# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DPT (Dense Prediction Transformer) depth head in PyTorch.

Ported from the Scenic/Flax implementation at:
  research/vision/scene_understanding/imsight/modules/dpt.py
  scenic/projects/dense_features/models/decoders.py

Architecture:
  ReassembleBlocks -> 4xConv3x3 -> 4xFeatureFusionBlock -> project -> task head
"""

import io
import os
import urllib.request
import zipfile

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F


# ── Building blocks ─────────────────────────────────────────────────────────


class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual convolution unit.

    Two 3x3 convolutions (no bias), each preceded by ReLU, with an identity
    skip connection added to the output.
    """

    def __init__(self, features: int):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(features, features, 3, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = F.relu(x)
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        return x + residual


class FeatureFusionBlock(nn.Module):
    """Fuses features with optional residual input, then upsamples 2x.

    Args:
      features: number of input channels.
      has_residual: if True, a second PreActResidualConvUnit processes the
        residual input before it is added to the main path.
      expand: if True, the final 1x1 conv halves the channel count.
    """

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(features, out_features, 1, bias=True)

    def forward(self, x: torch.Tensor,
                residual: "torch.Tensor | None" = None) -> torch.Tensor:
        if self.has_residual and residual is not None:
            if residual.shape != x.shape:
                # Match spatial size before the add (channels must agree).
                residual = F.interpolate(
                    residual, size=x.shape[2:], mode="bilinear",
                    align_corners=False)
            residual = self.residual_unit(residual)
            x = x + residual
        x = self.main_unit(x)
        # Upsample 2x with align_corners=True (matches Scenic reference)
        x = F.interpolate(x, scale_factor=2, mode="bilinear",
                          align_corners=True)
        x = self.out_conv(x)
        return x


class ReassembleBlocks(nn.Module):
    """Projects and resizes intermediate ViT features to different scales."""

    def __init__(self,
                 input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type
        # 1x1 conv to project to per-level channels
        self.out_projections = nn.ModuleList([
            nn.Conv2d(input_embed_dim, ch, 1) for ch in out_channels
        ])
        # Spatial resize layers: 4x up, 2x up, identity, 2x down
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0],
                               kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1],
                               kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2,
                      padding=1),
        ])
        # Readout projection (concatenate cls_token with patch features)
        if readout_type == "project":
            self.readout_projects = nn.ModuleList([
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels
            ])

    def forward(self, features):
        """Process list of (cls_token, spatial_features) tuples.

        Args:
          features: list of (cls_token [B,D], patch_feats [B,D,H,W])

        Returns:
          list of tensors at different scales.
        """
        out = []
        for i, (cls_token, x) in enumerate(features):
            B, D, H, W = x.shape
            if self.readout_type == "project":
                # Flatten spatial -> (B, HW, D)
                x_flat = x.flatten(2).transpose(1, 2)
                # Broadcast cls_token across all HW positions -> (B, HW, D)
                readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1],
                                                        -1)
                # Concat + project + GELU
                x_cat = torch.cat([x_flat, readout], dim=-1)
                x_proj = F.gelu(self.readout_projects[i](x_cat))
                # Reshape back to spatial
                x = x_proj.transpose(1, 2).reshape(B, D, H, W)
            # 1x1 projection, then per-level spatial resize
            x = self.out_projections[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


# ── Decoder trunk shared by all heads ───────────────────────────────────────


class _DPTDecoderBase(nn.Module):
    """Shared DPT decoder trunk: reassemble -> convs -> fusion -> project.

    Attribute names (`reassemble`, `convs`, `fusion_blocks`, `project`) are
    part of the checkpoint-loading contract below — do not rename them.
    """

    def __init__(self,
                 input_embed_dim: int,
                 channels: int,
                 post_process_channels: tuple,
                 readout_type: str):
        super().__init__()
        # Reassemble: project + resize
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # 3x3 convs to map each level to `channels`
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])
        # Fusion blocks: first has no residual, rest have residual
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=(i > 0))
            for i in range(4)
        ])
        # Final projection
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

    def _decode(self, intermediate_features) -> torch.Tensor:
        """Run the shared trunk; returns the projected fused feature map."""
        feats = self.reassemble(intermediate_features)
        feats = [conv(f) for conv, f in zip(self.convs, feats)]
        # Fuse bottom-up: start from deepest (feats[-1])
        out = self.fusion_blocks[0](feats[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=feats[-(i + 1)])
        return self.project(out)


class DPTDepthHead(_DPTDecoderBase):
    """Full DPT head + depth classification decoder.

    Takes 4 intermediate ViT features and produces a depth map via
    classification over discrete depth bins.
    """

    def __init__(self,
                 input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__(input_embed_dim, channels, post_process_channels,
                         readout_type)
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        # Depth classification head (Dense layer)
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          depth map tensor (B, 1, H, W)
        """
        out = self._decode(intermediate_features)
        out = F.relu(out)

        # Depth classification: (B, C, H, W) -> (B, H, W, num_bins)
        out = out.permute(0, 2, 3, 1)
        out = self.depth_head(out)

        # Classification-based depth: soft-weighted average of bin centers.
        bin_centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=out.device)
        out = F.relu(out) + self.min_depth
        out_norm = out / out.sum(dim=-1, keepdim=True)
        depth = torch.einsum("bhwn,n->bhw", out_norm, bin_centers)
        depth = depth.unsqueeze(1)  # (B, 1, H, W)

        # Resize to original image size
        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)
        return depth


class DPTNormalsHead(_DPTDecoderBase):
    """Full DPT head + surface normals decoder.

    Takes 4 intermediate ViT features and produces a unit-length normal map.
    """

    def __init__(self,
                 input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__(input_embed_dim, channels, post_process_channels,
                         readout_type)
        # Normals head (Dense layer)
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT normals prediction.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          normal map tensor (B, 3, H, W)
        """
        out = self._decode(intermediate_features)

        # Normals head: (B, C, H, W) -> (B, H, W, 3), then unit-normalize.
        out = out.permute(0, 2, 3, 1)
        out = self.normals_head(out)
        out = F.normalize(out, p=2, dim=-1)

        # PyTorch interpolate expects (B, C, H, W); note resizing happens
        # after normalization, matching the original port.
        out = out.permute(0, 3, 1, 2)
        if image_size is not None:
            out = F.interpolate(out, size=image_size, mode="bilinear",
                                align_corners=False)
        return out


class DPTSegmentationHead(_DPTDecoderBase):
    """Full DPT head + segmentation decoder.

    Takes 4 intermediate ViT features and produces per-class logits.
    """

    def __init__(self,
                 input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__(input_embed_dim, channels, post_process_channels,
                         readout_type)
        # Segmentation head (Dense layer)
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT segmentation prediction.

        Args:
          intermediate_features: list of 4 (cls_token, patch_feats) tuples
          image_size: (H, W) to resize output to, or None

        Returns:
          segmentation map tensor (B, num_classes, H, W)
        """
        out = self._decode(intermediate_features)

        # Segmentation head: (B, C, H, W) -> (B, H, W, num_classes)
        out = out.permute(0, 2, 3, 1)
        out = self.segmentation_head(out)

        # PyTorch interpolate expects (B, C, H, W)
        out = out.permute(0, 3, 1, 2)
        if image_size is not None:
            out = F.interpolate(out, size=image_size, mode="bilinear",
                                align_corners=False)
        return out


# ── Weight loading from Scenic/Flax checkpoint ──────────────────────────────


def _load_npy_from_zip(zf, name):
    """Load a single .npy array from a zipfile."""
    with zf.open(name) as f:
        return np.load(io.BytesIO(f.read()))


def _conv_kernel_flax_to_torch(w):
    """Convert Flax conv kernel (H,W,Cin,Cout) -> PyTorch (Cout,Cin,H,W)."""
    return torch.from_numpy(w.transpose(3, 2, 0, 1).copy())


def _conv_transpose_kernel_flax_to_torch(w):
    """Convert Flax ConvTranspose kernel (H,W,Cin,Cout) -> PyTorch (Cin,Cout,H,W)."""
    return torch.from_numpy(w.transpose(2, 3, 0, 1).copy())


def _linear_kernel_flax_to_torch(w):
    """Convert Flax Dense kernel (in,out) -> PyTorch Linear (out,in)."""
    return torch.from_numpy(w.T.copy())


def _bias(w):
    return torch.from_numpy(w.copy())


def _shared_decoder_state(npy, prefix: str = "decoder/dpt/") -> dict:
    """Build the state-dict entries for the decoder trunk shared by all heads.

    Args:
      npy: callable mapping a checkpoint entry name to a numpy array.
      prefix: checkpoint path prefix for the DPT decoder.

    Returns:
      dict of PyTorch state-dict tensors (trunk only, no task head).
    """
    sd = {}
    rb = f"{prefix}reassemble_blocks/"

    # --- ReassembleBlocks: out_projections (Conv2d 1x1) + readout (Linear) ---
    for i in range(4):
        sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{rb}out_projection_{i}/kernel.npy"))
        sd[f"reassemble.out_projections.{i}.bias"] = _bias(
            npy(f"{rb}out_projection_{i}/bias.npy"))
        sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
            npy(f"{rb}readout_projects_{i}/kernel.npy"))
        sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
            npy(f"{rb}readout_projects_{i}/bias.npy"))

    # resize_layers: 0=ConvTranspose, 1=ConvTranspose, 2=Identity (no weights),
    # 3=Conv — each needs the matching kernel-layout conversion.
    resize_converters = {
        0: _conv_transpose_kernel_flax_to_torch,
        1: _conv_transpose_kernel_flax_to_torch,
        3: _conv_kernel_flax_to_torch,
    }
    for i, convert in resize_converters.items():
        sd[f"reassemble.resize_layers.{i}.weight"] = convert(
            npy(f"{rb}resize_layers_{i}/kernel.npy"))
        sd[f"reassemble.resize_layers.{i}.bias"] = _bias(
            npy(f"{rb}resize_layers_{i}/bias.npy"))

    # --- Convs (3x3, no bias) ---
    for i in range(4):
        sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}convs_{i}/kernel.npy"))

    # --- Fusion blocks ---
    # Block 0 has only a main unit (PreActResidualConvUnit_0); blocks 1-3 have
    # a residual unit (index 0) plus a main unit (index 1).
    for i in range(4):
        fb = f"{prefix}fusion_blocks_{i}/"
        if i > 0:
            for j in (1, 2):
                sd[f"fusion_blocks.{i}.residual_unit.conv{j}.weight"] = (
                    _conv_kernel_flax_to_torch(
                        npy(f"{fb}PreActResidualConvUnit_0/conv{j}/kernel.npy")))
        main = "PreActResidualConvUnit_1" if i > 0 else "PreActResidualConvUnit_0"
        for j in (1, 2):
            sd[f"fusion_blocks.{i}.main_unit.conv{j}.weight"] = (
                _conv_kernel_flax_to_torch(
                    npy(f"{fb}{main}/conv{j}/kernel.npy")))
        # out_conv (Conv2d 1x1)
        sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{fb}Conv_0/kernel.npy"))
        sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
            npy(f"{fb}Conv_0/bias.npy"))

    # --- Project ---
    sd["project.weight"] = _conv_kernel_flax_to_torch(
        npy(f"{prefix}project/kernel.npy"))
    sd["project.bias"] = _bias(
        npy(f"{prefix}project/bias.npy"))
    return sd


def _apply_state(model: nn.Module, sd: dict, label: str) -> nn.Module:
    """Load `sd` into `model` strictly and report the result."""
    # strict=True raises on mismatch; the warnings below are a safety net.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT {label} head weights ({len(sd)} tensors)")
    return model


def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model."""
    # Context manager guarantees the archive is closed even if a read fails.
    with zipfile.ZipFile(zip_path, "r") as zf:
        def npy(name):
            return _load_npy_from_zip(zf, name)

        sd = _shared_decoder_state(npy)
        # --- Depth classification head ---
        sd["depth_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_depth_classif/kernel.npy"))
        sd["depth_head.bias"] = _bias(
            npy("decoder/pixel_depth_classif/bias.npy"))
    return _apply_state(model, sd, "depth")


def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model."""
    with zipfile.ZipFile(zip_path, "r") as zf:
        def npy(name):
            return _load_npy_from_zip(zf, name)

        sd = _shared_decoder_state(npy)
        # --- Normals head ---
        sd["normals_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_normals/kernel.npy"))
        sd["normals_head.bias"] = _bias(
            npy("decoder/pixel_normals/bias.npy"))
    return _apply_state(model, sd, "normals")


def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model."""
    with zipfile.ZipFile(zip_path, "r") as zf:
        def npy(name):
            return _load_npy_from_zip(zf, name)

        sd = _shared_decoder_state(npy)
        # --- Segmentation head ---
        sd["segmentation_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_segmentation/kernel.npy"))
        sd["segmentation_head.bias"] = _bias(
            npy("decoder/pixel_segmentation/bias.npy"))
    return _apply_state(model, sd, "segmentation")