import os
from typing import Dict, List, Union

import gradio as gr
import torch
from PIL import Image

from model import ClassifierModel


class GradioApp:
    """Gradio front-end that classifies leaf images with one of several saved models."""

    def __init__(self) -> None:
        # Each value starts as a checkpoint path. The first time that model is
        # requested, `_get_model` replaces it with the loaded model object.
        self.models: Dict[str, Union[str, ClassifierModel]] = {
            'Custom': 'models/my_vit.pth',
            'Pretrained': 'models/pretrained_vit.pth'
        }
        # One class name per line, in the same order as the model's output
        # indices. Blank lines (e.g. a trailing newline) are skipped so they
        # never become empty labels.
        with open('classname.txt', encoding='utf-8') as f:
            self.classes: List[str] = [line.strip() for line in f if line.strip()]

    def _get_model(self, model_name: str) -> ClassifierModel:
        """Return the model registered under `model_name`, loading it on first use.

        Raises:
            gr.Error: if `model_name` is not one of the registered models.
        """
        if model_name not in self.models:
            raise gr.Error('Please select one of the available models.')
        model = self.models[model_name]
        if isinstance(model, str):
            # The checkpoint is a whole pickled model object, not a state_dict.
            # weights_only=False is therefore required, because PyTorch 2.6+
            # defaults to weights_only=True.
            # SECURITY: unpickling can execute arbitrary code, so only load
            # trusted checkpoint files.
            model = torch.load(model, map_location='cpu', weights_only=False)
            # Switch layers such as dropout / batch-norm to inference behaviour;
            # otherwise predictions can vary between identical requests.
            if isinstance(model, torch.nn.Module):
                model.eval()
            self.models[model_name] = model
        return model

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify the image at `img_file` with the model named `model_name`.

        Args:
            img_file: Path to the uploaded image (the Gradio Image input uses
                type='filepath').
            model_name: Key into `self.models`, e.g. 'Custom' or 'Pretrained'.

        Returns:
            Mapping from class name to softmax probability (plain Python floats).

        Raises:
            gr.Error: if no image was supplied or the model name is unknown.
        """
        if img_file is None:
            raise gr.Error('Please upload an image to classify.')
        model = self._get_model(model_name)
        # `with` closes the underlying file handle. convert('RGB') normalises
        # RGBA/greyscale/palette uploads to 3 channels (assumes val_transform
        # expects RGB input; TODO confirm against model.py).
        with Image.open(img_file) as image:
            img = torch.unsqueeze(model.val_transform(image.convert('RGB')), 0)
        with torch.inference_mode():
            # Presumably the output is (batch, num_classes); [0] selects the
            # single image in this batch of one.
            preds = torch.softmax(model(img), dim=1)[0]
        # tolist() yields native floats, matching the declared return type.
        return dict(zip(self.classes, preds.tolist()))

    def launch(self) -> None:
        """Build the Gradio interface and start serving it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'
        examples_dir = 'examples'
        # Sorted for a stable gallery order. Hidden files and sub-directories
        # (e.g. .DS_Store) are skipped because they are not usable images.
        examples_list = [
            [os.path.join(examples_dir, name)]
            for name in sorted(os.listdir(examples_dir))
            if not name.startswith('.') and os.path.isfile(os.path.join(examples_dir, name))
        ]
        model_choices = list(self.models)
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                # A default selection avoids predict() being called with None.
                gr.Radio(choices=model_choices, value=model_choices[0], label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            title='Plants Diseases Classification',
            description=(
                'This model performs classification on images of leaves that are either healthy, '
                'have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. '
                f'The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).'
            ),
            examples=examples_list,
        )
        demo.launch()


if __name__ == '__main__':
    app = GradioApp()
    app.launch()