Spaces:
Sleeping
Sleeping
File size: 5,656 Bytes
00f636d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import json
import os
import sys
# ==============================================================================
# 0. ImageNet ํด๋์ค ์ด๋ฆ ๋ก๋
# ==============================================================================
CLASS_MAP_FILENAME = 'labels_map.txt'
class_name_map = None
# API ์๋ฒ ์์ ์ ImageNet ํด๋์ค ๋งต์ ๋ฉ๋ชจ๋ฆฌ์ ๋ก๋
try:
if not os.path.exists(CLASS_MAP_FILENAME):
# NOTE: API ํ๊ฒฝ์์๋ sys.exit ๋์ ์์ธ๋ฅผ ๋ฐ์์์ผ์ผ ํฉ๋๋ค.
raise FileNotFoundError(f"[์ค๋ฅ] ํด๋์ค ์ด๋ฆ ํ์ผ('{CLASS_MAP_FILENAME}')์ ์ฐพ์ ์ ์์ต๋๋ค. API ์๋ฒ๋ฅผ ์์ํ ์ ์์ต๋๋ค.")
with open(CLASS_MAP_FILENAME, 'r') as f:
class_map_json = json.load(f)
# ๐จ๐จ๐จ ์ด ๋ถ๋ถ์ด ์์ ๋์์ต๋๋ค. ๐จ๐จ๐จ
# ๊ฐ(Value)์ด ๋ฌธ์์ด์ธ ๊ฒฝ์ฐ: v ์์ฒด๊ฐ ํด๋์ค ์ด๋ฆ์
๋๋ค.
# ๊ฐ(Value)์ด ๋ฆฌ์คํธ์ธ ๊ฒฝ์ฐ: ๋ฆฌ์คํธ์ ๋ง์ง๋ง ์์(์ผ๋ฐ์ ์ผ๋ก ์ธ๋ฑ์ค 1)๋ฅผ ํด๋์ค ์ด๋ฆ์ผ๋ก ๊ฐ์ ํฉ๋๋ค.
labels_list = []
for k, v in class_map_json.items():
if k.isdigit() and 0 <= int(k) < 1000:
if isinstance(v, list) and len(v) > 1:
labels_list.append(v[1]) # ๋ฆฌ์คํธ์ผ ๊ฒฝ์ฐ ๋ ๋ฒ์งธ ์์ (์ด์ ์ฝ๋ ์ ์ง)
elif isinstance(v, str):
labels_list.append(v) # ๋ฌธ์์ด์ผ ๊ฒฝ์ฐ ์ ์ฒด ๋ฌธ์์ด ์ฌ์ฉ (์์ ๋ ํต์ฌ)
else:
# ์ ์ ์๋ ํ์์ ๋ฌด์ํ๊ฑฐ๋, ๊ธฐ๋ณธ๊ฐ ์ค์
labels_list.append(f"Unknown Class Index {k}")
# ์ธ๋ฑ์ค์ ์ด๋ฆ ๋งคํ ๋์
๋๋ฆฌ๋ก ๋ณํ
# labels_list์ ์์๊ฐ ๋ชจ๋ธ์ ์ถ๋ ฅ ์ธ๋ฑ์ค (0~999)์ ์ผ์นํด์ผ ํฉ๋๋ค.
class_name_map = {i: name for i, name in enumerate(labels_list)}
# ํด๋์ค ๋งต์ด 1000๊ฐ๊ฐ ๋ง๋์ง ํ์ธ (ImageNet ๊ธฐ์ค)
if len(class_name_map) != 1000:
print(f"[๊ฒฝ๊ณ ] ๋ก๋๋ ํด๋์ค ์: {len(class_name_map)}๊ฐ. ImageNet (1000๊ฐ)๊ณผ ๋ค๋ฆ
๋๋ค. ํ์ธํด ์ฃผ์ธ์.")
print(f"ImageNet ํด๋์ค ๋งต ๋ก๋ ์ฑ๊ณต. (์ด {len(class_name_map)}๊ฐ)")
except FileNotFoundError as e:
# API ์๋ฒ ์์์ ๋ง๊ธฐ ์ํด ๋ฐ์๋ ์ค๋ฅ๋ฅผ ๋ค์ ๋ฐ์
raise e
except Exception as e:
# JSON ํ์ฑ ์ค๋ฅ ๋ฑ ๊ธฐํ ๋ก๋ฉ ์ค๋ฅ
print(f"[์ค๋ฅ] ํด๋์ค ๋งต ๋ก๋ ์ค ์๊ธฐ์น ์์ ์ค๋ฅ ๋ฐ์: {e}")
class_name_map = None # ๋ก๋ ์คํจ ์ None ์ ์ง
# API ์๋ฒ ์์์ ๋ง๊ธฐ ์ํด RuntimeError ๋ฐ์
raise RuntimeError(f"ํด๋์ค ๋งต ๋ก๋ ์ค๋ฅ: {e}")
# ==============================================================================
# 1. ๋ชจ๋ธ ๋ฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋ (์ ์ญ์ ์ผ๋ก ํ ๋ฒ๋ง ์คํ)
# ==============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋
# ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์์ผ๋ฏ๋ก try-except ๋ธ๋ก์ผ๋ก ๊ฐ์๋๋ค.
try:
# weights ๊ฐ์ฒด๋ ์ ์ฒ๋ฆฌ(transforms) ์ ๋ณด๋ ํฌํจํฉ๋๋ค.
weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=weights).to(device).eval() # eval ๋ชจ๋๋ก ์ค์
preprocess = weights.transforms() # ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋
print("EfficientNetB0 ๋ชจ๋ธ ๋ฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ๋ก๋ ์ฑ๊ณต.")
except Exception as e:
print(f"[์ค๋ฅ] EfficientNetB0 ๋ชจ๋ธ ๋ก๋ ์ค ์ค๋ฅ ๋ฐ์: {e}")
# ๋ชจ๋ธ ๋ก๋ ์คํจ ์ None ์ค์ ํ, ๋ถ๋ฅ ์ ์ค๋ฅ๋ฅผ ๋ฐ์์ํค๋๋ก ํจ
model = None
preprocess = None
raise RuntimeError(f"๋ชจ๋ธ ๋ก๋ ์ค๋ฅ: {e}")
# ==============================================================================
# 2. ๋ถ๋ฅ ํจ์ (API์์ ์ฌ์ฉ)
# ==============================================================================
def classify_image_pil(img: Image.Image) -> list:
"""
์ฃผ์ด์ง PIL Image ๊ฐ์ฒด๋ฅผ EfficientNetB0 ๋ชจ๋ธ๋ก ๋ถ๋ฅํ๊ณ
Top-5 ๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์คํธ๋ก ๋ฐํํฉ๋๋ค.
"""
if class_name_map is None or not model:
raise RuntimeError("๋ชจ๋ธ ๋๋ ํด๋์ค ๋งต์ด ์์ง ๋ก๋๋์ง ์์์ต๋๋ค.")
try:
# 1. ์ด๋ฏธ์ง RGB ๋ณํ ๋ฐ ์ ์ฒ๋ฆฌ
img = img.convert('RGB')
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0).to(device)
# 2. ์ถ๋ก ์ํ
with torch.no_grad():
output = model(input_batch)
# 3. ํ๋ฅ ๋ฐ Top-K ์ถ์ถ
probabilities = F.softmax(output[0], dim=0)
# Top-5 ํ๋ฅ ๋ฐ ์ธ๋ฑ์ค (์นดํ
๊ณ ๋ฆฌ ID) ์ถ์ถ
top_prob, top_catid = torch.topk(probabilities, 5)
results = []
for i in range(top_prob.size(0)):
idx = top_catid[i].item()
# ํด๋์ค ์ด๋ฆ ๋งคํ ์ ์ฉ
class_name = class_name_map.get(idx, f"์ ์ ์๋ ํด๋์ค (ID: {idx})")
results.append({
"rank": i + 1,
"class_name": class_name,
"class_index": idx,
"probability": top_prob[i].item()
})
return results
except Exception as e:
# ๋ถ๋ฅ ์ค ๋ฐ์ํ๋ ๋ชจ๋ ์ค๋ฅ๋ ํธ์ถ์(app.py)์๊ฒ RuntimeError๋ก ์ ๋ฌ
raise RuntimeError(f"์ด๋ฏธ์ง ๋ถ๋ฅ ์ค PyTorch/CUDA ์ค๋ฅ ๋ฐ์: {e}")
|