# app.py
"""Gradio demo that classifies a single handwritten digit image with a CNN."""

from typing import Dict, Optional

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

from model import CNN

# Side length the input is resized to; the (1, 1, 28, 28) tensor built in
# predict_digit implies the CNN expects single-channel 28x28 input.
IMAGE_SIZE = 28

# Load model
model = CNN()
# NOTE(review): torch.load unpickles the checkpoint. It is a local file here,
# but if it can ever come from an untrusted source, pass weights_only=True
# (PyTorch >= 1.13) so arbitrary pickled objects are not executed.
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
model.eval()


def predict_digit(image: Optional[Image.Image]) -> Optional[Dict[str, float]]:
    """Return per-class softmax probabilities for a digit image.

    Args:
        image: PIL image from the Gradio input component, or ``None`` when
            the user submits without providing an image.

    Returns:
        Mapping from class index (as a string) to probability, in the format
        ``gr.Label`` expects; ``None`` when no image was supplied, so the
        output is left empty instead of the handler raising AttributeError.
    """
    if image is None:
        return None

    # Grayscale + resize to the model's input resolution, then scale to [0, 1].
    # NOTE(review): no colour inversion or mean/std normalisation is applied.
    # If the CNN was trained on MNIST-style data (light digit on a dark
    # background), dark-on-light inputs may need `1.0 - pixels` — confirm
    # against the training pipeline.
    gray = image.convert("L").resize((IMAGE_SIZE, IMAGE_SIZE))
    pixels = np.array(gray) / 255.0
    batch = torch.tensor(pixels).unsqueeze(0).unsqueeze(0).float()  # (1, 1, 28, 28)

    with torch.no_grad():
        logits = model(batch)
    probs = F.softmax(logits, dim=1).numpy().flatten()

    # One entry per model output class, rather than a hard-coded count of 10.
    return {str(i): float(p) for i, p in enumerate(probs)}


def _build_input_component():
    """Create the image input, tolerating Gradio versions without shape/tool."""
    try:
        return gr.Image(type="pil", shape=(280, 280), tool="editor")
    except TypeError:
        # Gradio 4 removed the `shape` and `tool` arguments of gr.Image, so
        # the call above fails at import time there. Fall back to a plain
        # PIL image input so the app still starts; predict_digit resizes the
        # image itself, so `shape` is not needed for correctness.
        return gr.Image(type="pil")


# Gradio UI
interface = gr.Interface(
    fn=predict_digit,
    inputs=_build_input_component(),
    outputs=gr.Label(num_top_classes=3),
    title="Handwritten Digit Classifier",
    description="Draw a digit or upload a digit image.",
)

if __name__ == "__main__":
    interface.launch()