i4ata's picture
Upload 2 files
8de41e5 verified
raw
history blame
850 Bytes
from typing import Dict, List, Optional

import gradio as gr
import torch
from PIL import Image
class GradioApp:
    """Gradio web demo that classifies an uploaded image with a pickled model."""

    # Label names used by ``predict`` when the caller supplies none; also
    # determines how many top classes the output component shows.
    DEFAULT_CLASSES = ('0', '1', '2')

    def __init__(self, model_path: str = 'pretrained_vit.pth') -> None:
        """Load the model stored at *model_path* onto the CPU in eval mode.

        Args:
            model_path: Path to a file holding a whole pickled model object
                (it must be callable and expose a ``val_transform`` attribute,
                as ``predict`` relies on both), not just a ``state_dict``.
        """
        # The checkpoint is a full pickled module; torch >= 2.6 defaults to
        # weights_only=True, which refuses to unpickle arbitrary classes.
        # Only load checkpoints you trust: unpickling can execute code.
        self.model = torch.load(
            model_path, map_location='cpu', weights_only=False
        )
        # Disable training-only behaviour (e.g. dropout) regardless of the
        # mode the module happened to be in when it was saved.
        self.model.eval()

    def predict(
        self, img_file: str, classes: Optional[List[str]] = None
    ) -> Dict[str, float]:
        """Classify the image at *img_file* and return per-class probabilities.

        Args:
            img_file: Path to the image file (the Gradio input component is
                configured with ``type='filepath'``).
            classes: Label names, one per model output, in output order.
                Defaults to ``DEFAULT_CLASSES``. Optional so Gradio can call
                this with just the image.

        Returns:
            Mapping from label name to softmax probability, as plain Python
            floats.

        Raises:
            ValueError: If the number of labels differs from the number of
                model outputs.
        """
        labels = list(self.DEFAULT_CLASSES if classes is None else classes)
        # Close the file handle once the transform has consumed the pixels.
        with Image.open(img_file) as image:
            batch = self.model.val_transform(image).unsqueeze(0)
        with torch.inference_mode():
            probs = torch.softmax(self.model(batch), dim=1)[0].tolist()
        if len(labels) != len(probs):
            raise ValueError(
                f'got {len(labels)} class labels for {len(probs)} model outputs'
            )
        return dict(zip(labels, probs))

    def launch(self) -> None:
        """Wrap ``predict`` in a Gradio image-to-label interface and launch it."""
        demo = gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type='filepath'),
            outputs=gr.Label(num_top_classes=len(self.DEFAULT_CLASSES)),
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: construct the demo (loads the model) and serve it.
    app = GradioApp()
    app.launch()