""" AlexNet — 허깅페이스 Spaces 데모 논문: Krizhevsky, Sutskever, Hinton (NeurIPS 2012) 핵심 변경: - torchvision AlexNet과 완전히 동일한 구조(groups=1)로 맞춰 사전학습 가중치를 Conv+FC 전체 로드 → 실제 분류 작동 - ImageNet 1000개 클래스 이름 자동 로드 (강아지, 고양이, 사과, 사람 등 모두 포함) """ import json import requests import torch import torch.nn as nn import torchvision.models as tv import torchvision.transforms as T import gradio as gr from PIL import Image # ────────────────────────────────────────────────────────────── # 1. 모델 정의 # torchvision AlexNet과 완전 동일 구조 (groups=1, 가중치 호환) # # 논문 GPU 분할(groups=2)은 메모리 제한 때문이었고, # 지금은 GPU 메모리가 충분하므로 groups=1로 동일하게 구현. # 논문의 모든 하이퍼파라미터(LRN, Dropout, padding 등)는 그대로 유지. # ────────────────────────────────────────────────────────────── class AlexNet(nn.Module): """ 논문 Figure 2 재현 — torchvision 가중치 완전 호환 버전. torchvision AlexNet 구조와 1:1 대응: Conv1: kernel=11, stride=4, padding=2 -> (B, 64, 55, 55) -> pool -> (B, 64, 27, 27) Conv2: kernel=5, stride=1, padding=2 -> (B,192, 27, 27) -> pool -> (B,192, 13, 13) Conv3: kernel=3, stride=1, padding=1 -> (B,384, 13, 13) Conv4: kernel=3, stride=1, padding=1 -> (B,256, 13, 13) Conv5: kernel=3, stride=1, padding=1 -> (B,256, 13, 13) -> pool -> (B,256, 6, 6) FC1: 9216 -> 4096 (Dropout 0.5) FC2: 4096 -> 4096 (Dropout 0.5) FC3: 4096 -> num_labels """ def __init__(self, num_labels: int = 1000, dropout: float = 0.5): super().__init__() # features: torchvision Sequential과 동일한 순서·파라미터 self.features = nn.Sequential( # Conv1 nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), # Conv2 nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), # Conv3 nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True), # Conv4 nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), # Conv5 nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=3, stride=2), ) self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) # classifier: torchvision Sequential과 동일 self.classifier = nn.Sequential( nn.Dropout(p=dropout), # 논문 4.2절: FC1 앞 Dropout nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True), nn.Dropout(p=dropout), # 논문 4.2절: FC2 앞 Dropout nn.Linear(4096, 4096), nn.ReLU(inplace=True), nn.Linear(4096, num_labels), # FC3: Dropout 없음 ) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) # (B, 256, 6, 6) x = self.avgpool(x) # (B, 256, 6, 6) — 크기 보장 x = x.view(x.size(0), -1) # (B, 9216) return self.classifier(x) # (B, num_labels) # ────────────────────────────────────────────────────────────── # 2. 모델 생성 + torchvision 사전학습 가중치 전체 로드 # ────────────────────────────────────────────────────────────── DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = AlexNet(num_labels=1000).to(DEVICE) WEIGHTS_STATUS = "랜덤 초기화 (예측 의미 없음)" try: pretrained = tv.alexnet(weights=tv.AlexNet_Weights.DEFAULT) model.load_state_dict(pretrained.state_dict()) # Conv + FC 전체 복사 WEIGHTS_STATUS = "ImageNet 사전학습 완료 (torchvision)" print("가중치 전체 로드 완료") except Exception as e: print(f"가중치 로드 실패: {e}") model.eval() # ────────────────────────────────────────────────────────────── # 3. ImageNet 1000개 클래스 이름 로드 # 강아지(n02085620~), 고양이(n02123045~), 사과(948), 사람 없음* # *ImageNet은 사람 클래스를 포함하지 않음 # ────────────────────────────────────────────────────────────── ID2LABEL = {} # 1순위: config.json try: with open("config.json") as f: cfg = json.load(f) ID2LABEL = {int(k): v for k, v in cfg.get("id2label", {}).items()} if ID2LABEL: print(f"config.json: {len(ID2LABEL)}개 클래스") except Exception: pass # 2순위: 허깅페이스 ViT config (ImageNet 1000 라벨 동일) if not ID2LABEL: try: resp = requests.get( "https://huggingface.co/google/vit-base-patch16-224/raw/main/config.json", timeout=15, ) vit_cfg = resp.json() ID2LABEL = {int(k): v for k, v in vit_cfg.get("id2label", {}).items()} print(f"허깅페이스: {len(ID2LABEL)}개 클래스 로드") except Exception as e: print(f"클래스 이름 로드 실패: {e}") LABEL_STATUS = f"ImageNet {len(ID2LABEL)}개 클래스" if ID2LABEL else "클래스 이름 없음" # ────────────────────────────────────────────────────────────── # 4. 전처리 (torchvision AlexNet_Weights.DEFAULT와 동일) # ────────────────────────────────────────────────────────────── TRANSFORM = T.Compose([ T.Resize(256), T.CenterCrop(224), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) # ────────────────────────────────────────────────────────────── # 5. 추론 함수 # ────────────────────────────────────────────────────────────── def predict(image: Image.Image) -> dict: if image is None: return {} tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE) with torch.no_grad(): logits = model(tensor) probs = torch.softmax(logits, dim=-1)[0] top5_probs, top5_idx = probs.topk(5) return { ID2LABEL.get(idx.item(), f"class_{idx.item()}"): round(prob.item(), 4) for prob, idx in zip(top5_probs, top5_idx) } # ────────────────────────────────────────────────────────────── # 6. Gradio UI # ────────────────────────────────────────────────────────────── with gr.Blocks(title="AlexNet — 논문 재현") as demo: gr.Markdown(f""" ## AlexNet — 논문 완전 재현 데모 **논문**: ImageNet Classification with Deep CNNs (Krizhevsky et al., NeurIPS 2012) | 항목 | 상태 | |------|------| | 가중치 | {WEIGHTS_STATUS} | | 클래스 | {LABEL_STATUS} | > ※ ImageNet은 사람(남자/여자) 클래스를 포함하지 않아요. > 강아지·고양이·사과·자동차 등 1000개 물체 카테고리를 인식합니다. """) with gr.Row(): with gr.Column(): image_input = gr.Image(type="pil", label="입력 이미지") run_btn = gr.Button("예측하기", variant="primary") with gr.Column(): label_output = gr.Label(num_top_classes=5, label="Top-5 예측") with gr.Accordion("인식 가능한 주요 카테고리", open=False): gr.Markdown(""" **동물**: 개(120종), 고양이(8종), 새(59종), 물고기, 뱀, 곰, 코끼리 등 **음식**: 사과, 레몬, 딸기, 아이스크림, 피자, 버섯 등 **탈것**: 자동차, 버스, 기차, 비행기, 배, 오토바이 등 **사물**: 의자, 시계, 컵, 키보드, 안경, 우산 등 **자연**: 산호초, 화산, 폭포, 빙하 등 > 사람(남자/여자)은 ImageNet 1000 클래스에 포함되지 않습니다. > 사람 인식이 필요하면 CLIP 또는 COCO 학습 모델이 필요해요. """) with gr.Accordion("모델 구조 (논문 Figure 2)", open=False): gr.Markdown(""" | 레이어 | 커널 | 출력 shape | 논문 섹션 | |--------|------|-----------------|-----------| | Conv1 | 11×11 stride=4 | (B, 64, 27, 27) | 3.5절 | | Conv2 | 5×5 | (B, 192, 13, 13) | 3.5절 | | Conv3 | 3×3 | (B, 384, 13, 13) | 3.5절 | | Conv4 | 3×3 | (B, 256, 13, 13) | 3.5절 | | Conv5 | 3×3 | (B, 256, 6, 6) | 3.5절 | | FC1·2 | — | (B, 4096) | 4.2절 Dropout 0.5 | | FC3 | — | (B, 1000) | Abstract | """) run_btn.click(fn=predict, inputs=image_input, outputs=label_output) image_input.change(fn=predict, inputs=image_input, outputs=label_output) if __name__ == "__main__": demo.launch()