|
|
|
|
|
|
|
|
import gradio as gr |
|
|
import os |
|
|
import torchvision.transforms as T |
|
|
|
|
|
from model import FlowerClassificationModel |
|
|
from timeit import default_timer as timer |
|
|
from typing import Tuple, Dict |
|
|
from data_setup import classes, model_tsfm |
|
|
from utils import * |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the classifier with one output per entry in `classes`.
flower_model = FlowerClassificationModel(num_classes=len(classes), pretrained=True)

# Checkpoint file holding the weights under the 'model_state_dict' key.
saved_path = 'flower_model_29.pth'

print('Loading Model State Dictionary')

# map_location forces every tensor in the checkpoint onto the CPU.
checkpoint = torch.load(f=saved_path, map_location=torch.device('cpu'))
flower_model.load_state_dict(checkpoint['model_state_dict'])
# Release the loaded checkpoint once its weights have been applied.
del checkpoint

print('Model Loaded ...')
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple, Dict |
|
|
|
|
|
def predict(img) -> Tuple[Dict, float]:
    """Classify an image and report how long the prediction took.

    Args:
        img: Image to classify. It is passed unchanged to ``model_tsfm``.
            ``None`` is rejected up front (an empty Gradio image input
            submits ``None``).

    Returns:
        A tuple ``(pred_labels_and_probs, pred_time)``: a dict mapping each
        ``classes[i]`` to the softmax probability at index ``i`` of the model
        output, and the elapsed time in seconds rounded to 5 decimal places.

    Raises:
        gr.Error: If ``img`` is ``None``.
    """
    if img is None:
        # Surface a readable message in the UI instead of letting the
        # transform fail with an opaque error on a missing image.
        raise gr.Error("Please provide an image to classify.")

    start_time = timer()

    # Apply the preprocessing transform, then add a leading dimension of
    # size 1 so the model receives a single-item batch.
    batch = model_tsfm(img).unsqueeze(0)

    # Switch to evaluation mode and run the forward pass without autograd.
    flower_model.eval()
    with torch.inference_mode():
        pred_probs = torch.softmax(flower_model(batch), dim=1)

    # Convert the single result row to Python floats once rather than
    # indexing the tensor for every class.
    probs = pred_probs[0].tolist()
    # Index-based lookup is kept on purpose: `classes` comes from data_setup
    # and its container type is not visible from this file.
    pred_labels_and_probs = {classes[i]: probs[i] for i in range(len(classes))}

    pred_time = round(timer() - start_time, 5)

    return pred_labels_and_probs, pred_time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Text displayed on the demo page: heading, short description and footer HTML.
title = 'United Kingdom Flower Classification Mini ๐ป๐ผ๐ธโ๐๐ท'
description = "An ResNet50 computer vision model to classify images of Flower Categories."
article = "<p>Flower Classification Created by Chukwuka </p><p style='text-align: center'><a href='https://github.com/Sylvesterchuks/flower_classification'>Github Repo</a></p>"

# One single-element row per entry found in the "examples" directory.
example_list = [[f"examples/{name}"] for name in os.listdir("examples")]

# Output widgets, in the same order as the two values returned by predict().
prediction_outputs = [
    gr.Label(num_top_classes=5, label="Predictions"),
    gr.Number(label='Prediction time (s)'),
]

# Wire predict() to an image input and the two outputs above.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type='pil'),
    outputs=prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
    article=article,
)

print('Gradio Demo Launched')
demo.launch()
|
|
|
|
|
|