from typing import Dict, List, Optional

import gradio as gr
import torch
from PIL import Image
|
|
class GradioApp:
    """Gradio web demo that classifies an uploaded image with a saved PyTorch model.

    The checkpoint is loaded as a whole model object (it is called directly and
    must expose a ``val_transform`` attribute that turns a PIL image into a
    tensor), not a bare ``state_dict``.
    """

    # Label names used when none are supplied; one per model output, in order.
    DEFAULT_CLASSES: List[str] = ['0', '1', '2']

    def __init__(
        self,
        model_path: str = 'pretrained_vit.pth',
        classes: Optional[List[str]] = None,
    ) -> None:
        """Load the model checkpoint onto the CPU.

        Args:
            model_path: Path of the checkpoint passed to ``torch.load``.
            classes: Label for each model output, in output order. Defaults to
                ``DEFAULT_CLASSES``.
        """
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code -- only load checkpoints from a trusted source. Newer PyTorch
        # releases default to weights_only=True, which rejects whole pickled
        # models; if loading fails, pin PyTorch or pass weights_only=False
        # explicitly for a trusted file.
        self.model = torch.load(model_path, map_location='cpu')
        # Put the model in inference mode. Otherwise, if the checkpoint was saved
        # in training mode, dropout layers (if any) would stay active and make
        # predictions non-deterministic.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
        self.classes: List[str] = list(
            self.DEFAULT_CLASSES if classes is None else classes
        )

    def predict(
        self, img_file: str, classes: Optional[List[str]] = None
    ) -> Dict[str, float]:
        """Classify the image stored at ``img_file``.

        Gradio calls this with a single argument (one input component), so
        ``classes`` must be optional.

        Args:
            img_file: Path to the image file (the input component uses
                ``type='filepath'``).
            classes: Optional per-call override of the label names; defaults to
                the labels configured in ``__init__``.

        Returns:
            Mapping of label -> softmax probability, in the format ``gr.Label``
            accepts.

        Raises:
            ValueError: if no image was supplied, or the number of labels does
                not match the number of model outputs.
        """
        if img_file is None:
            # Gradio can pass None when the user submits without an upload.
            raise ValueError('No image provided.')
        # The context manager closes the file handle Image.open keeps open.
        # convert('RGB') normalizes RGBA/grayscale/palette uploads to 3 channels
        # (assumes the model was trained on RGB input -- TODO confirm).
        with Image.open(img_file) as img:
            batch = self.model.val_transform(img.convert('RGB')).unsqueeze(0)
        return self._classify_batch(batch, classes)

    def _classify_batch(
        self, batch: torch.Tensor, classes: Optional[List[str]] = None
    ) -> Dict[str, float]:
        """Run the model on a one-image batch and label its softmax scores."""
        labels = self.classes if classes is None else classes
        with torch.inference_mode():
            # Softmax over the class dimension; take the single image's row.
            # tolist() yields plain Python floats, matching the return annotation.
            probs = torch.softmax(self.model(batch), dim=1)[0].cpu().tolist()
        if len(probs) != len(labels):
            raise ValueError(
                f'Model produced {len(probs)} scores but '
                f'{len(labels)} class labels were given.'
            )
        return dict(zip(labels, probs))

    def launch(self, **launch_kwargs) -> None:
        """Build the Gradio interface and start serving it.

        Args:
            **launch_kwargs: Forwarded to ``gr.Interface.launch`` (for example
                ``share`` or ``server_port``).
        """
        demo = gr.Interface(
            fn=self.predict,
            inputs=gr.Image(type='filepath'),
            # Show at most three labels, or fewer when there are fewer classes.
            outputs=gr.Label(num_top_classes=min(3, len(self.classes))),
        )
        demo.launch(**launch_kwargs)
|
|
if __name__ == '__main__':
    # Running this file directly loads the checkpoint and serves the demo.
    GradioApp().launch()
|
|