"""Gradio demo that classifies an uploaded picture with an AlexNet CIFAR-10 model.

At import time this loads weights from ``alexnet_cifar10.pth``, builds the
preprocessing pipeline, and launches a Gradio UI that shows the top-3 classes.
"""
import torch
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

from model import AlexNet

# Index i of this list names output logit i of the model (see predict).
# Assumes this matches the label order used during training — confirm.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Load the model: weights are mapped onto the CPU, and eval() switches the
# network to inference behaviour.
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model = AlexNet()
model.load_state_dict(torch.load('alexnet_cifar10.pth', map_location='cpu'))
model.eval()

# Resize to 32x32, convert to a float tensor in [0, 1], then scale each of the
# three channels to [-1, 1]. Presumably mirrors the training preprocessing —
# confirm against the training script.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


def predict(image):
    """Classify one image and return a probability per class name.

    Args:
        image: PIL image supplied by the Gradio input, or None when the
            input is empty.

    Returns:
        Dict mapping each name in CLASSES to its softmax probability as a
        Python float; an empty dict when no image was provided.
    """
    if image is None:
        # Nothing uploaded: return an empty result instead of crashing
        # inside the transform.
        return {}
    # Uploads may be RGBA (PNG with alpha) or grayscale; the 3-channel
    # Normalize above only works on RGB tensors.
    img = transform(image.convert('RGB')).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(img)
    probs = torch.softmax(output, dim=1)[0]
    return {name: float(p) for name, p in zip(CLASSES, probs.tolist())}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="AlexNet CIFAR-10 분류기",
    description="비행기, 자동차, 새, 고양이, 사슴, 개, 개구리, 말, 배, 트럭. 총 10가지 사진을 분류해주는 모델입니다. 사진을 넣어주세요"
)

demo.launch()