Ritik Kumar commited on
Commit
bd5437d
·
1 Parent(s): 1ad3861

Add application file

Browse files
Files changed (2) hide show
  1. app.py +60 -0
  2. model.pt +3 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Importing essential libraries
2
+ import gradio as gr
3
+ import torch
4
+ from torch import nn
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ # Creating Neural Model class used for training and validation
9
+ class NeuralNetwork(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.flatten = nn.Flatten()
13
+ self.linear_relu_stack = nn.Sequential(
14
+ nn.Linear(28*28, 512),
15
+ nn.ReLU(),
16
+ nn.Linear(512, 512),
17
+ nn.ReLU(),
18
+ nn.Linear(512, 10),
19
+ )
20
+
21
+ def forward(self, x):
22
+ x = self.flatten(x)
23
+ logits = self.linear_relu_stack(x)
24
+ return logits
25
+
26
+ # Intializing and loading the saved model
27
+ model = NeuralNetwork()
28
+ state_dict = torch.load('model.pt', map_location='cpu')
29
+ model.load_state_dict(state_dict, strict=False)
30
+ model.eval()
31
+
32
+ # Preprocessing input value for giving to function
33
+ def preprocess_input(input):
34
+ input = input['composite']
35
+
36
+ # Convert the image data to a PIL Image
37
+ input = Image.fromarray(input)
38
+
39
+ # Resize the image to match the input size expected by your model
40
+ input = input.resize((28, 28))
41
+
42
+ # Convert the image to grayscale
43
+ input = input.convert('L')
44
+
45
+ # Flatten the pixel values
46
+ input = np.array(input)
47
+
48
+ return input
49
+
50
+ # Define a predict function
51
+ def predict(img):
52
+ x = torch.tensor(preprocess_input(img), dtype=torch.float32)
53
+ with torch.no_grad():
54
+ return model(x.unsqueeze(0)).argmax().item()
55
+
56
+ # Design UI
57
+ gr.Interface(fn=predict,
58
+ inputs="sketchpad",
59
+ outputs="label",
60
+ live=True).launch(share=True)
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2747d4ed523df909a0f83c65be46375580bb86176f5921c53c3141f03c22a8a7
3
+ size 2681332