cnn_hf / effinet_basic_compo.py
WildOjisan's picture
.
00f636d
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import json
import os
import sys
# ==============================================================================
# 0. ImageNet ํด๋ž˜์Šค ์ด๋ฆ„ ๋กœ๋“œ
# ==============================================================================
CLASS_MAP_FILENAME = 'labels_map.txt'
class_name_map = None
# API ์„œ๋ฒ„ ์‹œ์ž‘ ์‹œ ImageNet ํด๋ž˜์Šค ๋งต์„ ๋ฉ”๋ชจ๋ฆฌ์— ๋กœ๋“œ
try:
if not os.path.exists(CLASS_MAP_FILENAME):
# NOTE: API ํ™˜๊ฒฝ์—์„œ๋Š” sys.exit ๋Œ€์‹  ์˜ˆ์™ธ๋ฅผ ๋ฐœ์ƒ์‹œ์ผœ์•ผ ํ•ฉ๋‹ˆ๋‹ค.
raise FileNotFoundError(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ์ด๋ฆ„ ํŒŒ์ผ('{CLASS_MAP_FILENAME}')์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. API ์„œ๋ฒ„๋ฅผ ์‹œ์ž‘ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
with open(CLASS_MAP_FILENAME, 'r') as f:
class_map_json = json.load(f)
# ๐Ÿšจ๐Ÿšจ๐Ÿšจ ์ด ๋ถ€๋ถ„์ด ์ˆ˜์ •๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ๐Ÿšจ๐Ÿšจ๐Ÿšจ
# ๊ฐ’(Value)์ด ๋ฌธ์ž์—ด์ธ ๊ฒฝ์šฐ: v ์ž์ฒด๊ฐ€ ํด๋ž˜์Šค ์ด๋ฆ„์ž…๋‹ˆ๋‹ค.
# ๊ฐ’(Value)์ด ๋ฆฌ์ŠคํŠธ์ธ ๊ฒฝ์šฐ: ๋ฆฌ์ŠคํŠธ์˜ ๋งˆ์ง€๋ง‰ ์š”์†Œ(์ผ๋ฐ˜์ ์œผ๋กœ ์ธ๋ฑ์Šค 1)๋ฅผ ํด๋ž˜์Šค ์ด๋ฆ„์œผ๋กœ ๊ฐ€์ •ํ•ฉ๋‹ˆ๋‹ค.
labels_list = []
for k, v in class_map_json.items():
if k.isdigit() and 0 <= int(k) < 1000:
if isinstance(v, list) and len(v) > 1:
labels_list.append(v[1]) # ๋ฆฌ์ŠคํŠธ์ผ ๊ฒฝ์šฐ ๋‘ ๋ฒˆ์งธ ์š”์†Œ (์ด์ „ ์ฝ”๋“œ ์œ ์ง€)
elif isinstance(v, str):
labels_list.append(v) # ๋ฌธ์ž์—ด์ผ ๊ฒฝ์šฐ ์ „์ฒด ๋ฌธ์ž์—ด ์‚ฌ์šฉ (์ˆ˜์ •๋œ ํ•ต์‹ฌ)
else:
# ์•Œ ์ˆ˜ ์—†๋Š” ํ˜•์‹์€ ๋ฌด์‹œํ•˜๊ฑฐ๋‚˜, ๊ธฐ๋ณธ๊ฐ’ ์„ค์ •
labels_list.append(f"Unknown Class Index {k}")
# ์ธ๋ฑ์Šค์™€ ์ด๋ฆ„ ๋งคํ•‘ ๋”•์…”๋„ˆ๋ฆฌ๋กœ ๋ณ€ํ™˜
# labels_list์˜ ์ˆœ์„œ๊ฐ€ ๋ชจ๋ธ์˜ ์ถœ๋ ฅ ์ธ๋ฑ์Šค (0~999)์™€ ์ผ์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.
class_name_map = {i: name for i, name in enumerate(labels_list)}
# ํด๋ž˜์Šค ๋งต์ด 1000๊ฐœ๊ฐ€ ๋งž๋Š”์ง€ ํ™•์ธ (ImageNet ๊ธฐ์ค€)
if len(class_name_map) != 1000:
print(f"[๊ฒฝ๊ณ ] ๋กœ๋“œ๋œ ํด๋ž˜์Šค ์ˆ˜: {len(class_name_map)}๊ฐœ. ImageNet (1000๊ฐœ)๊ณผ ๋‹ค๋ฆ…๋‹ˆ๋‹ค. ํ™•์ธํ•ด ์ฃผ์„ธ์š”.")
print(f"ImageNet ํด๋ž˜์Šค ๋งต ๋กœ๋“œ ์„ฑ๊ณต. (์ด {len(class_name_map)}๊ฐœ)")
except FileNotFoundError as e:
# API ์„œ๋ฒ„ ์‹œ์ž‘์„ ๋ง‰๊ธฐ ์œ„ํ•ด ๋ฐœ์ƒ๋œ ์˜ค๋ฅ˜๋ฅผ ๋‹ค์‹œ ๋ฐœ์ƒ
raise e
except Exception as e:
# JSON ํŒŒ์‹ฑ ์˜ค๋ฅ˜ ๋“ฑ ๊ธฐํƒ€ ๋กœ๋”ฉ ์˜ค๋ฅ˜
print(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ๋งต ๋กœ๋“œ ์ค‘ ์˜ˆ๊ธฐ์น˜ ์•Š์€ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
class_name_map = None # ๋กœ๋“œ ์‹คํŒจ ์‹œ None ์œ ์ง€
# API ์„œ๋ฒ„ ์‹œ์ž‘์„ ๋ง‰๊ธฐ ์œ„ํ•ด RuntimeError ๋ฐœ์ƒ
raise RuntimeError(f"ํด๋ž˜์Šค ๋งต ๋กœ๋“œ ์˜ค๋ฅ˜: {e}")
# ==============================================================================
# 1. ๋ชจ๋ธ ๋ฐ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ (์ „์—ญ์ ์œผ๋กœ ํ•œ ๋ฒˆ๋งŒ ์‹คํ–‰)
# ==============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ
# ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ try-except ๋ธ”๋ก์œผ๋กœ ๊ฐ์Œ‰๋‹ˆ๋‹ค.
try:
# weights ๊ฐ์ฒด๋Š” ์ „์ฒ˜๋ฆฌ(transforms) ์ •๋ณด๋„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=weights).to(device).eval() # eval ๋ชจ๋“œ๋กœ ์„ค์ •
preprocess = weights.transforms() # ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
print("EfficientNetB0 ๋ชจ๋ธ ๋ฐ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ ์„ฑ๊ณต.")
except Exception as e:
print(f"[์˜ค๋ฅ˜] EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
# ๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ ์‹œ None ์„ค์ • ํ›„, ๋ถ„๋ฅ˜ ์‹œ ์˜ค๋ฅ˜๋ฅผ ๋ฐœ์ƒ์‹œํ‚ค๋„๋ก ํ•จ
model = None
preprocess = None
raise RuntimeError(f"๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {e}")
# ==============================================================================
# 2. ๋ถ„๋ฅ˜ ํ•จ์ˆ˜ (API์—์„œ ์‚ฌ์šฉ)
# ==============================================================================
def classify_image_pil(img: Image.Image) -> list:
"""
์ฃผ์–ด์ง„ PIL Image ๊ฐ์ฒด๋ฅผ EfficientNetB0 ๋ชจ๋ธ๋กœ ๋ถ„๋ฅ˜ํ•˜๊ณ 
Top-5 ๊ฒฐ๊ณผ๋ฅผ ๋ฆฌ์ŠคํŠธ๋กœ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
"""
if class_name_map is None or not model:
raise RuntimeError("๋ชจ๋ธ ๋˜๋Š” ํด๋ž˜์Šค ๋งต์ด ์•„์ง ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.")
try:
# 1. ์ด๋ฏธ์ง€ RGB ๋ณ€ํ™˜ ๋ฐ ์ „์ฒ˜๋ฆฌ
img = img.convert('RGB')
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0).to(device)
# 2. ์ถ”๋ก  ์ˆ˜ํ–‰
with torch.no_grad():
output = model(input_batch)
# 3. ํ™•๋ฅ  ๋ฐ Top-K ์ถ”์ถœ
probabilities = F.softmax(output[0], dim=0)
# Top-5 ํ™•๋ฅ  ๋ฐ ์ธ๋ฑ์Šค (์นดํ…Œ๊ณ ๋ฆฌ ID) ์ถ”์ถœ
top_prob, top_catid = torch.topk(probabilities, 5)
results = []
for i in range(top_prob.size(0)):
idx = top_catid[i].item()
# ํด๋ž˜์Šค ์ด๋ฆ„ ๋งคํ•‘ ์ ์šฉ
class_name = class_name_map.get(idx, f"์•Œ ์ˆ˜ ์—†๋Š” ํด๋ž˜์Šค (ID: {idx})")
results.append({
"rank": i + 1,
"class_name": class_name,
"class_index": idx,
"probability": top_prob[i].item()
})
return results
except Exception as e:
# ๋ถ„๋ฅ˜ ์ค‘ ๋ฐœ์ƒํ•˜๋Š” ๋ชจ๋“  ์˜ค๋ฅ˜๋Š” ํ˜ธ์ถœ์ž(app.py)์—๊ฒŒ RuntimeError๋กœ ์ „๋‹ฌ
raise RuntimeError(f"์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ์ค‘ PyTorch/CUDA ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")