# dpt_head.py — from the tipsv2-l14-dpt Hugging Face repository
# (uploaded via huggingface_hub; file-listing web-page chrome removed so the
# module parses as Python).
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DPT (Dense Prediction Transformer) depth head in PyTorch.
Ported from the Scenic/Flax implementation at:
research/vision/scene_understanding/imsight/modules/dpt.py
scenic/projects/dense_features/models/decoders.py
Architecture:
ReassembleBlocks → 4×Conv3x3 → 4×FeatureFusionBlock → project → DepthHead
"""
import io
import os
import urllib.request
import zipfile
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# ── Building blocks ─────────────────────────────────────────────────────────
class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual convolution unit: (ReLU → 3×3 conv) twice, plus skip."""

    def __init__(self, features: int):
        super().__init__()
        # Two bias-free, channel-preserving 3×3 convolutions.
        self.conv1 = nn.Conv2d(features, features, 3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(features, features, 3, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-activation ordering: ReLU is applied before each conv.
        out = self.conv1(F.relu(x))
        out = self.conv2(F.relu(out))
        # Identity skip connection.
        return out + x
class FeatureFusionBlock(nn.Module):
    """Fuses features with optional residual input, then upsamples 2×."""

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        # `expand` halves the channel count on the way out.
        self.out_conv = nn.Conv2d(
            features, features // 2 if expand else features, 1, bias=True)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor = None) -> torch.Tensor:
        if self.has_residual and residual is not None:
            # Match spatial sizes before the elementwise add, if needed.
            if residual.shape != x.shape:
                residual = F.interpolate(
                    residual, size=x.shape[2:], mode="bilinear",
                    align_corners=False)
            x = x + self.residual_unit(residual)
        fused = self.main_unit(x)
        # Upsample 2× with align_corners=True (matches Scenic reference)
        fused = F.interpolate(fused, scale_factor=2, mode="bilinear",
                              align_corners=True)
        return self.out_conv(fused)
class ReassembleBlocks(nn.Module):
    """Projects and resizes intermediate ViT features to different scales."""

    def __init__(self, input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type
        # Per-level 1×1 projection onto the pyramid channel counts.
        self.out_projections = nn.ModuleList(
            nn.Conv2d(input_embed_dim, ch, 1) for ch in out_channels)
        # Spatial rescaling per level: 4× up, 2× up, identity, 2× down.
        up4 = nn.ConvTranspose2d(out_channels[0], out_channels[0],
                                 kernel_size=4, stride=4, padding=0)
        up2 = nn.ConvTranspose2d(out_channels[1], out_channels[1],
                                 kernel_size=2, stride=2, padding=0)
        down2 = nn.Conv2d(out_channels[3], out_channels[3], 3, stride=2,
                          padding=1)
        self.resize_layers = nn.ModuleList([up4, up2, nn.Identity(), down2])
        # Readout projection: folds the cls token into every patch feature.
        if readout_type == "project":
            self.readout_projects = nn.ModuleList(
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels)

    def forward(self, features):
        """Process list of (cls_token, spatial_features) tuples.

        Args:
            features: list of (cls_token [B,D], patch_feats [B,D,H,W])
        Returns:
            list of tensors at different scales.
        """
        outputs = []
        for idx, (cls_token, feat) in enumerate(features):
            batch, dim, height, width = feat.shape
            if self.readout_type == "project":
                # (B, D, H, W) → (B, HW, D)
                tokens = feat.flatten(2).transpose(1, 2)
                # Broadcast the cls token across all HW positions.
                cls_rep = cls_token.unsqueeze(1).expand(-1, tokens.shape[1], -1)
                # Concatenate, project back to D, apply GELU.
                projected = F.gelu(self.readout_projects[idx](
                    torch.cat([tokens, cls_rep], dim=-1)))
                # Restore the spatial layout (B, D, H, W).
                feat = projected.transpose(1, 2).reshape(
                    batch, dim, height, width)
            # 1×1 channel projection followed by the per-level resize.
            feat = self.out_projections[idx](feat)
            outputs.append(self.resize_layers[idx](feat))
        return outputs
class DPTDepthHead(nn.Module):
    """Full DPT head + depth classification decoder.

    Takes 4 intermediate ViT features and produces a depth map via soft
    bin classification: per-pixel non-negative scores over
    `num_depth_bins` evenly spaced depth values are normalized and used
    to average the bin centers.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__()
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        # Reassemble: readout + per-level projection + spatial resize.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # Bias-free 3×3 convs mapping each pyramid level to `channels`.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)
        # Fusion: deepest level first (no residual), then three with residual.
        self.fusion_blocks = nn.ModuleList(
            FeatureFusionBlock(channels, has_residual=(k > 0))
            for k in range(4))
        # Final 3×3 projection before the per-pixel head.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        # Per-pixel depth-bin classifier (Dense layer).
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None
        Returns:
            depth map tensor (B, 1, H, W)
        """
        levels = self.reassemble(intermediate_features)
        levels = [conv(feat) for conv, feat in zip(self.convs, levels)]
        # Fuse bottom-up, starting from the deepest (coarsest) level.
        fused = self.fusion_blocks[0](levels[-1])
        for k, skip in enumerate(reversed(levels[:-1])):
            fused = self.fusion_blocks[k + 1](fused, residual=skip)
        fused = F.relu(self.project(fused))
        # (B, C, H, W) → (B, H, W, C) so the Linear head scores per pixel.
        scores = self.depth_head(fused.permute(0, 2, 3, 1))
        # Classification-based depth: normalized non-negative scores weight
        # evenly spaced bin centers in [min_depth, max_depth].
        centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=scores.device)
        weights = F.relu(scores) + self.min_depth
        weights = weights / weights.sum(dim=-1, keepdim=True)
        depth = torch.einsum("bhwn,n->bhw", weights, centers).unsqueeze(1)
        # Resize to original image size when requested.
        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)
        return depth
class DPTNormalsHead(nn.Module):
    """Full DPT head + surface normals decoder.

    Takes 4 intermediate ViT features and produces a per-pixel
    unit-normalized 3-vector normal map.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        # Reassemble: readout + per-level projection + spatial resize.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # Bias-free 3×3 convs mapping each pyramid level to `channels`.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)
        # Fusion: deepest level first (no residual), then three with residual.
        self.fusion_blocks = nn.ModuleList(
            FeatureFusionBlock(channels, has_residual=(k > 0))
            for k in range(4))
        # Final 3×3 projection before the per-pixel head.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        # Per-pixel 3-vector normals head (Dense layer).
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT normals prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None
        Returns:
            normal map tensor (B, 3, H, W)
        """
        levels = self.reassemble(intermediate_features)
        levels = [conv(feat) for conv, feat in zip(self.convs, levels)]
        # Fuse bottom-up, starting from the deepest (coarsest) level.
        fused = self.fusion_blocks[0](levels[-1])
        for k, skip in enumerate(reversed(levels[:-1])):
            fused = self.fusion_blocks[k + 1](fused, residual=skip)
        fused = self.project(fused)
        # (B, C, H, W) → (B, H, W, C) so the Linear head acts per pixel.
        normals = self.normals_head(fused.permute(0, 2, 3, 1))
        # Normalize each per-pixel normal to unit length.
        normals = F.normalize(normals, p=2, dim=-1)
        # Back to channels-first (B, 3, H, W) for interpolate / the caller.
        normals = normals.permute(0, 3, 1, 2)
        if image_size is not None:
            # NOTE(review): bilinear resampling does not preserve unit
            # length; consumers may want to re-normalize after the resize.
            normals = F.interpolate(normals, size=image_size, mode="bilinear",
                                    align_corners=False)
        return normals
class DPTSegmentationHead(nn.Module):
    """Full DPT head + segmentation decoder.

    Takes 4 intermediate ViT features and produces per-pixel class logits.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__()
        # Reassemble: readout + per-level projection + spatial resize.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # Bias-free 3×3 convs mapping each pyramid level to `channels`.
        self.convs = nn.ModuleList(
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels)
        # Fusion: deepest level first (no residual), then three with residual.
        self.fusion_blocks = nn.ModuleList(
            FeatureFusionBlock(channels, has_residual=(k > 0))
            for k in range(4))
        # Final 3×3 projection before the per-pixel head.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        # Per-pixel class-logit head (Dense layer).
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT segmentation prediction.

        Args:
            intermediate_features: list of 4 (cls_token, patch_feats) tuples
            image_size: (H, W) to resize output to, or None
        Returns:
            segmentation map tensor (B, num_classes, H, W)
        """
        levels = self.reassemble(intermediate_features)
        levels = [conv(feat) for conv, feat in zip(self.convs, levels)]
        # Fuse bottom-up, starting from the deepest (coarsest) level.
        fused = self.fusion_blocks[0](levels[-1])
        for k, skip in enumerate(reversed(levels[:-1])):
            fused = self.fusion_blocks[k + 1](fused, residual=skip)
        fused = self.project(fused)
        # (B, C, H, W) → (B, H, W, C) so the Linear head scores per pixel.
        logits = self.segmentation_head(fused.permute(0, 2, 3, 1))
        # Back to channels-first (B, num_classes, H, W).
        logits = logits.permute(0, 3, 1, 2)
        if image_size is not None:
            logits = F.interpolate(logits, size=image_size, mode="bilinear",
                                   align_corners=False)
        return logits
# ── Weight loading from Scenic/Flax checkpoint ─────────────────────────────
def _load_npy_from_zip(zf, name):
    """Read archive member `name` from `zf` and decode it as a .npy array."""
    raw = zf.read(name)
    return np.load(io.BytesIO(raw))
def _conv_kernel_flax_to_torch(w):
    """Convert Flax conv kernel (H,W,Cin,Cout) → PyTorch (Cout,Cin,H,W)."""
    # ascontiguousarray forces a C-contiguous copy of the transposed view.
    return torch.from_numpy(np.ascontiguousarray(np.transpose(w, (3, 2, 0, 1))))
def _conv_transpose_kernel_flax_to_torch(w):
    """Convert Flax ConvTranspose kernel (H,W,Cin,Cout) → PyTorch (Cin,Cout,H,W)."""
    # ascontiguousarray forces a C-contiguous copy of the transposed view.
    return torch.from_numpy(np.ascontiguousarray(np.transpose(w, (2, 3, 0, 1))))
def _linear_kernel_flax_to_torch(w):
    """Convert Flax Dense kernel (in,out) → PyTorch Linear (out,in)."""
    # ascontiguousarray copies the transposed view into owned memory.
    return torch.from_numpy(np.ascontiguousarray(w.T))
def _bias(w):
    """Copy a Flax bias vector into a PyTorch tensor (dtype preserved)."""
    # np.array(w) makes an owned copy; from_numpy wraps it zero-copy.
    return torch.from_numpy(np.array(w))
def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: target DPTDepthHead; its state-dict keys must match exactly
            (the load below is strict).
        zip_path: path to a zip archive of exported .npy parameter arrays.
    Returns:
        `model`, with the depth-head weights loaded in place.
    Raises:
        KeyError: if an expected array is missing from the archive.
        RuntimeError: if the assembled state dict does not match the model.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # `with` guarantees the archive is closed even if a read fails
    # (the previous code leaked the handle on exception).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))
            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))
        # resize_layers: 0/1 = ConvTranspose, 2 = Identity (no weights), 3 = Conv
        for i, kernel_fn in ((0, _conv_transpose_kernel_flax_to_torch),
                             (1, _conv_transpose_kernel_flax_to_torch),
                             (3, _conv_kernel_flax_to_torch)):
            sd[f"reassemble.resize_layers.{i}.weight"] = kernel_fn(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/kernel.npy"))
            sd[f"reassemble.resize_layers.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/bias.npy"))
        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))
        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Flax unit 0 is the main unit for block 0 (no residual) and the
            # residual unit for blocks 1-3; Flax unit 1, when present, is
            # always the main unit.
            if i == 0:
                units = [("main_unit", "PreActResidualConvUnit_0")]
            else:
                units = [("residual_unit", "PreActResidualConvUnit_0"),
                         ("main_unit", "PreActResidualConvUnit_1")]
            for torch_name, flax_name in units:
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{torch_name}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(
                            npy(f"{fb}{flax_name}/{conv}/kernel.npy")))
            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))
        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))
        # --- Depth classification head ---
        sd["depth_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_depth_classif/kernel.npy"))
        sd["depth_head.bias"] = _bias(
            npy("decoder/pixel_depth_classif/bias.npy"))
    # strict=True raises on any mismatch, so the returned key lists are
    # empty when we reach this point; kept as a defensive trace.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT depth head weights ({len(sd)} tensors)")
    return model
