HamzaNaser commited on
Commit
d4c9145
·
verified ·
1 Parent(s): 7495ed7

Upload 4 files

Browse files
Files changed (4) hide show
  1. 1.png +0 -0
  2. 2.png +0 -0
  3. 3.png +0 -0
  4. app.py +22 -4
1.png ADDED
2.png ADDED
3.png ADDED
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
5
 
@@ -19,17 +20,34 @@ transform = Compose([
19
  ])
20
 
21
  def predict(img):
22
- img = Image.fromarray(img.astype('uint8'), 'RGB')
 
 
23
  img = transform(img)
24
  img = img.unsqueeze(0)
25
 
26
- prediction = model(img).argmax(axis=1)
27
- return f'Model prediction is {prediction[0]}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  demo = gr.Interface(
30
  fn=predict,
31
  inputs=["image"],
32
- outputs=["text"],
 
33
  )
34
 
35
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
  from PIL import Image
5
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
6
 
 
20
  ])
21
 
22
  def predict(img):
23
+ labels = list(range(10))
24
+ if isinstance(img, np.ndarray):
25
+ img = Image.fromarray(img.astype('uint8'), 'RGB')
26
  img = transform(img)
27
  img = img.unsqueeze(0)
28
 
29
+ with torch.inference_mode():
30
+ prediction = torch.softmax(model(img),dim=1)[0]
31
+
32
+
33
+ result = { num:float(prob.numpy()) for num, prob in enumerate(prediction)}
34
+
35
+
36
+ return result
37
+
38
+
39
+ example_images = [
40
+ "1.png", # Make sure these paths are correct
41
+ "2.png",
42
+ "3.png"
43
+ ]
44
+
45
 
46
  demo = gr.Interface(
47
  fn=predict,
48
  inputs=["image"],
49
+ outputs=[gr.Label(num_top_classes=5, label="Predictions")],
50
+ examples=example_images,
51
  )
52
 
53
  demo.launch()