i4ata's picture
small fix agen
0548088
raw
history blame
2.5 kB
import torch
import torch.nn as nn
from torchvision import models
import gradio as gr
from PIL import Image
import os
from typing import List, Dict, Union
from custom_transformer.vit import ViT
from transforms import model_transforms
class GradioApp:
    """Gradio web demo that classifies leaf images with one of two ViT models.

    Two models are loaded at construction time: a custom ``ViT`` and a
    torchvision ``vit_b_16`` whose classification head is replaced by a
    768 -> 3 linear layer. The user selects which model to run in the UI.
    """

    def __init__(self) -> None:
        """Load both model checkpoints and the class-name list from disk.

        Reads ``models/my_vit.pt``, ``models/pretrained_vit.pt`` and
        ``classname.txt`` (paths relative to the current working directory).
        """
        # Kept on the instance so ``predict`` can put its inputs on the same device.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        custom = ViT()
        custom.load_state_dict(torch.load('models/my_vit.pt', map_location=self.device))

        pretrained = models.vit_b_16()
        # Replace the head BEFORE moving the model to the device: a module assigned
        # after ``.to(device)`` stays on the CPU (``load_state_dict`` copies values
        # into the existing parameters in place), which breaks inference on CUDA.
        pretrained.heads = nn.Linear(768, 3)
        pretrained.load_state_dict(torch.load('models/pretrained_vit.pt', map_location=self.device))

        self.models: Dict[str, nn.Module] = {
            'Custom': custom.to(self.device).eval(),
            'Pretrained': pretrained.to(self.device).eval(),
        }

        with open('classname.txt', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line) so no empty class name
            # is produced and labels stay aligned with the model outputs.
            self.classes: List[str] = [line.strip() for line in f if line.strip()]

    def predict(self, img_file: str, model_name: str) -> Dict[str, float]:
        """Classify a single image with the selected model.

        Args:
            img_file: Path to the image file (the UI's image input uses
                ``type='filepath'``).
            model_name: Key into ``self.models`` and ``model_transforms``
                (``'Custom'`` or ``'Pretrained'``).

        Returns:
            Mapping from class name to softmax probability, as consumed by ``gr.Label``.

        Raises:
            gr.Error: If no image was supplied or the model name is unknown.
        """
        if img_file is None:
            raise gr.Error('Please upload an image to classify.')
        if model_name not in self.models:
            raise gr.Error('Please select a model.')

        # ``with`` releases the file handle. Converting to RGB guards against RGBA or
        # grayscale files; torchvision's vit_b_16 takes 3-channel input (the custom
        # ViT presumably does as well — TODO confirm).
        with Image.open(img_file) as image:
            img = model_transforms[model_name](image.convert('RGB')).unsqueeze(0)
        # The models live on ``self.device``; the input batch must be there too.
        img = img.to(self.device)

        with torch.inference_mode():
            logits = self.models[model_name](img)[0]
            probs = torch.softmax(logits, dim=0).cpu().tolist()
        # ``tolist`` yields plain Python floats, matching the return annotation.
        return dict(zip(self.classes, probs))

    def launch(self) -> None:
        """Build the Gradio interface and start serving it."""
        dataset_url = 'https://www.kaggle.com/datasets/marquis03/bean-leaf-lesions-classification/data'
        github_repo_url = 'https://github.com/i4ata/TransformerClassification'
        # Sorted so the examples appear in a stable order (os.listdir order is arbitrary).
        examples_list = [[os.path.join('examples', example)] for example in sorted(os.listdir('examples'))]
        demo = gr.Interface(
            fn=self.predict,
            inputs=[
                gr.Image(type='filepath', label='Input image to classify'),
                # A default model is required: examples only populate the image input,
                # so without one a submission could arrive with model_name=None.
                gr.Radio(choices=list(self.models), value='Custom', label='Available models'),
            ],
            outputs=gr.Label(num_top_classes=3, label='Model predictions'),
            examples=examples_list,
            cache_examples=False,
            title='Plants Diseases Classification',
            description=(
                'This model performs classification on images of leaves that are either healthy, '
                'have bean rust, or have an angular leaf spot. A vision transformer neural network architecture is used. '
                f'The dataset can be downloaded from [Kaggle]({dataset_url}) and the source code is on [GitHub]({github_repo_url}).'
            ),
        )
        demo.launch()
if __name__ == '__main__':
    # Script entry point: constructing the app loads both model checkpoints and
    # the class names; launch() then builds and starts the Gradio interface.
    GradioApp().launch()