def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: target DPTNormalsHead; its state-dict keys must match exactly
            (the load below is strict).
        zip_path: path to a zip archive of exported .npy parameter arrays.
    Returns:
        `model`, with the normals-head weights loaded in place.
    Raises:
        KeyError: if an expected array is missing from the archive.
        RuntimeError: if the assembled state dict does not match the model.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # `with` guarantees the archive is closed even if a read fails
    # (the previous code leaked the handle on exception).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))
            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))
        # resize_layers: 0/1 = ConvTranspose, 2 = Identity (no weights), 3 = Conv
        for i, kernel_fn in ((0, _conv_transpose_kernel_flax_to_torch),
                             (1, _conv_transpose_kernel_flax_to_torch),
                             (3, _conv_kernel_flax_to_torch)):
            sd[f"reassemble.resize_layers.{i}.weight"] = kernel_fn(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/kernel.npy"))
            sd[f"reassemble.resize_layers.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/bias.npy"))
        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))
        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Flax unit 0 is the main unit for block 0 (no residual) and the
            # residual unit for blocks 1-3; Flax unit 1, when present, is
            # always the main unit.
            if i == 0:
                units = [("main_unit", "PreActResidualConvUnit_0")]
            else:
                units = [("residual_unit", "PreActResidualConvUnit_0"),
                         ("main_unit", "PreActResidualConvUnit_1")]
            for torch_name, flax_name in units:
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{torch_name}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(
                            npy(f"{fb}{flax_name}/{conv}/kernel.npy")))
            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))
        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))
        # --- Normals head ---
        sd["normals_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_normals/kernel.npy"))
        sd["normals_head.bias"] = _bias(
            npy("decoder/pixel_normals/bias.npy"))
    # strict=True raises on any mismatch, so the returned key lists are
    # empty when we reach this point; kept as a defensive trace.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT normals head weights ({len(sd)} tensors)")
    return model
