| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| from PIL import Image | |
| import os | |
| from typing import List, Dict, Union | |
| from custom_transformer.vit import ViT | |
| from utils import model_transforms, get_pretrained_vit | |
class GradioApp:
    """Gradio web demo classifying bean-leaf images with one of two ViT models.

    Construction loads both model checkpoints and the class-name list from
    disk; ``launch`` builds the interface and starts serving it.
    """

    def __init__(self) -> None:
        """Load both models onto the best available device and read class names.

        Reads ``models/my_vit.pt``, ``models/pretrained_vit.pt`` and
        ``classname.txt`` relative to the current working directory.
        """
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        custom = ViT().to(device).eval()
        custom.load_state_dict(torch.load('models/my_vit.pt', map_location=device))

        pretrained = get_pretrained_vit().to(device).eval()
        pretrained.load_state_dict(
            torch.load('models/pretrained_vit.pt', map_location=device)
        )

        # The keys are offered as the model choices in the UI and are also used
        # to look up the matching preprocessing in ``model_transforms``.
        self.models: Dict[str, nn.Module] = {
            'Custom': custom,
            'Pretrained': pretrained,
        }

        # One class name per line, presumably in the order of the models'
        # output indices (predict() pairs them positionally). Blank lines are
        # skipped so stray empty lines do not turn into phantom classes.
        with open('classname.txt', encoding='utf-8') as f:
            self.classes: List[str] = [line.strip() for line in f if line.strip()]

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify one image file with the selected model.

        Args:
            img_file: Path to the image; Gradio passes a path because the
                image input is declared with ``type='filepath'``.
            model_name: Key into ``self.models`` (``'Custom'`` or ``'Pretrained'``).

        Returns:
            Mapping from class name to softmax probability, as plain floats.
        """
        model = self.models[model_name]
        # The models may live on the GPU; the input must be on the same device
        # as the weights or the forward pass raises a device-mismatch error.
        device = next(model.parameters()).device
        # Convert to RGB so greyscale or RGBA (e.g. PNG with alpha) uploads get
        # three channels; the context manager closes the underlying file.
        with Image.open(img_file) as image:
            img = model_transforms[model_name](image.convert('RGB')).unsqueeze(0).to(device)
        with torch.inference_mode():
            probs = torch.softmax(model(img)[0], dim=0).cpu().tolist()
        return dict(zip(self.classes, probs))

    def launch(self, **launch_kwargs) -> None:
        """Build the Gradio interface and start serving it.

        Args:
            **launch_kwargs: Forwarded unchanged to the interface's ``launch``
                call (e.g. ``share=True``); none are required.
        """
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'

        # Sorted for a stable example order (os.listdir order is arbitrary);
        # hidden files such as .DS_Store or .gitkeep are not example images.
        examples_list = [
            [os.path.join('examples', name)]
            for name in sorted(os.listdir('examples'))
            if not name.startswith('.')
        ]

        model_choices = list(self.models)
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                # Preselect the first model so predict() does not receive None
                # when the user submits without picking a radio option.
                gr.Radio(
                    choices=model_choices,
                    value=model_choices[0],
                    label='Available models',
                ),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
            title='Plants Diseases Classification',
            # Implicit concatenation rather than backslash continuations, which
            # would embed the source indentation into the rendered text.
            description=(
                'This model performs classification on images of leaves that are either healthy, '
                'have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. '
                f'The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).'
            ),
        )
        demo.launch(**launch_kwargs)
if __name__ == '__main__':
    # Script entry point: constructing the app loads both models from disk,
    # then launch() starts the Gradio server.
    GradioApp().launch()