# Pratik45's picture
# Add Gradio demo app
# 43f03c3
"""
Gradio Demo for MNIST CNN Classifier
Hugging Face Space Application
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
# Define the model architecture (must match training)
class ConvNet(nn.Module):
    """CNN digit classifier for single-channel 28x28 images.

    Layout: two stages of (conv -> batch-norm -> ReLU) applied twice, each
    stage followed by 2x2 max-pooling and channel dropout, then a three-layer
    fully connected head that emits ``num_classes`` raw logits (no softmax).

    The attribute names and the order in which they are registered define
    the ``state_dict`` keys, so they must not change independently of the
    checkpoint this model is loaded from.

    Args:
        dropout_rate: Dropout probability after the first dense layer. The
            convolutional stages and the second dense layer use half of it.
        num_classes: Number of output logits.
    """

    def __init__(self, dropout_rate=0.3, num_classes=10):
        super().__init__()
        half_rate = dropout_rate * 0.5

        # Stage 1 feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        # Stage 2 feature extractor: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Shared by both stages (neither layer has learnable parameters).
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout_conv = nn.Dropout2d(half_rate)

        # Classifier head. Two 2x2 pools shrink 28x28 to 7x7, hence 128*7*7.
        self.fc1 = nn.Linear(128 * 7 * 7, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 128)
        self.bn6 = nn.BatchNorm1d(128)
        self.dropout2 = nn.Dropout(half_rate)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)`` for ``x``."""
        # Stage 1, then halve the spatial resolution.
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = self.dropout_conv(self.pool(x))

        # Stage 2, then halve the spatial resolution again.
        x = torch.relu(self.bn3(self.conv3(x)))
        x = torch.relu(self.bn4(self.conv4(x)))
        x = self.dropout_conv(self.pool(x))

        # Flatten everything except the batch dimension for the dense head.
        x = x.view(x.size(0), -1)
        x = self.dropout1(torch.relu(self.bn5(self.fc1(x))))
        x = self.dropout2(torch.relu(self.bn6(self.fc2(x))))
        return self.fc3(x)
# Load model
device = torch.device('cpu')  # Use CPU for Hugging Face Spaces
model = ConvNet()

# Load the checkpoint. Loading stays best-effort so the app can still start
# when the weights file is missing or incompatible, but the failure is
# reported explicitly: without the checkpoint the network keeps its random
# initialisation and every prediction it makes is meaningless.
try:
    # NOTE(review): torch.load unpickles the file, so 'best_model.pth' must
    # come from a trusted source.
    checkpoint = torch.load('best_model.pth', map_location=device)
    # Accept both a training checkpoint ({'model_state_dict': ...}) and a
    # file that contains the bare state_dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
except Exception as e:
    print(f"Error loading model: {e}")
    print("⚠ WARNING: model is running with untrained (random) weights; "
          "predictions will be meaningless")
else:
    print("✓ Model loaded successfully")

model.to(device)
# Inference mode: disables dropout and makes batch-norm use running statistics.
model.eval()
# Preprocessing transform
# Normalisation constants applied after ToTensor.
# NOTE(review): these look like the commonly quoted MNIST mean/std; they
# should equal whatever was used at training time — confirm.
_NORM_MEAN = (0.1307,)
_NORM_STD = (0.3081,)

# Steps run in list order: resize to the 28x28 network input, convert to
# grayscale, convert to a tensor, then normalise.
_preprocess_steps = [
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
def predict_digit(image):
    """Classify a drawn digit and report the per-class confidences.

    Args:
        image: PIL image or numpy array holding the drawing, or ``None``
            when no image was supplied.

    Returns:
        A ``(result, confidences)`` tuple:

        * ``result`` - Markdown string naming the predicted digit and its
          confidence as a percentage, or ``None`` when ``image`` is ``None``.
        * ``confidences`` - dict mapping each digit label (``"0"`` .. ``"9"``)
          to its softmax probability in ``[0, 1]``; all ``0.0`` when
          ``image`` is ``None``.
    """
    if image is None:
        return None, {str(digit): 0.0 for digit in range(10)}

    # Convert to PIL Image if numpy array
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Preprocess and add the batch dimension the model expects.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)

    # Get prediction and confidence
    confidence, predicted = torch.max(probabilities, 1)
    predicted_digit = predicted.item()
    confidence_pct = confidence.item() * 100

    # gr.Label expects confidences as fractions in [0, 1] and renders them as
    # percentages itself, so the probabilities are passed through unscaled
    # (scaling by 100 here would display values such as "9960%").
    confidences = {
        str(digit): prob
        for digit, prob in enumerate(probabilities[0].tolist())
    }

    # The Markdown summary is plain text, so it carries the percentage form.
    result = f"**Predicted Digit: {predicted_digit}**\n\n**Confidence: {confidence_pct:.2f}%**"
    return result, confidences
# Create Gradio interface

# Markdown blurb rendered below the title.
# NOTE(review): the "Parameters: 271K" figure does not match the ConvNet
# defined in this file (fc1 alone holds 128*7*7*256 weights) — confirm the
# number against the trained model before publishing.
_DESCRIPTION = """
### Draw a digit (0-9) and see the AI predict it!
This model uses a Convolutional Neural Network trained on MNIST dataset, achieving **99.60% accuracy**.
**How to use:**
1. Draw a digit in the box on the left
2. The model will predict which digit you drew
3. See the confidence scores for all digits
**Model Details:**
- Architecture: 4-layer CNN with batch normalization
- Parameters: 271K
- Training: PyTorch with advanced techniques
- Performance: 99.60% test accuracy on MNIST
"""

# Input: a grayscale drawing canvas handed to predict_digit as a PIL image.
# NOTE(review): invert_colors=True presumably makes strokes arrive
# light-on-dark like MNIST digits — confirm against the Gradio version in use.
_digit_canvas = gr.Image(
    source="canvas",
    type="pil",
    image_mode="L",
    shape=(280, 280),
    invert_colors=True,
    brush_radius=15,
    label="Draw a digit (0-9)",
)

# Outputs, in the order predict_digit returns them: summary text, then the
# per-digit confidence scores.
_result_views = [
    gr.Markdown(label="Prediction"),
    gr.Label(label="Confidence Scores", num_top_classes=10),
]

demo = gr.Interface(
    fn=predict_digit,
    inputs=_digit_canvas,
    outputs=_result_views,
    title="🎯 MNIST Digit Recognition",
    description=_DESCRIPTION,
    examples=[
        # You can add example images here if you have them
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()