vishnuraggav commited on
Commit
21da3b2
·
1 Parent(s): 5c78476

Initial Commit

Browse files
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ''' Import Modules '''
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.models as models
5
+ import torchvision.transforms as T
6
+
7
+ import gradio as gr
8
+ import PIL.Image as Image
9
+ import numpy as np
10
+ import os
11
+
12
+ ''' Setup '''
13
+ weights_path = "vit_base_state_dict.pth"
14
+ model = models.vit_b_16()
15
+ model.heads = nn.Sequential(nn.Linear(768, 29))
16
+ model.load_state_dict(torch.load(weights_path, map_location="cpu"))
17
+
18
+ transform = T.Compose([
19
+ T.Resize((224, 224)),
20
+ T.ToTensor(),
21
+ T.Normalize(mean=[0.5 for _ in range(3)], std=[0.5 for _ in range(3)])
22
+ ])
23
+
24
+ label_to_idx = {
25
+ 0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F', 6: 'G', 7: 'H',
26
+ 8: 'I', 9: 'J', 10: 'K', 11: 'L', 12: 'M', 13: 'N', 14: 'O',
27
+ 15: 'P', 16: 'Q', 17: 'R', 18: 'S', 19: 'T', 20: 'U', 21: 'V',
28
+ 22: 'W', 23: 'X', 24: 'Y', 25: 'Z', 26: 'del', 27: 'nothing', 28: 'space'
29
+ }
30
+
31
+ def main(input_image: np.array):
32
+ pil_image = Image.fromarray(input_image)
33
+ tensor_image = transform(pil_image)
34
+
35
+ with torch.inference_mode():
36
+ pred = model(tensor_image.unsqueeze(0)).squeeze(0)
37
+ pred = torch.argmax(torch.softmax(pred, dim=0), dim=0)
38
+ pred = label_to_idx[pred.item()]
39
+
40
+ return pred
41
+
42
+ img_files = os.listdir("examples")
43
+ img_files.remove(".DS_Store")
44
+ examples = ["examples/"+img_name for img_name in img_files]
45
+
46
+ app = gr.Interface(
47
+ fn=main,
48
+ inputs=gr.Image(),
49
+ outputs=gr.Textbox(),
50
+ examples=examples
51
+ )
52
+
53
+ app.launch()
examples/.DS_Store ADDED
Binary file (8.2 kB). View file
 
examples/A_test.jpg ADDED
examples/E_test.jpg ADDED
examples/F_test.jpg ADDED
examples/L_test.jpg ADDED
examples/N_test.jpg ADDED
examples/S_test.jpg ADDED
examples/T_test.jpg ADDED
examples/U_test.jpg ADDED
examples/W_test.jpg ADDED
examples/space_test.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ gradio
4
+ torchvision
vit_base_state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2c26579fdbf54064e1d514e8b0e47d7e666450fe764d6dcfd91b8699ef979bf
3
+ size 343346066