i4ata's picture
small change
563e5fb
raw
history blame
2.44 kB
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import os
from typing import List, Dict, Union
from custom_transformer.vit import ViT
from utils import model_transforms, get_pretrained_vit
class GradioApp:
    """Gradio front-end that classifies leaf images with one of two ViT models.

    On construction it loads a custom ViT (``models/my_vit.pt``) and a
    pretrained ViT (``models/pretrained_vit.pt``) and reads the class names
    from ``classname.txt``. ``launch`` builds and starts the web UI.
    """

    def __init__(self) -> None:
        # Use the GPU when available; ``predict`` moves inputs to this same
        # device so model and data always agree.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.models: Dict[str, nn.Module] = {
            'Custom': self._load_weights(ViT(), 'models/my_vit.pt', self.device),
            'Pretrained': self._load_weights(
                get_pretrained_vit(), 'models/pretrained_vit.pt', self.device
            ),
        }
        # One class name per line, blank lines ignored. The order is assumed to
        # match the models' output indices -- TODO confirm against training code.
        with open('classname.txt', encoding='utf-8') as f:
            self.classes: List[str] = [line.strip() for line in f if line.strip()]

    @staticmethod
    def _load_weights(model: nn.Module, path: str, device: str) -> nn.Module:
        """Move *model* to *device*, load the state dict stored at *path*, and set eval mode."""
        model = model.to(device)
        model.load_state_dict(torch.load(path, map_location=device))
        return model.eval()

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify the image at *img_file* using the model named *model_name*.

        Args:
            img_file: Path to the uploaded image (the UI uses ``type='filepath'``).
            model_name: Key into ``self.models`` ('Custom' or 'Pretrained').

        Returns:
            Mapping from class name to softmax probability (plain Python floats).

        Raises:
            gr.Error: If no image was supplied or the model name is unknown;
                Gradio shows the message to the user.
        """
        if img_file is None:
            raise gr.Error('Please upload an image to classify.')
        if model_name not in self.models:
            raise gr.Error('Please select one of the available models.')
        # Close the file promptly, and force 3 channels so RGBA / grayscale
        # uploads do not break the transform (assumes the models take RGB input).
        with Image.open(img_file) as pil_img:
            img = model_transforms[model_name](pil_img.convert('RGB')).unsqueeze(0)
        with torch.inference_mode():
            # The input must live on the same device as the model weights.
            logits = self.models[model_name](img.to(self.device))[0]
            probs = torch.softmax(logits, dim=0).cpu().tolist()
        return dict(zip(self.classes, probs))

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and start serving it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'
        examples_dir = 'examples'
        # Sorted for a stable display order (os.listdir order is arbitrary);
        # hidden files such as .gitkeep are not images, so skip them.
        examples_list = [
            [os.path.join(examples_dir, name)]
            for name in sorted(os.listdir(examples_dir))
            if not name.startswith('.')
        ]
        description = (
            'This model performs classification on images of leaves that are either healthy, '
            'have bean rust, or have an angular leaf spot. '
            'A vision transformer neural network architecture is used. '
            f'The dataset can be downloaded from [Kaggle]({dataset_url}) '
            f'and the source code is on [GitHub]({github_repo_url}).'
        )
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                # Pre-select a model so submitting without touching the radio works.
                gr.Radio(choices=list(self.models), value='Custom', label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
            title='Plants Diseases Classification',
            description=description,
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: constructing the app loads both models and the
    # class list, then launch() builds and starts the Gradio interface.
    app = GradioApp()
    app.launch()