gberton committed on
Commit
ab33843
·
verified ·
1 Parent(s): 233ec29

Upload dpt_head.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dpt_head.py +702 -0
dpt_head.py ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """DPT (Dense Prediction Transformer) depth head in PyTorch.
17
+
18
+ Ported from the Scenic/Flax implementation at:
19
+ research/vision/scene_understanding/imsight/modules/dpt.py
20
+ scenic/projects/dense_features/models/decoders.py
21
+
22
+ Architecture:
23
+ ReassembleBlocks → 4×Conv3x3 → 4×FeatureFusionBlock → project → DepthHead
24
+ """
25
+
26
+ import io
27
+ import os
28
+ import urllib.request
29
+ import zipfile
30
+
31
+ import numpy as np
32
+ import torch
33
+ from torch import nn
34
+ import torch.nn.functional as F
35
+
36
+
37
+ # ── Building blocks ─────────────────────────────────────────────────────────
38
+
39
+
40
class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual convolution unit.

    Computes ``x + conv2(relu(conv1(relu(x))))`` using two bias-free 3x3
    convolutions with padding 1, so channel count and spatial size are
    preserved.

    Args:
        features: number of input (and output) channels.
    """

    def __init__(self, features: int):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(features, features, 3, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ReLU -> conv1 -> ReLU -> conv2 and add the skip connection."""
        # F.relu is not in-place here, so `residual` keeps the raw input.
        residual = x
        x = F.relu(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        return x + residual
55
+
56
+
57
class FeatureFusionBlock(nn.Module):
    """Fuses features with optional residual input, then upsamples 2×.

    Args:
        features: channel count of the main input (and of the residual input).
        has_residual: when True, an extra PreActResidualConvUnit is created
            to process the skip connection given to ``forward``.
        expand: when True, the final 1×1 conv halves the channel count;
            otherwise it keeps ``features`` channels.
    """

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(features, out_features, 1, bias=True)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor = None) -> torch.Tensor:
        """Fuse ``x`` with ``residual`` (if used) and upsample by 2.

        Args:
            x: main feature map; indexed as (B, C, H, W) by the resize below.
            residual: optional skip feature map. It is ignored when the block
                was built with ``has_residual=False`` or when it is None.

        Returns:
            The fused feature map at twice the spatial size of ``x``.
        """
        if self.has_residual and residual is not None:
            # Bring the skip connection to the spatial size of x before adding.
            if residual.shape != x.shape:
                residual = F.interpolate(
                    residual, size=x.shape[2:], mode="bilinear",
                    align_corners=False)
            residual = self.residual_unit(residual)
            x = x + residual
        x = self.main_unit(x)
        # Upsample 2× with align_corners=True (matches Scenic reference)
        x = F.interpolate(x, scale_factor=2, mode="bilinear",
                          align_corners=True)
        x = self.out_conv(x)
        return x
