"""Gradio demo that classifies an uploaded image with a saved PyTorch model."""

from typing import Dict, List, Optional

import gradio as gr
import torch
from PIL import Image


class GradioApp:
    """Wrap the model stored in ``pretrained_vit.pth`` in a Gradio interface."""

    # Labels reported when ``predict`` is called without explicit ``classes``.
    DEFAULT_CLASSES = ('0', '1', '2')

    def __init__(self) -> None:
        """Load the checkpoint onto the CPU and prepare it for inference."""
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only point this at a checkpoint from a trusted source.
        # NOTE(review): recent PyTorch releases may refuse to unpickle a whole
        # model object by default -- confirm against the installed version.
        self.model = torch.load('pretrained_vit.pth', map_location='cpu')
        # inference_mode() only turns off autograd; layers such as dropout and
        # batch norm still follow the module's train/eval flag.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()

    def predict(
        self,
        img_file: str,
        classes: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """Return a label -> probability mapping for one image file.

        Args:
            img_file: Path of the image to classify.
            classes: Label for each model output index, in order. When
                omitted, ``DEFAULT_CLASSES`` is used. Optional because the
                Gradio interface built in ``launch`` supplies only the image.

        Returns:
            A dict with one entry per label, holding the softmax probability
            of the output at the same index as a plain Python ``float``.

        Raises:
            IndexError: If more labels are given than the model has outputs.
        """
        labels = list(self.DEFAULT_CLASSES) if classes is None else list(classes)

        # Transform inside the ``with`` so the file handle is released as soon
        # as the pixel data has been turned into a tensor.
        with Image.open(img_file) as image:
            # NOTE(review): assumes model.val_transform yields an unbatched
            # tensor; unsqueeze(0) then adds the leading batch dimension.
            batch = self.model.val_transform(image).unsqueeze(0)

        with torch.inference_mode():
            # tolist() produces built-in floats, which serialise cleanly.
            probs = torch.softmax(self.model(batch), dim=1)[0].tolist()

        return {label: probs[idx] for idx, label in enumerate(labels)}

    def launch(self) -> None:
        """Build the image-in / label-out interface and start serving it."""
        demo = gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type='filepath'),
            outputs=gr.Label(num_top_classes=3),
        )
        demo.launch()


if __name__ == '__main__':
    app = GradioApp()
    app.launch()