AdityaManojShinde committed on
Commit
bc4ba5a
·
verified ·
1 Parent(s): d92cb8a

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms, models
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
# Load model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone whose final fully connected layer is replaced by a small
# MLP head with 10 outputs. The architecture built here has to match the
# parameter names and shapes stored in the checkpoint loaded below.
model = models.resnet18()
model.fc = nn.Sequential(
    nn.Linear(512, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 10)
)
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code while it is being loaded.
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(
    torch.load("model.pth", map_location=device, weights_only=True)
)
model = model.to(device)
# Switch dropout (and batch-norm) layers to their inference behaviour.
model.eval()
18
+
19
# Preprocessing
# Steps run in the listed order on the image handed over by predict_digit.
transform = transforms.Compose([
    # Emit three identical grayscale channels.
    transforms.Grayscale(num_output_channels=3),
    # Scale the drawing to a fixed 32x32 input.
    transforms.Resize(size=(32, 32)),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel, i.e. map [0, 1] onto [-1, 1].
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])
28
+
29
+
30
def predict_digit(image):
    """Classify a hand-drawn digit and return a confidence per class.

    Args:
        image: The sketchpad value. May be ``None``, a dict carrying a
            "composite" image and/or a "layers" list, a numpy array, or a
            PIL image.

    Returns:
        A dict mapping the labels "0" through "9" to float probabilities.
        Every value is 0.0 when there is nothing drawn to classify.
    """
    no_prediction = {str(i): 0.0 for i in range(10)}

    if image is None:
        return no_prediction

    # Sketchpad returns a dict with "composite" key (RGBA numpy array)
    # or directly a numpy array depending on Gradio version.
    if isinstance(image, dict):
        drawing = image.get("composite")
        if drawing is None:
            # Only consult "layers" when there is no composite, and tolerate
            # a missing, None or empty list instead of indexing it blindly.
            layers = image.get("layers") or [None]
            drawing = layers[0]
        image = drawing

    if image is None:
        return no_prediction

    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image).astype(np.uint8))

    # convert("L") drops the alpha channel, so a transparent canvas would
    # read as solid black and then invert to solid white. Flatten any
    # transparency onto an opaque white background first; images that are
    # already opaque pass through unchanged.
    if image.mode in ("RGBA", "LA") or "transparency" in image.info:
        rgba = image.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, rgba)

    # Convert to grayscale
    img_array = np.array(image.convert("L"))

    # The canvas is white (255) with dark strokes.
    # MNIST expects black background with white digit, so invert.
    img_array = 255 - img_array

    # Check if the canvas is essentially blank (all near-zero after inversion)
    if img_array.max() < 10:
        return no_prediction

    # Add the batch dimension the model expects and move to its device.
    img_tensor = transform(Image.fromarray(img_array)).unsqueeze(0).to(device)

    # Predict
    with torch.no_grad():
        output = model(img_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]

    return {str(i): float(probabilities[i]) for i in range(10)}
68
+
69
+
70
# Create Gradio interface with sketchpad (drawable white canvas)
# Input: a 280x280 sketchpad restricted to a single fixed black brush.
_digit_canvas = gr.Sketchpad(
    label="Draw a digit (0–9)",
    type="numpy",
    canvas_size=(280, 280),
    brush=gr.Brush(colors=["#000000"], color_mode="fixed", default_size=18),
)

# Output: a label widget listing all ten class confidences.
_prediction_label = gr.Label(num_top_classes=10, label="Predictions")

interface = gr.Interface(
    fn=predict_digit,
    inputs=_digit_canvas,
    outputs=_prediction_label,
    title="Handwritten Digit Recognizer",
    description="Draw a digit (0–9) on the white canvas below and click Predict.",
    submit_btn="Predict",
    clear_btn="Clear Canvas",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch(share=True)