#!/usr/bin/env python3
"""
ONNX ๋ชจ๋ธ์ ์ฌ์ฉํ ์ด๋ฏธ์ง ์ ์ฌ๋ ์ถ๋ก ์์
"""
import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
# Preprocessing pipeline: standard ImageNet resize/crop/normalize.
# NOTE(review): Resize((224, 224)) already yields 224x224, so the
# CenterCrop(224) that follows is a no-op — kept for identical behavior.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def preprocess_image(image_path):
    """Load an image file and turn it into a model-ready numpy array.

    Applies the module-level ``transform`` pipeline and prepends a batch
    axis, returning an array of shape (1, C, H, W).
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batched = transform(rgb_image).unsqueeze(0)  # add batch dimension
    return batched.numpy()
def predict_similarity(onnx_model_path, image1_path, image2_path):
    """Predict the similarity between two images with an ONNX model.

    Args:
        onnx_model_path: Path to the exported ONNX comparator model.
        image1_path: Path to the first image file.
        image2_path: Path to the second image file.

    Returns:
        Similarity probability in [0, 1] — the sigmoid of the model's
        single output logit.
    """
    # Create the ONNX Runtime inference session.
    # (Fix: the original comment was split across two lines, leaving a
    # stray bare token on its own line — a NameError at runtime.)
    session = ort.InferenceSession(onnx_model_path)

    # Preprocess both images into (1, C, H, W) float arrays.
    img1 = preprocess_image(image1_path)
    img2 = preprocess_image(image2_path)

    # Run inference; input names must match the exported model's graph.
    inputs = {'image1': img1, 'image2': img2}
    logits = session.run(None, inputs)[0]

    # Convert the raw logit to a probability via sigmoid; float() gives
    # callers a plain Python number instead of a numpy scalar.
    return float(1.0 / (1.0 + np.exp(-logits[0][0])))
# Usage example: compare two room photos and print the similarity score.
if __name__ == "__main__":
    onnx_path = "room_image_comparator.onnx"
    image_pair = ("room1.jpg", "room2.jpg")
    similarity = predict_similarity(onnx_path, *image_pair)
    print(f"์ ์ฌ๋: {similarity:.4f}")