mnist-space / app.py
bit-wander's picture
Initial Space
dbb3a41
import torch, torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download
import gradio as gr
# Fetch the trained checkpoint file from the Hugging Face Hub repository.
# The returned value is used as the file path handed to torch.load.
ckpt_path = hf_hub_download(
    repo_id="bit-wander/mnist-demo",
    filename="model.pt",
)
class SimpleCNN(nn.Module):
    """Small convolutional classifier.

    Layer stack: 3x3 conv (16 filters) -> ReLU -> 2x2 max-pool ->
    3x3 conv (32 filters) -> ReLU -> adaptive average pool to 1x1 ->
    flatten -> linear layer producing one logit per class.

    Args:
        in_channels: Number of channels in the input images. Defaults to 1,
            the value the original hard-coded.
        num_classes: Number of output logits. Defaults to 10, the value the
            original hard-coded.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        # The attribute name `net` and the position of each layer inside the
        # Sequential determine the state_dict keys ("net.0.weight", ...), so
        # both must stay as they are for saved checkpoints to keep loading.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            # Pooling to a fixed 1x1 output makes the linear layer's input
            # size (32) independent of the input image resolution.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for a batch of images ``x``."""
        return self.net(x)
# Read the checkpoint onto the CPU, then restore its weights into a freshly
# constructed network.
# NOTE(review): torch.load deserialises the file with pickle; depending on the
# installed PyTorch version's `weights_only` default this may execute code
# embedded in the file. Only load checkpoints from a repository you trust.
ckpt = torch.load(ckpt_path, map_location="cpu")
model = SimpleCNN()
model.load_state_dict(ckpt["model_state_dict"])
# The model is only used for prediction, so put it in evaluation mode.
model.eval()
# Preprocessing applied to each input image before it is given to the model:
# convert to one grayscale channel, resize to 28x28, convert to a tensor, and
# normalise with mean 0.5 / std 0.5.
# NOTE(review): presumably this mirrors the preprocessing used when the
# checkpoint was trained -- confirm against the training code.
transform = T.Compose(
    [
        T.Grayscale(num_output_channels=1),
        T.Resize(size=(28, 28)),
        T.ToTensor(),
        T.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
def predict(img: Image.Image) -> dict:
    """Classify an image and return the probability of each class.

    Args:
        img: PIL image delivered by the Gradio image input. Gradio passes
            ``None`` when the user submits without providing an image.

    Returns:
        A dict mapping each class index, as a string ("0", "1", ...), to its
        softmax probability.

    Raises:
        gr.Error: If no image was provided, so the UI shows a readable
            message instead of an internal traceback from the transform.
    """
    if img is None:
        raise gr.Error("Please provide an image before submitting.")

    # Add a leading batch dimension: the model is called on a batch of one.
    x = transform(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(x)
    # Take the single batch row and convert it to plain Python floats.
    probs = torch.softmax(logits, dim=-1)[0].tolist()
    return {str(i): p for i, p in enumerate(probs)}
# Gradio interface: a single image input (handed to `predict` as a PIL image)
# and a label output limited to the three highest-scoring classes.
demo = gr.Interface(
    title="MNIST Demo",
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()