| import gradio as gr
|
| import torch
|
| import torch.nn as nn
|
| from torchvision import transforms, models
|
| from PIL import Image
|
| import numpy as np
|
|
|
|
|
# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone whose final fully-connected layer is replaced by a small
# head producing 10 logits (one per digit). This head must match the
# architecture the checkpoint was saved from: load_state_dict() is strict by
# default and raises on missing or unexpected keys.
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Linear(512, 128),  # 512 = in_features of torchvision's ResNet-18 fc layer
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 10),
)

# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state dict needs. It prevents a tampered checkpoint file from executing
# arbitrary code at load time. map_location lets a GPU-saved checkpoint load on
# a CPU-only machine.
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model = model.to(device)
# Inference mode: Dropout becomes a no-op and BatchNorm uses running statistics.
model.eval()
|
|
|
|
|
# Preprocessing applied to each (already inverted) canvas drawing before it is
# fed to the network. NOTE(review): presumably this mirrors the training-time
# pipeline used to produce model.pth; confirm against the training script.
transform = transforms.Compose([
    # Replicate the single gray channel into 3 channels, since ResNet expects RGB input.
    transforms.Grayscale(num_output_channels=3),
    # Downscale the drawing to a fixed 32x32 resolution.
    transforms.Resize((32, 32)),
    # PIL image -> float tensor (C, H, W) with values scaled to [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1].
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
|
|
|
|
|
def _blank_confidences():
    """Return the all-zero confidence map used when there is nothing to classify."""
    return {str(i): 0.0 for i in range(10)}


def _extract_drawing(image):
    """Pull the drawn picture out of the value delivered by the Sketchpad.

    The Sketchpad may hand over a dict with "composite" and "layers" keys
    instead of a bare array/image. Prefer the composite; fall back to the
    first layer only when the composite is missing or None. Returns None when
    no picture is available (including an empty "layers" list).
    """
    if not isinstance(image, dict):
        return image
    composite = image.get("composite")
    if composite is not None:
        return composite
    # Evaluate the fallback lazily so an empty layer list cannot raise IndexError.
    layers = image.get("layers")
    if layers is not None and len(layers) > 0:
        return layers[0]
    return None


def _to_inverted_grayscale(image):
    """Convert a drawing to an inverted 8-bit grayscale numpy array.

    Dark ink on light paper becomes bright strokes on a dark (0) background.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype(np.uint8))
    # convert("L") ignores alpha. Fully transparent pixels usually carry black
    # RGB, so a transparent canvas would turn solid black and then, once
    # inverted, solid white, burying the strokes. Flatten onto opaque white
    # first. For an already-opaque image this is a no-op.
    if "A" in image.getbands() or "transparency" in image.info:
        rgba = image.convert("RGBA")
        paper = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(paper, rgba)
    gray = np.array(image.convert("L"))
    # The brush is black on a white canvas; invert so the ink is the bright signal.
    return 255 - gray


def predict_digit(image):
    """Classify a hand-drawn digit from the Sketchpad.

    Args:
        image: The Sketchpad value. This is None, a dict with "composite" /
            "layers" entries, a numpy array, or a PIL image.

    Returns:
        A dict mapping "0".."9" to softmax probabilities. All values are 0.0
        when there is no image or the canvas is effectively empty.
    """
    drawing = _extract_drawing(image)
    if drawing is None:
        return _blank_confidences()

    inverted = _to_inverted_grayscale(drawing)
    # No pixel rises above a faint threshold, so nothing was drawn: skip inference.
    if inverted.max() < 10:
        return _blank_confidences()

    img_tensor = transform(Image.fromarray(inverted)).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]

    # A single device-to-host transfer instead of one sync per class.
    scores = probabilities.cpu().tolist()
    return {str(i): float(p) for i, p in enumerate(scores)}
|
|
|
|
|
|
|
def _build_interface():
    """Assemble the Gradio app: a drawing canvas wired to predict_digit, with a label output."""
    # Single fixed black brush. The prediction code inverts the drawing, so
    # the ink must be dark on the light canvas.
    ink_brush = gr.Brush(colors=["#000000"], color_mode="fixed", default_size=18)
    canvas = gr.Sketchpad(
        label="Draw a digit (0–9)",
        type="numpy",
        canvas_size=(280, 280),
        brush=ink_brush,
    )
    # Show the probability of every one of the 10 digit classes.
    predictions = gr.Label(num_top_classes=10, label="Predictions")
    return gr.Interface(
        fn=predict_digit,
        inputs=canvas,
        outputs=predictions,
        title="Handwritten Digit Recognizer",
        description="Draw a digit (0–9) on the white canvas below and click Predict.",
        submit_btn="Predict",
        clear_btn="Clear Canvas",
    )


interface = _build_interface()
|
|
|
def main():
    """Start the Gradio server for the digit recognizer."""
    # share=True also requests a public share link, so the app becomes
    # reachable from outside localhost while it runs.
    interface.launch(share=True)


if __name__ == "__main__":
    main()
|
|
|