i4ata commited on
Commit
26a33b7
·
1 Parent(s): 05b6e9c

smol update

Browse files
Files changed (4) hide show
  1. __pycache__/model.cpython-310.pyc +0 -0
  2. app.py +23 -8
  3. classname.txt +3 -0
  4. main.py +4 -0
__pycache__/model.cpython-310.pyc ADDED
Binary file (2.44 kB). View file
 
app.py CHANGED
@@ -1,22 +1,37 @@
1
  import gradio as gr
2
  from PIL import Image
3
- from typing import List, Dict
4
 
5
  import torch
6
 
 
 
7
  class GradioApp:
 
8
  def __init__(self) -> None:
9
- self.model = torch.load('pretrained_vit.pth', map_location='cpu')
10
- def predict(self, img_file: str, classes: List[str]) -> Dict[str, float]:
11
- classes = ['0', '1', '2']
12
- img = self.model.val_transform(Image.open(img_file)).unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
13
  with torch.inference_mode():
14
- preds = torch.softmax(self.model(img), dim=1)[0].cpu().numpy()
15
- return {classes[i] : preds[i] for i in range(len(classes))}
 
16
  def launch(self):
 
17
  demo = gr.Interface(
18
  fn=self.predict,
19
- inputs=gr.Image(type='filepath'),
20
  outputs=gr.Label(num_top_classes=3),
21
  )
22
  demo.launch()
 
1
  import gradio as gr
2
  from PIL import Image
3
+ from typing import List, Dict, Union
4
 
5
  import torch
6
 
7
+ from model import ClassifierModel
8
+
9
  class GradioApp:
10
+
11
  def __init__(self) -> None:
12
+
13
+ self.models: Dict[str, Union[str, ClassifierModel]] = {
14
+ 'Custom': 'models/my_vit.pth',
15
+ 'Pretrained': 'models/pretrained_vit.pth'
16
+ }
17
+ with open('classname.txt') as f:
18
+ self.classes: List[str] = [line.strip() for line in f.readlines()]
19
+
20
+ def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
21
+
22
+ if isinstance(self.models[model_name], str):
23
+ self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
24
+
25
+ img = torch.unsqueeze(self.models[model_name].val_transform(Image.open(img_file)), 0)
26
  with torch.inference_mode():
27
+ preds = torch.softmax(self.models[model_name](img), dim=1)[0].numpy()
28
+ return dict(zip(self.classes, preds))
29
+
30
  def launch(self):
31
+
32
  demo = gr.Interface(
33
  fn=self.predict,
34
+ inputs=[gr.Image(type='filepath'), gr.Radio(('Custom', 'Pretrained'))],
35
  outputs=gr.Label(num_top_classes=3),
36
  )
37
  demo.launch()
classname.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ angular_leaf_spot
2
+ bean_rust
3
+ healthy
main.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import torch
2
+
3
+ a = torch.load('models/pretrained_vit.pth', map_location='cpu')
4
+ print(a)