Spaces:
Sleeping
Sleeping
File size: 2,150 Bytes
f61af91 7f83de3 c95b072 c1a4773 4867772 f61af91 4867772 c95b072 c1a4773 4867772 c1a4773 c95b072 c1a4773 c95b072 c1a4773 c95b072 c1a4773 c95b072 c1a4773 1a8dbe7 c1a4773 c95b072 7f83de3 9bd88bd 4867772 c95b072 c1a4773 4867772 c1a4773 9bd88bd 4867772 c95b072 4867772 c1a4773 d3c7673 c1a4773 c95b072 c1a4773 c95b072 c1a4773 7f83de3 4867772 c95b072 4867772 7f83de3 c1a4773 7f83de3 f61af91 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 | import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# =========================
# 1. MODEL DEFINITION
# (must match training exactly)
# =========================
def get_densenet121_mnist():
    """Build an untrained DenseNet-121 adapted for 1-channel, 10-class input.

    Two layers of the stock torchvision model are swapped out: the stem
    convolution (so it accepts a single grayscale channel) and the classifier
    head (so it emits 10 logits). This must stay identical to the training
    architecture, because ``load_model`` loads the trained weights with
    ``strict=True``.
    """
    net = models.densenet121(weights=None)

    # Replace the RGB stem with a single-channel one, keeping DenseNet's
    # 7x7 / stride-2 / padding-3 geometry and bias-free convolution.
    net.features.conv0 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    # New linear head: same input width as the stock classifier, 10 outputs.
    net.classifier = nn.Linear(net.classifier.in_features, 10)
    return net
# =========================
# 2. LOAD MODEL
# =========================
def load_model(checkpoint_path="best_densenet_mnist.pth"):
    """Build the DenseNet-121 MNIST model and load trained weights into it.

    Args:
        checkpoint_path: Path to a file written by ``torch.save`` that holds a
            plain ``state_dict`` for the architecture returned by
            ``get_densenet121_mnist``. Defaults to the filename this app has
            always used, so existing callers are unaffected.

    Returns:
        The model with weights mapped to CPU, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            architecture (the load is strict).
    """
    model = get_densenet121_mnist()
    # weights_only=True limits unpickling to tensors and plain containers, so
    # a tampered checkpoint cannot run arbitrary code when it is loaded. A bare
    # state_dict (all that load_state_dict below accepts) needs nothing more.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # strict=True: the architecture is meant to match training exactly, so
    # any missing or unexpected key indicates a real mismatch.
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    return model
# Load once at import time; every call to predict() reuses this instance.
model = load_model()
# =========================
# 3. PREPROCESS IMAGE
# (must match training transform exactly)
# =========================
# Pipeline: force 1 channel -> resize to 32x32 -> float tensor in [0, 1]
# -> standardize with the mean/std below (same values as training).
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((32, 32)),  # same as training
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),  # same as training
    ]
)
# =========================
# 4. PREDICT
# =========================
# Display label for class index i is the string of i: "0" through "9".
LABELS = list(map(str, range(10)))
def predict(image):
    """Classify one uploaded digit image.

    Args:
        image: The ``numpy.ndarray`` delivered by ``gr.Image(type="numpy")``,
            or ``None`` when the user submits without uploading an image.

    Returns:
        A ``{label: probability}`` dict with one entry per item of ``LABELS``,
        which is the format ``gr.Label`` renders.

    Raises:
        gr.Error: If no image was provided. The UI shows this as a readable
            message instead of a PIL traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of a digit.")
    # NOTE(review): preprocessing does not invert colors, so inputs are
    # expected in the same polarity as the training data (MNIST: light strokes
    # on a dark background). Dark-on-light uploads such as pen on paper may be
    # misclassified. Consider inverting such inputs if that is confirmed.
    pil_image = Image.fromarray(image).convert("L")  # ensure grayscale PIL
    x = preprocess(pil_image).unsqueeze(0)  # add batch dim -> (1, 1, 32, 32)
    with torch.no_grad():
        logits = model(x)
    # Take the single batch row explicitly rather than squeeze(), and pair
    # labels with probabilities directly instead of a hard-coded range(10).
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(LABELS, probs))
# =========================
# 5. GRADIO UI
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),  # shows top-3 predictions with confidence
    title="DenseNet121 MNIST Classifier",
    description="Upload a handwritten digit (0–9). Model: DenseNet121 trained on MNIST.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `predict` or `demo`) does not launch and
# block on the server.
if __name__ == "__main__":
    demo.launch()