"""Gradio demo: classify a hand-drawn digit (0-9) with a ResNet-18 checkpoint."""

import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import numpy as np

NUM_CLASSES = 10

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = models.resnet18()
# Replace the stock classification head; these layer shapes must match the
# head stored in model.pth or load_state_dict will fail.
model.fc = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, NUM_CLASSES),
)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load("model.pth", map_location=device))
model = model.to(device)
model.eval()

# Preprocessing
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


def _blank_prediction():
    """Return an all-zero confidence map, used when there is nothing to classify."""
    return {str(i): 0.0 for i in range(NUM_CLASSES)}


def _extract_image(image):
    """Pull the drawn image out of the Sketchpad value.

    Depending on the Gradio version the Sketchpad yields either a dict with a
    "composite" entry (and a "layers" list) or the image itself. Returns None
    when no usable image is present.
    """
    if not isinstance(image, dict):
        return image
    composite = image.get("composite")
    if composite is not None:
        return composite
    # Fall back to the first layer, guarding against a missing/empty list
    # (indexing an empty list would raise IndexError).
    layers = image.get("layers") or []
    return layers[0] if layers else None


def _to_white_background_gray(image):
    """Convert an array/PIL image to 8-bit grayscale on a white background.

    Transparent pixels are flattened onto white first: a plain
    convert("L") discards alpha, so a transparent background would turn
    black, the same colour as the black strokes, and the digit would vanish.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image).astype(np.uint8))
    has_alpha = image.mode in ("RGBA", "LA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if has_alpha:
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, rgba)
    return image.convert("L")


def predict_digit(image):
    """Classify a hand-drawn digit.

    Args:
        image: The Sketchpad value. This is a dict with "composite"/"layers"
            entries, a numpy array, a PIL image, or None.

    Returns:
        A dict mapping "0".."9" to softmax probabilities. All values are 0.0
        when the input is missing or the canvas is (nearly) blank.
    """
    image = _extract_image(image)
    if image is None:
        return _blank_prediction()

    img_array = np.array(_to_white_background_gray(image))

    # The canvas is white (255) with dark strokes. Invert to white-on-black,
    # the MNIST-style polarity the original code assumes the model expects.
    img_array = 255 - img_array

    # Check if the canvas is essentially blank (all near-zero after inversion)
    if img_array.max() < 10:
        return _blank_prediction()

    img_tensor = transform(Image.fromarray(img_array)).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0].tolist()

    return {str(i): float(p) for i, p in enumerate(probabilities)}


# Create Gradio interface with sketchpad (drawable white canvas)
interface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Sketchpad(
        label="Draw a digit (0–9)",
        type="numpy",
        canvas_size=(280, 280),
        brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=18),
    ),
    outputs=gr.Label(num_top_classes=10, label="Predictions"),
    title="Handwritten Digit Recognizer",
    description="Draw a digit (0–9) on the white canvas below and click Predict.",
    submit_btn="Predict",
    clear_btn="Clear Canvas",
)

if __name__ == "__main__":
    # share=True also publishes a temporary public URL for the demo.
    interface.launch(share=True)