# app.py — butterfly species classifier (Hugging Face Space, commit 5d8ea7e)
import numpy as np
import pandas as pd
import yaml
import os
import torch
import torchvision
from torchvision.transforms import v2
import PIL
from PIL import Image
import gradio as gr
import zipfile
# Unpack the bundled model checkpoints (models/*.pt) next to the app at
# startup so app_interface() can torch.load them from the models/ folder.
zip_file_path = "models.zip"
extract_dir = os.getcwd()
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)
def _return_classes():
    """Return the index -> butterfly-species-name map (75 classes)."""
    return {0: 'adonis',
            1: 'african giant swallowtail',
            2: 'american snoot',
            3: 'an 88',
            4: 'appollo',
            5: 'atala',
            6: 'banded orange heliconian',
            7: 'banded peacock',
            8: 'beckers white',
            9: 'black hairstreak',
            10: 'blue morpho',
            11: 'blue spotted crow',
            12: 'brown siproeta',
            13: 'cabbage white',
            14: 'cairns birdwing',
            15: 'checquered skipper',
            16: 'chestnut',
            17: 'cleopatra',
            18: 'clodius parnassian',
            19: 'clouded sulphur',
            20: 'common banded awl',
            21: 'common wood-nymph',
            22: 'copper tail',
            23: 'crecent',
            24: 'crimson patch',
            25: 'danaid eggfly',
            26: 'eastern coma',
            27: 'eastern dapple white',
            28: 'eastern pine elfin',
            29: 'elbowed pierrot',
            30: 'gold banded',
            31: 'great eggfly',
            32: 'great jay',
            33: 'green celled cattleheart',
            34: 'grey hairstreak',
            35: 'indra swallow',
            36: 'iphiclus sister',
            37: 'julia',
            38: 'large marble',
            39: 'malachite',
            40: 'mangrove skipper',
            41: 'mestra',
            42: 'metalmark',
            43: 'milberts tortoiseshell',
            44: 'monarch',
            45: 'mourning cloak',
            46: 'orange oakleaf',
            47: 'orange tip',
            48: 'orchard swallow',
            49: 'painted lady',
            50: 'paper kite',
            51: 'peacock',
            52: 'pine white',
            53: 'pipevine swallow',
            54: 'popinjay',
            55: 'purple hairstreak',
            56: 'purplish copper',
            57: 'question mark',
            58: 'red admiral',
            59: 'red cracker',
            60: 'red postman',
            61: 'red spotted purple',
            62: 'scarce swallow',
            63: 'silver spot skipper',
            64: 'sleepy orange',
            65: 'sootywing',
            66: 'southern dogface',
            67: 'straited queen',
            68: 'tropical leafwing',
            69: 'two barred flasher',
            70: 'ulyses',
            71: 'viceroy',
            72: 'wood satyr',
            73: 'yellow swallow tail',
            74: 'zebra long wing'}


def app_interface():
    """Build and launch the Gradio butterfly-species classifier UI.

    Reads the model name from model_names.yaml, loads the corresponding
    checkpoint from models/ ONCE at startup (the original reloaded it on
    every prediction), and serves a Gradio Interface whose callback
    returns the top-3 class probabilities plus a summary string.
    """
    class_map = _return_classes()

    with open("model_names.yaml") as config_file:
        config = yaml.safe_load(config_file)
    model_name = config["models"]["vit_model"]

    # NOTE(review): torch.load on a fully-pickled model executes code from
    # the checkpoint; acceptable here because the file ships inside the
    # app's own models.zip, but do not point this at untrusted files.
    model = torch.load(f"models/{model_name}.pt", map_location=torch.device('cpu'))
    model.eval()  # disable dropout/batchnorm training behavior for inference

    # Standard ImageNet preprocessing for ViT-style backbones.
    transform = v2.Compose([
        v2.Resize((224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])

    def make_predictions(image):
        """Classify one PIL image.

        Returns:
            tuple: ({species_name: probability} for the top-3 classes,
                    human-readable summary string).
        """
        batch = transform(image).unsqueeze(dim=0)
        with torch.inference_mode():  # no autograd bookkeeping at inference
            logits = model(batch)
        probs = torch.softmax(logits, dim=1).squeeze(0)

        # torch.topk replaces the original pandas sort round-trip.
        top_probs, top_idx = torch.topk(probs, k=3)
        top_n_dict = {class_map[idx.item()]: float(p)
                      for p, idx in zip(top_probs, top_idx)}

        # Bug fix: the original printed the raw 0-1 probability followed
        # by '%'; scale to a percentage and format to 2 decimals.
        best_name = class_map[int(top_idx[0])].capitalize()
        confidence = float(top_probs[0]) * 100
        summary = (f"predicted specie of butterfly is => {best_name}; "
                   f"with a confidence of {confidence:.2f}%")
        return top_n_dict, summary

    title = "ButterFly Specie Classifier"
    description = ("Vision Transformer feature extractor model trained to "
                   "classify species of butterflies")
    demo = gr.Interface(
        fn=make_predictions,
        inputs=gr.Image(type="pil", label="Upload Image here"),
        outputs=[gr.Label(num_top_classes=3, label="Predictions"),
                 gr.Textbox(label="Result")],
        title=title,
        description=description,
    )
    demo.launch(share=False)


if __name__ == "__main__":
    app_interface()