Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import json | |
| import os | |
| import sys | |
| # ============================================================================== | |
| # 0. ImageNet ํด๋์ค ์ด๋ฆ ๋ก๋ | |
| # ============================================================================== | |
| CLASS_MAP_FILENAME = 'labels_map.txt' | |
| class_name_map = None | |
| # API ์๋ฒ ์์ ์ ImageNet ํด๋์ค ๋งต์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋ | |
| try: | |
| if not os.path.exists(CLASS_MAP_FILENAME): | |
| # NOTE: API ํ๊ฒฝ์์๋ sys.exit ๋์ ์์ธ๋ฅผ ๋ฐ์์์ผ์ผ ํฉ๋๋ค. | |
| raise FileNotFoundError(f"[์ค๋ฅ] ํด๋์ค ์ด๋ฆ ํ์ผ('{CLASS_MAP_FILENAME}')์ ์ฐพ์ ์ ์์ต๋๋ค. API ์๋ฒ๋ฅผ ์์ํ ์ ์์ต๋๋ค.") | |
| with open(CLASS_MAP_FILENAME, 'r') as f: | |
| class_map_json = json.load(f) | |
| # ๐จ๐จ๐จ ์ด ๋ถ๋ถ์ด ์์ ๋์์ต๋๋ค. ๐จ๐จ๐จ | |
| # ๊ฐ(Value)์ด ๋ฌธ์์ด์ธ ๊ฒฝ์ฐ: v ์์ฒด๊ฐ ํด๋์ค ์ด๋ฆ์ ๋๋ค. | |
| # ๊ฐ(Value)์ด ๋ฆฌ์คํธ์ธ ๊ฒฝ์ฐ: ๋ฆฌ์คํธ์ ๋ง์ง๋ง ์์(์ผ๋ฐ์ ์ผ๋ก ์ธ๋ฑ์ค 1)๋ฅผ ํด๋์ค ์ด๋ฆ์ผ๋ก ๊ฐ์ ํฉ๋๋ค. | |
| labels_list = [] | |
| for k, v in class_map_json.items(): | |
| if k.isdigit() and 0 <= int(k) < 1000: | |
| if isinstance(v, list) and len(v) > 1: | |
| labels_list.append(v[1]) # ๋ฆฌ์คํธ์ผ ๊ฒฝ์ฐ ๋ ๋ฒ์งธ ์์ (์ด์ ์ฝ๋ ์ ์ง) | |
| elif isinstance(v, str): | |
| labels_list.append(v) # ๋ฌธ์์ด์ผ ๊ฒฝ์ฐ ์ ์ฒด ๋ฌธ์์ด ์ฌ์ฉ (์์ ๋ ํต์ฌ) | |
| else: | |
| # ์ ์ ์๋ ํ์์ ๋ฌด์ํ๊ฑฐ๋, ๊ธฐ๋ณธ๊ฐ ์ค์ | |
| labels_list.append(f"Unknown Class Index {k}") | |
| # ์ธ๋ฑ์ค์ ์ด๋ฆ ๋งคํ ๋์ ๋๋ฆฌ๋ก ๋ณํ | |
| # labels_list์ ์์๊ฐ ๋ชจ๋ธ์ ์ถ๋ ฅ ์ธ๋ฑ์ค (0~999)์ ์ผ์นํด์ผ ํฉ๋๋ค. | |
| class_name_map = {i: name for i, name in enumerate(labels_list)} | |
| # ํด๋์ค ๋งต์ด 1000๊ฐ๊ฐ ๋ง๋์ง ํ์ธ (ImageNet ๊ธฐ์ค) | |
| if len(class_name_map) != 1000: | |
| print(f"[๊ฒฝ๊ณ ] ๋ก๋๋ ํด๋์ค ์: {len(class_name_map)}๊ฐ. ImageNet (1000๊ฐ)๊ณผ ๋ค๋ฆ ๋๋ค. ํ์ธํด ์ฃผ์ธ์.") | |
| print(f"ImageNet ํด๋์ค ๋งต ๋ก๋ ์ฑ๊ณต. (์ด {len(class_name_map)}๊ฐ)") | |
| except FileNotFoundError as e: | |
| # API ์๋ฒ ์์์ ๋ง๊ธฐ ์ํด ๋ฐ์๋ ์ค๋ฅ๋ฅผ ๋ค์ ๋ฐ์ | |
| raise e | |
| except Exception as e: | |
| # JSON ํ์ฑ ์ค๋ฅ ๋ฑ ๊ธฐํ ๋ก๋ฉ ์ค๋ฅ | |
| print(f"[์ค๋ฅ] ํด๋์ค ๋งต ๋ก๋ ์ค ์๊ธฐ์น ์์ ์ค๋ฅ ๋ฐ์: {e}") | |
| class_name_map = None # ๋ก๋ ์คํจ ์ None ์ ์ง | |
| # API ์๋ฒ ์์์ ๋ง๊ธฐ ์ํด RuntimeError ๋ฐ์ | |
| raise RuntimeError(f"ํด๋์ค ๋งต ๋ก๋ ์ค๋ฅ: {e}") | |
| # ============================================================================== | |
| # 1. ๋ชจ๋ธ ๋ฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋ (์ ์ญ์ ์ผ๋ก ํ ๋ฒ๋ง ์คํ) | |
| # ============================================================================== | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # ์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋ | |
| # ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์์ผ๋ฏ๋ก try-except ๋ธ๋ก์ผ๋ก ๊ฐ์๋๋ค. | |
| try: | |
| # weights ๊ฐ์ฒด๋ ์ ์ฒ๋ฆฌ(transforms) ์ ๋ณด๋ ํฌํจํฉ๋๋ค. | |
| weights = EfficientNet_B0_Weights.DEFAULT | |
| model = efficientnet_b0(weights=weights).to(device).eval() # eval ๋ชจ๋๋ก ์ค์ | |
| preprocess = weights.transforms() # ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋ | |
| print("EfficientNetB0 ๋ชจ๋ธ ๋ฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋ ์ฑ๊ณต.") | |
| except Exception as e: | |
| print(f"[์ค๋ฅ] EfficientNetB0 ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}") | |
| # ๋ชจ๋ธ ๋ก๋ ์คํจ ์ None ์ค์ ํ, ๋ถ๋ฅ ์ ์ค๋ฅ๋ฅผ ๋ฐ์์ํค๋๋ก ํจ | |
| model = None | |
| preprocess = None | |
| raise RuntimeError(f"๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {e}") | |
| # ============================================================================== | |
| # 2. ๋ถ๋ฅ ํจ์ (API์์ ์ฌ์ฉ) | |
| # ============================================================================== | |
| def classify_image_pil(img: Image.Image) -> list: | |
| """ | |
| ์ฃผ์ด์ง PIL Image ๊ฐ์ฒด๋ฅผ EfficientNetB0 ๋ชจ๋ธ๋ก ๋ถ๋ฅํ๊ณ | |
| Top-5 ๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์คํธ๋ก ๋ฐํํฉ๋๋ค. | |
| """ | |
| if class_name_map is None or not model: | |
| raise RuntimeError("๋ชจ๋ธ ๋๋ ํด๋์ค ๋งต์ด ์์ง ๋ก๋๋์ง ์์์ต๋๋ค.") | |
| try: | |
| # 1. ์ด๋ฏธ์ง RGB ๋ณํ ๋ฐ ์ ์ฒ๋ฆฌ | |
| img = img.convert('RGB') | |
| input_tensor = preprocess(img) | |
| input_batch = input_tensor.unsqueeze(0).to(device) | |
| # 2. ์ถ๋ก ์ํ | |
| with torch.no_grad(): | |
| output = model(input_batch) | |
| # 3. ํ๋ฅ ๋ฐ Top-K ์ถ์ถ | |
| probabilities = F.softmax(output[0], dim=0) | |
| # Top-5 ํ๋ฅ ๋ฐ ์ธ๋ฑ์ค (์นดํ ๊ณ ๋ฆฌ ID) ์ถ์ถ | |
| top_prob, top_catid = torch.topk(probabilities, 5) | |
| results = [] | |
| for i in range(top_prob.size(0)): | |
| idx = top_catid[i].item() | |
| # ํด๋์ค ์ด๋ฆ ๋งคํ ์ ์ฉ | |
| class_name = class_name_map.get(idx, f"์ ์ ์๋ ํด๋์ค (ID: {idx})") | |
| results.append({ | |
| "rank": i + 1, | |
| "class_name": class_name, | |
| "class_index": idx, | |
| "probability": top_prob[i].item() | |
| }) | |
| return results | |
| except Exception as e: | |
| # ๋ถ๋ฅ ์ค ๋ฐ์ํ๋ ๋ชจ๋ ์ค๋ฅ๋ ํธ์ถ์(app.py)์๊ฒ RuntimeError๋ก ์ ๋ฌ | |
| raise RuntimeError(f"์ด๋ฏธ์ง ๋ถ๋ฅ ์ค PyTorch/CUDA ์ค๋ฅ ๋ฐ์: {e}") | |