File size: 5,712 Bytes
5081693
 
 
 
 
 
 
 
 
 
 
 
c0333f0
 
 
 
 
 
 
 
5081693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d8ea7e
5081693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da7c600
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import numpy as np
import pandas as pd
import yaml
import os

import torch
import torchvision
from torchvision.transforms import v2
import PIL
from PIL import Image

import gradio as gr
import zipfile

# Unpack the bundled model archive into the current working directory at
# import time. Presumably the archive holds the "models/" directory that
# app_interface loads checkpoints from — verify against the archive contents.
extract_dir = os.getcwd()
zip_file_path = "models.zip"

with zipfile.ZipFile(file=zip_file_path, mode="r") as archive:
    archive.extractall(path=extract_dir)

# Class names in model output order: index i of the logits corresponds to
# BUTTERFLY_CLASSES[i]. Spellings (e.g. "appollo", "crecent", "ulyses") are the
# labels the model was trained with and must not be "corrected".
BUTTERFLY_CLASSES = (
    "adonis", "african giant swallowtail", "american snoot", "an 88", "appollo",  # 0-4
    "atala", "banded orange heliconian", "banded peacock", "beckers white", "black hairstreak",  # 5-9
    "blue morpho", "blue spotted crow", "brown siproeta", "cabbage white", "cairns birdwing",  # 10-14
    "checquered skipper", "chestnut", "cleopatra", "clodius parnassian", "clouded sulphur",  # 15-19
    "common banded awl", "common wood-nymph", "copper tail", "crecent", "crimson patch",  # 20-24
    "danaid eggfly", "eastern coma", "eastern dapple white", "eastern pine elfin", "elbowed pierrot",  # 25-29
    "gold banded", "great eggfly", "great jay", "green celled cattleheart", "grey hairstreak",  # 30-34
    "indra swallow", "iphiclus sister", "julia", "large marble", "malachite",  # 35-39
    "mangrove skipper", "mestra", "metalmark", "milberts tortoiseshell", "monarch",  # 40-44
    "mourning cloak", "orange oakleaf", "orange tip", "orchard swallow", "painted lady",  # 45-49
    "paper kite", "peacock", "pine white", "pipevine swallow", "popinjay",  # 50-54
    "purple hairstreak", "purplish copper", "question mark", "red admiral", "red cracker",  # 55-59
    "red postman", "red spotted purple", "scarce swallow", "silver spot skipper", "sleepy orange",  # 60-64
    "sootywing", "southern dogface", "straited queen", "tropical leafwing", "two barred flasher",  # 65-69
    "ulyses", "viceroy", "wood satyr", "yellow swallow tail", "zebra long wing",  # 70-74
)


def _build_transform():
    """Return the preprocessing pipeline applied to every uploaded image.

    Resizes to 224x224, converts to a float32 tensor scaled to [0, 1], then
    normalizes with the standard ImageNet per-channel mean/std.
    """
    return v2.Compose([
        v2.Resize((224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


def summarize_logits(logits, class_names, k=3):
    """Turn single-image classifier logits into a top-k label dict and a summary.

    Args:
        logits: Tensor of shape (1, num_classes) — the model output for one image.
        class_names: Sequence mapping class index -> human-readable name.
        k: Number of most likely classes to return (clamped to [1, num_classes]).

    Returns:
        ``(top_k, message)`` where ``top_k`` maps class name -> probability
        (Python float), ordered from most to least likely, and ``message``
        names the top class and its confidence as a percentage.
    """
    # Softmax once over the class axis; detach so no autograd graph is kept.
    probs = torch.softmax(logits.detach(), dim=1)[0]
    k = max(1, min(k, probs.numel()))
    top_probs, top_idx = torch.topk(probs, k)  # sorted descending by default
    top_probs, top_idx = top_probs.tolist(), top_idx.tolist()

    top_k = {class_names[i]: float(p) for p, i in zip(top_probs, top_idx)}
    best_name = class_names[top_idx[0]]
    # Report the probability as a real percentage instead of interpolating the
    # raw tensor repr (which rendered as "tensor(0.98, grad_fn=...)%").
    message = (
        f"predicted specie of butterfly is => {best_name.capitalize()}; "
        f"with a confidence of {top_probs[0] * 100:.2f}%"
    )
    return top_k, message


def app_interface():
    """Build and launch the Gradio butterfly-species classifier.

    Reads the model name from ``model_names.yaml`` (key ``models.vit_model``),
    loads ``models/<name>.pt`` once onto the CPU, and serves a single-image
    interface that shows the top-3 predicted species plus a summary line.
    """
    with open("model_names.yaml", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    model_name = config["models"]["vit_model"]

    # The checkpoint is a fully pickled module (it is called directly below),
    # so weights_only=False is required on torch>=2.6 where the default became
    # True. Unpickling can execute arbitrary code: only load trusted,
    # bundled checkpoints here.
    model = torch.load(
        f"models/{model_name}.pt",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    # Inference mode for dropout/normalization layers; otherwise predictions
    # can vary between calls if the checkpoint was saved in training mode.
    model.eval()
    # Built once and reused by every request instead of per call.
    transform = _build_transform()

    def make_predictions(image):
        """Classify one uploaded PIL image; return (top-3 dict, summary text)."""
        if image is None:  # Gradio passes None when submitted without an upload.
            return {}, "Please upload an image of a butterfly."
        # Defensive: guarantee 3 channels to match the 3-channel Normalize.
        batch = transform(image.convert("RGB")).unsqueeze(dim=0)
        with torch.no_grad():
            logits = model(batch)
        return summarize_logits(logits, BUTTERFLY_CLASSES, k=3)

    title = "ButterFly Specie Classifier"
    description = "Vision Transformer  feature extractor model trained to classify species of butterflies"

    demo = gr.Interface(
        fn=make_predictions,
        inputs=gr.Image(type="pil", label="Upload Image here"),
        outputs=[gr.Label(num_top_classes=3, label="Predictions"), gr.Textbox(label="Result")],
        title=title,
        description=description,
    )

    demo.launch(share=False)


# Build and launch the Gradio app only when this file is executed directly.
if __name__ == "__main__":
    app_interface()