Spaces:
Sleeping
Sleeping
File size: 4,689 Bytes
e5e8675 4f364b6 e5e8675 4f364b6 e5e8675 4f364b6 e5e8675 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import json
import os
import sys # ํ์ผ ๋๋ฝ ์ ์ข
๋ฃ๋ฅผ ์ํด ์ถ๊ฐ
# ==============================================================================
# 0. ImageNet ํด๋์ค ์ด๋ฆ ๋ก๋
# ==============================================================================
CLASS_MAP_FILENAME = 'labels_map.txt'
class_name_map = None # ์ ์ญ ๋ณ์๋ก ์ด๊ธฐํ
try:
if not os.path.exists(CLASS_MAP_FILENAME):
print(f"[์ค๋ฅ] ํด๋์ค ์ด๋ฆ ํ์ผ('{CLASS_MAP_FILENAME}')์ ์ฐพ์ ์ ์์ต๋๋ค.")
print("ํ์ผ์ ํ์ฌ ๋๋ ํ ๋ฆฌ์ ์ ์ฅํ๋์ง ํ์ธํด ์ฃผ์ธ์.")
sys.exit(1) # ํ์ผ์ด ์์ผ๋ฉด ํ๋ก๊ทธ๋จ ์ข
๋ฃ
# 1. ํ์ผ ๋ก๋ (JSON ํ์)
with open(CLASS_MAP_FILENAME, 'r') as f:
class_map_json = json.load(f)
# 2. ์ ๊ณตํด์ฃผ์ ๋ก์ง ์ ์ฉ: ์ธ๋ฑ์ค 0๋ถํฐ 999๊น์ง ์ด๋ฆ๋ง ์ถ์ถํ์ฌ ๋ฆฌ์คํธ ์์ฑ
# JSON ํ์ผ์ ํค๊ฐ ๋ฌธ์์ด์ด๋ฏ๋ก str(i)๋ก ์ ๊ทผํ๊ณ , ๊ฐ ๋ฆฌ์คํธ์ ๋ ๋ฒ์งธ ์์(์ด๋ฆ)๋ฅผ ๊ฐ์ ธ์ต๋๋ค.
labels_list = [class_map_json[str(i)] for i in range(1000)]
# 3. ์ธ๋ฑ์ค์ ์ด๋ฆ์ ๋งคํํ๋ ๋์
๋๋ฆฌ๋ก ๋ณํ (๋์ค์ ํด๋์ค ID๋ก ์ด๋ฆ ์กฐํ ์ฉ์ด)
class_name_map = {i: name for i, name in enumerate(labels_list)}
print(f"ImageNet ํด๋์ค ์ด๋ฆ ({len(class_name_map)}๊ฐ) ๋ก๋ ์๋ฃ.")
except Exception as e:
print(f"[์ค๋ฅ] ํด๋์ค ํ์ผ ๋ก๋ ๋๋ ์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {e}")
sys.exit(1) # ์ค๋ฅ ๋ฐ์ ์ ํ๋ก๊ทธ๋จ ์ข
๋ฃ
# ==============================================================================
# 1. ์ค์ ๋ฐ ๋ชจ๋ธ ๋ก๋
# ==============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"์ฌ์ฉ ์ฅ์น: {device}")
# ImageNet์ผ๋ก ์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋
print("์ฌ์ ํ๋ จ๋ EfficientNetB0 ๋ชจ๋ธ ๋ก๋ ์ค...")
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
model.eval() # ํ๊ฐ ๋ชจ๋ ์ค์
model = model.to(device)
print("๋ชจ๋ธ ๋ก๋ ๋ฐ ํ๊ฐ ๋ชจ๋ ์ค์ ์๋ฃ.")
# ==============================================================================
# 2. ํ์ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ ์ ์
# ==============================================================================
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# ==============================================================================
# 3. ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฐ ์ถ๋ ฅ ํจ์
# ==============================================================================
def classify_image(image_path_string):
"""
์ฃผ์ด์ง ์ด๋ฏธ์ง ๊ฒฝ๋ก์ ํ์ผ์ EfficientNetB0 ๋ชจ๋ธ๋ก ๋ถ๋ฅํ๊ณ ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํฉ๋๋ค.
(ํด๋์ค ์ด๋ฆ ํฌํจ)
"""
try:
# 1. ์ด๋ฏธ์ง ๋ก๋ ๋ฐ RGB ๋ณํ
img = Image.open(image_path_string).convert('RGB')
print(f"\n[INFO] ์ด๋ฏธ์ง ๋ก๋ ์ฑ๊ณต: {image_path_string}")
input_tensor = preprocess(img)
input_batch = input_tensor.unsqueeze(0).to(device)
with torch.no_grad():
output = model(input_batch)
probabilities = F.softmax(output[0], dim=0)
top_prob, top_catid = torch.topk(probabilities, 5)
print("\n--- ๋ถ๋ฅ ๊ฒฐ๊ณผ (Top-5) ---")
for i in range(top_prob.size(0)):
idx = top_catid[i].item()
# ํด๋์ค ์ด๋ฆ ๋งคํ ์ ์ฉ: ๋ก๋๋ ๋์
๋๋ฆฌ ์ฌ์ฉ
class_name = class_name_map.get(idx, f"์ ์ ์๋ ํด๋์ค (ID: {idx})")
print(f"์์ {i+1}:")
print(f" - ํด๋์ค ์ด๋ฆ: **{class_name}**")
print(f" - ํด๋์ค ์ธ๋ฑ์ค (ID): {idx}")
print(f" - ํ๋ฅ : {top_prob[i].item():.4f}")
except FileNotFoundError:
print(f"\n[์ค๋ฅ] ์ด๋ฏธ์ง ํ์ผ์ ์ฐพ์ ์ ์์ต๋๋ค: {image_path_string}")
print("๊ฒฝ๋ก๋ฅผ ๋ค์ ํ์ธํด์ฃผ์ธ์.")
except Exception as e:
print(f"\n[์ค๋ฅ] ๋ถ๋ฅ ์ค ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ต๋๋ค: {e}")
# --- ์คํ ---
# ๋ถ๋ฅํ ์ด๋ฏธ์ง ํ์ผ ๊ฒฝ๋ก๋ฅผ ๋ฌธ์์ด๋ก ์ง์ (์ฌ์ฉ์ ํ๊ฒฝ์ ๋ง๊ฒ ์์ ํ์!)
CLASSIFY_TARGET_PATH = 'D:/pictures/muffin1.png'
# ํจ์ ์คํ
classify_image(CLASSIFY_TARGET_PATH) |