File size: 4,689 Bytes
e5e8675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f364b6
e5e8675
 
 
 
 
 
 
 
 
 
 
 
 
 
4f364b6
e5e8675
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f364b6
e5e8675
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import json 
import os
import sys # ํŒŒ์ผ ๋ˆ„๋ฝ ์‹œ ์ข…๋ฃŒ๋ฅผ ์œ„ํ•ด ์ถ”๊ฐ€

# ==============================================================================
# 0. ImageNet ํด๋ž˜์Šค ์ด๋ฆ„ ๋กœ๋“œ
# ==============================================================================
CLASS_MAP_FILENAME = 'labels_map.txt' 
class_name_map = None # ์ „์—ญ ๋ณ€์ˆ˜๋กœ ์ดˆ๊ธฐํ™”

try:
    if not os.path.exists(CLASS_MAP_FILENAME):
        print(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ์ด๋ฆ„ ํŒŒ์ผ('{CLASS_MAP_FILENAME}')์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
        print("ํŒŒ์ผ์„ ํ˜„์žฌ ๋””๋ ‰ํ† ๋ฆฌ์— ์ €์žฅํ–ˆ๋Š”์ง€ ํ™•์ธํ•ด ์ฃผ์„ธ์š”.")
        sys.exit(1) # ํŒŒ์ผ์ด ์—†์œผ๋ฉด ํ”„๋กœ๊ทธ๋žจ ์ข…๋ฃŒ

    # 1. ํŒŒ์ผ ๋กœ๋“œ (JSON ํ˜•์‹)
    with open(CLASS_MAP_FILENAME, 'r') as f:
        class_map_json = json.load(f)

    # 2. ์ œ๊ณตํ•ด์ฃผ์‹  ๋กœ์ง ์ ์šฉ: ์ธ๋ฑ์Šค 0๋ถ€ํ„ฐ 999๊นŒ์ง€ ์ด๋ฆ„๋งŒ ์ถ”์ถœํ•˜์—ฌ ๋ฆฌ์ŠคํŠธ ์ƒ์„ฑ
    # JSON ํŒŒ์ผ์˜ ํ‚ค๊ฐ€ ๋ฌธ์ž์—ด์ด๋ฏ€๋กœ str(i)๋กœ ์ ‘๊ทผํ•˜๊ณ , ๊ฐ’ ๋ฆฌ์ŠคํŠธ์˜ ๋‘ ๋ฒˆ์งธ ์š”์†Œ(์ด๋ฆ„)๋ฅผ ๊ฐ€์ ธ์˜ต๋‹ˆ๋‹ค.
    labels_list = [class_map_json[str(i)] for i in range(1000)]
    
    # 3. ์ธ๋ฑ์Šค์™€ ์ด๋ฆ„์„ ๋งคํ•‘ํ•˜๋Š” ๋”•์…”๋„ˆ๋ฆฌ๋กœ ๋ณ€ํ™˜ (๋‚˜์ค‘์— ํด๋ž˜์Šค ID๋กœ ์ด๋ฆ„ ์กฐํšŒ ์šฉ์ด)
    class_name_map = {i: name for i, name in enumerate(labels_list)}
    
    print(f"ImageNet ํด๋ž˜์Šค ์ด๋ฆ„ ({len(class_name_map)}๊ฐœ) ๋กœ๋“œ ์™„๋ฃŒ.")

except Exception as e:
    print(f"[์˜ค๋ฅ˜] ํด๋ž˜์Šค ํŒŒ์ผ ๋กœ๋“œ ๋˜๋Š” ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
    sys.exit(1) # ์˜ค๋ฅ˜ ๋ฐœ์ƒ ์‹œ ํ”„๋กœ๊ทธ๋žจ ์ข…๋ฃŒ


# ==============================================================================
# 1. ์„ค์ • ๋ฐ ๋ชจ๋ธ ๋กœ๋“œ
# ==============================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"์‚ฌ์šฉ ์žฅ์น˜: {device}")

# ImageNet์œผ๋กœ ์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ
print("์‚ฌ์ „ ํ›ˆ๋ จ๋œ EfficientNetB0 ๋ชจ๋ธ ๋กœ๋“œ ์ค‘...")
model = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
model.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •
model = model.to(device)
print("๋ชจ๋ธ ๋กœ๋“œ ๋ฐ ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ • ์™„๋ฃŒ.")

# ==============================================================================
# 2. ํ•„์ˆ˜ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ ์ •์˜
# ==============================================================================
preprocess = transforms.Compose([
    transforms.Resize(256),      
    transforms.CenterCrop(224), 
    transforms.ToTensor(),       
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

# ==============================================================================
# 3. ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ฐ ์ถœ๋ ฅ ํ•จ์ˆ˜
# ==============================================================================

def classify_image(image_path_string):
    """
    ์ฃผ์–ด์ง„ ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ์˜ ํŒŒ์ผ์„ EfficientNetB0 ๋ชจ๋ธ๋กœ ๋ถ„๋ฅ˜ํ•˜๊ณ  ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•ฉ๋‹ˆ๋‹ค.
    (ํด๋ž˜์Šค ์ด๋ฆ„ ํฌํ•จ)
    """
    try:
        # 1. ์ด๋ฏธ์ง€ ๋กœ๋“œ ๋ฐ RGB ๋ณ€ํ™˜
        img = Image.open(image_path_string).convert('RGB')
        print(f"\n[INFO] ์ด๋ฏธ์ง€ ๋กœ๋“œ ์„ฑ๊ณต: {image_path_string}")
        
        input_tensor = preprocess(img)
        input_batch = input_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output = model(input_batch)

        probabilities = F.softmax(output[0], dim=0)
        top_prob, top_catid = torch.topk(probabilities, 5)

        print("\n--- ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ (Top-5) ---")
        
        for i in range(top_prob.size(0)):
            idx = top_catid[i].item()
            
            # ํด๋ž˜์Šค ์ด๋ฆ„ ๋งคํ•‘ ์ ์šฉ: ๋กœ๋“œ๋œ ๋”•์…”๋„ˆ๋ฆฌ ์‚ฌ์šฉ
            class_name = class_name_map.get(idx, f"์•Œ ์ˆ˜ ์—†๋Š” ํด๋ž˜์Šค (ID: {idx})")
            
            print(f"์ˆœ์œ„ {i+1}:")
            print(f"  - ํด๋ž˜์Šค ์ด๋ฆ„: **{class_name}**")
            print(f"  - ํด๋ž˜์Šค ์ธ๋ฑ์Šค (ID): {idx}")
            print(f"  - ํ™•๋ฅ : {top_prob[i].item():.4f}")

    except FileNotFoundError:
        print(f"\n[์˜ค๋ฅ˜] ์ด๋ฏธ์ง€ ํŒŒ์ผ์„ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค: {image_path_string}")
        print("๊ฒฝ๋กœ๋ฅผ ๋‹ค์‹œ ํ™•์ธํ•ด์ฃผ์„ธ์š”.")
    except Exception as e:
        print(f"\n[์˜ค๋ฅ˜] ๋ถ„๋ฅ˜ ์ค‘ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}")


# --- ์‹คํ–‰ ---
# ๋ถ„๋ฅ˜ํ•  ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฒฝ๋กœ๋ฅผ ๋ฌธ์ž์—ด๋กœ ์ง€์ • (์‚ฌ์šฉ์ž ํ™˜๊ฒฝ์— ๋งž๊ฒŒ ์ˆ˜์ • ํ•„์š”!)
CLASSIFY_TARGET_PATH = 'D:/pictures/muffin1.png' 

# ํ•จ์ˆ˜ ์‹คํ–‰
classify_image(CLASSIFY_TARGET_PATH)