# Assignment / app.py
# Author: osamajan90 — commit 1a8dbe7 ("final fix"), 2.15 kB
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
# =========================
# 1. MODEL DEFINITION
# (must match training exactly)
# =========================
def get_densenet121_mnist():
    """Build the DenseNet-121 variant used for digit classification.

    Starts from an untrained torchvision DenseNet-121 (``weights=None``) and
    swaps two layers:

    * the stem convolution, so the network accepts 1-channel (grayscale) input;
    * the classifier head, so it emits 10 logits (one per digit).

    The layer shapes must match the training script exactly; ``load_model``
    loads the checkpoint with ``strict=True``, so any difference is an error.

    Returns:
        torch.nn.Module: the modified, untrained network.
    """
    net = models.densenet121(weights=None)

    # Stem: 1 input channel, 64 filters, 7x7 kernel, stride 2, no bias.
    net.features.conv0 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    # Head: keep the feature width of the existing classifier, output 10 classes.
    net.classifier = nn.Linear(net.classifier.in_features, 10)
    return net
# =========================
# 2. LOAD MODEL
# =========================
def load_model():
    """Build the network and load the trained weights from disk.

    Reads ``best_densenet_mnist.pth``, a relative path, so it is resolved
    against the current working directory. All tensors are mapped to the
    CPU. The weights are loaded with ``strict=True``, so any mismatch between
    the checkpoint and the architecture raises instead of being silently
    skipped.

    Returns:
        torch.nn.Module: the model switched to evaluation mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the checkpoint's keys/shapes do not match the
            architecture built by ``get_densenet121_mnist``.
    """
    model = get_densenet121_mnist()
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered checkpoint cannot execute arbitrary code during load. A plain
    # state_dict needs nothing more, and this also avoids the FutureWarning
    # newer torch versions emit when the argument is omitted.
    state_dict = torch.load(
        "best_densenet_mnist.pth", map_location="cpu", weights_only=True
    )
    model.load_state_dict(state_dict, strict=True)  # architecture must match exactly
    # Inference mode for layers whose behaviour differs between train and eval.
    model.eval()
    return model
model = load_model()  # built once at import time; every predict() call reuses it

# =========================
# 3. PREPROCESS IMAGE
# (must match the training transform exactly)
# =========================
preprocess = transforms.Compose(
    [
        transforms.Grayscale(1),  # single channel, as the model's stem expects
        transforms.Resize(size=(32, 32)),  # training resolution
        transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),  # training statistics
    ]
)

# =========================
# 4. PREDICT
# =========================
LABELS = list(map(str, range(10)))  # "0".."9", index-aligned with the 10 model outputs
def predict(image):
    """Classify one uploaded digit image.

    Args:
        image: the array delivered by ``gr.Image(type="numpy")``, or ``None``
            when the user presses Submit without uploading anything.

    Returns:
        A dict mapping each label in ``LABELS`` to its softmax probability, in
        the format ``gr.Label`` expects. Returns ``None`` (which clears the
        output) when no image was provided, instead of raising.
    """
    if image is None:
        return None

    gray = Image.fromarray(image).convert("L")  # single-channel PIL image

    # The training normalisation mean (0.1307 on a [0, 1] scale) shows the
    # training images are mostly dark: light strokes on a dark background.
    # Real uploads are usually the opposite (dark ink on light paper), which
    # the model has never seen, so flip images whose average pixel is light.
    pixels = np.asarray(gray)
    if pixels.mean() > 127:
        gray = Image.fromarray(255 - pixels)

    x = preprocess(gray).unsqueeze(0)  # add batch dim -> (1, 1, 32, 32)
    with torch.no_grad():
        logits = model(x)
    probs = torch.softmax(logits, dim=1).squeeze(0)
    return {label: float(p) for label, p in zip(LABELS, probs.tolist())}
# =========================
# 5. GRADIO UI
# =========================
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),  # shows top-3 predictions with confidence
    title="DenseNet121 MNIST Classifier",
    description="Upload a handwritten digit (0–9). Model: DenseNet121 trained on MNIST.",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module (tests, notebooks) builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()