File size: 10,364 Bytes
80572ca
 
 
 
7a5c9ed
 
 
 
 
80572ca
 
 
62c8941
80572ca
 
62c8941
 
80572ca
 
 
 
 
 
7a5c9ed
 
 
 
 
80572ca
 
7a5c9ed
306acbb
7a5c9ed
 
 
 
 
 
 
 
 
 
 
306acbb
7a5c9ed
80572ca
 
7a5c9ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80572ca
7a5c9ed
80572ca
7a5c9ed
80572ca
7a5c9ed
80572ca
 
7a5c9ed
80572ca
 
7a5c9ed
80572ca
7a5c9ed
 
 
 
 
 
80572ca
 
 
7a5c9ed
80572ca
 
 
7a5c9ed
62c8941
7a5c9ed
62c8941
 
7a5c9ed
 
 
62c8941
7a5c9ed
80572ca
 
 
62c8941
 
7a5c9ed
 
 
62c8941
 
 
 
 
80572ca
 
 
63e7e05
62c8941
7a5c9ed
80572ca
62c8941
 
7a5c9ed
62c8941
 
 
 
7a5c9ed
62c8941
 
 
7a5c9ed
62c8941
 
80572ca
7a5c9ed
 
80572ca
 
7a5c9ed
80572ca
 
 
 
 
 
 
 
 
 
 
 
62c8941
80572ca
 
 
 
 
63e7e05
80572ca
63e7e05
 
80572ca
 
 
 
 
 
 
 
62c8941
80572ca
 
 
62c8941
80572ca
 
 
7a5c9ed
 
 
 
 
 
 
80572ca
 
 
 
 
 
 
 
 
7a5c9ed
 
 
 
 
 
 
 
 
 
 
 
80572ca
 
7a5c9ed
 
 
 
 
 
 
 
 
80572ca
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
AlexNet — Hugging Face Spaces demo
Paper: Krizhevsky, Sutskever, Hinton (NeurIPS 2012)

Key changes:
  - The architecture matches torchvision's AlexNet exactly (groups=1), so the
    pretrained weights load for every Conv and FC layer -> classification
    actually works.
  - The 1000 ImageNet class names are loaded automatically
    (dogs, cats, apples, ...; ImageNet-1k has no "person" class).
"""

import json
import requests
import torch
import torch.nn as nn
import torchvision.models as tv
import torchvision.transforms as T
import gradio as gr
from PIL import Image


# ──────────────────────────────────────────────────────────────
# 1. λͺ¨λΈ μ •μ˜
#    torchvision AlexNetκ³Ό μ™„μ „ 동일 ꡬ쑰 (groups=1, κ°€μ€‘μΉ˜ ν˜Έν™˜)
#
#    λ…Όλ¬Έ GPU λΆ„ν• (groups=2)은 λ©”λͺ¨λ¦¬ μ œν•œ λ•Œλ¬Έμ΄μ—ˆκ³ ,
#    μ§€κΈˆμ€ GPU λ©”λͺ¨λ¦¬κ°€ μΆ©λΆ„ν•˜λ―€λ‘œ groups=1둜 λ™μΌν•˜κ²Œ κ΅¬ν˜„.
#    λ…Όλ¬Έμ˜ λͺ¨λ“  ν•˜μ΄νΌνŒŒλΌλ―Έν„°(LRN, Dropout, padding λ“±)λŠ” κ·ΈλŒ€λ‘œ μœ μ§€.
# ──────────────────────────────────────────────────────────────

class AlexNet(nn.Module):
    """AlexNet image classifier laid out like ``torchvision.models.alexnet``.

    The layer order inside ``features`` and ``classifier`` is intended to
    mirror the torchvision model (ungrouped convolutions throughout) so that
    the ``state_dict`` keys of both models line up one-to-one.

    Shapes for a 224x224 RGB input (B = batch size):
      Conv1: kernel=11, stride=4, padding=2  -> (B, 64,  55, 55) -> pool -> (B, 64,  27, 27)
      Conv2: kernel=5,  stride=1, padding=2  -> (B,192,  27, 27) -> pool -> (B,192,  13, 13)
      Conv3: kernel=3,  stride=1, padding=1  -> (B,384,  13, 13)
      Conv4: kernel=3,  stride=1, padding=1  -> (B,256,  13, 13)
      Conv5: kernel=3,  stride=1, padding=1  -> (B,256,  13, 13) -> pool -> (B,256,   6,  6)
      FC1: 9216 -> 4096  (preceded by Dropout)
      FC2: 4096 -> 4096  (preceded by Dropout)
      FC3: 4096 -> num_labels

    Args:
        num_labels: Width of the final linear layer, i.e. number of logits.
        dropout: Drop probability used by both Dropout layers.
    """
    def __init__(self, num_labels: int = 1000, dropout: float = 0.5):
        super().__init__()

        # features: the positions inside the Sequential determine the
        # state_dict keys (features.0, features.3, ...), so do not reorder.
        self.features = nn.Sequential(
            # Conv1
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Conv2
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # Conv3
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Conv4
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Conv5
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # Pins the spatial size to 6x6 so the classifier input is always 9216 wide.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # classifier: same ordering constraint as `features`.
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),          # Dropout ahead of FC1 (paper, section 4.2)
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),          # Dropout ahead of FC2 (paper, section 4.2)
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_labels),    # FC3: no Dropout
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits of shape (B, num_labels) for an image batch."""
        x = self.features(x)
        x = self.avgpool(x)        # (B, 256, 6, 6) regardless of input resolution
        # flatten() falls back to a copy when the activation is not contiguous
        # (e.g. channels_last), where .view() would raise.
        x = torch.flatten(x, 1)    # (B, 9216)
        return self.classifier(x)  # (B, num_labels)


