| | import gradio as gr |
| | from PIL import Image |
| | import os |
| |
|
| | import torch |
| |
|
| | from model import ClassifierModel |
| |
|
| | from typing import List, Dict, Union |
| |
|
class GradioApp:
    """Gradio front end for classifying leaf images with a selectable model.

    Each entry of ``self.models`` starts as a checkpoint path and is replaced by
    the loaded model object the first time that model is used in ``predict``.
    """

    def __init__(
        self,
        classes_path: str = 'classname.txt',
        examples_dir: str = 'examples',
    ) -> None:
        """Register the available models and read the class names.

        Args:
            classes_path: Text file with one class name per line. Surrounding
                whitespace is stripped and blank lines are ignored. The i-th name
                is paired with the i-th model output in ``predict``.
            examples_dir: Directory whose files are offered as clickable
                examples in the interface built by ``launch``.
        """
        # Values are lazily swapped from path -> loaded model in ``predict``.
        self.models: Dict[str, Union[str, ClassifierModel]] = {
            'Custom': 'models/my_vit.pth',
            'Pretrained': 'models/pretrained_vit.pth',
        }
        with open(classes_path, encoding='utf-8') as f:
            # Drop blank lines (e.g. a trailing newline) so they cannot become
            # empty class names that shift the name/probability pairing.
            self.classes: List[str] = [
                name for name in (line.strip() for line in f) if name
            ]
        self.examples_dir = examples_dir

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify one image with the selected model.

        Args:
            img_file: Path to the image file (the input component uses
                ``type='filepath'``). May be ``None`` if nothing was uploaded.
            model_name: Key of ``self.models`` picked in the radio button. May be
                ``None`` if no model was selected.

        Returns:
            Mapping from class name to softmax probability, as plain Python floats.

        Raises:
            gr.Error: If no image was provided or the model name is not known.
        """
        if img_file is None:
            raise gr.Error('Please provide an image to classify.')
        if model_name not in self.models:
            raise gr.Error(f'Please select one of the available models: {", ".join(self.models)}.')

        model = self.models[model_name]
        if isinstance(model, str):
            # NOTE(review): this unpickles a whole model object, which can run
            # arbitrary code -- only load trusted checkpoints. On torch >= 2.6 the
            # default weights_only=True will likely reject a pickled module; pass
            # weights_only=False there (trusted files only).
            model = torch.load(model, map_location='cpu')
            model.eval()
            self.models[model_name] = model  # cache so later calls skip the load

        # The context manager releases the file handle. convert('RGB') returns a
        # fully loaded copy, so it stays valid after the file is closed. It also
        # normalises RGBA/palette/grayscale uploads -- assumes val_transform
        # expects 3-channel input; TODO confirm against ClassifierModel.
        with Image.open(img_file) as raw:
            rgb = raw.convert('RGB')

        batch = torch.unsqueeze(model.val_transform(rgb), 0)  # add batch dimension
        with torch.inference_mode():
            probs = torch.softmax(model(batch), dim=1)[0]
        # tolist() yields plain floats, matching the annotated return type
        # (numpy float32 scalars are not serializable by the stdlib json module).
        return dict(zip(self.classes, probs.tolist()))

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and start it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'
        # os.listdir order is arbitrary; sort it for a stable example gallery.
        examples_list = [
            [os.path.join(self.examples_dir, name)]
            for name in sorted(os.listdir(self.examples_dir))
        ]
        # Derive the choices from self.models so the UI and the lookup in
        # ``predict`` cannot drift apart. Pre-select the first model so clicking
        # an example works without picking a model first.
        model_names = list(self.models)
        description = (
            'This model performs classification on images of leaves that are either healthy, '
            'have bean rust, or have an angular leaf spot. A vision transformer neural network '
            'architecture is used. The dataset can be downloaded from '
            f'[Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).'
        )

        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                gr.Radio(choices=model_names, value=model_names[0], label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
            title='Plants Diseases Classification',
            description=description,
        )
        demo.launch()
| |
|
if __name__ == '__main__':
    # Script entry point: construct the app (reads classname.txt) and start the UI.
    GradioApp().launch()
| |
|