"""Gradio demo serving a DenseNet121 classifier trained on MNIST digits."""

from typing import Dict, Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# =========================
# 0. CONFIGURATION
# =========================
CHECKPOINT_PATH = "best_densenet_mnist.pth"
NUM_CLASSES = 10
LABELS = [str(i) for i in range(NUM_CLASSES)]


# =========================
# 1. MODEL DEFINITION
#    (must match training exactly)
# =========================
def get_densenet121_mnist(num_classes: int = NUM_CLASSES) -> nn.Module:
    """Build an untrained DenseNet121 adapted to single-channel digit images.

    The stem convolution is replaced so the network accepts 1 input channel,
    and the classifier head is replaced with a ``num_classes``-way linear
    layer. The architecture must match the one that produced the checkpoint,
    otherwise ``load_state_dict(strict=True)`` in ``load_model`` fails.

    Args:
        num_classes: Number of output logits (10 for MNIST digits).

    Returns:
        A randomly initialised torchvision DenseNet121 with the modified
        stem and head.
    """
    model = models.densenet121(weights=None)
    # Grayscale: 1 input channel (same as training).
    model.features.conv0 = nn.Conv2d(
        1, 64, kernel_size=7, stride=2, padding=3, bias=False
    )
    # Replace the ImageNet head with one sized for our classes (same as training).
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    return model


# =========================
# 2. LOAD MODEL
# =========================
def load_model(checkpoint_path: str = CHECKPOINT_PATH) -> nn.Module:
    """Build the network, load trained weights onto CPU, and set eval mode.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model in evaluation mode, with weights mapped to CPU.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If the checkpoint's keys or tensor shapes do not match
            the architecture (``strict=True``).
    """
    model = get_densenet121_mnist()
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; on torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # strict=True since the architecture matches training exactly.
    model.load_state_dict(state_dict, strict=True)
    model.eval()  # BatchNorm layers use running statistics at inference
    return model


model = load_model()

# =========================
# 3. PREPROCESS IMAGE
#    (must match training transform exactly)
# =========================
preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((32, 32)),  # same as training
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # same as training
])


# =========================
# 4. PREDICT
# =========================
def predict(image: Optional[np.ndarray]) -> Dict[str, float]:
    """Classify one uploaded digit image.

    Args:
        image: The array Gradio supplies for ``gr.Image(type="numpy")``, or
            ``None`` when the user submits without uploading an image.

    Returns:
        A mapping from class label ("0".."9") to softmax probability, the
        format ``gr.Label`` expects. Empty when no image was provided.
    """
    if image is None:
        # Image.fromarray(None) would raise; show an empty result instead.
        return {}
    # NOTE(review): MNIST digits are light strokes on a dark background.
    # Dark-ink-on-white uploads are presumably out of distribution unless the
    # training pipeline inverted images — confirm before adding inversion here.
    pil_image = Image.fromarray(image).convert("L")  # ensure grayscale PIL
    batch = preprocess(pil_image).unsqueeze(0)  # (1, 1, 32, 32)
    with torch.no_grad():
        logits = model(batch)
    # Take the single batch row explicitly rather than squeeze(), so the
    # result is always 1-D with NUM_CLASSES entries.
    probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(LABELS, probs.tolist()))


# =========================
# 5. GRADIO UI
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),  # shows top-3 predictions with confidence
    title="DenseNet121 MNIST Classifier",
    description="Upload a handwritten digit (0–9). Model: DenseNet121 trained on MNIST.",
)

if __name__ == "__main__":
    # Guarded so that importing this module does not start a server;
    # running the file directly launches the app as before.
    demo.launch()