import torch from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights import torchvision.transforms as transforms from PIL import Image import torch.nn.functional as F import json import os import sys # 파일 누락 시 종료를 위해 추가 # ============================================================================== # 0. ImageNet 클래스 이름 로드 # ============================================================================== CLASS_MAP_FILENAME = 'labels_map.txt' class_name_map = None # 전역 변수로 초기화 try: if not os.path.exists(CLASS_MAP_FILENAME): print(f"[오류] 클래스 이름 파일('{CLASS_MAP_FILENAME}')을 찾을 수 없습니다.") print("파일을 현재 디렉토리에 저장했는지 확인해 주세요.") sys.exit(1) # 파일이 없으면 프로그램 종료 # 1. 파일 로드 (JSON 형식) with open(CLASS_MAP_FILENAME, 'r') as f: class_map_json = json.load(f) # 2. 제공해주신 로직 적용: 인덱스 0부터 999까지 이름만 추출하여 리스트 생성 # JSON 파일의 키가 문자열이므로 str(i)로 접근하고, 값 리스트의 두 번째 요소(이름)를 가져옵니다. labels_list = [class_map_json[str(i)] for i in range(1000)] # 3. 인덱스와 이름을 매핑하는 딕셔너리로 변환 (나중에 클래스 ID로 이름 조회 용이) class_name_map = {i: name for i, name in enumerate(labels_list)} print(f"ImageNet 클래스 이름 ({len(class_name_map)}개) 로드 완료.") except Exception as e: print(f"[오류] 클래스 파일 로드 또는 처리 중 오류 발생: {e}") sys.exit(1) # 오류 발생 시 프로그램 종료 # ============================================================================== # 1. 설정 및 모델 로드 # ============================================================================== device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"사용 장치: {device}") # ImageNet으로 사전 훈련된 EfficientNetB0 모델 로드 print("사전 훈련된 EfficientNetB0 모델 로드 중...") model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1) model.eval() # 평가 모드 설정 model = model.to(device) print("모델 로드 및 평가 모드 설정 완료.") # ============================================================================== # 2. 필수 전처리 파이프라인 정의 # ============================================================================== preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # ============================================================================== # 3. 이미지 분류 및 출력 함수 # ============================================================================== def classify_image(image_path_string): """ 주어진 이미지 경로의 파일을 EfficientNetB0 모델로 분류하고 결과를 출력합니다. (클래스 이름 포함) """ try: # 1. 이미지 로드 및 RGB 변환 img = Image.open(image_path_string).convert('RGB') print(f"\n[INFO] 이미지 로드 성공: {image_path_string}") input_tensor = preprocess(img) input_batch = input_tensor.unsqueeze(0).to(device) with torch.no_grad(): output = model(input_batch) probabilities = F.softmax(output[0], dim=0) top_prob, top_catid = torch.topk(probabilities, 5) print("\n--- 분류 결과 (Top-5) ---") for i in range(top_prob.size(0)): idx = top_catid[i].item() # 클래스 이름 매핑 적용: 로드된 딕셔너리 사용 class_name = class_name_map.get(idx, f"알 수 없는 클래스 (ID: {idx})") print(f"순위 {i+1}:") print(f" - 클래스 이름: **{class_name}**") print(f" - 클래스 인덱스 (ID): {idx}") print(f" - 확률: {top_prob[i].item():.4f}") except FileNotFoundError: print(f"\n[오류] 이미지 파일을 찾을 수 없습니다: {image_path_string}") print("경로를 다시 확인해주세요.") except Exception as e: print(f"\n[오류] 분류 중 문제가 발생했습니다: {e}") # --- 실행 --- # 분류할 이미지 파일 경로를 문자열로 지정 (사용자 환경에 맞게 수정 필요!) CLASSIFY_TARGET_PATH = 'D:/pictures/muffin1.png' # 함수 실행 classify_image(CLASSIFY_TARGET_PATH)