# Hugging Face Space upload by AdityaManojShinde.
# Commit bc4ba5a (verified): "Upload app.py with huggingface_hub".
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np
# Inference device: prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone (constructed without requesting pretrained weights) whose
# final fully connected layer is swapped for a 512 -> 128 -> 10 head, one output
# per digit class.
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 10),
)

# Restore the trained weights. map_location remaps tensors onto `device`, so the
# checkpoint also loads on a CPU-only host.
model.load_state_dict(torch.load("model.pth", map_location=device))

# Move to the target device and switch to inference mode (keeps the Dropout in
# the head inactive during prediction).
model = model.to(device).eval()
# Preprocessing applied to the (already inverted) canvas image before inference.
transform = transforms.Compose([
    # Replicate the single grey channel three times: ResNet-18's first
    # convolution takes 3 input channels.
    transforms.Grayscale(num_output_channels=3),
    # Downscale to the 32x32 resolution the classifier is fed.
    transforms.Resize((32, 32)),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps [0, 1] onto [-1, 1] for every channel.
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
def _empty_confidences():
    """Return an all-zero confidence map for the ten digit labels "0".."9"."""
    return {str(i): 0.0 for i in range(10)}


def _canvas_from_editor_value(value):
    """Extract the drawn image from a Sketchpad value.

    Depending on the Gradio version, the component passes either the image
    itself or a dict with a "composite" image (RGBA numpy array) and a list of
    "layers". Returns None when no image can be found.
    """
    if not isinstance(value, dict):
        return value
    composite = value.get("composite")
    if composite is not None:
        return composite
    # Fall back to the first layer. "layers" may be missing, None or empty, so
    # it is checked explicitly. The old eager `.get("layers", [None])[0]`
    # default raised on an empty list even when "composite" was present.
    layers = value.get("layers")
    if layers is None or len(layers) == 0:
        return None
    return layers[0]


def _flatten_onto_white(image):
    """Composite any transparency in a PIL image onto an opaque white background.

    A plain `convert("L")` drops the alpha channel, so transparent pixels would
    keep whatever RGB values they carry (possibly black) and be read as ink
    after inversion. Images without an alpha channel are returned unchanged,
    and fully opaque RGBA pixels keep their colour.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image
    rgba = image.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba)


def predict_digit(image):
    """Classify a hand-drawn digit and return per-class confidences.

    Args:
        image: The Sketchpad value. It may be None, a dict with "composite"
            and/or "layers" entries, a numpy array, or a PIL image.

    Returns:
        A dict mapping each label "0".."9" to its softmax probability (float).
        All values are 0.0 when no image is supplied or the canvas is
        essentially blank.
    """
    image = _canvas_from_editor_value(image)
    if image is None:
        return _empty_confidences()

    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype(np.uint8))

    # Flatten transparency onto white first so undrawn areas read as background,
    # then reduce to a single grey channel.
    image = _flatten_onto_white(image).convert("L")

    # The canvas is white (255) with dark strokes. MNIST-style input expects a
    # black background with a white digit, so invert.
    img_array = 255 - np.array(image)

    # Treat an empty or near-zero (after inversion) canvas as "nothing drawn".
    if img_array.size == 0 or img_array.max() < 10:
        return _empty_confidences()

    img_tensor = transform(Image.fromarray(img_array)).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]

    # One host transfer for all ten values instead of ten per-element reads.
    return {str(i): p for i, p in enumerate(probabilities.cpu().tolist())}
# Web UI: a drawable sketchpad feeds predict_digit, and a Label component shows
# the probability assigned to each of the ten digits.
interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(
        type="numpy",
        label="Draw a digit (0–9)",
        canvas_size=(280, 280),
        # Single fixed black brush, so strokes are always dark on the canvas.
        brush=gr.Brush(default_size=18, colors=["#000000"], color_mode="fixed"),
    ),
    outputs=gr.Label(label="Predictions", num_top_classes=10),
    title="Handwritten Digit Recognizer",
    description="Draw a digit (0–9) on the white canvas below and click Predict.",
    clear_btn="Clear Canvas",
    submit_btn="Predict",
)

if __name__ == "__main__":
    # share=True additionally asks Gradio for a publicly shareable link.
    interface.launch(share=True)