File size: 4,183 Bytes
0812af4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#!/usr/bin/env python3
"""
ONNX 모델을 사용한 멀티헤드 이미지 분류 추론 예제
"""

import onnxruntime as ort
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import json

# Preprocessing pipeline applied to every input image:
#   1. resize to exactly 224x224,
#   2. center-crop to 224 (NOTE(review): this step is a no-op because the
#      resize already yields 224x224; the usual ImageNet recipe is
#      Resize(256) + CenterCrop(224) — confirm against the training setup),
#   3. convert the PIL image to a tensor,
#   4. normalize each channel with the listed mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def load_model_info(model_info_path):
    """Parse the model metadata JSON file.

    Args:
        model_info_path: Path to a UTF-8 encoded JSON file.

    Returns:
        The decoded JSON document.
    """
    with open(model_info_path, encoding='utf-8') as handle:
        info = json.load(handle)
    return info

def preprocess_image(image_path):
    """Load an image and turn it into a model-ready batch array.

    The image is converted to 3-channel RGB and passed through the
    module-level ``transform`` pipeline. A leading batch dimension is then
    added.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        numpy array shaped (1, C, H, W): the transformed tensor with a batch
        axis of size 1 prepended.
    """
    # Use Image.open as a context manager so the underlying file handle is
    # released deterministically instead of waiting for garbage collection.
    # convert() loads the pixels and returns an independent image, so it
    # remains valid after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    tensor = transform(image)
    return tensor.unsqueeze(0).numpy()  # add the batch dimension

def softmax(x, axis=1):
    """Numerically stable softmax along ``axis``.

    Args:
        x: Array-like of logits.
        axis: Axis to normalise over. Defaults to 1, the class axis of
            (batch, classes) logits; pass ``axis=-1`` or ``axis=0`` for 1-D
            input.

    Returns:
        numpy array with the same shape as ``x`` whose entries along ``axis``
        are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Subtracting the per-slice maximum leaves the result unchanged
    # mathematically but keeps np.exp from overflowing on large logits.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def _summarize_features(feature_output):
    """Build the JSON-friendly summary for the backbone feature output (first batch item)."""
    features = feature_output[0]  # first item of the batch
    return {
        'embedding': features.tolist(),
        'dimension': len(features),
        'description': 'DINOv2 backbone features'
    }

def _summarize_head(logits, class_names, top_k=3):
    """Build the top-1 / top-k prediction summary for one classification head.

    Args:
        logits: Raw head output; softmax is taken over axis 1 and the first
            batch item is used.
        class_names: Mapping from ``str(class_index)`` to a readable name;
            missing indices fall back to ``"Class_<index>"``.
        top_k: Number of highest-probability classes to list.
    """
    probabilities = softmax(logits)[0]  # first item of the batch

    def label(idx):
        return class_names.get(str(idx), f"Class_{idx}")

    # Cast numpy integer indices to plain ints for key lookup and output.
    pred_idx = int(np.argmax(probabilities))
    top_indices = np.argsort(probabilities)[-top_k:][::-1]
    top_results = [
        {'class': label(int(idx)), 'probability': float(probabilities[idx])}
        for idx in top_indices
    ]
    return {
        'predicted_class': label(pred_idx),
        'confidence': float(probabilities[pred_idx]),
        'top3': top_results
    }

def predict_image(onnx_model_path, model_info_path, image_path):
    """Run the multi-head ONNX classifier on one image and summarise its outputs.

    Args:
        onnx_model_path: Path to the ONNX model file.
        model_info_path: Path to the metadata JSON. It must contain
            ``output_specification.heads`` (one key per classification head).
            It may contain ``class_mappings`` of the form
            ``{head_name: {str(index): class_name}}``.
        image_path: Image to classify.

    Returns:
        dict with one entry per head. Each entry has ``predicted_class``,
        ``confidence`` and ``top3`` (a list of ``{'class', 'probability'}``).
        There is also a ``features`` entry holding ``embedding``,
        ``dimension`` and ``description``.

    Raises:
        ValueError: if the model yields fewer outputs than the metadata
            declares heads plus the ``features`` output.
    """
    model_info = load_model_info(model_info_path)
    session = ort.InferenceSession(onnx_model_path)
    image_array = preprocess_image(image_path)

    # Feed the tensor under 'image' when the model declares that input name;
    # otherwise use the model's first input.
    input_names = [inp.name for inp in session.get_inputs()]
    input_name = 'image' if 'image' in input_names else input_names[0]
    outputs = session.run(None, {input_name: image_array})

    head_names = list(model_info['output_specification']['heads'].keys())
    expected_names = head_names + ['features']

    # session.run(None, ...) returns outputs in the graph's declared order.
    # When the graph output names match the metadata, pair them by name so
    # a JSON key order that differs from the graph order cannot mislabel
    # heads. Otherwise fall back to positional pairing.
    model_output_names = [out.name for out in session.get_outputs()]
    if set(expected_names).issubset(model_output_names):
        by_name = dict(zip(model_output_names, outputs))
        ordered_outputs = [by_name[name] for name in expected_names]
    elif len(outputs) >= len(expected_names):
        ordered_outputs = outputs[:len(expected_names)]
    else:
        raise ValueError(
            f"model returned {len(outputs)} outputs but metadata expects "
            f"{len(expected_names)} ({', '.join(expected_names)})"
        )

    class_mappings = model_info.get('class_mappings', {})
    results = {}
    for name, value in zip(expected_names, ordered_outputs):
        if name == 'features':
            results[name] = _summarize_features(value)
        else:
            results[name] = _summarize_head(value, class_mappings.get(name, {}))

    return results

# Example usage
if __name__ == "__main__":
    import argparse
    import sys

    # The paths used to be hard-coded; they are now optional CLI arguments
    # with the same defaults, so running without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Run the multi-head ONNX image classifier on one image.")
    parser.add_argument("image_path", nargs="?", default="test_image.jpg",
                        help="image to classify (default: %(default)s)")
    parser.add_argument("--onnx", dest="onnx_path",
                        default="image_classifier.onnx",
                        help="ONNX model file (default: %(default)s)")
    parser.add_argument("--model-info", dest="model_info_path",
                        default="model_info.json",
                        help="model metadata JSON (default: %(default)s)")
    args = parser.parse_args()

    onnx_path = args.onnx_path
    model_info_path = args.model_info_path
    image_path = args.image_path

    try:
        results = predict_image(onnx_path, model_info_path, image_path)
    except Exception as e:
        # Top-level boundary: report on stderr and exit non-zero so shells
        # and CI can detect the failure.
        print(f"추론 실패: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"이미지 분류 결과: {image_path}")
    print("=" * 50)

    for output_name, result in results.items():
        print(f"\n{output_name.upper()}:")
        if output_name == 'features':
            print(f"  차원: {result['dimension']}")
            print(f"  설명: {result['description']}")
            print(f"  특징 벡터 (처음 10개): {result['embedding'][:10]}")
        else:
            print(f"  예측 클래스: {result['predicted_class']}")
            print(f"  신뢰도: {result['confidence']:.4f}")
            print("  Top 3:")
            for i, top_result in enumerate(result['top3'], 1):
                print(f"    {i}. {top_result['class']}: {top_result['probability']:.4f}")