smol update
Browse files- __pycache__/model.cpython-310.pyc +0 -0
- app.py +23 -8
- classname.txt +3 -0
- main.py +4 -0
__pycache__/model.cpython-310.pyc
ADDED
|
Binary file (2.44 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,22 +1,37 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
-
from typing import List, Dict
|
| 4 |
|
| 5 |
import torch
|
| 6 |
|
|
|
|
|
|
|
| 7 |
class GradioApp:
|
|
|
|
| 8 |
def __init__(self) -> None:
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
with torch.inference_mode():
|
| 14 |
-
preds = torch.softmax(self.
|
| 15 |
-
return
|
|
|
|
| 16 |
def launch(self):
|
|
|
|
| 17 |
demo = gr.Interface(
|
| 18 |
fn=self.predict,
|
| 19 |
-
inputs=gr.Image(type='filepath'),
|
| 20 |
outputs=gr.Label(num_top_classes=3),
|
| 21 |
)
|
| 22 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from PIL import Image
|
| 3 |
+
from typing import List, Dict, Union
|
| 4 |
|
| 5 |
import torch
|
| 6 |
|
| 7 |
+
from model import ClassifierModel
|
| 8 |
+
|
| 9 |
class GradioApp:
|
| 10 |
+
|
| 11 |
def __init__(self) -> None:
|
| 12 |
+
|
| 13 |
+
self.models: Dict[str, Union[str, ClassifierModel]] = {
|
| 14 |
+
'Custom': 'models/my_vit.pth',
|
| 15 |
+
'Pretrained': 'models/pretrained_vit.pth'
|
| 16 |
+
}
|
| 17 |
+
with open('classname.txt') as f:
|
| 18 |
+
self.classes: List[str] = [line.strip() for line in f.readlines()]
|
| 19 |
+
|
| 20 |
+
def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
|
| 21 |
+
|
| 22 |
+
if isinstance(self.models[model_name], str):
|
| 23 |
+
self.models[model_name] = torch.load(self.models[model_name], map_location='cpu')
|
| 24 |
+
|
| 25 |
+
img = torch.unsqueeze(self.models[model_name].val_transform(Image.open(img_file)), 0)
|
| 26 |
with torch.inference_mode():
|
| 27 |
+
preds = torch.softmax(self.models[model_name](img), dim=1)[0].numpy()
|
| 28 |
+
return dict(zip(self.classes, preds))
|
| 29 |
+
|
| 30 |
def launch(self):
|
| 31 |
+
|
| 32 |
demo = gr.Interface(
|
| 33 |
fn=self.predict,
|
| 34 |
+
inputs=[gr.Image(type='filepath'), gr.Radio(('Custom', 'Pretrained'))],
|
| 35 |
outputs=gr.Label(num_top_classes=3),
|
| 36 |
)
|
| 37 |
demo.launch()
|
classname.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
angular_leaf_spot
|
| 2 |
+
bean_rust
|
| 3 |
+
healthy
|
main.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
a = torch.load('models/pretrained_vit.pth', map_location='cpu')
|
| 4 |
+
print(a)
|