85
+
86
+
87
class ReassembleBlocks(nn.Module):
    """Projects and resizes intermediate ViT features to different scales.

    Args:
        input_embed_dim: channel dimension D of the incoming ViT features.
        out_channels: per-level output channels. Exactly four resize layers
            are built (indices 0, 1 and 3 of this tuple are read for them),
            so four entries are expected.
        readout_type: when equal to ``"project"``, the cls token is
            concatenated to every patch feature and projected back to D
            channels; for any other value the cls token is not used.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type

        # 1×1 conv to project to per-level channels
        self.out_projections = nn.ModuleList([
            nn.Conv2d(input_embed_dim, ch, 1) for ch in out_channels
        ])

        # Spatial resize layers: 4× up, 2× up, identity, 2× down
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0],
                               kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1],
                               kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2,
                      padding=1),
        ])

        # Readout projection (concatenate cls_token with patch features)
        if readout_type == "project":
            self.readout_projects = nn.ModuleList([
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels
            ])

    def forward(self, features):
        """Process list of (cls_token, spatial_features) tuples.

        Args:
            features: list of (cls_token [B,D], patch_feats [B,D,H,W])

        Returns:
            list of tensors at different scales, one per input level.
        """
        out = []
        for i, (cls_token, x) in enumerate(features):
            if self.readout_type == "project":
                B, D, H, W = x.shape
                # Flatten spatial → (B, HW, D)
                x_flat = x.flatten(2).transpose(1, 2)
                # Expand cls_token → (B, HW, D)
                readout = cls_token.unsqueeze(1).expand(
                    -1, x_flat.shape[1], -1)
                # Concat + project + GELU
                x_cat = torch.cat([x_flat, readout], dim=-1)
                x_proj = F.gelu(self.readout_projects[i](x_cat))
                # Reshape back to spatial
                x = x_proj.transpose(1, 2).reshape(B, D, H, W)

            # 1×1 projection
            x = self.out_projections[i](x)
            # Spatial resize
            x = self.resize_layers[i](x)
            out.append(x)
        return out
149
+
150
+
151
class DPTDepthHead(nn.Module):
    """Full DPT head + depth classification decoder.

    Takes 4 intermediate ViT features and produces a depth map.

    Args:
        input_embed_dim: channel dimension of the incoming ViT features.
        channels: channel width used by the fusion blocks and projection.
        post_process_channels: per-level channels after reassembly.
        readout_type: forwarded to ReassembleBlocks.
        num_depth_bins: number of depth bins predicted per pixel.
        min_depth: centre of the first depth bin; also added to the
            rectified logits before normalisation.
        max_depth: centre of the last depth bin.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__()
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth

        # Reassemble: project + resize
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3×3 convs to map each level to `channels`
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])

        # Fusion blocks: first has no residual, rest have residual
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=False),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
        ])

        # Final projection
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Depth classification head (Dense layer)
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            depth map tensor (B, 1, H, W)
        """
        # Reassemble
        x = self.reassemble(intermediate_features)
        # 3×3 conv per level
        x = [conv(feat) for conv, feat in zip(self.convs, x)]

        # Fuse bottom-up: start from deepest (x[-1])
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=x[-(i + 1)])

        # Project
        out = self.project(out)
        out = F.relu(out)

        # Depth classification
        # out: (B, C, H, W) → (B, H, W, C)
        out = out.permute(0, 2, 3, 1)
        out = self.depth_head(out)  # (B, H, W, num_bins)

        # Classification-based depth prediction.
        # The bin centres are created in the dtype of the logits so that the
        # einsum below does not mix dtypes when the module runs in a
        # non-default precision (e.g. after .half() or .double()).
        bin_centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=out.device, dtype=out.dtype)
        # Rectify and offset so every bin has a strictly positive weight,
        # then normalise the weights to sum to one per pixel.
        out = F.relu(out) + self.min_depth
        out_norm = out / out.sum(dim=-1, keepdim=True)
        # Depth is the weighted average of the bin centres.
        depth = torch.einsum("bhwn,n->bhw", out_norm, bin_centers)
        depth = depth.unsqueeze(1)  # (B, 1, H, W)

        # Resize to original image size
        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)

        return depth
240
+
241
+
242
class DPTNormalsHead(nn.Module):
    """Full DPT head + surface normals decoder.

    Takes 4 intermediate ViT features and produces a normal map.

    Args:
        input_embed_dim: channel dimension of the incoming ViT features.
        channels: channel width used by the fusion blocks and projection.
        post_process_channels: per-level channels after reassembly.
        readout_type: forwarded to ReassembleBlocks.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()

        # Reassemble: project + resize
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3×3 convs to map each level to `channels`
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])

        # Fusion blocks: first has no residual, rest have residual
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=False),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
        ])

        # Final projection
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Normals head (Dense layer)
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT normals prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            normal map tensor (B, 3, H, W)
        """
        # Reassemble
        x = self.reassemble(intermediate_features)
        # 3×3 conv per level
        x = [conv(feat) for conv, feat in zip(self.convs, x)]

        # Fuse bottom-up: start from deepest (x[-1])
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=x[-(i + 1)])

        # Project (no ReLU here, unlike the depth head)
        out = self.project(out)

        # Normals head
        # out: (B, C, H, W) → (B, H, W, C)
        out = out.permute(0, 2, 3, 1)
        out = self.normals_head(out)  # (B, H, W, 3)

        # Normalize to unit length
        out = F.normalize(out, p=2, dim=-1)

        # Back to channels-first: (B, H, W, 3) → (B, 3, H, W)
        out = out.permute(0, 3, 1, 2)

        # Resize to original image size.
        # NOTE: the bilinear resize runs after normalisation, so resized
        # vectors are blends of unit vectors and are generally not exactly
        # unit length.
        if image_size is not None:
            out = F.interpolate(out, size=image_size, mode="bilinear",
                                align_corners=False)

        return out
322
+
323
+
324
class DPTSegmentationHead(nn.Module):
    """Full DPT head + segmentation decoder.

    Takes 4 intermediate ViT features and produces a segmentation map.

    Args:
        input_embed_dim: channel dimension of the incoming ViT features.
        channels: channel width used by the fusion blocks and projection.
        post_process_channels: per-level channels after reassembly.
        readout_type: forwarded to ReassembleBlocks.
        num_classes: number of per-pixel output logits.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__()

        # Reassemble: project + resize
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3×3 convs to map each level to `channels`
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])

        # Fusion blocks: first has no residual, rest have residual
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=False),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
        ])

        # Final projection
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Segmentation head (Dense layer)
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT segmentation prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            segmentation map tensor (B, num_classes, H, W) holding the raw
            outputs of the linear head (no softmax is applied here).
        """
        # Reassemble
        x = self.reassemble(intermediate_features)
        # 3×3 conv per level
        x = [conv(feat) for conv, feat in zip(self.convs, x)]

        # Fuse bottom-up: start from deepest (x[-1])
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=x[-(i + 1)])

        # Project (no ReLU here, unlike the depth head)
        out = self.project(out)

        # Segmentation head
        # out: (B, C, H, W) → (B, H, W, C)
        out = out.permute(0, 2, 3, 1)
        out = self.segmentation_head(out)  # (B, H, W, num_classes)

        # Back to channels-first: (B, H, W, K) → (B, K, H, W)
        out = out.permute(0, 3, 1, 2)

        # Resize to original image size
        if image_size is not None:
            out = F.interpolate(out, size=image_size, mode="bilinear",
                                align_corners=False)

        return out
402
+
403
+
404
+ # ── Weight loading from Scenic/Flax checkpoint ─────────────────────────────
405
+
406
+
407
def _load_npy_from_zip(zf, name):
    """Load a single .npy array from a zipfile.

    Args:
        zf: an open ``zipfile.ZipFile``.
        name: archive member name of the ``.npy`` entry.

    Returns:
        The array stored in that member.
    """
    with zf.open(name) as f:
        # The member is read fully into memory first so that np.load gets a
        # seekable buffer. Pickled object arrays are refused explicitly
        # because checkpoint archives are external data.
        return np.load(io.BytesIO(f.read()), allow_pickle=False)
411
+
412
+
413
def _conv_kernel_flax_to_torch(w: np.ndarray) -> torch.Tensor:
    """Convert Flax conv kernel (H,W,Cin,Cout) → PyTorch (Cout,Cin,H,W).

    The transposed view is copied so that the resulting tensor owns
    contiguous memory instead of aliasing the input array.
    """
    return torch.from_numpy(w.transpose(3, 2, 0, 1).copy())
416
+
417
+
418
def _conv_transpose_kernel_flax_to_torch(w: np.ndarray) -> torch.Tensor:
    """Convert Flax ConvTranspose kernel (H,W,Cin,Cout) → PyTorch (Cin,Cout,H,W).

    The transposed view is copied so that the resulting tensor owns
    contiguous memory instead of aliasing the input array.
    """
    # NOTE(review): only the axes are permuted; the spatial axes are not
    # flipped. Confirm this matches the kernel convention of the Flax layer
    # that produced the checkpoint.
    return torch.from_numpy(w.transpose(2, 3, 0, 1).copy())
421
+
422
+
423
def _linear_kernel_flax_to_torch(w: np.ndarray) -> torch.Tensor:
    """Convert Flax Dense kernel (in,out) → PyTorch Linear (out,in).

    The transposed view is copied so that the resulting tensor owns
    contiguous memory instead of aliasing the input array.
    """
    return torch.from_numpy(w.T.copy())
426
+
427
+
428
def _bias(w: np.ndarray) -> torch.Tensor:
    """Convert a bias array to a tensor, shape unchanged.

    The array is copied first so the tensor does not share memory with it.
    """
    return torch.from_numpy(w.copy())
430
+
431
+
432
def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: depth head whose parameters are overwritten in place.
        zip_path: path to the zip/npz archive holding one ``.npy`` entry per
            Flax parameter.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when the converted
            keys or shapes do not match the model exactly.
    """
    sd = {}
    prefix = "decoder/dpt/"

    # The context manager closes the archive even when a read below fails.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        # resize_layers_2 = Identity (no weights)
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has a single PreActResidualConvUnit (the main unit).
            # Later blocks have the residual unit at index 0 and the main
            # unit at index 1.
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = _conv_kernel_flax_to_torch(
                        npy(f"{fb}PreActResidualConvUnit_{j}/{conv}/kernel.npy"))

            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # --- Depth classification head ---
        sd["depth_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_depth_classif/kernel.npy"))
        sd["depth_head.bias"] = _bias(
            npy("decoder/pixel_depth_classif/bias.npy"))

    # Load into model. strict=True raises on any missing or unexpected key,
    # so no separate warning for those cases is needed after this call.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT depth head weights ({len(sd)} tensors)")
    return model
