# Scraped Hugging Face file-page header: author i4ata, commit cce011e ("done myb"), 2.2 kB.
import gradio as gr
from PIL import Image
import os
import torch
from model import ClassifierModel
from typing import List, Dict, Union
class GradioApp:
    """Gradio web demo that classifies leaf images with one of two selectable models.

    Model checkpoints are registered by display name in ``self.models`` and are
    only loaded from disk the first time they are used for a prediction.
    """

    def __init__(self) -> None:
        # Display name -> checkpoint path. Each path is replaced by the loaded
        # model object the first time that model is requested (see _get_model).
        self.models: Dict[str, Union[str, ClassifierModel]] = {
            'Custom': 'models/my_vit.pth',
            'Pretrained': 'models/pretrained_vit.pth',
        }
        # Class labels, zipped positionally with the model's output probabilities.
        # NOTE(review): assumes the file order matches the training label order.
        self.classes: List[str] = self._load_class_names('classname.txt')

    @staticmethod
    def _load_class_names(path: str) -> List[str]:
        """Return the stripped, non-blank lines of *path* (one class name per line)."""
        with open(path, encoding='utf-8') as f:
            return [name for name in (line.strip() for line in f) if name]

    def _get_model(self, model_name: str) -> 'ClassifierModel':
        """Return the model registered as *model_name*, loading it on first use.

        A freshly loaded model is switched to eval mode and cached back into
        ``self.models`` so later calls reuse the same object.
        """
        model = self.models[model_name]
        if isinstance(model, str):
            # The checkpoints pickle the whole model object (not a state_dict), so
            # full unpickling is required. torch >= 2.6 defaults to
            # weights_only=True and would refuse them. Only load trusted, bundled files.
            model = torch.load(model, map_location='cpu', weights_only=False)
            # Eval mode affects layers such as dropout / batch-norm;
            # torch.inference_mode() alone does not switch it.
            model.eval()
            self.models[model_name] = model
        return model

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify a single image with the selected model.

        Args:
            img_file: Path of the uploaded image (the Image input uses type='filepath').
            model_name: Key of ``self.models`` chosen in the radio selector.

        Returns:
            Mapping of class name to softmax probability, as consumed by ``gr.Label``.

        Raises:
            gr.Error: If no image was supplied or no known model is selected.
        """
        if img_file is None:
            raise gr.Error('Please upload an image to classify.')
        if model_name not in self.models:
            raise gr.Error('Please select one of the available models.')
        model = self._get_model(model_name)
        # Close the file promptly, and force 3 channels so RGBA/greyscale uploads
        # do not break the transform/model (which presumably expects RGB input).
        with Image.open(img_file) as img:
            rgb = img.convert('RGB')
        batch = torch.unsqueeze(model.val_transform(rgb), 0)
        with torch.inference_mode():
            probs = torch.softmax(model(batch), dim=1)[0]
        # .tolist() yields plain Python floats, matching the annotated return type.
        return dict(zip(self.classes, probs.tolist()))

    def launch(self) -> None:
        """Build the Gradio interface around ``predict`` and launch it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'
        # Sorted for a stable example order. Hidden files (e.g. .gitkeep) are skipped
        # because Gradio would try to load them as images.
        examples_list = [
            [os.path.join('examples', name)]
            for name in sorted(os.listdir('examples'))
            if not name.startswith('.')
        ]
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                # Derived from self.models so the choices cannot drift from the registry.
                gr.Radio(choices=list(self.models), label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            title='Plants Diseases Classification',
            description=(
                'This model performs classification on images of leaves that are either healthy, '
                'have bean rust, or have an angular leaf spot. A vision transformer neural network '
                'architecture is used. '
                f'The dataset can be downloaded from [Kaggle]({dataset_url}) '
                f'and the source code is on [GitHub]({github_repo_url}).'
            ),
            examples=examples_list,
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: construct the demo app and start serving it.
    GradioApp().launch()