# Source: AlexNetCode / app.py — uploaded by JangTaeng ("Upload app.py", commit 7a5c9ed, verified)
"""
AlexNet β€” ν—ˆκΉ…νŽ˜μ΄μŠ€ Spaces 데λͺ¨
λ…Όλ¬Έ: Krizhevsky, Sutskever, Hinton (NeurIPS 2012)
핡심 λ³€κ²½:
- torchvision AlexNetκ³Ό μ™„μ „νžˆ λ™μΌν•œ ꡬ쑰(groups=1)둜 맞좰
μ‚¬μ „ν•™μŠ΅ κ°€μ€‘μΉ˜λ₯Ό Conv+FC 전체 λ‘œλ“œ β†’ μ‹€μ œ λΆ„λ₯˜ μž‘λ™
- ImageNet 1000개 클래슀 이름 μžλ™ λ‘œλ“œ
(κ°•μ•„μ§€, 고양이, 사과, μ‚¬λžŒ λ“± λͺ¨λ‘ 포함)
"""
import json
import requests
import torch
import torch.nn as nn
import torchvision.models as tv
import torchvision.transforms as T
import gradio as gr
from PIL import Image
# ──────────────────────────────────────────────────────────────
# 1. λͺ¨λΈ μ •μ˜
# torchvision AlexNetκ³Ό μ™„μ „ 동일 ꡬ쑰 (groups=1, κ°€μ€‘μΉ˜ ν˜Έν™˜)
#
# λ…Όλ¬Έ GPU λΆ„ν• (groups=2)은 λ©”λͺ¨λ¦¬ μ œν•œ λ•Œλ¬Έμ΄μ—ˆκ³ ,
# μ§€κΈˆμ€ GPU λ©”λͺ¨λ¦¬κ°€ μΆ©λΆ„ν•˜λ―€λ‘œ groups=1둜 λ™μΌν•˜κ²Œ κ΅¬ν˜„.
# λ…Όλ¬Έμ˜ λͺ¨λ“  ν•˜μ΄νΌνŒŒλΌλ―Έν„°(LRN, Dropout, padding λ“±)λŠ” κ·ΈλŒ€λ‘œ μœ μ§€.
# ──────────────────────────────────────────────────────────────
class AlexNet(nn.Module):
"""
λ…Όλ¬Έ Figure 2 μž¬ν˜„ β€” torchvision κ°€μ€‘μΉ˜ μ™„μ „ ν˜Έν™˜ 버전.
torchvision AlexNet ꡬ쑰와 1:1 λŒ€μ‘:
Conv1: kernel=11, stride=4, padding=2 -> (B, 64, 55, 55) -> pool -> (B, 64, 27, 27)
Conv2: kernel=5, stride=1, padding=2 -> (B,192, 27, 27) -> pool -> (B,192, 13, 13)
Conv3: kernel=3, stride=1, padding=1 -> (B,384, 13, 13)
Conv4: kernel=3, stride=1, padding=1 -> (B,256, 13, 13)
Conv5: kernel=3, stride=1, padding=1 -> (B,256, 13, 13) -> pool -> (B,256, 6, 6)
FC1: 9216 -> 4096 (Dropout 0.5)
FC2: 4096 -> 4096 (Dropout 0.5)
FC3: 4096 -> num_labels
"""
def __init__(self, num_labels: int = 1000, dropout: float = 0.5):
super().__init__()
# features: torchvision Sequentialκ³Ό λ™μΌν•œ μˆœμ„œΒ·νŒŒλΌλ―Έν„°
self.features = nn.Sequential(
# Conv1
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
# Conv2
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
# Conv3
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
# Conv4
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
# Conv5
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
# classifier: torchvision Sequentialκ³Ό 동일
self.classifier = nn.Sequential(
nn.Dropout(p=dropout), # λ…Όλ¬Έ 4.2절: FC1 μ•ž Dropout
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout), # λ…Όλ¬Έ 4.2절: FC2 μ•ž Dropout
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_labels), # FC3: Dropout μ—†μŒ
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.features(x) # (B, 256, 6, 6)
x = self.avgpool(x) # (B, 256, 6, 6) β€” 크기 보μž₯
x = x.view(x.size(0), -1) # (B, 9216)
return self.classifier(x) # (B, num_labels)
# ──────────────────────────────────────────────────────────────
# 2. λͺ¨λΈ 생성 + torchvision μ‚¬μ „ν•™μŠ΅ κ°€μ€‘μΉ˜ 전체 λ‘œλ“œ
# ──────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet(num_labels=1000).to(DEVICE)
WEIGHTS_STATUS = "랜덀 μ΄ˆκΈ°ν™” (예츑 의미 μ—†μŒ)"
try:
pretrained = tv.alexnet(weights=tv.AlexNet_Weights.DEFAULT)
model.load_state_dict(pretrained.state_dict()) # Conv + FC 전체 볡사
WEIGHTS_STATUS = "ImageNet μ‚¬μ „ν•™μŠ΅ μ™„λ£Œ (torchvision)"
print("κ°€μ€‘μΉ˜ 전체 λ‘œλ“œ μ™„λ£Œ")
except Exception as e:
print(f"κ°€μ€‘μΉ˜ λ‘œλ“œ μ‹€νŒ¨: {e}")
model.eval()
# ──────────────────────────────────────────────────────────────
# 3. ImageNet 1000개 클래슀 이름 λ‘œλ“œ
# κ°•μ•„μ§€(n02085620~), 고양이(n02123045~), 사과(948), μ‚¬λžŒ μ—†μŒ*
# *ImageNet은 μ‚¬λžŒ 클래슀λ₯Ό ν¬ν•¨ν•˜μ§€ μ•ŠμŒ
# ──────────────────────────────────────────────────────────────
ID2LABEL = {}
# 1μˆœμœ„: config.json
try:
with open("config.json") as f:
cfg = json.load(f)
ID2LABEL = {int(k): v for k, v in cfg.get("id2label", {}).items()}
if ID2LABEL:
print(f"config.json: {len(ID2LABEL)}개 클래슀")
except Exception:
pass
# 2μˆœμœ„: ν—ˆκΉ…νŽ˜μ΄μŠ€ ViT config (ImageNet 1000 라벨 동일)
if not ID2LABEL:
try:
resp = requests.get(
"https://huggingface.co/google/vit-base-patch16-224/raw/main/config.json",
timeout=15,
)
vit_cfg = resp.json()
ID2LABEL = {int(k): v for k, v in vit_cfg.get("id2label", {}).items()}
print(f"ν—ˆκΉ…νŽ˜μ΄μŠ€: {len(ID2LABEL)}개 클래슀 λ‘œλ“œ")
except Exception as e:
print(f"클래슀 이름 λ‘œλ“œ μ‹€νŒ¨: {e}")
LABEL_STATUS = f"ImageNet {len(ID2LABEL)}개 클래슀" if ID2LABEL else "클래슀 이름 μ—†μŒ"
# ──────────────────────────────────────────────────────────────
# 4. μ „μ²˜λ¦¬ (torchvision AlexNet_Weights.DEFAULT와 동일)
# ──────────────────────────────────────────────────────────────
TRANSFORM = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# ──────────────────────────────────────────────────────────────
# 5. μΆ”λ‘  ν•¨μˆ˜
# ──────────────────────────────────────────────────────────────
def predict(image: Image.Image) -> dict:
if image is None:
return {}
tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
logits = model(tensor)
probs = torch.softmax(logits, dim=-1)[0]
top5_probs, top5_idx = probs.topk(5)
return {
ID2LABEL.get(idx.item(), f"class_{idx.item()}"): round(prob.item(), 4)
for prob, idx in zip(top5_probs, top5_idx)
}
# ──────────────────────────────────────────────────────────────
# 6. Gradio UI
# ──────────────────────────────────────────────────────────────
with gr.Blocks(title="AlexNet β€” λ…Όλ¬Έ μž¬ν˜„") as demo:
gr.Markdown(f"""
## AlexNet β€” λ…Όλ¬Έ μ™„μ „ μž¬ν˜„ 데λͺ¨
**λ…Όλ¬Έ**: ImageNet Classification with Deep CNNs (Krizhevsky et al., NeurIPS 2012)
| ν•­λͺ© | μƒνƒœ |
|------|------|
| κ°€μ€‘μΉ˜ | {WEIGHTS_STATUS} |
| 클래슀 | {LABEL_STATUS} |
> β€» ImageNet은 μ‚¬λžŒ(λ‚¨μž/μ—¬μž) 클래슀λ₯Ό ν¬ν•¨ν•˜μ§€ μ•Šμ•„μš”.
> κ°•μ•„μ§€Β·κ³ μ–‘μ΄Β·μ‚¬κ³ΌΒ·μžλ™μ°¨ λ“± 1000개 물체 μΉ΄ν…Œκ³ λ¦¬λ₯Ό μΈμ‹ν•©λ‹ˆλ‹€.
""")
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="μž…λ ₯ 이미지")
run_btn = gr.Button("μ˜ˆμΈ‘ν•˜κΈ°", variant="primary")
with gr.Column():
label_output = gr.Label(num_top_classes=5, label="Top-5 예츑")
with gr.Accordion("인식 κ°€λŠ₯ν•œ μ£Όμš” μΉ΄ν…Œκ³ λ¦¬", open=False):
gr.Markdown("""
**동물**: 개(120μ’…), 고양이(8μ’…), μƒˆ(59μ’…), λ¬Όκ³ κΈ°, λ±€, κ³°, 코끼리 λ“±
**μŒμ‹**: 사과, 레λͺ¬, λ”ΈκΈ°, μ•„μ΄μŠ€ν¬λ¦Ό, ν”Όμž, 버섯 λ“±
**νƒˆκ²ƒ**: μžλ™μ°¨, λ²„μŠ€, κΈ°μ°¨, λΉ„ν–‰κΈ°, λ°°, μ˜€ν† λ°”μ΄ λ“±
**사물**: 의자, μ‹œκ³„, μ»΅, ν‚€λ³΄λ“œ, μ•ˆκ²½, μš°μ‚° λ“±
**μžμ—°**: μ‚°ν˜Έμ΄ˆ, ν™”μ‚°, 폭포, λΉ™ν•˜ λ“±
> μ‚¬λžŒ(λ‚¨μž/μ—¬μž)은 ImageNet 1000 ν΄λž˜μŠ€μ— ν¬ν•¨λ˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€.
> μ‚¬λžŒ 인식이 ν•„μš”ν•˜λ©΄ CLIP λ˜λŠ” COCO ν•™μŠ΅ λͺ¨λΈμ΄ ν•„μš”ν•΄μš”.
""")
with gr.Accordion("λͺ¨λΈ ꡬ쑰 (λ…Όλ¬Έ Figure 2)", open=False):
gr.Markdown("""
| λ ˆμ΄μ–΄ | 컀널 | 좜λ ₯ shape | λ…Όλ¬Έ μ„Ήμ…˜ |
|--------|------|-----------------|-----------|
| Conv1 | 11Γ—11 stride=4 | (B, 64, 27, 27) | 3.5절 |
| Conv2 | 5Γ—5 | (B, 192, 13, 13) | 3.5절 |
| Conv3 | 3Γ—3 | (B, 384, 13, 13) | 3.5절 |
| Conv4 | 3Γ—3 | (B, 256, 13, 13) | 3.5절 |
| Conv5 | 3Γ—3 | (B, 256, 6, 6) | 3.5절 |
| FC1Β·2 | β€” | (B, 4096) | 4.2절 Dropout 0.5 |
| FC3 | β€” | (B, 1000) | Abstract |
""")
run_btn.click(fn=predict, inputs=image_input, outputs=label_output)
image_input.change(fn=predict, inputs=image_input, outputs=label_output)
if __name__ == "__main__":
demo.launch()