import torch from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights import torchvision.transforms as transforms from PIL import Image import torch.nn.functional as F import json import os import sys # ============================================================================== # 0. ImageNet 클래스 이름 로드 # ============================================================================== CLASS_MAP_FILENAME = 'labels_map.txt' class_name_map = None # API 서버 시작 시 ImageNet 클래스 맵을 메모리에 로드 try: if not os.path.exists(CLASS_MAP_FILENAME): # NOTE: API 환경에서는 sys.exit 대신 예외를 발생시켜야 합니다. raise FileNotFoundError(f"[오류] 클래스 이름 파일('{CLASS_MAP_FILENAME}')을 찾을 수 없습니다. API 서버를 시작할 수 없습니다.") with open(CLASS_MAP_FILENAME, 'r') as f: class_map_json = json.load(f) # 🚨🚨🚨 이 부분이 수정되었습니다. 🚨🚨🚨 # 값(Value)이 문자열인 경우: v 자체가 클래스 이름입니다. # 값(Value)이 리스트인 경우: 리스트의 마지막 요소(일반적으로 인덱스 1)를 클래스 이름으로 가정합니다. labels_list = [] for k, v in class_map_json.items(): if k.isdigit() and 0 <= int(k) < 1000: if isinstance(v, list) and len(v) > 1: labels_list.append(v[1]) # 리스트일 경우 두 번째 요소 (이전 코드 유지) elif isinstance(v, str): labels_list.append(v) # 문자열일 경우 전체 문자열 사용 (수정된 핵심) else: # 알 수 없는 형식은 무시하거나, 기본값 설정 labels_list.append(f"Unknown Class Index {k}") # 인덱스와 이름 매핑 딕셔너리로 변환 # labels_list의 순서가 모델의 출력 인덱스 (0~999)와 일치해야 합니다. class_name_map = {i: name for i, name in enumerate(labels_list)} # 클래스 맵이 1000개가 맞는지 확인 (ImageNet 기준) if len(class_name_map) != 1000: print(f"[경고] 로드된 클래스 수: {len(class_name_map)}개. ImageNet (1000개)과 다릅니다. 확인해 주세요.") print(f"ImageNet 클래스 맵 로드 성공. (총 {len(class_name_map)}개)") except FileNotFoundError as e: # API 서버 시작을 막기 위해 발생된 오류를 다시 발생 raise e except Exception as e: # JSON 파싱 오류 등 기타 로딩 오류 print(f"[오류] 클래스 맵 로드 중 예기치 않은 오류 발생: {e}") class_name_map = None # 로드 실패 시 None 유지 # API 서버 시작을 막기 위해 RuntimeError 발생 raise RuntimeError(f"클래스 맵 로드 오류: {e}") # ============================================================================== # 1. 모델 및 전처리 파이프라인 로드 (전역적으로 한 번만 실행) # ============================================================================== device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 사전 훈련된 EfficientNetB0 모델 로드 # 모델 로드 중 오류가 발생할 수 있으므로 try-except 블록으로 감쌉니다. try: # weights 객체는 전처리(transforms) 정보도 포함합니다. weights = EfficientNet_B0_Weights.DEFAULT model = efficientnet_b0(weights=weights).to(device).eval() # eval 모드로 설정 preprocess = weights.transforms() # 전처리 파이프라인 로드 print("EfficientNetB0 모델 및 전처리 파이프라인 로드 성공.") except Exception as e: print(f"[오류] EfficientNetB0 모델 로드 중 오류 발생: {e}") # 모델 로드 실패 시 None 설정 후, 분류 시 오류를 발생시키도록 함 model = None preprocess = None raise RuntimeError(f"모델 로드 오류: {e}") # ============================================================================== # 2. 분류 함수 (API에서 사용) # ============================================================================== def classify_image_pil(img: Image.Image) -> list: """ 주어진 PIL Image 객체를 EfficientNetB0 모델로 분류하고 Top-5 결과를 리스트로 반환합니다. """ if class_name_map is None or not model: raise RuntimeError("모델 또는 클래스 맵이 아직 로드되지 않았습니다.") try: # 1. 이미지 RGB 변환 및 전처리 img = img.convert('RGB') input_tensor = preprocess(img) input_batch = input_tensor.unsqueeze(0).to(device) # 2. 추론 수행 with torch.no_grad(): output = model(input_batch) # 3. 확률 및 Top-K 추출 probabilities = F.softmax(output[0], dim=0) # Top-5 확률 및 인덱스 (카테고리 ID) 추출 top_prob, top_catid = torch.topk(probabilities, 5) results = [] for i in range(top_prob.size(0)): idx = top_catid[i].item() # 클래스 이름 매핑 적용 class_name = class_name_map.get(idx, f"알 수 없는 클래스 (ID: {idx})") results.append({ "rank": i + 1, "class_name": class_name, "class_index": idx, "probability": top_prob[i].item() }) return results except Exception as e: # 분류 중 발생하는 모든 오류는 호출자(app.py)에게 RuntimeError로 전달 raise RuntimeError(f"이미지 분류 중 PyTorch/CUDA 오류 발생: {e}")