"""Gradio demo that classifies a handwritten digit image with a small CNN.

The checkpoint is fetched from the Hugging Face Hub when this module is imported.
"""
from typing import Dict, Optional

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as T
from huggingface_hub import hf_hub_download
from PIL import Image

REPO_ID = "bit-wander/mnist-demo"
CKPT_FILENAME = "model.pt"
NUM_CLASSES = 10

# Download model; hf_hub_download returns a local path to the checkpoint file.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename=CKPT_FILENAME)


class SimpleCNN(nn.Module):
    """Two conv blocks, global average pooling, and a linear head over 10 classes.

    Takes single-channel input shaped (batch, 1, H, W). AdaptiveAvgPool2d(1)
    reduces each feature map to 1x1, so the head does not depend on H or W.
    """

    def __init__(self):
        super().__init__()
        # Layer order must stay fixed: the state-dict keys ("net.0", "net.3",
        # "net.7", ...) come from positions in this Sequential.
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, NUM_CLASSES),
        )

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (batch, NUM_CLASSES)."""
        return self.net(x)


model = SimpleCNN()
# NOTE(security): torch.load unpickles the file, and unpickling can execute
# arbitrary code. This checkpoint comes from a remote repo. If it holds only
# tensors and plain containers, pass weights_only=True (the default in newer
# PyTorch releases).
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model.eval()  # switch to eval mode; no dropout/batchnorm here, but it is harmless

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1].
# NOTE(review): presumably this mirrors the training preprocessing -- confirm.
# Also check whether uploads with dark ink on a light background need to be
# inverted to match the training data's polarity.
transform = T.Compose([
    T.Grayscale(num_output_channels=1),
    T.Resize((28, 28)),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,)),
])


def predict(img: Optional[Image.Image]) -> Dict[str, float]:
    """Classify a digit image.

    Args:
        img: Input PIL image, or None when the form is submitted with no image.

    Returns:
        A mapping from each digit label ("0" through "9") to its softmax
        probability. The mapping is empty when ``img`` is None; this avoids
        raising inside the transform.
    """
    if img is None:
        return {}
    x = transform(img).unsqueeze(0)  # add batch dimension -> (1, 1, 28, 28)
    with torch.no_grad():
        logits = model(x)
    probs = torch.softmax(logits[0], dim=-1).tolist()
    return {str(digit): p for digit, p in enumerate(probs)}


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="MNIST Demo",
)

if __name__ == "__main__":
    demo.launch()