airlabshare commited on
Commit
702b5a3
·
verified ·
1 Parent(s): edf3e27

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +94 -0
model.py CHANGED
@@ -107,6 +107,100 @@ class AnyThermalVPRModel(PreTrainedModel):
107
  # Global descriptor
108
  return self.vpr_head((patch_tokens, cls_token))
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  AnyThermalVPRModel.register_for_auto_class("AutoModel")
111
  # Register for AutoModel discovery
112
  AnyThermalSegmentationModel.register_for_auto_class("AutoModel")
 
107
  # Global descriptor
108
  return self.vpr_head((patch_tokens, cls_token))
109
 
110
+ class AnyThermalDepthConfig(Dinov2Config):
111
+ model_type = "anythermal_depth"
112
+ def __init__(self, features=256, **kwargs):
113
+ super().__init__(**kwargs)
114
+ self.features = features
115
+
116
+ class ResidualConvUnit(nn.Module):
117
+ def __init__(self, features):
118
+ super().__init__()
119
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
120
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
121
+ self.relu = nn.ReLU(inplace=True)
122
+ def forward(self, x):
123
+ out = self.relu(x)
124
+ out = self.conv1(out)
125
+ out = self.relu(out)
126
+ out = self.conv2(out)
127
+ return out + x
128
+
129
+ class FeatureFusionBlock(nn.Module):
130
+ def __init__(self, features):
131
+ super().__init__()
132
+ self.resConfUnit1 = ResidualConvUnit(features)
133
+ self.resConfUnit2 = ResidualConvUnit(features)
134
+ def forward(self, *xs):
135
+ output = xs[0]
136
+ if len(xs) == 2:
137
+ output = output + self.resConfUnit1(xs[1])
138
+ output = self.resConfUnit2(output)
139
+ output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True)
140
+ return output
141
+
142
+ class AnyThermalDepthModel(PreTrainedModel):
143
+ config_class = AnyThermalDepthConfig
144
+ def __init__(self, config):
145
+ super().__init__(config)
146
+ self.backbone = Dinov2Model(config)
147
+ features = config.features
148
+
149
+ # Layers to match 'scratch' in blocks.py
150
+ self.scratch = nn.Module()
151
+ self.scratch.layer1_rn = nn.Conv2d(96, features, kernel_size=3, padding=1, bias=False)
152
+ self.scratch.layer2_rn = nn.Conv2d(192, features, kernel_size=3, padding=1, bias=False)
153
+ self.scratch.layer3_rn = nn.Conv2d(384, features, kernel_size=3, padding=1, bias=False)
154
+ self.scratch.layer4_rn = nn.Conv2d(768, features, kernel_size=3, padding=1, bias=False)
155
+
156
+ # Post-processing from vit.py
157
+ self.act_postprocess1 = nn.Sequential(nn.Conv2d(768, 96, 1), nn.ConvTranspose2d(96, 96, 4, stride=4))
158
+ self.act_postprocess2 = nn.Sequential(nn.Conv2d(768, 192, 1), nn.ConvTranspose2d(192, 192, 2, stride=2))
159
+ self.act_postprocess3 = nn.Sequential(nn.Conv2d(768, 384, 1))
160
+ self.act_postprocess4 = nn.Sequential(nn.Conv2d(768, 768, 1), nn.Conv2d(768, 768, 3, stride=2, padding=1))
161
+
162
+ # Fusion and output
163
+ self.refinenet4 = FeatureFusionBlock(features)
164
+ self.refinenet3 = FeatureFusionBlock(features)
165
+ self.refinenet2 = FeatureFusionBlock(features)
166
+ self.refinenet1 = FeatureFusionBlock(features)
167
+ self.output_conv = nn.Sequential(
168
+ nn.Conv2d(features, 128, kernel_size=3, padding=1),
169
+ nn.Upsample(scale_factor=1.75, mode="bilinear"), # Specific to Dinov2-ViT-B14
170
+ nn.Conv2d(128, 32, kernel_size=3, padding=1),
171
+ nn.ReLU(True),
172
+ nn.Conv2d(32, 1, kernel_size=1),
173
+ nn.ReLU(True)
174
+ )
175
+ self.post_init()
176
+
177
+ def forward(self, pixel_values):
178
+ # Extract features from layers 2, 5, 8, 11
179
+ outputs = self.backbone(pixel_values, output_hidden_states=True)
180
+ layers = [outputs.hidden_states[i] for i in [3, 6, 9, 12]]
181
+
182
+ def process(l, h, w):
183
+ l = l[:, 1:, :].transpose(1, 2)
184
+ return l.reshape(l.shape[0], l.shape[1], h//14, w//14)
185
+
186
+ b, _, h, w = pixel_values.shape
187
+ l1, l2, l3, l4 = [process(layers[i], h, w) for i in range(4)]
188
+
189
+ # Sequential Fusion
190
+ layer_1_rn = self.scratch.layer1_rn(self.act_postprocess1(l1))
191
+ layer_2_rn = self.scratch.layer2_rn(self.act_postprocess2(l2))
192
+ layer_3_rn = self.scratch.layer3_rn(self.act_postprocess3(l3))
193
+ layer_4_rn = self.scratch.layer4_rn(self.act_postprocess4(l4))
194
+
195
+ path_4 = self.refinenet4(layer_4_rn)
196
+ path_3 = self.refinenet3(path_4, layer_3_rn)
197
+ path_2 = self.refinenet2(path_3, layer_2_rn)
198
+ path_1 = self.refinenet1(path_2, layer_1_rn)
199
+
200
+ return self.output_conv(path_1).squeeze(1)
201
+
202
+ AnyThermalDepthModel.register_for_auto_class("AutoModel")
203
+
204
  AnyThermalVPRModel.register_for_auto_class("AutoModel")
205
  # Register for AutoModel discovery
206
  AnyThermalSegmentationModel.register_for_auto_class("AutoModel")