Billy-06 commited on
Commit
a263b83
·
1 Parent(s): eb9829e

Added the file Architectures

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +53 -4
  3. classes.txt +200 -0
  4. model.py +248 -0
  5. requirements.txt +86 -0
.gitignore CHANGED
@@ -3,6 +3,9 @@ flagged/
3
  *.png
4
  *.jpg
5
  *.jpeg
 
6
  gradio_cache/
7
 
8
  venv/
 
 
 
3
  *.png
4
  *.jpg
5
  *.jpeg
6
+ *.pyc
7
  gradio_cache/
8
 
9
  venv/
10
+ __pychache__/
11
+
app.py CHANGED
@@ -1,7 +1,56 @@
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
 
5
+ from model import *
 
6
 
7
+ def load_cub200_classes():
8
+ """
9
+ This function loads the classes from the classes.txt file and returns a dictionary
10
+ """
11
+ with open("classes.txt", encoding="utf-8") as f:
12
+ classes = f.read().splitlines()
13
+
14
+ # convert classes to dictionary separating the lines by the first space
15
+ classes = {int(line.split(" ")[0]) : line.split(" ")[1] for line in classes}
16
+
17
+ # return the classes dictionary
18
+ return classes
19
+
20
+ def load_model():
21
+ """
22
+ This function loads the trained model and returns it
23
+ """
24
+
25
+ # load the resnet model
26
+ model = resnet50(pretrained=False, stride=[1, 2, 2, 1], num_classes=200)
27
+ # load the trained weights
28
+ model.load_state_dict(torch.load("resnet.pt", map_location=torch.device('cpu')))
29
+ # set the model to evaluation mode
30
+ model.eval()
31
+ # return the model
32
+ return model
33
+
34
+ def predict_image(image):
35
+ """
36
+ This function takes an image as input and returns the class label
37
+ """
38
+
39
+ # load the model
40
+ model = load_model()
41
+ # load the classes
42
+ classes = load_cub200_classes()
43
+
44
+ # convert image to tensor
45
+ tensor = torch.from_numpy(image).permute(2, 0, 1).float().unsqueeze(0)
46
+ # make prediction
47
+ prediction = model(tensor).detach().numpy()[0]
48
+ # convert prediction to probabilities
49
+ probabilities = np.exp(prediction) / np.sum(np.exp(prediction))
50
+ # get the class with the highest probability
51
+ class_idx = np.argmax(probabilities)
52
+ # return the class label
53
+ return "Class: " + classes[class_idx]
54
+
55
+ # create a gradio interface
56
+ gr.Interface(fn=predict_image, inputs="image", outputs="text").launch()
classes.txt ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 1 001.Black_footed_Albatross
2
+ 2 002.Laysan_Albatross
3
+ 3 003.Sooty_Albatross
4
+ 4 004.Groove_billed_Ani
5
+ 5 005.Crested_Auklet
6
+ 6 006.Least_Auklet
7
+ 7 007.Parakeet_Auklet
8
+ 8 008.Rhinoceros_Auklet
9
+ 9 009.Brewer_Blackbird
10
+ 10 010.Red_winged_Blackbird
11
+ 11 011.Rusty_Blackbird
12
+ 12 012.Yellow_headed_Blackbird
13
+ 13 013.Bobolink
14
+ 14 014.Indigo_Bunting
15
+ 15 015.Lazuli_Bunting
16
+ 16 016.Painted_Bunting
17
+ 17 017.Cardinal
18
+ 18 018.Spotted_Catbird
19
+ 19 019.Gray_Catbird
20
+ 20 020.Yellow_breasted_Chat
21
+ 21 021.Eastern_Towhee
22
+ 22 022.Chuck_will_Widow
23
+ 23 023.Brandt_Cormorant
24
+ 24 024.Red_faced_Cormorant
25
+ 25 025.Pelagic_Cormorant
26
+ 26 026.Bronzed_Cowbird
27
+ 27 027.Shiny_Cowbird
28
+ 28 028.Brown_Creeper
29
+ 29 029.American_Crow
30
+ 30 030.Fish_Crow
31
+ 31 031.Black_billed_Cuckoo
32
+ 32 032.Mangrove_Cuckoo
33
+ 33 033.Yellow_billed_Cuckoo
34
+ 34 034.Gray_crowned_Rosy_Finch
35
+ 35 035.Purple_Finch
36
+ 36 036.Northern_Flicker
37
+ 37 037.Acadian_Flycatcher
38
+ 38 038.Great_Crested_Flycatcher
39
+ 39 039.Least_Flycatcher
40
+ 40 040.Olive_sided_Flycatcher
41
+ 41 041.Scissor_tailed_Flycatcher
42
+ 42 042.Vermilion_Flycatcher
43
+ 43 043.Yellow_bellied_Flycatcher
44
+ 44 044.Frigatebird
45
+ 45 045.Northern_Fulmar
46
+ 46 046.Gadwall
47
+ 47 047.American_Goldfinch
48
+ 48 048.European_Goldfinch
49
+ 49 049.Boat_tailed_Grackle
50
+ 50 050.Eared_Grebe
51
+ 51 051.Horned_Grebe
52
+ 52 052.Pied_billed_Grebe
53
+ 53 053.Western_Grebe
54
+ 54 054.Blue_Grosbeak
55
+ 55 055.Evening_Grosbeak
56
+ 56 056.Pine_Grosbeak
57
+ 57 057.Rose_breasted_Grosbeak
58
+ 58 058.Pigeon_Guillemot
59
+ 59 059.California_Gull
60
+ 60 060.Glaucous_winged_Gull
61
+ 61 061.Heermann_Gull
62
+ 62 062.Herring_Gull
63
+ 63 063.Ivory_Gull
64
+ 64 064.Ring_billed_Gull
65
+ 65 065.Slaty_backed_Gull
66
+ 66 066.Western_Gull
67
+ 67 067.Anna_Hummingbird
68
+ 68 068.Ruby_throated_Hummingbird
69
+ 69 069.Rufous_Hummingbird
70
+ 70 070.Green_Violetear
71
+ 71 071.Long_tailed_Jaeger
72
+ 72 072.Pomarine_Jaeger
73
+ 73 073.Blue_Jay
74
+ 74 074.Florida_Jay
75
+ 75 075.Green_Jay
76
+ 76 076.Dark_eyed_Junco
77
+ 77 077.Tropical_Kingbird
78
+ 78 078.Gray_Kingbird
79
+ 79 079.Belted_Kingfisher
80
+ 80 080.Green_Kingfisher
81
+ 81 081.Pied_Kingfisher
82
+ 82 082.Ringed_Kingfisher
83
+ 83 083.White_breasted_Kingfisher
84
+ 84 084.Red_legged_Kittiwake
85
+ 85 085.Horned_Lark
86
+ 86 086.Pacific_Loon
87
+ 87 087.Mallard
88
+ 88 088.Western_Meadowlark
89
+ 89 089.Hooded_Merganser
90
+ 90 090.Red_breasted_Merganser
91
+ 91 091.Mockingbird
92
+ 92 092.Nighthawk
93
+ 93 093.Clark_Nutcracker
94
+ 94 094.White_breasted_Nuthatch
95
+ 95 095.Baltimore_Oriole
96
+ 96 096.Hooded_Oriole
97
+ 97 097.Orchard_Oriole
98
+ 98 098.Scott_Oriole
99
+ 99 099.Ovenbird
100
+ 100 100.Brown_Pelican
101
+ 101 101.White_Pelican
102
+ 102 102.Western_Wood_Pewee
103
+ 103 103.Sayornis
104
+ 104 104.American_Pipit
105
+ 105 105.Whip_poor_Will
106
+ 106 106.Horned_Puffin
107
+ 107 107.Common_Raven
108
+ 108 108.White_necked_Raven
109
+ 109 109.American_Redstart
110
+ 110 110.Geococcyx
111
+ 111 111.Loggerhead_Shrike
112
+ 112 112.Great_Grey_Shrike
113
+ 113 113.Baird_Sparrow
114
+ 114 114.Black_throated_Sparrow
115
+ 115 115.Brewer_Sparrow
116
+ 116 116.Chipping_Sparrow
117
+ 117 117.Clay_colored_Sparrow
118
+ 118 118.House_Sparrow
119
+ 119 119.Field_Sparrow
120
+ 120 120.Fox_Sparrow
121
+ 121 121.Grasshopper_Sparrow
122
+ 122 122.Harris_Sparrow
123
+ 123 123.Henslow_Sparrow
124
+ 124 124.Le_Conte_Sparrow
125
+ 125 125.Lincoln_Sparrow
126
+ 126 126.Nelson_Sharp_tailed_Sparrow
127
+ 127 127.Savannah_Sparrow
128
+ 128 128.Seaside_Sparrow
129
+ 129 129.Song_Sparrow
130
+ 130 130.Tree_Sparrow
131
+ 131 131.Vesper_Sparrow
132
+ 132 132.White_crowned_Sparrow
133
+ 133 133.White_throated_Sparrow
134
+ 134 134.Cape_Glossy_Starling
135
+ 135 135.Bank_Swallow
136
+ 136 136.Barn_Swallow
137
+ 137 137.Cliff_Swallow
138
+ 138 138.Tree_Swallow
139
+ 139 139.Scarlet_Tanager
140
+ 140 140.Summer_Tanager
141
+ 141 141.Artic_Tern
142
+ 142 142.Black_Tern
143
+ 143 143.Caspian_Tern
144
+ 144 144.Common_Tern
145
+ 145 145.Elegant_Tern
146
+ 146 146.Forsters_Tern
147
+ 147 147.Least_Tern
148
+ 148 148.Green_tailed_Towhee
149
+ 149 149.Brown_Thrasher
150
+ 150 150.Sage_Thrasher
151
+ 151 151.Black_capped_Vireo
152
+ 152 152.Blue_headed_Vireo
153
+ 153 153.Philadelphia_Vireo
154
+ 154 154.Red_eyed_Vireo
155
+ 155 155.Warbling_Vireo
156
+ 156 156.White_eyed_Vireo
157
+ 157 157.Yellow_throated_Vireo
158
+ 158 158.Bay_breasted_Warbler
159
+ 159 159.Black_and_white_Warbler
160
+ 160 160.Black_throated_Blue_Warbler
161
+ 161 161.Blue_winged_Warbler
162
+ 162 162.Canada_Warbler
163
+ 163 163.Cape_May_Warbler
164
+ 164 164.Cerulean_Warbler
165
+ 165 165.Chestnut_sided_Warbler
166
+ 166 166.Golden_winged_Warbler
167
+ 167 167.Hooded_Warbler
168
+ 168 168.Kentucky_Warbler
169
+ 169 169.Magnolia_Warbler
170
+ 170 170.Mourning_Warbler
171
+ 171 171.Myrtle_Warbler
172
+ 172 172.Nashville_Warbler
173
+ 173 173.Orange_crowned_Warbler
174
+ 174 174.Palm_Warbler
175
+ 175 175.Pine_Warbler
176
+ 176 176.Prairie_Warbler
177
+ 177 177.Prothonotary_Warbler
178
+ 178 178.Swainson_Warbler
179
+ 179 179.Tennessee_Warbler
180
+ 180 180.Wilson_Warbler
181
+ 181 181.Worm_eating_Warbler
182
+ 182 182.Yellow_Warbler
183
+ 183 183.Northern_Waterthrush
184
+ 184 184.Louisiana_Waterthrush
185
+ 185 185.Bohemian_Waxwing
186
+ 186 186.Cedar_Waxwing
187
+ 187 187.American_Three_toed_Woodpecker
188
+ 188 188.Pileated_Woodpecker
189
+ 189 189.Red_bellied_Woodpecker
190
+ 190 190.Red_cockaded_Woodpecker
191
+ 191 191.Red_headed_Woodpecker
192
+ 192 192.Downy_Woodpecker
193
+ 193 193.Bewick_Wren
194
+ 194 194.Cactus_Wren
195
+ 195 195.Carolina_Wren
196
+ 196 196.House_Wren
197
+ 197 197.Marsh_Wren
198
+ 198 198.Rock_Wren
199
+ 199 199.Winter_Wren
200
+ 200 200.Common_Yellowthroat
model.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch.nn as nn
3
+ import torch.utils.model_zoo as model_zoo
4
+ import torch.optim as optim
5
+ from torchvision import transforms
6
+ import time
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ model_urls = {
11
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16
+ }
17
+
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ """
22
+ This is a basic block that contains two convolutional layers followed by
23
+ a batch normalization layer and a ReLU activation function, where the skip
24
+ connection is added before the second relu.
25
+ ---
26
+
27
+ - inplanes: { int } - The number of input channels.
28
+ - planes: { int } - The number of output channels.
29
+ - stride: { int } - The stride of convolutional layers.
30
+ - downsample: { nn.Sequential } - A sequential of convolutional layers that fit the
31
+ identity mapping to the desired output size.
32
+ """
33
+ expansion = 1
34
+
35
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
36
+ super(BasicBlock, self).__init__()
37
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
38
+ padding=1, bias=False)
39
+ self.bn1 = nn.BatchNorm2d(planes)
40
+ self.relu = nn.ReLU(inplace=True)
41
+
42
+ self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
43
+ padding=1, bias=False)
44
+ self.bn2 = nn.BatchNorm2d(planes)
45
+ self.downsample = downsample
46
+ self.stride = stride
47
+
48
+ def forward(self, x):
49
+ """
50
+ This is the forward pass of the basic block where the input tensor x is passed
51
+ through the first convolutional layer, batch normalization layer, and the ReLU
52
+ activation function. The result is passed through the second convolutional layer,
53
+ batch normalization layer, and the ReLU activation function. The result is then
54
+ added to the identity mapping and passed through the ReLU activation function.
55
+ """
56
+ residual = x
57
+
58
+ # Convolve with a 3X3Xplanes kernel
59
+ out = self.conv1(x)
60
+ out = self.bn1(out)
61
+ out = self.relu(out)
62
+
63
+ # Convolve with a 3X3Xplanes kernel
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+
67
+ # If the stride is not 1 or the number of input channels is not equal
68
+ # to the number of output channels then we need to fit the identity
69
+ # mapping to the desired output size by applying the downsample.
70
+ if self.downsample is not None:
71
+ residual = self.downsample(x)
72
+
73
+ # Add the identity mapping to the output of the second convolutional layer.
74
+ out += residual
75
+ # Apply the ReLU activation function after the addition.
76
+ out = self.relu(out)
77
+
78
+ return out
79
+
80
+
81
+
82
+ class Bottleneck(nn.Module):
83
+ """
84
+ This class defines a bottle neck that fits the identity mapping to the desired
85
+ output size before adding it to the output of the following layers.
86
+ ---
87
+ - inplanes: { int } - The number of input channels.
88
+ - planes: { int } - The number of output channels.
89
+ - stride: { int } - The stride of the second convolutional layer.
90
+ - downsample: { nn.Sequential } - A sequential of convolutional layers that fit the
91
+ identity mapping to the desired output size.
92
+
93
+ The following layers are defined:
94
+ - A 1x1 convolutional layer (self.conv1) with inplanes input channels and planes
95
+ output channels is defined.
96
+ - A batch normalization layer (self.bn1) is defined for the output of self.conv1.
97
+ - A 3x3 convolutional layer (self.conv2) with planes input channels, planes output
98
+ channels, and stride 'stride' is defined.
99
+ - A batch normalization layer (self.bn2) is defined for the output of self.conv2.
100
+ - A 1x1 convolutional layer (self.conv3) with planes input channels
101
+ and planes * self.expansion output channels is defined.
102
+ - A batch normalization layer (self.bn3) is defined for the output of self.conv3.
103
+ - A ReLU activation function (self.relu) is defined.
104
+ """
105
+ expansion = 4
106
+
107
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
108
+ super(Bottleneck, self).__init__()
109
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
110
+ self.bn1 = nn.BatchNorm2d(planes)
111
+
112
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
113
+ stride=stride, padding=1, bias=False)
114
+ self.bn2 = nn.BatchNorm2d(planes)
115
+
116
+ self.conv3 = nn.Conv2d(
117
+ planes, planes * self.expansion, kernel_size=1, bias=False)
118
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
119
+ self.relu = nn.ReLU(inplace=True)
120
+
121
+ self.downsample = downsample
122
+ self.stride = stride
123
+
124
+ def forward(self, x):
125
+ """
126
+ The Forward Pass
127
+ ----------------
128
+ Steps:
129
+
130
+ - The input tensor x is saved as residual.
131
+ - x is passed through self.conv1, self.bn1, and self.relu.
132
+ - The result is passed through self.conv2, self.bn2, and self.relu.
133
+ - The result is passed through self.conv3 and self.bn3.
134
+
135
+ - If self.downsample is not None, residual is passed through self.downsample.
136
+ - The output of the previous step is added to out.
137
+ - The result is passed through self.relu.
138
+ - The result is returned.
139
+ """
140
+ residual = x
141
+ # Convolve with a 1X1Xplanes kernel
142
+ out = self.conv1(x)
143
+ out = self.bn1(out)
144
+ out = self.relu(out)
145
+
146
+ # Convolve with a 3X3Xplanes kernel
147
+ out = self.conv2(out)
148
+ out = self.bn2(out)
149
+ out = self.relu(out)
150
+
151
+ # Convolve with a 1X1Xplanes*expansion kernel
152
+ out = self.conv3(out)
153
+ out = self.bn3(out)
154
+
155
+ # If the stride is not 1 or the number of input channels is not equal
156
+ # to the number of output channels then we need to fit the identity
157
+ # mapping to the desired output size by applying the downsample.
158
+ if self.downsample is not None:
159
+ residual = self.downsample(x)
160
+
161
+ out += residual
162
+ # Apply the ReLU activation function after the addition.
163
+ out = self.relu(out)
164
+
165
+ return out
166
+
167
+
168
+ class ResNet(nn.Module):
169
+ """
170
+ This is the ResNet class that is used in ResNet50, ResNet101, and ResNet152.
171
+ """
172
+ def __init__(self, block, layers, stride=None):
173
+ self.inplanes = 64
174
+ super(ResNet, self).__init__()
175
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
176
+ self.bn1 = nn.BatchNorm2d(64)
177
+ self.relu = nn.ReLU(inplace=True)
178
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
179
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0])
180
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1])
181
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2])
182
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3])
183
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
184
+
185
+ self.fc = nn.Linear(512 * block.expansion, 1000)
186
+
187
+ for m in self.modules():
188
+ if isinstance(m, nn.Conv2d):
189
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
190
+ m.weight.data.normal_(0, math.sqrt(2. / n))
191
+ elif isinstance(m, nn.BatchNorm2d):
192
+ m.weight.data.fill_(1)
193
+ m.bias.data.zero_()
194
+
195
+ def _make_layer(self, block, planes, blocks, stride=1):
196
+ downsample = None
197
+ if stride != 1 or self.inplanes != planes * block.expansion:
198
+ downsample = nn.Sequential(
199
+ nn.Conv2d(self.inplanes, planes * block.expansion,
200
+ kernel_size=1, stride=stride, bias=False),
201
+ nn.BatchNorm2d(planes * block.expansion),
202
+ )
203
+
204
+ layers = []
205
+ layers.append(block(self.inplanes, planes, stride, downsample))
206
+ self.inplanes = planes * block.expansion
207
+ for i in range(1, blocks):
208
+ layers.append(block(self.inplanes, planes))
209
+
210
+ return nn.Sequential(*layers)
211
+
212
+ def forward(self, x):
213
+ x = self.conv1(x)
214
+ x = self.bn1(x)
215
+ x = self.relu(x)
216
+ x = self.maxpool(x)
217
+
218
+ x = self.layer1(x)
219
+ x = self.layer2(x)
220
+ x = self.layer3(x)
221
+ x = self.layer4(x)
222
+
223
+ x = self.avgpool(x)
224
+ x = x.view(x.size(0), -1)
225
+ x = self.fc(x)
226
+
227
+ return x
228
+
229
+
230
+
231
+
232
+ def resnet50(pretrained=False, stride=None, num_classes=200, **kwargs):
233
+ """Constructs a ResNet-50 model.
234
+
235
+ Args:
236
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
237
+ :param pretrained:
238
+ :param stride:
239
+ """
240
+ if stride is None:
241
+ stride = [1, 2, 2, 1]
242
+ model = ResNet(Bottleneck, [3, 4, 6, 3], stride=stride, **kwargs)
243
+ if pretrained:
244
+ model.load_state_dict(model_zoo.load_url(
245
+ model_urls['resnet50']), strict=True)
246
+ if num_classes != 1000:
247
+ model.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)
248
+ return model
requirements.txt CHANGED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ aiofiles==23.2.1
3
+ altair==5.1.2
4
+ annotated-types==0.6.0
5
+ anyio==3.7.1
6
+ attrs==23.1.0
7
+ certifi==2023.7.22
8
+ charset-normalizer==3.3.1
9
+ chex==0.1.84
10
+ click==8.1.7
11
+ colorama==0.4.6
12
+ contourpy==1.1.1
13
+ cycler==0.12.1
14
+ etils==1.5.2
15
+ fastapi==0.104.0
16
+ ffmpy==0.3.1
17
+ filelock==3.12.4
18
+ flax==0.7.4
19
+ fonttools==4.43.1
20
+ fsspec==2023.10.0
21
+ gradio==3.50.2
22
+ gradio_client==0.6.1
23
+ h11==0.14.0
24
+ httpcore==0.18.0
25
+ httpx==0.25.0
26
+ huggingface-hub==0.17.3
27
+ idna==3.4
28
+ importlib-resources==6.1.0
29
+ jax==0.4.19
30
+ jaxlib==0.4.19
31
+ Jinja2==3.1.2
32
+ jsonschema==4.19.1
33
+ jsonschema-specifications==2023.7.1
34
+ kiwisolver==1.4.5
35
+ markdown-it-py==3.0.0
36
+ MarkupSafe==2.1.3
37
+ matplotlib==3.8.0
38
+ mdurl==0.1.2
39
+ ml-dtypes==0.3.1
40
+ mpmath==1.3.0
41
+ msgpack==1.0.7
42
+ nest-asyncio==1.5.8
43
+ networkx==3.2
44
+ numpy==1.26.1
45
+ opt-einsum==3.3.0
46
+ optax==0.1.7
47
+ orbax-checkpoint==0.4.1
48
+ orjson==3.9.9
49
+ packaging==23.2
50
+ pandas==2.1.1
51
+ Pillow==10.1.0
52
+ protobuf==4.24.4
53
+ pydantic==2.4.2
54
+ pydantic_core==2.10.1
55
+ pydub==0.25.1
56
+ Pygments==2.16.1
57
+ pyparsing==3.1.1
58
+ python-dateutil==2.8.2
59
+ python-multipart==0.0.6
60
+ pytz==2023.3.post1
61
+ PyYAML==6.0.1
62
+ referencing==0.30.2
63
+ regex==2023.10.3
64
+ requests==2.31.0
65
+ rich==13.6.0
66
+ rpds-py==0.10.6
67
+ safetensors==0.4.0
68
+ scipy==1.11.3
69
+ semantic-version==2.10.0
70
+ six==1.16.0
71
+ sniffio==1.3.0
72
+ starlette==0.27.0
73
+ sympy==1.12
74
+ tensorstore==0.1.46
75
+ tokenizers==0.14.1
76
+ toolz==0.12.0
77
+ torch==2.1.0
78
+ torchvision==0.16.0
79
+ tqdm==4.66.1
80
+ transformers==4.34.1
81
+ typing_extensions==4.8.0
82
+ tzdata==2023.3
83
+ urllib3==2.0.7
84
+ uvicorn==0.23.2
85
+ websockets==11.0.3
86
+ zipp==3.17.0