Jugal-sheth commited on
Commit
52a915d
·
1 Parent(s): e0d6178

Add application file

Browse files
Files changed (1) hide show
  1. app.py +41 -0
app.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from model import model, classes
4
+ from torchvision import transforms
5
+
6
+ checkpoint = torch.load('mnist_model.pth', map_location=torch.device('cpu'))
7
+
8
+ # Load the state dictionary into model
9
+ model.load_state_dict(checkpoint['model_state_dict'])
10
+
11
+ # Set your model to evaluation mode
12
+ model.eval()
13
+
14
+
15
+ def preprocess_image(image):
16
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
17
+ threshold = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)
18
+ resized = cv2.resize(threshold, (28, 28), interpolation=cv2.INTER_AREA)
19
+ tensor = transforms.ToTensor()(resized).unsqueeze(0)
20
+ tensor = transforms.Normalize((0.5,), (0.5,))(tensor)
21
+ return tensor
22
+
23
+
24
+ def classify(image):
25
+ tensor = preprocess_image(image)
26
+ with torch.no_grad():
27
+ output = model(tensor)
28
+ prediction = output.argmax(dim=1, keepdim=True).item()
29
+ return str(prediction) # Convert prediction to string
30
+
31
+
32
+ iface = gr.Interface(
33
+ fn=classify,
34
+ inputs="sketchpad",
35
+ outputs='label',
36
+ theme="huggingface",
37
+ title="Digit Recognition",
38
+ description="Draw a Digit 0-9 and the algorithm will detect it in real time!",
39
+ article="<p style='text-align: center'>Digit Recognition | Demo Model by Jugal</p>",
40
+ live=True)
41
+ iface.launch(debug=True)