# ──────────────────────────────────────────────────────────────
# 2. λͺ¨λΈ 생성 + torchvision μ‚¬μ „ν•™μŠ΅ κ°€μ€‘μΉ˜ 전체 λ‘œλ“œ
# ──────────────────────────────────────────────────────────────

# Run on the GPU when one is visible, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet(num_labels=1000).to(DEVICE)

# Status text shown in the UI. It keeps this "random initialisation" value
# unless the pretrained weights below are loaded successfully.
# NOTE: the Korean literals in this block were mis-encoded (mojibake) and have
# been restored to the intended text.
WEIGHTS_STATUS = "랜덤 초기화 (예측 의미 없음)"
try:
    pretrained = tv.alexnet(weights=tv.AlexNet_Weights.DEFAULT)
    model.load_state_dict(pretrained.state_dict())   # copies every Conv and FC tensor
    WEIGHTS_STATUS = "ImageNet 사전학습 완료 (torchvision)"
    print("가중치 전체 로드 완료")
except Exception as e:
    # Deliberately best-effort: the app still starts, with random weights,
    # when the download or the state_dict load fails.
    print(f"가중치 로드 실패: {e}")

# Inference only: puts the Dropout layers into evaluation mode.
model.eval()


# ──────────────────────────────────────────────────────────────
# 3. ImageNet 1000개 클래슀 이름 λ‘œλ“œ
#    κ°•μ•„μ§€(n02085620~), 고양이(n02123045~), 사과(948), μ‚¬λžŒ μ—†μŒ*
#    *ImageNet은 μ‚¬λžŒ 클래슀λ₯Ό ν¬ν•¨ν•˜μ§€ μ•ŠμŒ
# ──────────────────────────────────────────────────────────────

# Maps class index -> human-readable label; stays empty when no source works.
# NOTE: the Korean literals in this block were mis-encoded (mojibake) and have
# been restored to the intended text.
ID2LABEL = {}

