cnn_hf / effinet_basic.py
WildOjisan's picture
.
4f364b6
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import json
import os
import sys # ํŒŒ์ผ ๋ˆ„๋ฝ ์‹œ ์ข…๋ฃŒ๋ฅผ ์œ„ํ•ด ์ถ”๊ฐ€
# ==============================================================================
# 0. ImageNet ํด๋ž˜์Šค ์ด๋ฆ„ ๋กœ๋“œ
# ==============================================================================
CLASS_MAP_FILENAME = 'labels_map.txt'
class_name_map = None # ์ „์—ญ ๋ณ€์ˆ˜๋กœ ์ดˆ๊ธฐํ™”
try:
if not os.path.exists(CLASS_MAP_FILENAME):
print(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ์ด๋ฆ„ ํŒŒ์ผ('{CLASS_MAP_FILENAME}')์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
print("ํŒŒ์ผ์„ ํ˜„์žฌ ๋””๋ ‰ํ† ๋ฆฌ์— ์ €์žฅํ–ˆ๋Š”์ง€ ํ™•์ธํ•ด ์ฃผ์„ธ์š”.")
sys.exit(1) # ํŒŒ์ผ์ด ์—†์œผ๋ฉด ํ”„๋กœ๊ทธ๋žจ ์ข…๋ฃŒ
# 1. ํŒŒ์ผ ๋กœ๋“œ (JSON ํ˜•์‹)
with open(CLASS_MAP_FILENAME, 'r') as f:
class_map_json = json.load(f)
# 2. ์ œ๊ณตํ•ด์ฃผ์‹  ๋กœ์ง ์ ์šฉ: ์ธ๋ฑ์Šค 0๋ถ€ํ„ฐ 999๊นŒ์ง€ ์ด๋ฆ„๋งŒ ์ถ”์ถœํ•˜์—ฌ ๋ฆฌ์ŠคํŠธ ์ƒ์„ฑ
# JSON ํŒŒ์ผ์˜ ํ‚ค๊ฐ€ ๋ฌธ์ž์—ด์ด๋ฏ€๋กœ str(i)๋กœ ์ ‘๊ทผํ•˜๊ณ , ๊ฐ’ ๋ฆฌ์ŠคํŠธ์˜ ๋‘ ๋ฒˆ์งธ ์š”์†Œ(์ด๋ฆ„)๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
labels_list = [class_map_json[str(i)] for i in range(1000)]
# 3. ์ธ๋ฑ์Šค์™€ ์ด๋ฆ„์„ ๋งคํ•‘ํ•˜๋Š” ๋”•์…”๋„ˆ๋ฆฌ๋กœ ๋ณ€ํ™˜ (๋‚˜์ค‘์— ํด๋ž˜์Šค ID๋กœ ์ด๋ฆ„ ์กฐํšŒ ์šฉ์ด)
class_name_map = {i: name for i, name in enumerate(labels_list)}
print(f"ImageNet ํด๋ž˜์Šค ์ด๋ฆ„ ({len(class_name_map)}๊ฐœ) ๋กœ๋“œ ์™„๋ฃŒ.")
except Exception as e:
print(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ํŒŒ์ผ ๋กœ๋“œ ๋˜๋Š” ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
sys.exit(1) # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ ํ”„๋กœ๊ทธ๋žจ ์ข…๋ฃŒ
# ==============================================================================
# 1. ์„ค์ • ๋ฐ ๋ชจ๋ธ ๋กœ๋“œ
# ==============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")
# ImageNet์œผ๋กœ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ
print("์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •
model = model.to(device)
print("๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ • ์™„๋ฃŒ.")
# ==============================================================================
# 2. ํ•„์ˆ˜ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ ์ •์˜
# ==============================================================================
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# ==============================================================================
# 3. ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ฐ ์ถœ๋ ฅ ํ•จ์ˆ˜
# ==============================================================================
def classify_image(image_path_string):
"""
์ฃผ์–ด์ง„ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ์˜ ํŒŒ์ผ์„ EfficientNetB0 ๋ชจ๋ธ๋กœ ๋ถ„๋ฅ˜ํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.
(ํด๋ž˜์Šค ์ด๋ฆ„ ํฌํ•จ)
"""
try:
# 1. ์ด๋ฏธ์ง€ ๋กœ๋“œ ๋ฐ RGB ๋ณ€ํ™˜
img = Image.open(image_path_string).convert('RGB')
print(f"\n[INFO] ์ด๋ฏธ์ง€ ๋กœ๋“œ ์„ฑ๊ณต: {image_path_string}")
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_batch)
probabilities = F.softmax(output[0], dim=0)
top_prob, top_catid = torch.topk(probabilities, 5)
print("\n--- ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ (Top-5) ---")
for i in range(top_prob.size(0)):
idx = top_catid[i].item()
# ํด๋ž˜์Šค ์ด๋ฆ„ ๋งคํ•‘ ์ ์šฉ: ๋กœ๋“œ๋œ ๋”•์…”๋„ˆ๋ฆฌ ์‚ฌ์šฉ
class_name = class_name_map.get(idx, f"์•Œ ์ˆ˜ ์—†๋Š” ํด๋ž˜์Šค (ID: {idx})")
print(f"์ˆœ์œ„ {i+1}:")
print(f" - ํด๋ž˜์Šค ์ด๋ฆ„: **{class_name}**")
print(f" - ํด๋ž˜์Šค ์ธ๋ฑ์Šค (ID): {idx}")
print(f" - ํ™•๋ฅ : {top_prob[i].item():.4f}")
except FileNotFoundError:
print(f"\n[์˜ค๋ฅ˜] ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {image_path_string}")
print("๊ฒฝ๋กœ๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
except Exception as e:
print(f"\n[์˜ค๋ฅ˜] ๋ถ„๋ฅ˜ ์ค‘ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")
# --- ์‹คํ–‰ ---
# ๋ถ„๋ฅ˜ํ•  ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ๋ฌธ์ž์—ด๋กœ ์ง€์ • (์‚ฌ์šฉ์ž ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ์ˆ˜์ • ํ•„์š”!)
CLASSIFY_TARGET_PATH = 'D:/pictures/muffin1.png'
# ํ•จ์ˆ˜ ์‹คํ–‰
classify_image(CLASSIFY_TARGET_PATH)