# pallabi1608's picture
# Update app.py
# 3886447 verified
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.figure import Figure
from PIL import Image
# Define the network
class Net(nn.Module):
    """Two-convolution CNN that scores single-channel images over 10 classes.

    The attribute names below are also the ``state_dict`` keys, so they have
    to match the checkpoint passed to ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 10 -> 20 channels with 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Classifier head: 320 flattened features -> 50 hidden units -> 10 scores.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores: one row of 10 values per input image.

        ``fc1`` takes 320 features (20 channels x 4 x 4 after two conv+pool
        stages), which corresponds to 1x28x28 inputs.
        """
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))
        flat = stage2.view(-1, 320)
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
# Load model: build the network once at import time so every request reuses it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = Net().to(device)
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs. Without it, older PyTorch versions fall back to the
# full pickle loader, which can run arbitrary code stored in the file.
# NOTE(review): "model.pth" is resolved relative to the working directory.
network.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
# Inference mode: turns off the dropout layers Net uses during training.
network.eval()
def predict_digit(image):
    """Classify a hand-drawn digit and plot the model's confidence.

    Args:
        image: Value from the Sketchpad component: either an image array or a
            dict whose ``"composite"`` entry holds the image array. May be
            ``None``.

    Returns:
        A ``matplotlib.figure.Figure`` showing the downscaled 28x28 grayscale
        input (titled with the predicted digit) next to a bar chart of
        per-digit confidence in percent, or ``None`` if there is no image.
    """
    if image is None:
        return None
    # Handle Gradio 6.0 dict format
    if isinstance(image, dict):
        image = image.get("composite")
        # A missing/None composite would otherwise crash in Image.fromarray.
        if image is None:
            return None

    # Preprocess: grayscale, 28x28, scale to [0, 1], then standardise.
    # NOTE(review): convert("L") drops any alpha channel. The constants below
    # are presumably the MNIST training mean/std, which assume light strokes
    # on a dark background; confirm the canvas produces that polarity.
    img = Image.fromarray(image).convert("L").resize((28, 28))
    img_array = np.array(img, dtype=np.float32) / 255.0
    img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0).to(device)
    img_tensor = (img_tensor - 0.1307) / 0.3081

    with torch.no_grad():
        log_probs = network(img_tensor)  # (1, 10) log-softmax scores
    probs = torch.exp(log_probs[0])
    pred = int(log_probs.argmax(dim=1)[0].item())
    confidences = [p * 100 for p in probs.tolist()]

    # Use the object-oriented Figure API, not pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), so a long-running server would
    # leak one figure per request, and its global state is not thread-safe.
    fig = Figure(figsize=(12, 4), tight_layout=True)
    ax_img, ax_bar = fig.subplots(1, 2)

    # Left: the downscaled grayscale digit.
    ax_img.imshow(img, cmap="gray")
    ax_img.set_title(f"Predicted: {pred}", fontsize=16, fontweight="bold")
    ax_img.axis("off")

    # Right: confidence per digit, with the predicted class highlighted.
    digits = list(range(10))
    colors = ["#4CAF50" if d == pred else "#FF6B6B" for d in digits]
    ax_bar.bar(digits, confidences, color=colors, alpha=0.7)
    ax_bar.set_xlabel("Digit", fontsize=12)
    ax_bar.set_ylabel("Confidence (%)", fontsize=12)
    ax_bar.set_title("Confidence Scores", fontsize=12, fontweight="bold")
    ax_bar.set_ylim([0, 100])
    ax_bar.set_xticks(digits)
    ax_bar.grid(axis="y", alpha=0.3)
    return fig
# Create interface
with gr.Blocks(title="MNIST Digit Recognition") as demo:
    gr.Markdown("# 🔢 MNIST Digit Recognition")
    gr.Markdown("Draw a digit (0-9) and the model will predict it!")
    with gr.Row():
        canvas = gr.Sketchpad(label="Draw Here", type="numpy")
        output = gr.Plot(label="Prediction")
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        clear_btn = gr.Button("Clear")
    predict_btn.click(fn=predict_digit, inputs=canvas, outputs=output)
    # Clear the drawing and the previous prediction together, so the plot
    # never describes a digit that is no longer on the canvas.
    clear_btn.click(fn=lambda: (None, None), outputs=[canvas, output])

if __name__ == "__main__":
    demo.launch()