HamzaNaser commited on
Commit
6e48ac8
·
verified ·
1 Parent(s): 915fe6e
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
5
+
6
+
7
+
8
+ print('loading model..')
9
+ model = torch.load('model.pth')
10
+ model.eval()
11
+ print('loaded.')
12
+
13
+
14
+ transform = Compose([
15
+ Resize((300,300)),
16
+ ToTensor(),
17
+ Normalize(mean=[0.485, 0.456, 0.406],
18
+ std=[0.229, 0.224, 0.225]),
19
+ ])
20
+
21
+ def predict(img):
22
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
23
+ img = transform(img)
24
+ img = img.unsqueeze(0)
25
+
26
+ prediction = model(img).argmax(axis=1)
27
+ return f'Model prediction is {prediction[0]}'
28
+
29
+ demo = gr.Interface(
30
+ fn=predict,
31
+ inputs=["image"],
32
+ outputs=["text"],
33
+ )
34
+
35
+ demo.launch()
36
+