| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """DPT (Dense Prediction Transformer) depth head in PyTorch. |
| |
| Ported from the Scenic/Flax implementation at: |
| research/vision/scene_understanding/imsight/modules/dpt.py |
| scenic/projects/dense_features/models/decoders.py |
| |
| Architecture: |
| ReassembleBlocks → 4×Conv3x3 → 4×FeatureFusionBlock → project → DepthHead |
| """ |
|
|
import io
import os
import urllib.request
import zipfile
from typing import Optional

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
|
|
|
|
| |
|
|
|
|
class PreActResidualConvUnit(nn.Module):
    """Residual unit that applies ReLU *before* each 3x3 convolution.

    Computes ``x + conv2(relu(conv1(relu(x))))``. Both convolutions are
    bias-free and preserve the channel count and spatial size.
    """

    def __init__(self, features: int):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(features, features, 3, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.conv2(F.relu(self.conv1(F.relu(x))))
        return out + x
|
|
|
|
class FeatureFusionBlock(nn.Module):
    """Fuses features with an optional residual input, then upsamples 2x.

    Args:
        features: number of input channels.
        has_residual: if True, a second ``PreActResidualConvUnit`` is built
            to process the residual (skip) input before it is added.
        expand: if True, the final 1x1 conv halves the channel count.
    """

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(features, out_features, 1, bias=True)

    def forward(self, x: torch.Tensor,
                residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Fuse ``x`` with ``residual`` (if provided), then upsample by 2.

        The residual is bilinearly resized to match ``x`` when their shapes
        differ; a residual passed to a block built with
        ``has_residual=False`` is ignored.
        """
        if self.has_residual and residual is not None:
            if residual.shape != x.shape:
                # Align the skip connection spatially before adding.
                residual = F.interpolate(
                    residual, size=x.shape[2:], mode="bilinear",
                    align_corners=False)
            residual = self.residual_unit(residual)
            x = x + residual
        x = self.main_unit(x)
        # NOTE(review): align_corners=True here vs. False for the residual
        # resize above — presumably deliberate in the Flax port; confirm
        # against the Scenic reference before changing.
        x = F.interpolate(x, scale_factor=2, mode="bilinear",
                          align_corners=True)
        return self.out_conv(x)
|
|
|
|
class ReassembleBlocks(nn.Module):
    """Projects and resizes intermediate ViT features to a pyramid of scales."""

    def __init__(self, input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type

        # Per-scale 1x1 channel projections from the ViT embedding dim.
        self.out_projections = nn.ModuleList(
            nn.Conv2d(input_embed_dim, ch, 1) for ch in out_channels)

        # Spatial rescaling per level: 4x up, 2x up, identity, 2x down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0],
                               kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1],
                               kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2,
                      padding=1),
        ])

        # "project" readout concatenates the CLS token onto every patch
        # token and maps the doubled dimension back down with a Linear.
        if readout_type == "project":
            self.readout_projects = nn.ModuleList(
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels)

    def forward(self, features):
        """Process a list of (cls_token, spatial_features) tuples.

        Args:
            features: list of (cls_token [B,D], patch_feats [B,D,H,W])

        Returns:
            list of tensors at different scales.
        """
        outputs = []
        for idx, (cls_token, feat) in enumerate(features):
            batch, dim, height, width = feat.shape

            if self.readout_type == "project":
                # [B,D,H,W] -> [B,HW,D] token sequence.
                tokens = feat.flatten(2).transpose(1, 2)
                # Broadcast the CLS token across all spatial positions.
                cls_rep = cls_token.unsqueeze(1).expand(
                    -1, tokens.shape[1], -1)
                fused = torch.cat([tokens, cls_rep], dim=-1)
                fused = F.gelu(self.readout_projects[idx](fused))
                # Back to [B,D,H,W].
                feat = fused.transpose(1, 2).reshape(batch, dim, height,
                                                     width)

            feat = self.out_projections[idx](feat)
            outputs.append(self.resize_layers[idx](feat))
        return outputs
|
|
|
|
class DPTDepthHead(nn.Module):
    """Full DPT head plus a depth classification decoder.

    Consumes 4 intermediate ViT features and predicts a dense depth map by
    taking a soft-weighted average over uniformly spaced depth bins.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__()
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth

        # Reassemble ViT tokens into a 4-level feature pyramid.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3x3 convs bringing every pyramid level to a common channel count.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)

        # Coarse-to-fine fusion; only the first block has no skip input.
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(channels, has_residual=False)]
            + [FeatureFusionBlock(channels, has_residual=True)
               for _ in range(3)])

        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Per-pixel logits over depth bins.
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            depth map tensor (B, 1, H, W)
        """
        feats = self.reassemble(intermediate_features)
        feats = [conv(f) for conv, f in zip(self.convs, feats)]

        # Start from the deepest level, then merge in shallower skips.
        fused = self.fusion_blocks[0](feats[3])
        fused = self.fusion_blocks[1](fused, residual=feats[2])
        fused = self.fusion_blocks[2](fused, residual=feats[1])
        fused = self.fusion_blocks[3](fused, residual=feats[0])

        fused = F.relu(self.project(fused))

        # Channels-last so the Linear head acts per pixel.
        logits = self.depth_head(fused.permute(0, 2, 3, 1))

        # Soft argmax over uniformly spaced bin centers in
        # [min_depth, max_depth].
        centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=logits.device)
        weights = F.relu(logits) + self.min_depth
        weights = weights / weights.sum(dim=-1, keepdim=True)
        depth = torch.einsum("bhwn,n->bhw", weights, centers)
        depth = depth.unsqueeze(1)

        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)

        return depth
|
|
|
|
class DPTNormalsHead(nn.Module):
    """Full DPT head plus a surface-normals decoder.

    Consumes 4 intermediate ViT features and predicts a unit-norm
    per-pixel normal map.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()

        # Reassemble ViT tokens into a 4-level feature pyramid.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3x3 convs bringing every pyramid level to a common channel count.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)

        # Coarse-to-fine fusion; only the first block has no skip input.
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(channels, has_residual=False)]
            + [FeatureFusionBlock(channels, has_residual=True)
               for _ in range(3)])

        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Per-pixel 3-vector (x, y, z) prediction.
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT normals prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            normal map tensor (B, 3, H, W)
        """
        feats = self.reassemble(intermediate_features)
        feats = [conv(f) for conv, f in zip(self.convs, feats)]

        # Start from the deepest level, then merge in shallower skips.
        fused = self.fusion_blocks[0](feats[3])
        fused = self.fusion_blocks[1](fused, residual=feats[2])
        fused = self.fusion_blocks[2](fused, residual=feats[1])
        fused = self.fusion_blocks[3](fused, residual=feats[0])

        fused = self.project(fused)

        # Channels-last so the Linear head acts per pixel, then L2-normalize
        # each predicted normal vector.
        normals = self.normals_head(fused.permute(0, 2, 3, 1))
        normals = F.normalize(normals, p=2, dim=-1)

        # Back to channels-first; resize only when a target size is given.
        normals = normals.permute(0, 3, 1, 2)
        if image_size is not None:
            normals = F.interpolate(normals, size=image_size,
                                    mode="bilinear", align_corners=False)

        return normals
|
|
|
|
class DPTSegmentationHead(nn.Module):
    """Full DPT head plus a segmentation decoder.

    Consumes 4 intermediate ViT features and predicts per-pixel class
    logits.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__()

        # Reassemble ViT tokens into a 4-level feature pyramid.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )

        # 3x3 convs bringing every pyramid level to a common channel count.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)

        # Coarse-to-fine fusion; only the first block has no skip input.
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(channels, has_residual=False)]
            + [FeatureFusionBlock(channels, has_residual=True)
               for _ in range(3)])

        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)

        # Per-pixel class logits.
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT segmentation prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None

        Returns:
            segmentation map tensor (B, num_classes, H, W)
        """
        feats = self.reassemble(intermediate_features)
        feats = [conv(f) for conv, f in zip(self.convs, feats)]

        # Start from the deepest level, then merge in shallower skips.
        fused = self.fusion_blocks[0](feats[3])
        fused = self.fusion_blocks[1](fused, residual=feats[2])
        fused = self.fusion_blocks[2](fused, residual=feats[1])
        fused = self.fusion_blocks[3](fused, residual=feats[0])

        fused = self.project(fused)

        # Channels-last so the Linear head acts per pixel.
        logits = self.segmentation_head(fused.permute(0, 2, 3, 1))

        # Back to channels-first; resize only when a target size is given.
        logits = logits.permute(0, 3, 1, 2)
        if image_size is not None:
            logits = F.interpolate(logits, size=image_size, mode="bilinear",
                                   align_corners=False)

        return logits
|
|
|
|
| |
|
|
|
|
| def _load_npy_from_zip(zf, name): |
| """Load a single .npy array from a zipfile.""" |
| with zf.open(name) as f: |
| return np.load(io.BytesIO(f.read())) |
|
|
|
|
| def _conv_kernel_flax_to_torch(w): |
| """Convert Flax conv kernel (H,W,Cin,Cout) → PyTorch (Cout,Cin,H,W).""" |
| return torch.from_numpy(w.transpose(3, 2, 0, 1).copy()) |
|
|
|
|
| def _conv_transpose_kernel_flax_to_torch(w): |
| """Convert Flax ConvTranspose kernel (H,W,Cin,Cout) → PyTorch (Cin,Cout,H,W).""" |
| return torch.from_numpy(w.transpose(2, 3, 0, 1).copy()) |
|
|
|
|
| def _linear_kernel_flax_to_torch(w): |
| """Convert Flax Dense kernel (in,out) → PyTorch Linear (out,in).""" |
| return torch.from_numpy(w.T.copy()) |
|
|
|
|
| def _bias(w): |
| return torch.from_numpy(w.copy()) |
|
|
|
|
def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT depth-head weights from a zip of .npy files.

    Args:
        model: target PyTorch module whose parameter names follow the
            ``DPTDepthHead`` layout.
        zip_path: path to the zip archive of exported Flax arrays.

    Returns:
        The same ``model``, with weights loaded in place.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # Context manager so the archive is closed even if an entry is
    # missing or malformed (the original leaked the handle on error).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        for i in range(4):
            # 1x1 output projection convs: kernel + bias.
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # CLS-token readout Linear layers: kernel + bias.
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # Resize layers: 0 and 1 are transposed convs, 2 is an Identity
        # (no parameters), 3 is a strided conv.
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # Channel-unifying 3x3 convs (bias-free in the model).
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # Fusion blocks: block 0 has only the main unit; blocks 1-3 also
        # carry a residual unit (Flax numbers the units 0 and 1).
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            if i == 0:
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
            else:
                sd[f"fusion_blocks.{i}.residual_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.residual_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv2/kernel.npy"))

            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # Final projection conv.
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # Depth classification head (lives outside the dpt/ prefix).
        sd["depth_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_depth_classif/kernel.npy"))
        sd["depth_head.bias"] = _bias(
            npy("decoder/pixel_depth_classif/bias.npy"))

    # strict=False so mismatches are reported via the warnings below;
    # with strict=True, load_state_dict raises and the warning branches
    # were unreachable dead code.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT depth head weights ({len(sd)} tensors)")
    return model
|
|
|
|
def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT normals-head weights from a zip of .npy files.

    Args:
        model: target PyTorch module whose parameter names follow the
            ``DPTNormalsHead`` layout.
        zip_path: path to the zip archive of exported Flax arrays.

    Returns:
        The same ``model``, with weights loaded in place.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # Context manager so the archive is closed even if an entry is
    # missing or malformed (the original leaked the handle on error).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        for i in range(4):
            # 1x1 output projection convs: kernel + bias.
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # CLS-token readout Linear layers: kernel + bias.
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # Resize layers: 0 and 1 are transposed convs, 2 is an Identity
        # (no parameters), 3 is a strided conv.
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # Channel-unifying 3x3 convs (bias-free in the model).
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # Fusion blocks: block 0 has only the main unit; blocks 1-3 also
        # carry a residual unit (Flax numbers the units 0 and 1).
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            if i == 0:
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
            else:
                sd[f"fusion_blocks.{i}.residual_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.residual_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv2/kernel.npy"))

            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # Final projection conv.
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # Normals head (lives outside the dpt/ prefix).
        sd["normals_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_normals/kernel.npy"))
        sd["normals_head.bias"] = _bias(
            npy("decoder/pixel_normals/bias.npy"))

    # strict=False so mismatches are reported via the warnings below;
    # with strict=True, load_state_dict raises and the warning branches
    # were unreachable dead code.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT normals head weights ({len(sd)} tensors)")
    return model
|
|
|
|
def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT segmentation-head weights from a zip of .npy files.

    Args:
        model: target PyTorch module whose parameter names follow the
            ``DPTSegmentationHead`` layout.
        zip_path: path to the zip archive of exported Flax arrays.

    Returns:
        The same ``model``, with weights loaded in place.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # Context manager so the archive is closed even if an entry is
    # missing or malformed (the original leaked the handle on error).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        for i in range(4):
            # 1x1 output projection convs: kernel + bias.
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))

            # CLS-token readout Linear layers: kernel + bias.
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))

        # Resize layers: 0 and 1 are transposed convs, 2 is an Identity
        # (no parameters), 3 is a strided conv.
        sd["reassemble.resize_layers.0.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/kernel.npy"))
        sd["reassemble.resize_layers.0.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_0/bias.npy"))
        sd["reassemble.resize_layers.1.weight"] = _conv_transpose_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/kernel.npy"))
        sd["reassemble.resize_layers.1.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_1/bias.npy"))
        sd["reassemble.resize_layers.3.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/kernel.npy"))
        sd["reassemble.resize_layers.3.bias"] = _bias(
            npy(f"{prefix}reassemble_blocks/resize_layers_3/bias.npy"))

        # Channel-unifying 3x3 convs (bias-free in the model).
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # Fusion blocks: block 0 has only the main unit; blocks 1-3 also
        # carry a residual unit (Flax numbers the units 0 and 1).
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            if i == 0:
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
            else:
                sd[f"fusion_blocks.{i}.residual_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.residual_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_0/conv2/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv1.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv1/kernel.npy"))
                sd[f"fusion_blocks.{i}.main_unit.conv2.weight"] = _conv_kernel_flax_to_torch(
                    npy(f"{fb}PreActResidualConvUnit_1/conv2/kernel.npy"))

            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # Final projection conv.
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))

        # Segmentation head (lives outside the dpt/ prefix).
        sd["segmentation_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_segmentation/kernel.npy"))
        sd["segmentation_head.bias"] = _bias(
            npy("decoder/pixel_segmentation/bias.npy"))

    # strict=False so mismatches are reported via the warnings below;
    # with strict=True, load_state_dict raises and the warning branches
    # were unreachable dead code.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT segmentation head weights ({len(sd)} tensors)")
    return model
|
|