# Priority 1: a config.json in the working directory with an "id2label" table.
try:
    with open("config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    ID2LABEL = {int(k): v for k, v in cfg.get("id2label", {}).items()}
    if ID2LABEL:
        print(f"config.json: {len(ID2LABEL)}개 클래스")
except FileNotFoundError:
    # No local config is an expected situation; fall through to the download.
    pass
except Exception as e:
    # Unreadable or malformed file: report it instead of hiding it, then fall back.
    print(f"config.json 로드 실패: {e}")

# Priority 2: the "id2label" table of the Hugging Face ViT config
# (per the original note it carries the same ImageNet 1000 labels).
if not ID2LABEL:
    try:
        resp = requests.get(
            "https://huggingface.co/google/vit-base-patch16-224/raw/main/config.json",
            timeout=15,
        )
        # Fail on HTTP errors so an error page is never parsed as a label file.
        resp.raise_for_status()
        vit_cfg = resp.json()
        ID2LABEL = {int(k): v for k, v in vit_cfg.get("id2label", {}).items()}
        print(f"허깅페이스: {len(ID2LABEL)}개 클래스 로드")
    except Exception as e:
        # Best effort: without labels, predictions are shown as class_<index>.
        print(f"클래스 이름 로드 실패: {e}")

# Status text shown in the UI.
LABEL_STATUS = f"ImageNet {len(ID2LABEL)}개 클래스" if ID2LABEL else "클래스 이름 없음"


# ──────────────────────────────────────────────────────────────
# 4. μ „μ²˜λ¦¬ (torchvision AlexNet_Weights.DEFAULT와 동일)
# ──────────────────────────────────────────────────────────────

# Preprocessing pipeline applied to every input image before inference:
# resize, centre-crop to 224, convert to a tensor, then normalise per channel.
TRANSFORM = T.Compose(
    [
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# ──────────────────────────────────────────────────────────────
# 5. μΆ”λ‘  ν•¨μˆ˜
# ──────────────────────────────────────────────────────────────

def predict(image: Image.Image) -> dict:
    """Classify one image and return its five most likely labels.

    Args:
        image: PIL image to classify, or ``None`` when no image was supplied.

    Returns:
        A dict mapping label -> softmax probability (rounded to 4 decimals)
        for the five highest-scoring classes, or an empty dict when ``image``
        is ``None``. Class indices missing from ``ID2LABEL`` are reported as
        ``class_<index>``.
    """
    if image is None:
        return {}
    # TRANSFORM normalises with three-channel statistics, so grayscale, RGBA
    # or palette images must be converted before they reach it.
    if image.mode != "RGB":
        image = image.convert("RGB")
    tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE)   # add the batch dimension
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=-1)[0]
    top5_probs, top5_idx = probs.topk(5)
    # One tolist() per tensor instead of an .item() call per element.
    return {
        ID2LABEL.get(idx, f"class_{idx}"): round(prob, 4)
        for prob, idx in zip(top5_probs.tolist(), top5_idx.tolist())
    }


# ──────────────────────────────────────────────────────────────
# 6. Gradio UI
# ──────────────────────────────────────────────────────────────

# NOTE: every user-facing Korean string in this block was mis-encoded
# (mojibake) and has been restored to the intended text; the layout and the
# event wiring are unchanged.
with gr.Blocks(title="AlexNet — 논문 재현") as demo:
    # Header: paper reference plus the weight/label status computed at import time.
    gr.Markdown(f"""
    ## AlexNet — 논문 완전 재현 데모
    **논문**: ImageNet Classification with Deep CNNs (Krizhevsky et al., NeurIPS 2012)

    | 항목 | 상태 |
    |------|------|
    | 가중치 | {WEIGHTS_STATUS} |
    | 클래스 | {LABEL_STATUS} |

    > ※ ImageNet은 사람(남자/여자) 클래스를 포함하지 않아요.
    > 강아지·고양이·사과·자동차 등 1000개 물체 카테고리를 인식합니다.
    """)

    # Left column: image upload and trigger button. Right column: top-5 result.
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="입력 이미지")
            run_btn = gr.Button("예측하기", variant="primary")
        with gr.Column():
            label_output = gr.Label(num_top_classes=5, label="Top-5 예측")

    # Collapsed help panel listing example categories the model can recognise.
    with gr.Accordion("인식 가능한 주요 카테고리", open=False):
        gr.Markdown("""
        **동물**: 개(120종), 고양이(8종), 새(59종), 물고기, 뱀, 곰, 코끼리 등  
        **음식**: 사과, 레몬, 딸기, 아이스크림, 피자, 버섯 등  
        **탈것**: 자동차, 버스, 기차, 비행기, 배, 오토바이 등  
        **사물**: 의자, 시계, 컵, 키보드, 안경, 우산 등  
        **자연**: 산호초, 화산, 폭포, 빙하 등

        > 사람(남자/여자)은 ImageNet 1000 클래스에 포함되지 않습니다.
        > 사람 인식이 필요하면 CLIP 또는 COCO 학습 모델이 필요해요.
        """)

    # Collapsed panel summarising the layer shapes (post-pooling where a pool follows).
    with gr.Accordion("모델 구조 (논문 Figure 2)", open=False):
        gr.Markdown("""
        | 레이어 | 커널 | 출력 shape      | 논문 섹션 |
        |--------|------|-----------------|-----------|
        | Conv1  | 11×11 stride=4 | (B, 64, 27, 27)  | 3.5절 |
        | Conv2  | 5×5           | (B, 192, 13, 13) | 3.5절 |
        | Conv3  | 3×3           | (B, 384, 13, 13) | 3.5절 |
        | Conv4  | 3×3           | (B, 256, 13, 13) | 3.5절 |
        | Conv5  | 3×3           | (B, 256, 6, 6)   | 3.5절 |
        | FC1·2  | —             | (B, 4096)        | 4.2절 Dropout 0.5 |
        | FC3    | —             | (B, 1000)        | Abstract |
        """)

    # Predict on an explicit button press and whenever the image changes.
    run_btn.click(fn=predict, inputs=image_input, outputs=label_output)
    image_input.change(fn=predict, inputs=image_input, outputs=label_output)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()