File size: 1,289 Bytes
8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 26a33b7 8de41e5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import gradio as gr
from PIL import Image
from typing import List, Dict, Union
import torch
from model import ClassifierModel
class GradioApp:
    """Gradio front-end that classifies an uploaded image with a selectable model.

    Model checkpoints are loaded lazily on first use and cached in
    ``self.models``; class labels are read once from a text file.
    """

    def __init__(self, classes_path: str = 'classname.txt') -> None:
        """Register the model checkpoints and read the class labels.

        Args:
            classes_path: Text file with one class label per line, in the same
                order as the model's output logits. Blank lines are ignored.
        """
        # Each value starts as a checkpoint path and is replaced by the loaded
        # model object the first time that model is requested (see _get_model).
        self.models: Dict[str, Union[str, ClassifierModel]] = {
            'Custom': 'models/my_vit.pth',
            'Pretrained': 'models/pretrained_vit.pth',
        }
        with open(classes_path, encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            # Drop blank lines (e.g. a trailing newline at EOF) so they can
            # never become empty labels that shift names out of step with logits.
            self.classes: List[str] = [name for name in stripped if name]

    def _get_model(self, model_name: str) -> 'ClassifierModel':
        """Return the model registered as *model_name*, loading it on first use.

        Raises:
            KeyError: if *model_name* is not a key of ``self.models``.
        """
        entry = self.models[model_name]
        if isinstance(entry, str):
            # The checkpoint is a whole pickled model object (it is called and
            # exposes ``val_transform``), not a state dict, so weights_only=False
            # is required: PyTorch >= 2.6 defaults to weights_only=True and
            # refuses such files. Unpickling can execute arbitrary code, so only
            # point these paths at trusted files.
            entry = torch.load(entry, map_location='cpu', weights_only=False)
            # torch.save keeps the module's train/eval flag; force eval mode so
            # training-only behaviour (e.g. dropout, if present) is disabled.
            # NOTE(review): assumes ClassifierModel is an nn.Module — confirm.
            entry.eval()
            self.models[model_name] = entry
        return entry

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify the image at *img_file* with the model named *model_name*.

        Args:
            img_file: Path to the uploaded image (Gradio ``type='filepath'``).
            model_name: Key of ``self.models`` chosen in the UI.

        Returns:
            Mapping of class label to softmax probability, as ``gr.Label`` expects.
        """
        model = self._get_model(model_name)
        # Use a context manager so the file handle is closed promptly, and
        # normalise to RGB: RGBA/palette/greyscale uploads would otherwise reach
        # a transform that presumably expects 3-channel input.
        with Image.open(img_file) as img:
            batch = torch.unsqueeze(model.val_transform(img.convert('RGB')), 0)
        with torch.inference_mode():
            probs = torch.softmax(model(batch), dim=1)[0]
        # .tolist() yields plain Python floats, matching the return annotation.
        return dict(zip(self.classes, probs.tolist()))

    def launch(self) -> None:
        """Build the Gradio interface and start serving it."""
        choices = list(self.models)
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath'),
                # Preselect a model so predict() is not called with
                # model_name=None when the user leaves the selector untouched.
                gr.Radio(choices, value=choices[0]),
            ],
            outputs=gr.Label(num_top_classes=3),
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: construct the demo wrapper and start the Gradio server.
    GradioApp().launch()
|