def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT weights from a zip/npz file into PyTorch model.

    Args:
        model: target DPTSegmentationHead; its state-dict keys must match
            exactly (the load below is strict).
        zip_path: path to a zip archive of exported .npy parameter arrays.
    Returns:
        `model`, with the segmentation-head weights loaded in place.
    Raises:
        KeyError: if an expected array is missing from the archive.
        RuntimeError: if the assembled state dict does not match the model.
    """
    sd = {}
    prefix = "decoder/dpt/"
    # `with` guarantees the archive is closed even if a read fails
    # (the previous code leaked the handle on exception).
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks ---
        for i in range(4):
            # out_projections (Conv2d 1×1)
            sd[f"reassemble.out_projections.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/kernel.npy"))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/out_projection_{i}/bias.npy"))
            # readout_projects (Linear)
            sd[f"reassemble.readout_projects.{i}.weight"] = _linear_kernel_flax_to_torch(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/kernel.npy"))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/readout_projects_{i}/bias.npy"))
        # resize_layers: 0/1 = ConvTranspose, 2 = Identity (no weights), 3 = Conv
        for i, kernel_fn in ((0, _conv_transpose_kernel_flax_to_torch),
                             (1, _conv_transpose_kernel_flax_to_torch),
                             (3, _conv_kernel_flax_to_torch)):
            sd[f"reassemble.resize_layers.{i}.weight"] = kernel_fn(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/kernel.npy"))
            sd[f"reassemble.resize_layers.{i}.bias"] = _bias(
                npy(f"{prefix}reassemble_blocks/resize_layers_{i}/bias.npy"))
        # --- Convs (3×3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))
        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Flax unit 0 is the main unit for block 0 (no residual) and the
            # residual unit for blocks 1-3; Flax unit 1, when present, is
            # always the main unit.
            if i == 0:
                units = [("main_unit", "PreActResidualConvUnit_0")]
            else:
                units = [("residual_unit", "PreActResidualConvUnit_0"),
                         ("main_unit", "PreActResidualConvUnit_1")]
            for torch_name, flax_name in units:
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{torch_name}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(
                            npy(f"{fb}{flax_name}/{conv}/kernel.npy")))
            # out_conv (Conv2d 1×1)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{fb}Conv_0/kernel.npy"))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))
        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(
            npy(f"{prefix}project/bias.npy"))
        # --- Segmentation head ---
        sd["segmentation_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_segmentation/kernel.npy"))
        sd["segmentation_head.bias"] = _bias(
            npy("decoder/pixel_segmentation/bias.npy"))
    # strict=True raises on any mismatch, so the returned key lists are
    # empty when we reach this point; kept as a defensive trace.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT segmentation head weights ({len(sd)} tensors)")
    return model