File size: 1,447 Bytes
d65c86f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#!/usr/bin/env python3
"""
Example: predicting the similarity of an image pair with an ONNX model.
"""

import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

# Preprocessing pipeline applied to every input image before inference.
def _build_transform():
    """Return the torchvision transform that turns a PIL image into a model input tensor."""
    steps = [
        # Resize straight to 224x224 (a (h, w) size does not preserve aspect ratio).
        transforms.Resize((224, 224)),
        # NOTE(review): after the fixed-size Resize above this crop changes nothing;
        # it is kept so the pipeline matches what the model was run with.
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # The widely used ImageNet channel mean/std; presumably the model was
        # trained with the same normalization -- TODO confirm.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    return transforms.Compose(steps)


transform = _build_transform()

def preprocess_image(image_path):
    """Load an image and convert it into a batched array for the ONNX model.

    Args:
        image_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        A NumPy array produced by the module-level ``transform`` with a leading
        batch axis added (shape (1, 3, 224, 224) with the current pipeline).
    """
    # Open inside a context manager so the file handle is released promptly;
    # ``convert`` returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as raw:
        image = raw.convert('RGB')
    tensor = transform(image)
    # Add the batch dimension expected by the model.
    return tensor.unsqueeze(0).numpy()

def _sigmoid(x):
    """Return the logistic function of scalar ``x`` without overflowing.

    The naive form ``1 / (1 + np.exp(-x))`` evaluates ``np.exp`` of a large
    positive number for very negative ``x``, which overflows and emits a
    RuntimeWarning. Here ``np.exp`` is only ever given a non-positive argument.
    """
    if x >= 0:
        return 1 / (1 + np.exp(-x))
    # For x < 0 use the equivalent form e^x / (1 + e^x); e^x lies in (0, 1).
    z = np.exp(x)
    return z / (1 + z)


def predict_similarity(onnx_model_path, image1_path, image2_path):
    """Predict how similar two images are using an ONNX model.

    Args:
        onnx_model_path: Path of the ONNX model file. The model must accept two
            inputs named ``'image1'`` and ``'image2'``.
        image1_path: Path of the first image.
        image2_path: Path of the second image.

    Returns:
        The sigmoid of element ``[0][0]`` of the model's first output, a value
        in [0, 1].

    NOTE: a new ``InferenceSession`` is created on every call; reuse a session
    if this is called repeatedly.
    """
    # Create the ONNX Runtime session.
    session = ort.InferenceSession(onnx_model_path)

    # Preprocess both images.
    img1 = preprocess_image(image1_path)
    img2 = preprocess_image(image2_path)

    # Run inference; take the first output.
    inputs = {'image1': img1, 'image2': img2}
    logits = session.run(None, inputs)[0]

    # Convert the logit to a probability with a numerically stable sigmoid.
    return _sigmoid(logits[0][0])

# Usage example. With no arguments the original default paths are used.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Predict the similarity of two images with an ONNX model."
    )
    parser.add_argument("onnx_path", nargs="?", default="room_image_comparator.onnx")
    parser.add_argument("img1_path", nargs="?", default="room1.jpg")
    parser.add_argument("img2_path", nargs="?", default="room2.jpg")
    args = parser.parse_args()

    similarity = predict_similarity(args.onnx_path, args.img1_path, args.img2_path)
    # "유사도" is Korean for "similarity".
    print(f"유사도: {similarity:.4f}")