AyoAgbaje commited on
Commit
5081693
·
verified ·
1 Parent(s): 051a39b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -133
app.py CHANGED
@@ -1,134 +1,134 @@
1
- import numpy as np
2
- import pandas as pd
3
- import yaml
4
- import os
5
-
6
- import torch
7
- import torchvision
8
- from torchvision.transforms import v2
9
- import PIL
10
- from PIL import Image
11
-
12
- import gradio as gr
13
-
14
- def app_interface():
15
- with open("model_names.yaml") as config:
16
- config__ = yaml.safe_load(config)
17
- model_1 = config__["models"]["vit_model"]
18
- def return_classes_():
19
- classes_ = {0: 'adonis',
20
- 1: 'african giant swallowtail',
21
- 2: 'american snoot',
22
- 3: 'an 88',
23
- 4: 'appollo',
24
- 5: 'atala',
25
- 6: 'banded orange heliconian',
26
- 7: 'banded peacock',
27
- 8: 'beckers white',
28
- 9: 'black hairstreak',
29
- 10: 'blue morpho',
30
- 11: 'blue spotted crow',
31
- 12: 'brown siproeta',
32
- 13: 'cabbage white',
33
- 14: 'cairns birdwing',
34
- 15: 'checquered skipper',
35
- 16: 'chestnut',
36
- 17: 'cleopatra',
37
- 18: 'clodius parnassian',
38
- 19: 'clouded sulphur',
39
- 20: 'common banded awl',
40
- 21: 'common wood-nymph',
41
- 22: 'copper tail',
42
- 23: 'crecent',
43
- 24: 'crimson patch',
44
- 25: 'danaid eggfly',
45
- 26: 'eastern coma',
46
- 27: 'eastern dapple white',
47
- 28: 'eastern pine elfin',
48
- 29: 'elbowed pierrot',
49
- 30: 'gold banded',
50
- 31: 'great eggfly',
51
- 32: 'great jay',
52
- 33: 'green celled cattleheart',
53
- 34: 'grey hairstreak',
54
- 35: 'indra swallow',
55
- 36: 'iphiclus sister',
56
- 37: 'julia',
57
- 38: 'large marble',
58
- 39: 'malachite',
59
- 40: 'mangrove skipper',
60
- 41: 'mestra',
61
- 42: 'metalmark',
62
- 43: 'milberts tortoiseshell',
63
- 44: 'monarch',
64
- 45: 'mourning cloak',
65
- 46: 'orange oakleaf',
66
- 47: 'orange tip',
67
- 48: 'orchard swallow',
68
- 49: 'painted lady',
69
- 50: 'paper kite',
70
- 51: 'peacock',
71
- 52: 'pine white',
72
- 53: 'pipevine swallow',
73
- 54: 'popinjay',
74
- 55: 'purple hairstreak',
75
- 56: 'purplish copper',
76
- 57: 'question mark',
77
- 58: 'red admiral',
78
- 59: 'red cracker',
79
- 60: 'red postman',
80
- 61: 'red spotted purple',
81
- 62: 'scarce swallow',
82
- 63: 'silver spot skipper',
83
- 64: 'sleepy orange',
84
- 65: 'sootywing',
85
- 66: 'southern dogface',
86
- 67: 'straited queen',
87
- 68: 'tropical leafwing',
88
- 69: 'two barred flasher',
89
- 70: 'ulyses',
90
- 71: 'viceroy',
91
- 72: 'wood satyr',
92
- 73: 'yellow swallow tail',
93
- 74: 'zebra long wing'}
94
- return classes_
95
- model = torch.load(f"models/{model_1}.pt", map_location=torch.device('cpu'))
96
- def make_predictions(path_:str):
97
- transform = v2.Compose([
98
- v2.Resize((224,224)),
99
- v2.ToImage(),
100
- v2.ToDtype(torch.float32, scale = True),
101
- v2.Normalize(mean=[0.485, 0.456, 0.406],
102
- std=[0.229, 0.224, 0.225])])
103
- class_map = return_classes_()
104
- model_ = torch.load(f"models/{model_1}.pt", map_location=torch.device('cpu'))
105
- # img_ = Image.open(path_)
106
- img__ = transform(path_).unsqueeze(dim = 0)
107
- pred_ = model_(img__)
108
- pred = torch.argmax(torch.softmax(pred_, dim = 1), axis = 1)
109
- top_n = pd.DataFrame(data = torch.softmax(pred_, dim = 1).detach().numpy(), columns = class_map.values()).T
110
- top_n.columns = ["probs"]
111
- top_n = top_n.sort_values("probs", ascending = False)
112
- top_n = top_n.loc[top_n.index[:3],:]
113
- top_n_dict = {}
114
- top_specie, top_prob = [i for i in top_n.index], [i for i in top_n.probs.values]
115
- for i, j in zip(top_specie, top_prob):
116
- top_n_dict[f"{i}"] = float(j)
117
- # display(top_n)
118
- # plt.imshow(img_)
119
- # plt.grid(axis = "both", lw = 0)
120
- # plt.title(f"{class_map[pred.item()].capitalize()} | {torch.softmax(pred_, dim = 1).max():.2f}", fontweight = "bold", fontsize = 10)
121
- return top_n_dict, f"predicted specie of butterfly is => {class_map[pred.item()].capitalize()}; with a confidence of {torch.softmax(pred_, dim = 1).max()}%"
122
- title = "ButterFly Specie Classifier"
123
- description = "Vision Transformer feature extractor model trained to classify species of butterflies"
124
-
125
- demo = gr.Interface(fn = make_predictions, inputs = gr.Image(type = "pil", label = "Upload Image here"),
126
- outputs=[gr.Label(num_top_classes=3, label = "Predictions"), gr.Textbox(label = "Result")],
127
- title = title,
128
- description=description)
129
-
130
- demo.launch(share = True)
131
-
132
-
133
- if __name__ == "__main__":
134
  app_interface()
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import yaml
4
+ import os
5
+
6
+ import torch
7
+ import torchvision
8
+ from torchvision.transforms import v2
9
+ import PIL
10
+ from PIL import Image
11
+
12
+ import gradio as gr
13
+
14
+ def app_interface():
15
+ with open("model_names.yaml") as config:
16
+ config__ = yaml.safe_load(config)
17
+ model_1 = config__["models"]["vit_model"]
18
+ def return_classes_():
19
+ classes_ = {0: 'adonis',
20
+ 1: 'african giant swallowtail',
21
+ 2: 'american snoot',
22
+ 3: 'an 88',
23
+ 4: 'appollo',
24
+ 5: 'atala',
25
+ 6: 'banded orange heliconian',
26
+ 7: 'banded peacock',
27
+ 8: 'beckers white',
28
+ 9: 'black hairstreak',
29
+ 10: 'blue morpho',
30
+ 11: 'blue spotted crow',
31
+ 12: 'brown siproeta',
32
+ 13: 'cabbage white',
33
+ 14: 'cairns birdwing',
34
+ 15: 'checquered skipper',
35
+ 16: 'chestnut',
36
+ 17: 'cleopatra',
37
+ 18: 'clodius parnassian',
38
+ 19: 'clouded sulphur',
39
+ 20: 'common banded awl',
40
+ 21: 'common wood-nymph',
41
+ 22: 'copper tail',
42
+ 23: 'crecent',
43
+ 24: 'crimson patch',
44
+ 25: 'danaid eggfly',
45
+ 26: 'eastern coma',
46
+ 27: 'eastern dapple white',
47
+ 28: 'eastern pine elfin',
48
+ 29: 'elbowed pierrot',
49
+ 30: 'gold banded',
50
+ 31: 'great eggfly',
51
+ 32: 'great jay',
52
+ 33: 'green celled cattleheart',
53
+ 34: 'grey hairstreak',
54
+ 35: 'indra swallow',
55
+ 36: 'iphiclus sister',
56
+ 37: 'julia',
57
+ 38: 'large marble',
58
+ 39: 'malachite',
59
+ 40: 'mangrove skipper',
60
+ 41: 'mestra',
61
+ 42: 'metalmark',
62
+ 43: 'milberts tortoiseshell',
63
+ 44: 'monarch',
64
+ 45: 'mourning cloak',
65
+ 46: 'orange oakleaf',
66
+ 47: 'orange tip',
67
+ 48: 'orchard swallow',
68
+ 49: 'painted lady',
69
+ 50: 'paper kite',
70
+ 51: 'peacock',
71
+ 52: 'pine white',
72
+ 53: 'pipevine swallow',
73
+ 54: 'popinjay',
74
+ 55: 'purple hairstreak',
75
+ 56: 'purplish copper',
76
+ 57: 'question mark',
77
+ 58: 'red admiral',
78
+ 59: 'red cracker',
79
+ 60: 'red postman',
80
+ 61: 'red spotted purple',
81
+ 62: 'scarce swallow',
82
+ 63: 'silver spot skipper',
83
+ 64: 'sleepy orange',
84
+ 65: 'sootywing',
85
+ 66: 'southern dogface',
86
+ 67: 'straited queen',
87
+ 68: 'tropical leafwing',
88
+ 69: 'two barred flasher',
89
+ 70: 'ulyses',
90
+ 71: 'viceroy',
91
+ 72: 'wood satyr',
92
+ 73: 'yellow swallow tail',
93
+ 74: 'zebra long wing'}
94
+ return classes_
95
+ model = torch.load(f"models/{model_1}.pt", map_location=torch.device('cpu'))
96
+ def make_predictions(path_:str):
97
+ transform = v2.Compose([
98
+ v2.Resize((224,224)),
99
+ v2.ToImage(),
100
+ v2.ToDtype(torch.float32, scale = True),
101
+ v2.Normalize(mean=[0.485, 0.456, 0.406],
102
+ std=[0.229, 0.224, 0.225])])
103
+ class_map = return_classes_()
104
+ model_ = torch.load(f"models/{model_1}.pt", map_location=torch.device('cpu'))
105
+ # img_ = Image.open(path_)
106
+ img__ = transform(path_).unsqueeze(dim = 0)
107
+ pred_ = model_(img__)
108
+ pred = torch.argmax(torch.softmax(pred_, dim = 1), axis = 1)
109
+ top_n = pd.DataFrame(data = torch.softmax(pred_, dim = 1).detach().numpy(), columns = class_map.values()).T
110
+ top_n.columns = ["probs"]
111
+ top_n = top_n.sort_values("probs", ascending = False)
112
+ top_n = top_n.loc[top_n.index[:3],:]
113
+ top_n_dict = {}
114
+ top_specie, top_prob = [i for i in top_n.index], [i for i in top_n.probs.values]
115
+ for i, j in zip(top_specie, top_prob):
116
+ top_n_dict[f"{i}"] = float(j)
117
+ # display(top_n)
118
+ # plt.imshow(img_)
119
+ # plt.grid(axis = "both", lw = 0)
120
+ # plt.title(f"{class_map[pred.item()].capitalize()} | {torch.softmax(pred_, dim = 1).max():.2f}", fontweight = "bold", fontsize = 10)
121
+ return top_n_dict, f"predicted specie of butterfly is => {class_map[pred.item()].capitalize()}; with a confidence of {torch.softmax(pred_, dim = 1).max()}%"
122
+ title = "ButterFly Specie Classifier"
123
+ description = "Vision Transformer feature extractor model trained to classify species of butterflies"
124
+
125
+ demo = gr.Interface(fn = make_predictions, inputs = gr.Image(type = "pil", label = "Upload Image here"),
126
+ outputs=[gr.Label(num_top_classes=3, label = "Predictions"), gr.Textbox(label = "Result")],
127
+ title = title,
128
+ description=description)
129
+
130
+ demo.launch(share = False)
131
+
132
+
133
+ if __name__ == "__main__":
134
  app_interface()