|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import yaml |
|
|
import os |
|
|
|
|
|
import torch |
|
|
import torchvision |
|
|
from torchvision.transforms import v2 |
|
|
import PIL |
|
|
from PIL import Image |
|
|
|
|
|
import gradio as gr |
|
|
import zipfile |
|
|
|
|
|
# Bundled model archive and the directory it is unpacked into.
zip_file_path = "models.zip"
extract_dir = os.getcwd()

# Unpacked at import time into the current working directory, which is where
# the relative "models/<name>.pt" checkpoint path opened later is resolved.
# NOTE(review): presumably the archive contains a top-level "models/" folder;
# confirm against the shipped models.zip.
with zipfile.ZipFile(zip_file_path, mode="r") as zip_ref:
    zip_ref.extractall(path=extract_dir)
|
|
|
|
|
def app_interface():
    """Build and launch the Gradio butterfly-species classifier UI.

    Reads the checkpoint name from ``model_names.yaml`` (key
    ``models -> vit_model``), loads ``models/<name>.pt`` once onto the CPU,
    and serves a ``gr.Interface`` that returns the top-3 predicted species
    (plus a one-line summary) for an uploaded image.
    """
    with open("model_names.yaml", encoding="utf-8") as config:
        config__ = yaml.safe_load(config)

    model_1 = config__["models"]["vit_model"]

    def return_classes_():
        """Return the mapping from model output index to species name."""
        # Spellings such as 'appollo', 'crecent' and 'ulyses' are kept
        # verbatim; presumably they must match the labels used in training,
        # so do not "fix" them without checking the training data.
        classes_ = {
            0: 'adonis',
            1: 'african giant swallowtail',
            2: 'american snoot',
            3: 'an 88',
            4: 'appollo',
            5: 'atala',
            6: 'banded orange heliconian',
            7: 'banded peacock',
            8: 'beckers white',
            9: 'black hairstreak',
            10: 'blue morpho',
            11: 'blue spotted crow',
            12: 'brown siproeta',
            13: 'cabbage white',
            14: 'cairns birdwing',
            15: 'checquered skipper',
            16: 'chestnut',
            17: 'cleopatra',
            18: 'clodius parnassian',
            19: 'clouded sulphur',
            20: 'common banded awl',
            21: 'common wood-nymph',
            22: 'copper tail',
            23: 'crecent',
            24: 'crimson patch',
            25: 'danaid eggfly',
            26: 'eastern coma',
            27: 'eastern dapple white',
            28: 'eastern pine elfin',
            29: 'elbowed pierrot',
            30: 'gold banded',
            31: 'great eggfly',
            32: 'great jay',
            33: 'green celled cattleheart',
            34: 'grey hairstreak',
            35: 'indra swallow',
            36: 'iphiclus sister',
            37: 'julia',
            38: 'large marble',
            39: 'malachite',
            40: 'mangrove skipper',
            41: 'mestra',
            42: 'metalmark',
            43: 'milberts tortoiseshell',
            44: 'monarch',
            45: 'mourning cloak',
            46: 'orange oakleaf',
            47: 'orange tip',
            48: 'orchard swallow',
            49: 'painted lady',
            50: 'paper kite',
            51: 'peacock',
            52: 'pine white',
            53: 'pipevine swallow',
            54: 'popinjay',
            55: 'purple hairstreak',
            56: 'purplish copper',
            57: 'question mark',
            58: 'red admiral',
            59: 'red cracker',
            60: 'red postman',
            61: 'red spotted purple',
            62: 'scarce swallow',
            63: 'silver spot skipper',
            64: 'sleepy orange',
            65: 'sootywing',
            66: 'southern dogface',
            67: 'straited queen',
            68: 'tropical leafwing',
            69: 'two barred flasher',
            70: 'ulyses',
            71: 'viceroy',
            72: 'wood satyr',
            73: 'yellow swallow tail',
            74: 'zebra long wing',
        }
        return classes_

    # Load the checkpoint ONCE; previously it was re-read from disk on every
    # prediction while this copy sat unused.
    # torch.load unpickles arbitrary Python objects, which is acceptable only
    # because the checkpoint ships with the app (presumably from models.zip).
    # The file holds a whole pickled module (it is called directly below), so
    # weights_only=False is needed on torch >= 2.6, where the default is True.
    model = torch.load(
        f"models/{model_1}.pt",
        map_location=torch.device('cpu'),
        weights_only=False,
    )
    # Inference behaviour for layers such as dropout/normalization, if present.
    model.eval()

    # Preprocessing is request-independent, so build it once.
    transform = v2.Compose([
        v2.Resize((224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # Standard ImageNet per-channel mean/std.
        v2.Normalize(mean=[0.485, 0.456, 0.406],
                     std=[0.229, 0.224, 0.225]),
    ])

    class_map = return_classes_()
    top_k = 3  # kept in step with gr.Label(num_top_classes=3) below

    def make_predictions(path_):
        """Classify a single uploaded image.

        Args:
            path_: The value from the ``gr.Image(type="pil")`` input, i.e. a
                PIL image, or ``None`` when the user submitted nothing.

        Returns:
            ``(top_n_dict, message)``: ``top_n_dict`` maps up to three species
            names to their softmax probabilities (for ``gr.Label``);
            ``message`` is a human-readable summary for the result textbox.
        """
        if path_ is None:
            return None, "Please upload an image first."

        if isinstance(path_, Image.Image):
            # Normalize expects exactly 3 channels; RGBA, greyscale or palette
            # uploads would otherwise fail inside the transform.
            path_ = path_.convert("RGB")

        img__ = transform(path_).unsqueeze(dim=0)  # add batch dimension

        # No autograd bookkeeping is needed for inference.
        with torch.inference_mode():
            logits = model(img__)

        # Softmax once; [0] drops the batch dimension of size 1.
        probs = torch.softmax(logits, dim=1)[0]
        top_probs, top_idx = torch.topk(probs, k=min(top_k, probs.numel()))

        top_n_dict = {
            class_map[idx]: prob
            for idx, prob in zip(top_idx.tolist(), top_probs.tolist())
        }

        # Use the top-1 entry from topk so label and text always agree.
        best_name = class_map[top_idx[0].item()]
        # Report a real percentage rather than a raw tensor repr.
        confidence = top_probs[0].item() * 100
        return top_n_dict, (
            f"predicted specie of butterfly is => {best_name.capitalize()}; "
            f"with a confidence of {confidence:.2f}%"
        )

    title = "ButterFly Specie Classifier"

    description = "Vision Transformer feature extractor model trained to classify species of butterflies"

    demo = gr.Interface(
        fn=make_predictions,
        inputs=gr.Image(type="pil", label="Upload Image here"),
        outputs=[gr.Label(num_top_classes=3, label="Predictions"),
                 gr.Textbox(label="Result")],
        title=title,
        description=description,
    )

    demo.launch(share=False)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the Gradio interface and start serving it only when this file is
    # run as a script. Note that the module-level models.zip extraction still
    # executes on any import of this module.
    app_interface()