sandbox338 commited on
Commit
8f8c75e
·
verified ·
1 Parent(s): 6e13a93

Upload model_architecture.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_architecture.py +155 -0
model_architecture.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Essential code to recreate model architecture
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.models as models
6
+ from detectron2.modeling.backbone import Backbone, BACKBONE_REGISTRY
7
+ from detectron2.layers import ShapeSpec
8
+
9
class WildlifeInceptionBackbone(nn.Module):
    """Inception-v3 feature extractor emitting FPN-style "res4"/"res5" maps.

    Reuses the ImageNet-pretrained Inception-v3 stem and Mixed blocks from
    torchvision, then projects the 768-channel ``Mixed_6e`` output ("res4")
    and the 2048-channel ``Mixed_7c`` output ("res5") down to 256 channels
    each with a small 3x3 + 1x1 conv head.

    ``forward`` returns ``{"res4": Tensor, "res5": Tensor}``; both maps have
    256 channels.
    """

    def __init__(self):
        super(WildlifeInceptionBackbone, self).__init__()
        inception = models.inception_v3(
            weights=models.Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True
        )
        # NOTE(review): .eval() only clears the training flag on this
        # temporary model — it does not freeze weights (requires_grad stays
        # True) and a later .train() on the enclosing model undoes it.
        # Kept for behavioral parity; confirm whether freezing was intended.
        inception.eval()

        # Stem: conv/pool layers up to the first Inception block.
        self.Conv2d_1a_3x3 = inception.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = inception.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = inception.Conv2d_2b_3x3
        self.maxpool1 = inception.maxpool1
        self.Conv2d_3b_1x1 = inception.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = inception.Conv2d_4a_3x3
        self.maxpool2 = inception.maxpool2

        # Inception mixed blocks (attribute names preserved so pretrained
        # state-dict keys keep matching).
        self.Mixed_5b = inception.Mixed_5b
        self.Mixed_5c = inception.Mixed_5c
        self.Mixed_5d = inception.Mixed_5d
        self.Mixed_6a = inception.Mixed_6a
        self.Mixed_6b = inception.Mixed_6b
        self.Mixed_6c = inception.Mixed_6c
        self.Mixed_6d = inception.Mixed_6d
        self.Mixed_6e = inception.Mixed_6e
        self.Mixed_7a = inception.Mixed_7a
        self.Mixed_7b = inception.Mixed_7b
        self.Mixed_7c = inception.Mixed_7c

        # Channel-projection heads (previously two duplicated Sequential
        # literals differing only in the input channel count).
        self.level4_enhance = self._make_enhance(768)
        self.level5_enhance = self._make_enhance(2048)

        self._init_weights()

    @staticmethod
    def _make_enhance(in_channels):
        """Build a 3x3 -> 1x1 conv head projecting *in_channels* to 256."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self):
        """Kaiming-initialize the (non-pretrained) enhancement heads."""
        for m in [self.level4_enhance, self.level5_enhance]:
            for layer in m.modules():
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(layer, nn.BatchNorm2d):
                    nn.init.constant_(layer.weight, 1)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the Inception stem and blocks; return {"res4", "res5"} maps.

        Args:
            x: input image batch tensor (N, 3, H, W expected by Inception-v3
               — TODO confirm the caller's preprocessing).

        Returns:
            dict mapping "res4" and "res5" to 256-channel feature maps.
        """
        x = self.Conv2d_1a_3x3(x)
        x = self.Conv2d_2a_3x3(x)
        x = self.Conv2d_2b_3x3(x)
        x = self.maxpool1(x)
        x = self.Conv2d_3b_1x1(x)
        x = self.Conv2d_4a_3x3(x)
        x = self.maxpool2(x)

        x = self.Mixed_5b(x)
        x = self.Mixed_5c(x)
        x = self.Mixed_5d(x)
        x = self.Mixed_6a(x)
        x = self.Mixed_6b(x)
        x = self.Mixed_6c(x)
        x = self.Mixed_6d(x)

        # "res4" branch taps the 768-channel Mixed_6e output ...
        level4_raw = self.Mixed_6e(x)
        level4_features = self.level4_enhance(level4_raw)

        # ... and the trunk continues to the 2048-channel Mixed_7c ("res5").
        x = self.Mixed_7a(level4_raw)
        x = self.Mixed_7b(x)
        level5_raw = self.Mixed_7c(x)
        level5_features = self.level5_enhance(level5_raw)

        return {
            "res4": level4_features,
            "res5": level5_features
        }
93
+
94
class EnhancedResNetBackbone(nn.Module):
    """ResNet-101 feature extractor emitting FPN-style "res4"/"res5" maps.

    Reuses the ImageNet-pretrained ResNet-101 stem and residual stages from
    torchvision, then projects the 1024-channel ``layer3`` output ("res4")
    and the 2048-channel ``layer4`` output ("res5") down to 256 channels
    each with a small 3x3 + 1x1 conv head.

    ``forward`` returns ``{"res4": Tensor, "res5": Tensor}``; both maps have
    256 channels.
    """

    def __init__(self):
        super(EnhancedResNetBackbone, self).__init__()
        resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V2)

        # Stem + residual stages; attribute names preserved so pretrained
        # state-dict keys keep matching.
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Channel-projection heads (previously two duplicated Sequential
        # literals differing only in the input channel count).
        self.enhance_res4 = self._make_enhance(1024)
        self.enhance_res5 = self._make_enhance(2048)

        self._init_weights()

    @staticmethod
    def _make_enhance(in_channels):
        """Build a 3x3 -> 1x1 conv head projecting *in_channels* to 256."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def _init_weights(self):
        """Kaiming-initialize the (non-pretrained) enhancement heads."""
        for m in [self.enhance_res4, self.enhance_res5]:
            for layer in m.modules():
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(layer, nn.BatchNorm2d):
                    nn.init.constant_(layer.weight, 1)
                    nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the ResNet stem and stages; return {"res4", "res5"} maps.

        Args:
            x: input image batch tensor (N, 3, H, W expected by ResNet —
               TODO confirm the caller's preprocessing).

        Returns:
            dict mapping "res4" and "res5" to 256-channel feature maps.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)

        # "res4" branch taps the 1024-channel layer3 output ...
        res4_raw = self.layer3(x)
        res4_enhanced = self.enhance_res4(res4_raw)

        # ... and the trunk continues to the 2048-channel layer4 ("res5").
        res5_raw = self.layer4(res4_raw)
        res5_enhanced = self.enhance_res5(res5_raw)

        return {
            "res4": res4_enhanced,
            "res5": res5_enhanced
        }