dgomes03 commited on
Commit
cc27db7
·
verified ·
1 Parent(s): e610a4f

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +35 -0
  2. model.py +24 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ from model import CNN
8
+
9
+ # Load model
10
+ model = CNN()
11
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
12
+ model.eval()
13
+
14
+ # Inference function
15
+ def predict_digit(image):
16
+ image = image.convert("L").resize((28, 28)) # Convert to grayscale
17
+ image = np.array(image) / 255.0 # Normalize
18
+ image = torch.tensor(image).unsqueeze(0).unsqueeze(0).float() # (1, 1, 28, 28)
19
+ with torch.no_grad():
20
+ logits = model(image)
21
+ probs = F.softmax(logits, dim=1).numpy().flatten()
22
+ predicted = np.argmax(probs)
23
+ return {str(i): float(probs[i]) for i in range(10)}
24
+
25
+ # Gradio UI
26
+ interface = gr.Interface(
27
+ fn=predict_digit,
28
+ inputs=gr.Image(type="pil", tool="editor", label="Draw or Upload a Digit"),
29
+ outputs=gr.Label(num_top_classes=3),
30
+ title="Handwritten Digit Classifier",
31
+ description="Draw a digit or upload an image of a handwritten digit."
32
+ )
33
+
34
+ if __name__ == "__main__":
35
+ interface.launch()
model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class CNN(nn.Module):
7
+ def __init__(self):
8
+ super(CNN, self).__init__()
9
+ self.conv1 = nn.Conv2d(1, 32, kernel_size=3) # (1, 28, 28) -> (32, 26, 26)
10
+ self.pool1 = nn.MaxPool2d(2, 2) # (32, 26, 26) -> (32, 13, 13)
11
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3) # (32, 13, 13) -> (64, 11, 11)
12
+ self.pool2 = nn.MaxPool2d(2, 2) # (64, 11, 11) -> (64, 5, 5)
13
+ self.fc1 = nn.Linear(64 * 5 * 5, 64)
14
+ self.fc2 = nn.Linear(64, 10)
15
+
16
+ def forward(self, x):
17
+ x = F.relu(self.conv1(x))
18
+ x = self.pool1(x)
19
+ x = F.relu(self.conv2(x))
20
+ x = self.pool2(x)
21
+ x = x.view(-1, 64 * 5 * 5)
22
+ x = F.relu(self.fc1(x))
23
+ x = self.fc2(x)
24
+ return x