import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import os
from typing import List, Dict, Union
from custom_transformer.vit import ViT
from utils import model_transforms, get_pretrained_vit
class GradioApp:
    """Gradio front end that classifies leaf images with one of two ViT models.

    Two models are loaded at construction time: the project's own ``ViT`` and
    the model returned by ``get_pretrained_vit``. The user picks one of them
    in the UI, and the per-class probabilities are shown as a label widget.
    """

    def __init__(self) -> None:
        """Load both models' weights and the class names from disk.

        Reads ``models/my_vit.pt``, ``models/pretrained_vit.pt`` and
        ``classname.txt`` relative to the current working directory.
        """
        # Remember the device so that `predict` can place its input on the
        # same device as the models.
        self.device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

        custom = ViT().to(self.device).eval()
        custom.load_state_dict(
            torch.load('models/my_vit.pt', map_location=self.device)
        )

        pretrained = get_pretrained_vit().to(self.device).eval()
        pretrained.load_state_dict(
            torch.load('models/pretrained_vit.pt', map_location=self.device)
        )

        # Keys double as the radio-button choices offered in `launch` and as
        # the keys used to look up the preprocessing in `model_transforms`.
        self.models: Dict[str, nn.Module] = {
            'Custom': custom,
            'Pretrained': pretrained,
        }

        # One class name per line; surrounding whitespace is stripped.
        with open('classname.txt') as f:
            self.classes: List[str] = [line.strip() for line in f]

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify one image and return a probability for each class.

        Args:
            img_file: Path of the image file to classify.
            model_name: Key of the model to use (a key of ``self.models``).

        Returns:
            A mapping from class name to softmax probability, as plain
            Python floats.

        Raises:
            gr.Error: If no image was supplied or ``model_name`` is not one of
                the available models (e.g. nothing was selected in the UI).
        """
        if img_file is None:
            raise gr.Error('Please provide an image to classify.')
        if model_name not in self.models:
            raise gr.Error('Please select one of the available models.')

        # Close the file handle as soon as the transform has produced a tensor.
        # NOTE(review): the image is passed to the transform in whatever mode
        # PIL opened it; confirm `model_transforms` copes with non-RGB inputs.
        with Image.open(img_file) as image:
            batch = model_transforms[model_name](image).unsqueeze(0)

        # The models live on `self.device`; the input has to be there as well,
        # otherwise the forward pass fails when running on a GPU.
        batch = batch.to(self.device)

        with torch.inference_mode():
            logits = self.models[model_name](batch)[0]
            probabilities = torch.softmax(logits, dim=0).cpu().tolist()

        return dict(zip(self.classes, probabilities))

    def launch(self):
        """Build the Gradio interface and start serving it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'

        # `os.listdir` returns entries in arbitrary order; sort them so the
        # example gallery is laid out the same way on every run.
        examples_list = [
            [os.path.join('examples', example)]
            for example in sorted(os.listdir('examples'))
        ]

        description = (
            'This model performs classification on images of leaves that are either healthy, '
            'have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. '
            f'The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).'
        )

        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                gr.Radio(choices=list(self.models), label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
            title='Plants Diseases Classification',
            description=description,
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: construct the app (which loads the models) and
    # start serving the Gradio interface.
    GradioApp().launch()