Spaces:
Sleeping
Sleeping
| import torch, torch.nn as nn | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
# Fetch the trained checkpoint file from the Hugging Face Hub; the call returns
# a local filesystem path that is later passed to torch.load.
# NOTE(review): no `revision` argument is passed, so whatever "model.pt" is on
# the repo's default branch gets used — consider pinning a commit for
# reproducibility.
ckpt_path = hf_hub_download(
    filename="model.pt",
    repo_id="bit-wander/mnist-demo",
)
class SimpleCNN(nn.Module):
    """Small convolutional classifier mapping a 1-channel image to 10 logits.

    The layers live in a positional ``nn.Sequential`` stored as ``self.net``,
    so the ``state_dict`` keys are ``net.0.*``, ``net.3.*`` and ``net.7.*``.
    ``load_state_dict`` (strict by default) relies on exactly these names, so
    the attribute name and the layer order must not change.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 1 -> 16 -> 32 channels. AdaptiveAvgPool2d(1) collapses whatever
        # spatial size remains to 1x1, so the linear head always gets 32 features.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(in_features=32, out_features=10),
        )

    def forward(self, x):
        """Return the raw linear-layer outputs (no softmax applied) for batch ``x``."""
        logits = self.net(x)
        return logits
# Build the network and restore its trained weights, mapping tensors to CPU.
model = SimpleCNN()
# NOTE(security): torch.load unpickles the file, and this checkpoint was just
# downloaded from a remote Hub repo. Consider passing weights_only=True
# (available since PyTorch 1.13, the default from 2.6) so that only tensors and
# plain containers can be deserialized.
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
model.load_state_dict(state_dict=ckpt["model_state_dict"])
model.train(False)  # equivalent to model.eval(): switch to inference mode
# Preprocessing applied to every input image before it reaches the model:
#   1. collapse to a single grayscale channel,
#   2. resize to exactly 28x28 (aspect ratio is not preserved),
#   3. convert to a float tensor in [0, 1],
#   4. normalize with mean 0.5 / std 0.5, i.e. map [0, 1] onto [-1, 1].
transform = T.Compose(
    transforms=[
        T.Grayscale(1),
        T.Resize(size=(28, 28)),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def predict(img: Image.Image):
    """Classify one image and return per-class probabilities.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input. Gradio passes
            ``None`` when the user submits with an empty image input.

    Returns:
        Dict mapping each class index as a string (``"0"``, ``"1"``, ...) to its
        softmax probability, or ``None`` when no image was given (``gr.Label``
        renders ``None`` as an empty output instead of erroring).
    """
    # Guard the empty-input case: transform(None) would raise inside the handler.
    if img is None:
        return None
    x = transform(img).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(x)
    # Take the single batch row explicitly rather than relying on squeeze().
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    return {str(i): p for i, p in enumerate(probs)}
# Gradio UI: the input image is handed to `predict` as a PIL image, and the
# returned {class: probability} dict is displayed as the 3 most likely classes.
demo = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=3),
    title="MNIST Demo",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()