Spaces:
Sleeping
Sleeping
| import torch | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from model import AlexNet | |
# Human-readable class names, indexed by model output position:
# CLASSES[i] labels probability i in predict(). The order is assumed to match
# the label order used when the checkpoint was trained — keep them in sync.
CLASSES = [
    'airplane',    # 0
    'automobile',  # 1
    'bird',        # 2
    'cat',         # 3
    'deer',        # 4
    'dog',         # 5
    'frog',        # 6
    'horse',       # 7
    'ship',        # 8
    'truck',       # 9
]
# Load the model: build the network, then restore its trained weights.
# map_location='cpu' lets a checkpoint saved on a GPU machine load on a
# CPU-only host.
# NOTE(review): torch.load unpickles the checkpoint. The file ships with the
# app, but passing weights_only=True (torch>=1.13) would prevent a tampered
# checkpoint from executing code — confirm the deployed torch version first.
model = AlexNet()
model.load_state_dict(
    torch.load('alexnet_cifar10.pth', map_location='cpu'),
)
# Put the network in evaluation mode before serving predictions.
model.eval()
# Preprocessing applied to every uploaded image:
#   1. resize to 32x32 pixels (the CIFAR-10 image size),
#   2. convert the PIL image to a float CHW tensor,
#   3. normalize each channel as (x - 0.5) / 0.5.
# Normalize carries three per-channel statistics, so the input is expected to
# have exactly three (RGB) channels.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
def predict(image):
    """Classify one image and return a probability for every class.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an image.

    Returns:
        A dict mapping each name in ``CLASSES`` to its softmax probability
        (a Python float), in the shape ``gr.Label`` expects. Returns an
        empty dict when ``image`` is ``None``.
    """
    # Gradio passes None when Submit is pressed with no upload; the transform
    # would raise on it, so report "no prediction" instead of erroring.
    if image is None:
        return {}

    # Normalize in `transform` has three per-channel statistics, so the tensor
    # must have exactly three channels. gr.Image defaults to image_mode="RGB",
    # but converting here keeps predict safe for RGBA/grayscale PIL images
    # passed directly or if the component's image_mode is changed.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(batch)
    probs = torch.softmax(logits, dim=1)[0]

    # Pair names with probabilities by position; zip avoids hard-coding the
    # number of classes.
    return dict(zip(CLASSES, probs.tolist()))
# Gradio UI: one image in, the top-3 class probabilities out.
# The title/description below were mojibake (UTF-8 Korean mis-decoded as Thai)
# and have been restored. In English:
#   title:       "AlexNet CIFAR-10 classifier"
#   description: "Airplane, automobile, bird, cat, deer, dog, frog, horse,
#                 ship, truck. A model that classifies photos into these 10
#                 categories. Please upload a photo."
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="AlexNet CIFAR-10 분류기",
    description=(
        "비행기, 자동차, 새, 고양이, 사슴, 개, 개구리, 말, 배, 트럭. "
        "총 10가지 사진을 분류해주는 모델입니다. 사진을 넣어주세요"
    ),
)
demo.launch()