521
+
522
+
523
def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: normals head whose parameters are overwritten in place.
        zip_path: path to the zip/npz archive holding one ``.npy`` entry per
            Flax parameter.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when the converted
            keys or shapes do not match the model exactly.
    """
    sd = {}
    prefix = "decoder/dpt/"

    # The context manager closes the archive even when a read below fails.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        # resize_layers_2 = Identity (no weights)
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has a single PreActResidualConvUnit (the main unit).
            # Later blocks have the residual unit at index 0 and the main
            # unit at index 1.
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = _conv_kernel_flax_to_torch(
                        npy(f"{fb}PreActResidualConvUnit_{j}/{conv}/kernel.npy"))

            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # --- Normals head ---
        sd["normals_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_normals/kernel.npy"))
        sd["normals_head.bias"] = _bias(
            npy("decoder/pixel_normals/bias.npy"))

    # Load into model. strict=True raises on any missing or unexpected key,
    # so no separate warning for those cases is needed after this call.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT normals head weights ({len(sd)} tensors)")
    return model
612
+
613
+
614
def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: segmentation head whose parameters are overwritten in place.
        zip_path: path to the zip/npz archive holding one ``.npy`` entry per
            Flax parameter.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        RuntimeError: from ``load_state_dict(strict=True)`` when the converted
            keys or shapes do not match the model exactly.
    """
    sd = {}
    prefix = "decoder/dpt/"

    # The context manager closes the archive even when a read below fails.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity, 3=Conv
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        # resize_layers_2 = Identity (no weights)
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has a single PreActResidualConvUnit (the main unit).
            # Later blocks have the residual unit at index 0 and the main
            # unit at index 1.
            units = ("main_unit",) if i == 0 else ("residual_unit", "main_unit")
            for j, unit in enumerate(units):
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = _conv_kernel_flax_to_torch(
                        npy(f"{fb}PreActResidualConvUnit_{j}/{conv}/kernel.npy"))

            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # --- Segmentation head ---
        sd["segmentation_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_segmentation/kernel.npy"))
        sd["segmentation_head.bias"] = _bias(
            npy("decoder/pixel_segmentation/bias.npy"))

    # Load into model. strict=True raises on any missing or unexpected key,
    # so no separate warning for those cases is needed after this call.
    model.load_state_dict(sd, strict=True)
    print(f"Loaded DPT segmentation head weights ({len(sd)} tensors)")
    return model