i4ata's picture
smol update
26a33b7
raw
history blame
1.29 kB
import gradio as gr
from PIL import Image
from typing import List, Dict, Union
import torch
from model import ClassifierModel
class GradioApp:
    """Gradio front end that classifies an uploaded image with a chosen model.

    Checkpoints are loaded lazily: ``self.models`` maps a display name to a
    checkpoint path until that model is first requested, after which the
    path is replaced in place by the loaded model object.
    """

    def __init__(self) -> None:
        """Register the checkpoint paths and read the class labels."""
        # Values start out as checkpoint paths and are swapped for the loaded
        # model objects on first use (see ``predict``).
        self.models: Dict[str, Union[str, ClassifierModel]] = {
            'Custom': 'models/my_vit.pth',
            'Pretrained': 'models/pretrained_vit.pth',
        }

        # One label per line. Blank lines are skipped so they cannot become
        # empty labels or shift the label/probability pairing in ``predict``.
        with open('classname.txt', encoding='utf-8') as f:
            self.classes: List[str] = [
                label for label in (line.strip() for line in f) if label
            ]

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify one image and return a probability for each class label.

        Args:
            img_file: Path of the image to classify (the image input in
                ``launch`` is configured with ``type='filepath'``).
            model_name: Key of ``self.models`` selecting which model to use.

        Returns:
            Mapping from class label to softmax probability, as plain Python
            floats. Labels and probabilities are paired positionally.

        Raises:
            gr.Error: If no image was uploaded or no model was selected.
        """
        # Gradio passes None for inputs the user left empty; report that in
        # the UI instead of failing with an opaque AttributeError/KeyError.
        if img_file is None:
            raise gr.Error('Please upload an image.')
        if model_name is None:
            raise gr.Error('Please select a model.')

        model = self.models[model_name]
        if isinstance(model, str):
            # NOTE(review): torch.load unpickles the whole object, so only
            # trusted checkpoint files must be used here. Recent PyTorch
            # releases default to weights_only=True, which may reject a fully
            # pickled model -- confirm against the pinned torch version.
            model = torch.load(model, map_location='cpu')
            # A pickled module keeps the training flag it was saved with;
            # force eval mode so dropout/batch-norm behave for inference.
            if isinstance(model, torch.nn.Module):
                model.eval()
            self.models[model_name] = model  # cache for subsequent calls

        # NOTE(review): the image is passed to val_transform in whatever mode
        # the file has (e.g. RGBA or L) -- confirm the transform handles that.
        with Image.open(img_file) as image:
            batch = torch.unsqueeze(model.val_transform(image), 0)

        with torch.inference_mode():
            probs = torch.softmax(model(batch), dim=1)[0]

        # tolist() yields built-in floats, matching the annotated return type.
        return dict(zip(self.classes, probs.tolist()))

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and launch it."""
        demo = gr.Interface(
            fn=self.predict,
            # Radio choices come from the model registry so they stay in sync.
            inputs=[gr.Image(type='filepath'), gr.Radio(list(self.models))],
            outputs=gr.Label(num_top_classes=3),
        )
        demo.launch()
if __name__ == "__main__":
    # Script entry point: construct the app and start the Gradio UI.
    GradioApp().launch()