#!/usr/bin/env python3 """ ONNX 모델을 사용한 멀티헤드 이미지 분류 추론 예제 전체 모델(model.onnx) 또는 분리 모델(encoder.onnx + head.onnx) 사용 가능 """ import onnxruntime as ort import numpy as np from PIL import Image import torchvision.transforms as transforms import json from pathlib import Path # 전처리 파이프라인 transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def load_model_info(model_info_path): """모델 정보 로드""" with open(model_info_path, 'r', encoding='utf-8') as f: return json.load(f) def preprocess_image(image_path): """이미지 전처리""" image = Image.open(image_path).convert('RGB') tensor = transform(image) return tensor.unsqueeze(0).numpy() # 배치 차원 추가 def softmax(x): """Softmax 함수""" exp_x = np.exp(x - np.max(x, axis=1, keepdims=True)) return exp_x / np.sum(exp_x, axis=1, keepdims=True) def predict_image_full_model(model_path, model_info_path, image_path): """전체 모델을 사용한 이미지 분류 예측""" # 모델 정보 로드 model_info = load_model_info(model_info_path) # ONNX 세션 생성 session = ort.InferenceSession(model_path) # 이미지 전처리 image_array = preprocess_image(image_path) # 추론 실행 inputs = {'image': image_array} outputs = session.run(None, inputs) # 결과 해석 results = {} head_names = list(model_info['output_specification']['heads'].keys()) for i, output_name in enumerate(head_names): logits = outputs[i] probabilities = softmax(logits)[0] # 클래스 이름 매핑 class_names = model_info['class_mappings'].get(output_name, {}) # 최고 확률 클래스 pred_idx = np.argmax(probabilities) pred_class = class_names.get(str(pred_idx), f"Class_{pred_idx}") pred_prob = probabilities[pred_idx] # 상위 3개 클래스 top3_indices = np.argsort(probabilities)[-3:][::-1] top3_results = [] for idx in top3_indices: class_name = class_names.get(str(idx), f"Class_{idx}") prob = probabilities[idx] top3_results.append({'class': class_name, 'probability': float(prob)}) results[output_name] = { 'predicted_class': pred_class, 'confidence': float(pred_prob), 'top3': top3_results } return results def predict_image_split_model(encoder_path, head_path, model_info_path, image_path): """분리 모델을 사용한 이미지 분류 예측""" # 모델 정보 로드 model_info = load_model_info(model_info_path) # ONNX 세션 생성 encoder_session = ort.InferenceSession(encoder_path) head_session = ort.InferenceSession(head_path) # 이미지 전처리 image_array = preprocess_image(image_path) # 인코더로 특징 벡터 추출 encoder_inputs = {'image': image_array} features = encoder_session.run(None, encoder_inputs)[0] # 헤드로 분류 head_inputs = {'features': features} outputs = head_session.run(None, head_inputs) # 결과 해석 results = {} head_names = list(model_info['output_specification']['heads'].keys()) for i, output_name in enumerate(head_names): logits = outputs[i] probabilities = softmax(logits)[0] # 클래스 이름 매핑 class_names = model_info['class_mappings'].get(output_name, {}) # 최고 확률 클래스 pred_idx = np.argmax(probabilities) pred_class = class_names.get(str(pred_idx), f"Class_{pred_idx}") pred_prob = probabilities[pred_idx] # 상위 3개 클래스 top3_indices = np.argsort(probabilities)[-3:][::-1] top3_results = [] for idx in top3_indices: class_name = class_names.get(str(idx), f"Class_{idx}") prob = probabilities[idx] top3_results.append({'class': class_name, 'probability': float(prob)}) results[output_name] = { 'predicted_class': pred_class, 'confidence': float(pred_prob), 'top3': top3_results } return results # 사용 예시 if __name__ == "__main__": model_info_path = "model_info.json" image_path = "test_image.jpg" # 분리 모델이 있는지 확인 if Path("encoder.onnx").exists() and Path("head.onnx").exists(): print("분리 모델 사용") results = predict_image_split_model("encoder.onnx", "head.onnx", model_info_path, image_path) elif Path("model.onnx").exists(): print("전체 모델 사용") results = predict_image_full_model("model.onnx", model_info_path, image_path) else: print("ONNX 모델을 찾을 수 없습니다.") exit(1) print(f"\n이미지 분류 결과: {image_path}") print("=" * 50) for output_name, result in results.items(): print(f"\n{output_name.upper()}:") print(f" 예측 클래스: {result['predicted_class']}") print(f" 신뢰도: {result['confidence']:.4f}") print(f" Top 3:") for i, top_result in enumerate(result['top3'], 1): print(f" {i}. {top_result['class']}: {top_result['probability']:.4f}")