# Spaces status: Sleeping
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
# =========================
# 1. MODEL DEFINITION
# (must match training exactly)
# =========================
def get_densenet121_mnist():
    """Build a DenseNet-121 adapted to 1-channel input and 10 output classes.

    The network is created without pretrained weights (``weights=None``);
    the first convolution and the classifier head are then swapped out.

    Returns:
        The modified ``torchvision`` DenseNet-121 module.
    """
    net = models.densenet121(weights=None)

    # Stem convolution: take a single (grayscale) channel instead of three.
    net.features.conv0 = nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False
    )

    # Classification head: keep the feature width, emit 10 logits.
    net.classifier = nn.Linear(net.classifier.in_features, 10)
    return net
# =========================
# 2. LOAD MODEL
# =========================
def load_model():
    """Build the network, load its weights from disk and return it in eval mode.

    The checkpoint is read from ``best_densenet_mnist.pth`` (a relative path)
    and mapped onto the CPU.

    Returns:
        The model with the checkpoint applied, switched to evaluation mode.
    """
    net = get_densenet121_mnist()
    # NOTE(review): torch.load unpickles the file; only ship a checkpoint you
    # trust, or consider passing weights_only=True if the installed torch
    # version supports it.
    checkpoint = torch.load("best_densenet_mnist.pth", map_location="cpu")
    # strict=True: every key in the checkpoint must match the architecture.
    net.load_state_dict(checkpoint, strict=True)
    return net.eval()
# Load the network once at import time; predict() reuses this instance.
model = load_model()
# =========================
# 3. PREPROCESS IMAGE
# (must match training transform exactly)
# =========================
# Input pipeline applied to every image before inference. Per the original
# author's notes, the resize target and normalisation constants mirror the
# training transform.
preprocess = transforms.Compose(
    [
        # Collapse to a single channel.
        transforms.Grayscale(num_output_channels=1),
        # Fixed spatial size fed to the network.
        transforms.Resize((32, 32)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel standardisation with the constants used in training.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# =========================
# 4. PREDICT
# =========================
# Class names shown in the UI, one per output logit ("0" .. "9").
LABELS = [str(digit) for digit in range(10)]


def predict(image):
    """Classify one image and return a probability for every digit label.

    Args:
        image: Image as a NumPy array (the Gradio input is declared with
            ``type="numpy"``), or ``None`` when no image was provided.

    Returns:
        A dict mapping each entry of ``LABELS`` to its softmax probability as
        a Python float. An empty dict is returned when ``image`` is ``None``.
    """
    # Gradio passes None when the form is submitted without an image; bail out
    # instead of crashing inside Image.fromarray.
    if image is None:
        return {}

    # NOTE(review): MNIST digits are light strokes on a dark background;
    # dark-on-light uploads may need inverting -- confirm against training data.
    grayscale = Image.fromarray(image).convert("L")
    batch = preprocess(grayscale).unsqueeze(0)  # add batch dim -> (1, 1, 32, 32)

    with torch.no_grad():
        logits = model(batch)
        # Drop only the batch dimension so the result is always 1-D.
        probs = torch.softmax(logits, dim=1).squeeze(0)

    # Pair labels with probabilities rather than indexing a hard-coded range.
    return dict(zip(LABELS, probs.tolist()))
# =========================
# 5. GRADIO UI
# =========================
# Input widget: hands the uploaded picture to predict() as a NumPy array.
_image_input = gr.Image(type="numpy")
# Output widget: shows the three highest-scoring labels with their confidence.
_label_output = gr.Label(num_top_classes=3)

demo = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs=_label_output,
    title="DenseNet121 MNIST Classifier",
    description="Upload a handwritten digit (0–9). Model: DenseNet121 trained on MNIST.",
)

# Start the web app as soon as the module is executed.
demo.launch()