Billy-06 commited on
Commit ·
a263b83
1
Parent(s): eb9829e
Added the file Architectures
Browse files- .gitignore +3 -0
- app.py +53 -4
- classes.txt +200 -0
- model.py +248 -0
- requirements.txt +86 -0
.gitignore
CHANGED
|
@@ -3,6 +3,9 @@ flagged/
|
|
| 3 |
*.png
|
| 4 |
*.jpg
|
| 5 |
*.jpeg
|
|
|
|
| 6 |
gradio_cache/
|
| 7 |
|
| 8 |
venv/
|
|
|
|
|
|
|
|
|
| 3 |
*.png
|
| 4 |
*.jpg
|
| 5 |
*.jpeg
|
| 6 |
+
*.pyc
|
| 7 |
gradio_cache/
|
| 8 |
|
| 9 |
venv/
|
| 10 |
+
__pychache__/
|
| 11 |
+
|
app.py
CHANGED
|
@@ -1,7 +1,56 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
return "Hello " + name + "!!"
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
|
| 5 |
+
from model import *
|
|
|
|
| 6 |
|
| 7 |
+
def load_cub200_classes():
|
| 8 |
+
"""
|
| 9 |
+
This function loads the classes from the classes.txt file and returns a dictionary
|
| 10 |
+
"""
|
| 11 |
+
with open("classes.txt", encoding="utf-8") as f:
|
| 12 |
+
classes = f.read().splitlines()
|
| 13 |
+
|
| 14 |
+
# convert classes to dictionary separating the lines by the first space
|
| 15 |
+
classes = {int(line.split(" ")[0]) : line.split(" ")[1] for line in classes}
|
| 16 |
+
|
| 17 |
+
# return the classes dictionary
|
| 18 |
+
return classes
|
| 19 |
+
|
| 20 |
+
def load_model():
|
| 21 |
+
"""
|
| 22 |
+
This function loads the trained model and returns it
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
# load the resnet model
|
| 26 |
+
model = resnet50(pretrained=False, stride=[1, 2, 2, 1], num_classes=200)
|
| 27 |
+
# load the trained weights
|
| 28 |
+
model.load_state_dict(torch.load("resnet.pt", map_location=torch.device('cpu')))
|
| 29 |
+
# set the model to evaluation mode
|
| 30 |
+
model.eval()
|
| 31 |
+
# return the model
|
| 32 |
+
return model
|
| 33 |
+
|
| 34 |
+
def predict_image(image):
|
| 35 |
+
"""
|
| 36 |
+
This function takes an image as input and returns the class label
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# load the model
|
| 40 |
+
model = load_model()
|
| 41 |
+
# load the classes
|
| 42 |
+
classes = load_cub200_classes()
|
| 43 |
+
|
| 44 |
+
# convert image to tensor
|
| 45 |
+
tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0)
|
| 46 |
+
# make prediction
|
| 47 |
+
prediction = model(tensor).detach().numpy()[0]
|
| 48 |
+
# convert prediction to probabilities
|
| 49 |
+
probabilities = np.exp(prediction) / np.sum(np.exp(prediction))
|
| 50 |
+
# get the class with the highest probability
|
| 51 |
+
class_idx = np.argmax(probabilities)
|
| 52 |
+
# return the class label
|
| 53 |
+
return "Class: " + classes[class_idx]
|
| 54 |
+
|
| 55 |
+
# create a gradio interface
|
| 56 |
+
gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()
|
classes.txt
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
1 001.Black_footed_Albatross
|
| 2 |
+
2 002.Laysan_Albatross
|
| 3 |
+
3 003.Sooty_Albatross
|
| 4 |
+
4 004.Groove_billed_Ani
|
| 5 |
+
5 005.Crested_Auklet
|
| 6 |
+
6 006.Least_Auklet
|
| 7 |
+
7 007.Parakeet_Auklet
|
| 8 |
+
8 008.Rhinoceros_Auklet
|
| 9 |
+
9 009.Brewer_Blackbird
|
| 10 |
+
10 010.Red_winged_Blackbird
|
| 11 |
+
11 011.Rusty_Blackbird
|
| 12 |
+
12 012.Yellow_headed_Blackbird
|
| 13 |
+
13 013.Bobolink
|
| 14 |
+
14 014.Indigo_Bunting
|
| 15 |
+
15 015.Lazuli_Bunting
|
| 16 |
+
16 016.Painted_Bunting
|
| 17 |
+
17 017.Cardinal
|
| 18 |
+
18 018.Spotted_Catbird
|
| 19 |
+
19 019.Gray_Catbird
|
| 20 |
+
20 020.Yellow_breasted_Chat
|
| 21 |
+
21 021.Eastern_Towhee
|
| 22 |
+
22 022.Chuck_will_Widow
|
| 23 |
+
23 023.Brandt_Cormorant
|
| 24 |
+
24 024.Red_faced_Cormorant
|
| 25 |
+
25 025.Pelagic_Cormorant
|
| 26 |
+
26 026.Bronzed_Cowbird
|
| 27 |
+
27 027.Shiny_Cowbird
|
| 28 |
+
28 028.Brown_Creeper
|
| 29 |
+
29 029.American_Crow
|
| 30 |
+
30 030.Fish_Crow
|
| 31 |
+
31 031.Black_billed_Cuckoo
|
| 32 |
+
32 032.Mangrove_Cuckoo
|
| 33 |
+
33 033.Yellow_billed_Cuckoo
|
| 34 |
+
34 034.Gray_crowned_Rosy_Finch
|
| 35 |
+
35 035.Purple_Finch
|
| 36 |
+
36 036.Northern_Flicker
|
| 37 |
+
37 037.Acadian_Flycatcher
|
| 38 |
+
38 038.Great_Crested_Flycatcher
|
| 39 |
+
39 039.Least_Flycatcher
|
| 40 |
+
40 040.Olive_sided_Flycatcher
|
| 41 |
+
41 041.Scissor_tailed_Flycatcher
|
| 42 |
+
42 042.Vermilion_Flycatcher
|
| 43 |
+
43 043.Yellow_bellied_Flycatcher
|
| 44 |
+
44 044.Frigatebird
|
| 45 |
+
45 045.Northern_Fulmar
|
| 46 |
+
46 046.Gadwall
|
| 47 |
+
47 047.American_Goldfinch
|
| 48 |
+
48 048.European_Goldfinch
|
| 49 |
+
49 049.Boat_tailed_Grackle
|
| 50 |
+
50 050.Eared_Grebe
|
| 51 |
+
51 051.Horned_Grebe
|
| 52 |
+
52 052.Pied_billed_Grebe
|
| 53 |
+
53 053.Western_Grebe
|
| 54 |
+
54 054.Blue_Grosbeak
|
| 55 |
+
55 055.Evening_Grosbeak
|
| 56 |
+
56 056.Pine_Grosbeak
|
| 57 |
+
57 057.Rose_breasted_Grosbeak
|
| 58 |
+
58 058.Pigeon_Guillemot
|
| 59 |
+
59 059.California_Gull
|
| 60 |
+
60 060.Glaucous_winged_Gull
|
| 61 |
+
61 061.Heermann_Gull
|
| 62 |
+
62 062.Herring_Gull
|
| 63 |
+
63 063.Ivory_Gull
|
| 64 |
+
64 064.Ring_billed_Gull
|
| 65 |
+
65 065.Slaty_backed_Gull
|
| 66 |
+
66 066.Western_Gull
|
| 67 |
+
67 067.Anna_Hummingbird
|
| 68 |
+
68 068.Ruby_throated_Hummingbird
|
| 69 |
+
69 069.Rufous_Hummingbird
|
| 70 |
+
70 070.Green_Violetear
|
| 71 |
+
71 071.Long_tailed_Jaeger
|
| 72 |
+
72 072.Pomarine_Jaeger
|
| 73 |
+
73 073.Blue_Jay
|
| 74 |
+
74 074.Florida_Jay
|
| 75 |
+
75 075.Green_Jay
|
| 76 |
+
76 076.Dark_eyed_Junco
|
| 77 |
+
77 077.Tropical_Kingbird
|
| 78 |
+
78 078.Gray_Kingbird
|
| 79 |
+
79 079.Belted_Kingfisher
|
| 80 |
+
80 080.Green_Kingfisher
|
| 81 |
+
81 081.Pied_Kingfisher
|
| 82 |
+
82 082.Ringed_Kingfisher
|
| 83 |
+
83 083.White_breasted_Kingfisher
|
| 84 |
+
84 084.Red_legged_Kittiwake
|
| 85 |
+
85 085.Horned_Lark
|
| 86 |
+
86 086.Pacific_Loon
|
| 87 |
+
87 087.Mallard
|
| 88 |
+
88 088.Western_Meadowlark
|
| 89 |
+
89 089.Hooded_Merganser
|
| 90 |
+
90 090.Red_breasted_Merganser
|
| 91 |
+
91 091.Mockingbird
|
| 92 |
+
92 092.Nighthawk
|
| 93 |
+
93 093.Clark_Nutcracker
|
| 94 |
+
94 094.White_breasted_Nuthatch
|
| 95 |
+
95 095.Baltimore_Oriole
|
| 96 |
+
96 096.Hooded_Oriole
|
| 97 |
+
97 097.Orchard_Oriole
|
| 98 |
+
98 098.Scott_Oriole
|
| 99 |
+
99 099.Ovenbird
|
| 100 |
+
100 100.Brown_Pelican
|
| 101 |
+
101 101.White_Pelican
|
| 102 |
+
102 102.Western_Wood_Pewee
|
| 103 |
+
103 103.Sayornis
|
| 104 |
+
104 104.American_Pipit
|
| 105 |
+
105 105.Whip_poor_Will
|
| 106 |
+
106 106.Horned_Puffin
|
| 107 |
+
107 107.Common_Raven
|
| 108 |
+
108 108.White_necked_Raven
|
| 109 |
+
109 109.American_Redstart
|
| 110 |
+
110 110.Geococcyx
|
| 111 |
+
111 111.Loggerhead_Shrike
|
| 112 |
+
112 112.Great_Grey_Shrike
|
| 113 |
+
113 113.Baird_Sparrow
|
| 114 |
+
114 114.Black_throated_Sparrow
|
| 115 |
+
115 115.Brewer_Sparrow
|
| 116 |
+
116 116.Chipping_Sparrow
|
| 117 |
+
117 117.Clay_colored_Sparrow
|
| 118 |
+
118 118.House_Sparrow
|
| 119 |
+
119 119.Field_Sparrow
|
| 120 |
+
120 120.Fox_Sparrow
|
| 121 |
+
121 121.Grasshopper_Sparrow
|
| 122 |
+
122 122.Harris_Sparrow
|
| 123 |
+
123 123.Henslow_Sparrow
|
| 124 |
+
124 124.Le_Conte_Sparrow
|
| 125 |
+
125 125.Lincoln_Sparrow
|
| 126 |
+
126 126.Nelson_Sharp_tailed_Sparrow
|
| 127 |
+
127 127.Savannah_Sparrow
|
| 128 |
+
128 128.Seaside_Sparrow
|
| 129 |
+
129 129.Song_Sparrow
|
| 130 |
+
130 130.Tree_Sparrow
|
| 131 |
+
131 131.Vesper_Sparrow
|
| 132 |
+
132 132.White_crowned_Sparrow
|
| 133 |
+
133 133.White_throated_Sparrow
|
| 134 |
+
134 134.Cape_Glossy_Starling
|
| 135 |
+
135 135.Bank_Swallow
|
| 136 |
+
136 136.Barn_Swallow
|
| 137 |
+
137 137.Cliff_Swallow
|
| 138 |
+
138 138.Tree_Swallow
|
| 139 |
+
139 139.Scarlet_Tanager
|
| 140 |
+
140 140.Summer_Tanager
|
| 141 |
+
141 141.Artic_Tern
|
| 142 |
+
142 142.Black_Tern
|
| 143 |
+
143 143.Caspian_Tern
|
| 144 |
+
144 144.Common_Tern
|
| 145 |
+
145 145.Elegant_Tern
|
| 146 |
+
146 146.Forsters_Tern
|
| 147 |
+
147 147.Least_Tern
|
| 148 |
+
148 148.Green_tailed_Towhee
|
| 149 |
+
149 149.Brown_Thrasher
|
| 150 |
+
150 150.Sage_Thrasher
|
| 151 |
+
151 151.Black_capped_Vireo
|
| 152 |
+
152 152.Blue_headed_Vireo
|
| 153 |
+
153 153.Philadelphia_Vireo
|
| 154 |
+
154 154.Red_eyed_Vireo
|
| 155 |
+
155 155.Warbling_Vireo
|
| 156 |
+
156 156.White_eyed_Vireo
|
| 157 |
+
157 157.Yellow_throated_Vireo
|
| 158 |
+
158 158.Bay_breasted_Warbler
|
| 159 |
+
159 159.Black_and_white_Warbler
|
| 160 |
+
160 160.Black_throated_Blue_Warbler
|
| 161 |
+
161 161.Blue_winged_Warbler
|
| 162 |
+
162 162.Canada_Warbler
|
| 163 |
+
163 163.Cape_May_Warbler
|
| 164 |
+
164 164.Cerulean_Warbler
|
| 165 |
+
165 165.Chestnut_sided_Warbler
|
| 166 |
+
166 166.Golden_winged_Warbler
|
| 167 |
+
167 167.Hooded_Warbler
|
| 168 |
+
168 168.Kentucky_Warbler
|
| 169 |
+
169 169.Magnolia_Warbler
|
| 170 |
+
170 170.Mourning_Warbler
|
| 171 |
+
171 171.Myrtle_Warbler
|
| 172 |
+
172 172.Nashville_Warbler
|
| 173 |
+
173 173.Orange_crowned_Warbler
|
| 174 |
+
174 174.Palm_Warbler
|
| 175 |
+
175 175.Pine_Warbler
|
| 176 |
+
176 176.Prairie_Warbler
|
| 177 |
+
177 177.Prothonotary_Warbler
|
| 178 |
+
178 178.Swainson_Warbler
|
| 179 |
+
179 179.Tennessee_Warbler
|
| 180 |
+
180 180.Wilson_Warbler
|
| 181 |
+
181 181.Worm_eating_Warbler
|
| 182 |
+
182 182.Yellow_Warbler
|
| 183 |
+
183 183.Northern_Waterthrush
|
| 184 |
+
184 184.Louisiana_Waterthrush
|
| 185 |
+
185 185.Bohemian_Waxwing
|
| 186 |
+
186 186.Cedar_Waxwing
|
| 187 |
+
187 187.American_Three_toed_Woodpecker
|
| 188 |
+
188 188.Pileated_Woodpecker
|
| 189 |
+
189 189.Red_bellied_Woodpecker
|
| 190 |
+
190 190.Red_cockaded_Woodpecker
|
| 191 |
+
191 191.Red_headed_Woodpecker
|
| 192 |
+
192 192.Downy_Woodpecker
|
| 193 |
+
193 193.Bewick_Wren
|
| 194 |
+
194 194.Cactus_Wren
|
| 195 |
+
195 195.Carolina_Wren
|
| 196 |
+
196 196.House_Wren
|
| 197 |
+
197 197.Marsh_Wren
|
| 198 |
+
198 198.Rock_Wren
|
| 199 |
+
199 199.Winter_Wren
|
| 200 |
+
200 200.Common_Yellowthroat
|
model.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.utils.model_zoo as model_zoo
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
import time
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
model_urls = {
|
| 11 |
+
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
| 12 |
+
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
| 13 |
+
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
| 14 |
+
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
| 15 |
+
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class BasicBlock(nn.Module):
|
| 21 |
+
"""
|
| 22 |
+
This is a basic block that contains two convolutional layers followed by
|
| 23 |
+
a batch normalization layer and a ReLU activation function, where the skip
|
| 24 |
+
connection is added before the second relu.
|
| 25 |
+
---
|
| 26 |
+
|
| 27 |
+
- inplanes: { int } - The number of input channels.
|
| 28 |
+
- planes: { int } - The number of output channels.
|
| 29 |
+
- stride: { int } - The stride of convolutional layers.
|
| 30 |
+
- downsample: { nn.Sequential } - A sequential of convolutional layers that fit the
|
| 31 |
+
identity mapping to the desired output size.
|
| 32 |
+
"""
|
| 33 |
+
expansion = 1
|
| 34 |
+
|
| 35 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 36 |
+
super(BasicBlock, self).__init__()
|
| 37 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
| 38 |
+
padding=1, bias=False)
|
| 39 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 40 |
+
self.relu = nn.ReLU(inplace=True)
|
| 41 |
+
|
| 42 |
+
self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
|
| 43 |
+
padding=1, bias=False)
|
| 44 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 45 |
+
self.downsample = downsample
|
| 46 |
+
self.stride = stride
|
| 47 |
+
|
| 48 |
+
def forward(self, x):
|
| 49 |
+
"""
|
| 50 |
+
This is the forward pass of the basic block where the input tensor x is passed
|
| 51 |
+
through the first convolutional layer, batch normalization layer, and the ReLU
|
| 52 |
+
activation function. The result is passed through the second convolutional layer,
|
| 53 |
+
batch normalization layer, and the ReLU activation function. The result is then
|
| 54 |
+
added to the identity mapping and passed through the ReLU activation function.
|
| 55 |
+
"""
|
| 56 |
+
residual = x
|
| 57 |
+
|
| 58 |
+
# Convolve with a 3X3Xplanes kernel
|
| 59 |
+
out = self.conv1(x)
|
| 60 |
+
out = self.bn1(out)
|
| 61 |
+
out = self.relu(out)
|
| 62 |
+
|
| 63 |
+
# Convolve with a 3X3Xplanes kernel
|
| 64 |
+
out = self.conv2(out)
|
| 65 |
+
out = self.bn2(out)
|
| 66 |
+
|
| 67 |
+
# If the stride is not 1 or the number of input channels is not equal
|
| 68 |
+
# to the number of output channels then we need to fit the identity
|
| 69 |
+
# mapping to the desired output size by applying the downsample.
|
| 70 |
+
if self.downsample is not None:
|
| 71 |
+
residual = self.downsample(x)
|
| 72 |
+
|
| 73 |
+
# Add the identity mapping to the output of the second convolutional layer.
|
| 74 |
+
out += residual
|
| 75 |
+
# Apply the ReLU activation function after the addition.
|
| 76 |
+
out = self.relu(out)
|
| 77 |
+
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class Bottleneck(nn.Module):
|
| 83 |
+
"""
|
| 84 |
+
This class defines a bottle neck that fits the identity mapping to the desired
|
| 85 |
+
output size before adding it to the output of the following layers.
|
| 86 |
+
---
|
| 87 |
+
- inplanes: { int } - The number of input channels.
|
| 88 |
+
- planes: { int } - The number of output channels.
|
| 89 |
+
- stride: { int } - The stride of the second convolutional layer.
|
| 90 |
+
- downsample: { nn.Sequential } - A sequential of convolutional layers that fit the
|
| 91 |
+
identity mapping to the desired output size.
|
| 92 |
+
|
| 93 |
+
The following layers are defined:
|
| 94 |
+
- A 1x1 convolutional layer (self.conv1) with inplanes input channels and planes
|
| 95 |
+
output channels is defined.
|
| 96 |
+
- A batch normalization layer (self.bn1) is defined for the output of self.conv1.
|
| 97 |
+
- A 3x3 convolutional layer (self.conv2) with planes input channels, planes output
|
| 98 |
+
channels, and stride 'stride' is defined.
|
| 99 |
+
- A batch normalization layer (self.bn2) is defined for the output of self.conv2.
|
| 100 |
+
- A 1x1 convolutional layer (self.conv3) with planes input channels
|
| 101 |
+
and planes * self.expansion output channels is defined.
|
| 102 |
+
- A batch normalization layer (self.bn3) is defined for the output of self.conv3.
|
| 103 |
+
- A ReLU activation function (self.relu) is defined.
|
| 104 |
+
"""
|
| 105 |
+
expansion = 4
|
| 106 |
+
|
| 107 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 108 |
+
super(Bottleneck, self).__init__()
|
| 109 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 110 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 111 |
+
|
| 112 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
|
| 113 |
+
stride=stride, padding=1, bias=False)
|
| 114 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 115 |
+
|
| 116 |
+
self.conv3 = nn.Conv2d(
|
| 117 |
+
planes, planes * self.expansion, kernel_size=1, bias=False)
|
| 118 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
| 119 |
+
self.relu = nn.ReLU(inplace=True)
|
| 120 |
+
|
| 121 |
+
self.downsample = downsample
|
| 122 |
+
self.stride = stride
|
| 123 |
+
|
| 124 |
+
def forward(self, x):
|
| 125 |
+
"""
|
| 126 |
+
The Forward Pass
|
| 127 |
+
----------------
|
| 128 |
+
Steps:
|
| 129 |
+
|
| 130 |
+
- The input tensor x is saved as residual.
|
| 131 |
+
- x is passed through self.conv1, self.bn1, and self.relu.
|
| 132 |
+
- The result is passed through self.conv2, self.bn2, and self.relu.
|
| 133 |
+
- The result is passed through self.conv3 and self.bn3.
|
| 134 |
+
|
| 135 |
+
- If self.downsample is not None, residual is passed through self.downsample.
|
| 136 |
+
- The output of the previous step is added to out.
|
| 137 |
+
- The result is passed through self.relu.
|
| 138 |
+
- The result is returned.
|
| 139 |
+
"""
|
| 140 |
+
residual = x
|
| 141 |
+
# Convolve with a 1X1Xplanes kernel
|
| 142 |
+
out = self.conv1(x)
|
| 143 |
+
out = self.bn1(out)
|
| 144 |
+
out = self.relu(out)
|
| 145 |
+
|
| 146 |
+
# Convolve with a 3X3Xplanes kernel
|
| 147 |
+
out = self.conv2(out)
|
| 148 |
+
out = self.bn2(out)
|
| 149 |
+
out = self.relu(out)
|
| 150 |
+
|
| 151 |
+
# Convolve with a 1X1Xplanes*expansion kernel
|
| 152 |
+
out = self.conv3(out)
|
| 153 |
+
out = self.bn3(out)
|
| 154 |
+
|
| 155 |
+
# If the stride is not 1 or the number of input channels is not equal
|
| 156 |
+
# to the number of output channels then we need to fit the identity
|
| 157 |
+
# mapping to the desired output size by applying the downsample.
|
| 158 |
+
if self.downsample is not None:
|
| 159 |
+
residual = self.downsample(x)
|
| 160 |
+
|
| 161 |
+
out += residual
|
| 162 |
+
# Apply the ReLU activation function after the addition.
|
| 163 |
+
out = self.relu(out)
|
| 164 |
+
|
| 165 |
+
return out
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ResNet(nn.Module):
|
| 169 |
+
"""
|
| 170 |
+
This is the ResNet class that is used in ResNet50, ResNet101, and ResNet152.
|
| 171 |
+
"""
|
| 172 |
+
def __init__(self, block, layers, stride=None):
|
| 173 |
+
self.inplanes = 64
|
| 174 |
+
super(ResNet, self).__init__()
|
| 175 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 176 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 177 |
+
self.relu = nn.ReLU(inplace=True)
|
| 178 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 179 |
+
self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0])
|
| 180 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1])
|
| 181 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2])
|
| 182 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3])
|
| 183 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 184 |
+
|
| 185 |
+
self.fc = nn.Linear(512 * block.expansion, 1000)
|
| 186 |
+
|
| 187 |
+
for m in self.modules():
|
| 188 |
+
if isinstance(m, nn.Conv2d):
|
| 189 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 190 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 191 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 192 |
+
m.weight.data.fill_(1)
|
| 193 |
+
m.bias.data.zero_()
|
| 194 |
+
|
| 195 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 196 |
+
downsample = None
|
| 197 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 198 |
+
downsample = nn.Sequential(
|
| 199 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
| 200 |
+
kernel_size=1, stride=stride, bias=False),
|
| 201 |
+
nn.BatchNorm2d(planes * block.expansion),
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
layers = []
|
| 205 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 206 |
+
self.inplanes = planes * block.expansion
|
| 207 |
+
for i in range(1, blocks):
|
| 208 |
+
layers.append(block(self.inplanes, planes))
|
| 209 |
+
|
| 210 |
+
return nn.Sequential(*layers)
|
| 211 |
+
|
| 212 |
+
def forward(self, x):
|
| 213 |
+
x = self.conv1(x)
|
| 214 |
+
x = self.bn1(x)
|
| 215 |
+
x = self.relu(x)
|
| 216 |
+
x = self.maxpool(x)
|
| 217 |
+
|
| 218 |
+
x = self.layer1(x)
|
| 219 |
+
x = self.layer2(x)
|
| 220 |
+
x = self.layer3(x)
|
| 221 |
+
x = self.layer4(x)
|
| 222 |
+
|
| 223 |
+
x = self.avgpool(x)
|
| 224 |
+
x = x.view(x.size(0), -1)
|
| 225 |
+
x = self.fc(x)
|
| 226 |
+
|
| 227 |
+
return x
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def resnet50(pretrained=False, stride=None, num_classes=200, **kwargs):
|
| 233 |
+
"""Constructs a ResNet-50 model.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 237 |
+
:param pretrained:
|
| 238 |
+
:param stride:
|
| 239 |
+
"""
|
| 240 |
+
if stride is None:
|
| 241 |
+
stride = [1, 2, 2, 1]
|
| 242 |
+
model = ResNet(Bottleneck, [3, 4, 6, 3], stride=stride, **kwargs)
|
| 243 |
+
if pretrained:
|
| 244 |
+
model.load_state_dict(model_zoo.load_url(
|
| 245 |
+
model_urls['resnet50']), strict=True)
|
| 246 |
+
if num_classes != 1000:
|
| 247 |
+
model.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
|
| 248 |
+
return model
|
requirements.txt
CHANGED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.0.0
|
| 2 |
+
aiofiles==23.2.1
|
| 3 |
+
altair==5.1.2
|
| 4 |
+
annotated-types==0.6.0
|
| 5 |
+
anyio==3.7.1
|
| 6 |
+
attrs==23.1.0
|
| 7 |
+
certifi==2023.7.22
|
| 8 |
+
charset-normalizer==3.3.1
|
| 9 |
+
chex==0.1.84
|
| 10 |
+
click==8.1.7
|
| 11 |
+
colorama==0.4.6
|
| 12 |
+
contourpy==1.1.1
|
| 13 |
+
cycler==0.12.1
|
| 14 |
+
etils==1.5.2
|
| 15 |
+
fastapi==0.104.0
|
| 16 |
+
ffmpy==0.3.1
|
| 17 |
+
filelock==3.12.4
|
| 18 |
+
flax==0.7.4
|
| 19 |
+
fonttools==4.43.1
|
| 20 |
+
fsspec==2023.10.0
|
| 21 |
+
gradio==3.50.2
|
| 22 |
+
gradio_client==0.6.1
|
| 23 |
+
h11==0.14.0
|
| 24 |
+
httpcore==0.18.0
|
| 25 |
+
httpx==0.25.0
|
| 26 |
+
huggingface-hub==0.17.3
|
| 27 |
+
idna==3.4
|
| 28 |
+
importlib-resources==6.1.0
|
| 29 |
+
jax==0.4.19
|
| 30 |
+
jaxlib==0.4.19
|
| 31 |
+
Jinja2==3.1.2
|
| 32 |
+
jsonschema==4.19.1
|
| 33 |
+
jsonschema-specifications==2023.7.1
|
| 34 |
+
kiwisolver==1.4.5
|
| 35 |
+
markdown-it-py==3.0.0
|
| 36 |
+
MarkupSafe==2.1.3
|
| 37 |
+
matplotlib==3.8.0
|
| 38 |
+
mdurl==0.1.2
|
| 39 |
+
ml-dtypes==0.3.1
|
| 40 |
+
mpmath==1.3.0
|
| 41 |
+
msgpack==1.0.7
|
| 42 |
+
nest-asyncio==1.5.8
|
| 43 |
+
networkx==3.2
|
| 44 |
+
numpy==1.26.1
|
| 45 |
+
opt-einsum==3.3.0
|
| 46 |
+
optax==0.1.7
|
| 47 |
+
orbax-checkpoint==0.4.1
|
| 48 |
+
orjson==3.9.9
|
| 49 |
+
packaging==23.2
|
| 50 |
+
pandas==2.1.1
|
| 51 |
+
Pillow==10.1.0
|
| 52 |
+
protobuf==4.24.4
|
| 53 |
+
pydantic==2.4.2
|
| 54 |
+
pydantic_core==2.10.1
|
| 55 |
+
pydub==0.25.1
|
| 56 |
+
Pygments==2.16.1
|
| 57 |
+
pyparsing==3.1.1
|
| 58 |
+
python-dateutil==2.8.2
|
| 59 |
+
python-multipart==0.0.6
|
| 60 |
+
pytz==2023.3.post1
|
| 61 |
+
PyYAML==6.0.1
|
| 62 |
+
referencing==0.30.2
|
| 63 |
+
regex==2023.10.3
|
| 64 |
+
requests==2.31.0
|
| 65 |
+
rich==13.6.0
|
| 66 |
+
rpds-py==0.10.6
|
| 67 |
+
safetensors==0.4.0
|
| 68 |
+
scipy==1.11.3
|
| 69 |
+
semantic-version==2.10.0
|
| 70 |
+
six==1.16.0
|
| 71 |
+
sniffio==1.3.0
|
| 72 |
+
starlette==0.27.0
|
| 73 |
+
sympy==1.12
|
| 74 |
+
tensorstore==0.1.46
|
| 75 |
+
tokenizers==0.14.1
|
| 76 |
+
toolz==0.12.0
|
| 77 |
+
torch==2.1.0
|
| 78 |
+
torchvision==0.16.0
|
| 79 |
+
tqdm==4.66.1
|
| 80 |
+
transformers==4.34.1
|
| 81 |
+
typing_extensions==4.8.0
|
| 82 |
+
tzdata==2023.3
|
| 83 |
+
urllib3==2.0.7
|
| 84 |
+
uvicorn==0.23.2
|
| 85 |
+
websockets==11.0.3
|
| 86 |
+
zipp==3.17.0
|