File size: 2,150 Bytes
f61af91
7f83de3
c95b072
 
c1a4773
 
4867772
f61af91
4867772
c95b072
c1a4773
4867772
c1a4773
 
c95b072
c1a4773
 
c95b072
c1a4773
 
 
c95b072
c1a4773
c95b072
 
 
 
 
 
c1a4773
1a8dbe7
c1a4773
c95b072
 
 
 
7f83de3
9bd88bd
4867772
c95b072
c1a4773
4867772
c1a4773
 
 
 
 
 
9bd88bd
4867772
 
c95b072
4867772
c1a4773
 
d3c7673
c1a4773
 
c95b072
 
 
c1a4773
c95b072
c1a4773
7f83de3
 
4867772
c95b072
4867772
7f83de3
 
c1a4773
 
 
 
7f83de3
f61af91
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image


# =========================
# 1. MODEL DEFINITION
# (must match training exactly)
# =========================
def get_densenet121_mnist():
    """Build the DenseNet-121 variant used when training on MNIST.

    Starts from torchvision's DenseNet-121 with no pretrained weights, then
    swaps the stem convolution so it accepts a single (grayscale) input
    channel and replaces the classifier head with a 10-way linear layer.
    These shapes have to match the saved checkpoint, because it is loaded
    with ``strict=True``.

    Returns:
        The freshly initialised ``torch.nn.Module``.
    """
    net = models.densenet121(weights=None)
    features, head = net.features, net.classifier

    # Stem: one input channel instead of three; the other hyperparameters
    # are kept identical to the training script.
    features.conv0 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )

    # Head: reuse the backbone's feature width and emit 10 class logits.
    net.classifier = nn.Linear(in_features=head.in_features, out_features=10)
    return net


# =========================
# 2. LOAD MODEL
# =========================
def load_model(checkpoint_path="best_densenet_mnist.pth"):
    """Build the network and load the trained weights into it.

    Args:
        checkpoint_path: path to a file produced by ``torch.save`` holding
            the model's ``state_dict``. Defaults to the checkpoint name used
            by this demo.

    Returns:
        The model with weights loaded, on CPU, switched to eval mode.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
        RuntimeError: if the state_dict keys/shapes do not match the
            architecture (``strict=True``).
    """
    net = get_densenet121_mnist()
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot execute arbitrary code on load. A plain
    # state_dict (what load_state_dict below expects) loads fine this way.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    # strict=True: the architecture mirrors training exactly, so any missing
    # or unexpected key indicates a wrong checkpoint.
    net.load_state_dict(state_dict, strict=True)
    net.eval()  # inference-mode behaviour for BatchNorm/Dropout layers
    return net

model = load_model()


# =========================
# 3. PREPROCESS IMAGE
# (must match training transform exactly)
# =========================
# Pipeline applied to every uploaded image before inference. It mirrors the
# training-time transform: single-channel grayscale, a fixed 32x32 size,
# conversion to a float tensor, then normalization with the same mean/std
# constants the training script used.
preprocess = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)


# =========================
# 4. PREDICT
# =========================
# Class names shown in the UI, in the same order as the model's output logits.
LABELS = [str(i) for i in range(10)]

def predict(image):
    """Classify one uploaded digit image.

    Args:
        image: the array delivered by ``gr.Image(type="numpy")``, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each label in ``LABELS`` to its softmax probability,
        which is the format ``gr.Label`` consumes.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable
            message instead of a traceback from ``Image.fromarray``.
    """
    if image is None:
        raise gr.Error("Please upload an image of a handwritten digit.")

    # Force a single-channel PIL image whatever mode the upload arrived in.
    pil_image = Image.fromarray(image).convert("L")
    # NOTE(review): MNIST digits are light strokes on a dark background; an
    # upload drawn dark-on-light may need inverting before preprocessing.
    # Confirm against how the training data was fed before changing this.
    x = preprocess(pil_image).unsqueeze(0)  # add batch dim -> (1, 1, 32, 32)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0]  # drop the batch dim

    # zip keeps labels and probabilities aligned without a hard-coded count.
    return {label: float(p) for label, p in zip(LABELS, probs.tolist())}


# =========================
# 5. GRADIO UI
# =========================
# Gradio UI: image in, top-3 digit probabilities out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=3),   # shows top-3 predictions with confidence
    title="DenseNet121 MNIST Classifier",
    description="Upload a handwritten digit (0–9). Model: DenseNet121 trained on MNIST.",
)

# Start the server only when executed as a script, so importing this module
# (e.g. from tests or tooling) does not launch a web server as a side effect.
if __name__ == "__main__":
    demo.launch()