# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DPT (Dense Prediction Transformer) prediction heads in PyTorch.

Provides depth, surface-normal and segmentation heads, plus loaders for
Scenic/Flax checkpoints. Ported from the Scenic/Flax implementation at:
  research/vision/scene_understanding/imsight/modules/dpt.py
  scenic/projects/dense_features/models/decoders.py

Architecture:
  ReassembleBlocks → 4×Conv3x3 → 4×FeatureFusionBlock → project → task head
"""
import io
import os
import urllib.request
import zipfile
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# ── Building blocks ──────────────────────────────────────────────────────────
class PreActResidualConvUnit(nn.Module):
    """Pre-activation residual unit computing ``x + conv2(relu(conv1(relu(x))))``.

    Both convolutions are 3x3, bias-free, and use padding=1. They therefore
    keep the channel count and spatial size, so the skip sum needs no
    projection.

    Args:
        features: Number of input (and output) channels.
    """

    def __init__(self, features: int):
        super().__init__()
        # Submodule names determine the state_dict keys used when loading
        # converted checkpoints.
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, padding=1,
                               bias=False)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1,
                               bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.conv1(F.relu(x))
        branch = self.conv2(F.relu(branch))
        return branch + x
class FeatureFusionBlock(nn.Module):
    """DPT fusion stage: add a refined skip input, refine, upsample 2x, project.

    Args:
        features: Channel count of the main input and of the skip input.
        has_residual: If True, a ``residual_unit`` refines the skip tensor
            passed to ``forward`` before it is added to the main input. If
            False, any skip tensor passed to ``forward`` is ignored.
        expand: If True, the final 1x1 conv halves the channel count
            (``features // 2``). Otherwise it keeps ``features`` channels.
    """

    def __init__(self, features: int, has_residual: bool = False,
                 expand: bool = False):
        super().__init__()
        self.has_residual = has_residual
        if has_residual:
            self.residual_unit = PreActResidualConvUnit(features)
        self.main_unit = PreActResidualConvUnit(features)
        self.out_conv = nn.Conv2d(
            features, features // 2 if expand else features, kernel_size=1,
            bias=True)

    def forward(self, x: torch.Tensor,
                residual: torch.Tensor = None) -> torch.Tensor:
        """Fuse ``x`` with the optional skip tensor ``residual``.

        Args:
            x: Main input. It must be 4-D, channels-first, because it is
                bilinearly upsampled.
            residual: Optional skip input. If its shape differs from ``x``, it
                is first bilinearly resized to ``x``'s spatial size.

        Returns:
            The fused tensor at twice the spatial size of ``x``.
        """
        if self.has_residual and residual is not None:
            if residual.shape != x.shape:
                residual = F.interpolate(residual, size=x.shape[2:],
                                         mode="bilinear", align_corners=False)
            x = x + self.residual_unit(residual)
        refined = self.main_unit(x)
        # align_corners=True here follows the Scenic reference implementation.
        upsampled = F.interpolate(refined, scale_factor=2, mode="bilinear",
                                  align_corners=True)
        return self.out_conv(upsampled)
class ReassembleBlocks(nn.Module):
    """Turn four ViT token maps into a four-level feature pyramid.

    Each level i is processed in three steps:
      1. When ``readout_type == "project"``, the CLS token is concatenated to
         every patch token and projected back to ``input_embed_dim`` with a
         Linear layer followed by GELU.
      2. A 1x1 conv maps the result to ``out_channels[i]`` channels.
      3. A resize layer rescales it spatially: 4x up (transposed conv),
         2x up (transposed conv), unchanged, and 2x down (strided 3x3 conv).

    For any other ``readout_type`` the CLS token is ignored.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 out_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.readout_type = readout_type
        self.out_projections = nn.ModuleList(
            nn.Conv2d(input_embed_dim, c, kernel_size=1) for c in out_channels)
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0],
                               kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1],
                               kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3,
                      stride=2, padding=1),
        ])
        if readout_type == "project":
            # One readout projection per level: (patch ++ cls) -> patch dim.
            self.readout_projects = nn.ModuleList(
                nn.Linear(2 * input_embed_dim, input_embed_dim)
                for _ in out_channels)

    def forward(self, features):
        """Build the pyramid.

        Args:
            features: List of (cls_token [B, D], patch_feats [B, D, H, W]).

        Returns:
            List with one tensor per level, at the rescaled resolutions.
        """
        pyramid = []
        for level, (cls_token, tokens) in enumerate(features):
            batch, dim, height, width = tokens.shape
            if self.readout_type == "project":
                seq = tokens.flatten(2).transpose(1, 2)  # (B, H*W, D)
                cls = cls_token.unsqueeze(1).expand(-1, seq.shape[1], -1)
                mixed = F.gelu(
                    self.readout_projects[level](torch.cat([seq, cls], dim=-1)))
                tokens = mixed.transpose(1, 2).reshape(batch, dim, height, width)
            projected = self.out_projections[level](tokens)
            pyramid.append(self.resize_layers[level](projected))
        return pyramid
class DPTDepthHead(nn.Module):
    """DPT decoder followed by a per-pixel depth-bin classifier.

    Four intermediate ViT features are processed in stages:
      1. They are reassembled into a pyramid.
      2. 3x3 convs map each level to ``channels``.
      3. The levels are fused coarse-to-fine and projected.
      4. A dense layer turns each pixel into scores over ``num_depth_bins``
         bins.

    Depth is the score-weighted mean of bin centers, linearly spaced in
    [``min_depth``, ``max_depth``].
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_depth_bins: int = 256,
                 min_depth: float = 1e-3,
                 max_depth: float = 10.0):
        super().__init__()
        self.num_depth_bins = num_depth_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        # Reassemble: project + resize ViT tokens into a 4-level pyramid.
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # 3x3 convs mapping each level to `channels`.
        self.convs = nn.ModuleList([
            nn.Conv2d(ch, channels, 3, padding=1, bias=False)
            for ch in post_process_channels
        ])
        # Fusion blocks: the first (deepest level) has no skip input.
        self.fusion_blocks = nn.ModuleList([
            FeatureFusionBlock(channels, has_residual=False),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
            FeatureFusionBlock(channels, has_residual=True),
        ])
        # Final projection.
        self.project = nn.Conv2d(channels, channels, 3, padding=1, bias=True)
        # Per-pixel dense layer producing one score per depth bin.
        self.depth_head = nn.Linear(channels, num_depth_bins)

    def forward(self, intermediate_features, image_size=None):
        """Run DPT depth prediction.

        Args:
            intermediate_features: List of 4 (cls_token, patch_feats) tuples.
            image_size: (H, W) to bilinearly resize the output to, or None.

        Returns:
            Depth map tensor of shape (B, 1, H, W). It has the dtype of the
            model's activations.
        """
        x = self.reassemble(intermediate_features)
        x = [self.convs[i](feat) for i, feat in enumerate(x)]
        # Fuse coarse-to-fine: start from the deepest level (x[-1]) and feed
        # the shallower levels in as skip inputs.
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, 4):
            out = self.fusion_blocks[i](out, residual=x[-(i + 1)])
        out = F.relu(self.project(out))

        # (B, C, H, W) -> (B, H, W, C) so the Linear layer acts per pixel.
        scores = self.depth_head(out.permute(0, 2, 3, 1))  # (B, H, W, bins)
        # ReLU the scores and add min_depth as a positive floor, so the
        # normalizer below cannot be zero. This relies on min_depth > 0.
        scores = F.relu(scores) + self.min_depth
        probs = scores / scores.sum(dim=-1, keepdim=True)

        # Do the weighted sum in the activations' dtype, promoted to at least
        # float32. einsum needs both operands in the same dtype; a default-dtype
        # linspace breaks after model.half()/.bfloat16()/.double(), and low-
        # precision bin centers would collapse adjacent bins.
        compute_dtype = torch.promote_types(probs.dtype, torch.float32)
        bin_centers = torch.linspace(
            self.min_depth, self.max_depth, self.num_depth_bins,
            device=probs.device, dtype=compute_dtype)
        depth = torch.einsum("bhwn,n->bhw", probs.to(compute_dtype),
                             bin_centers).to(probs.dtype)
        depth = depth.unsqueeze(1)  # (B, 1, H, W)

        if image_size is not None:
            depth = F.interpolate(depth, size=image_size, mode="bilinear",
                                  align_corners=False)
        return depth
class DPTNormalsHead(nn.Module):
    """DPT decoder followed by a per-pixel surface-normal regressor.

    The features pass through ReassembleBlocks, then per-level 3x3 convs, four
    coarse-to-fine fusion blocks, and a 3x3 projection. A 3-output dense layer
    follows, and its output is L2-normalized per pixel.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project"):
        super().__init__()
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # One bias-free 3x3 conv per pyramid level, mapping to `channels`.
        self.convs = nn.ModuleList(
            [nn.Conv2d(c, channels, kernel_size=3, padding=1, bias=False)
             for c in post_process_channels])
        # Only the first (deepest) fusion block lacks a skip input.
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(channels, has_residual=k > 0) for k in range(4)])
        self.project = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                                 bias=True)
        self.normals_head = nn.Linear(channels, 3)

    def forward(self, intermediate_features, image_size=None):
        """Predict per-pixel surface normals.

        Args:
            intermediate_features: List of 4 (cls_token, patch_feats) tuples.
            image_size: (H, W) to bilinearly resize the output to, or None.

        Returns:
            Normal map tensor of shape (B, 3, H, W).
        """
        pyramid = self.reassemble(intermediate_features)
        pyramid = [self.convs[level](feat) for level, feat in enumerate(pyramid)]
        fused = self.fusion_blocks[0](pyramid[-1])
        # Blocks 1..3 take skips from pyramid[-2], pyramid[-3], pyramid[-4].
        for k, block in enumerate(self.fusion_blocks[1:], start=2):
            fused = block(fused, residual=pyramid[-k])
        pixels = self.project(fused).permute(0, 2, 3, 1)  # (B, H, W, C)
        normals = F.normalize(self.normals_head(pixels), p=2, dim=-1)
        normals = normals.permute(0, 3, 1, 2)  # (B, 3, H, W)
        if image_size is not None:
            # NOTE(review): bilinear resizing blends neighbouring unit
            # vectors, so resized normals are generally not unit length.
            normals = F.interpolate(normals, size=image_size, mode="bilinear",
                                    align_corners=False)
        return normals
class DPTSegmentationHead(nn.Module):
    """DPT decoder followed by a per-pixel classifier over ``num_classes``.

    The features pass through ReassembleBlocks, then per-level 3x3 convs, four
    coarse-to-fine fusion blocks, and a 3x3 projection. A dense layer then
    produces one logit per class for each pixel.
    """

    def __init__(self, input_embed_dim: int = 1024,
                 channels: int = 256,
                 post_process_channels: tuple = (128, 256, 512, 1024),
                 readout_type: str = "project",
                 num_classes: int = 150):
        super().__init__()
        self.reassemble = ReassembleBlocks(
            input_embed_dim=input_embed_dim,
            out_channels=post_process_channels,
            readout_type=readout_type,
        )
        # One bias-free 3x3 conv per pyramid level, mapping to `channels`.
        self.convs = nn.ModuleList(
            [nn.Conv2d(c, channels, kernel_size=3, padding=1, bias=False)
             for c in post_process_channels])
        # Only the first (deepest) fusion block lacks a skip input.
        self.fusion_blocks = nn.ModuleList(
            [FeatureFusionBlock(channels, has_residual=k > 0) for k in range(4)])
        self.project = nn.Conv2d(channels, channels, kernel_size=3, padding=1,
                                 bias=True)
        self.segmentation_head = nn.Linear(channels, num_classes)

    def forward(self, intermediate_features, image_size=None):
        """Predict per-pixel class logits.

        Args:
            intermediate_features: List of 4 (cls_token, patch_feats) tuples.
            image_size: (H, W) to bilinearly resize the output to, or None.

        Returns:
            Segmentation logits tensor of shape (B, num_classes, H, W).
        """
        pyramid = self.reassemble(intermediate_features)
        pyramid = [self.convs[level](feat) for level, feat in enumerate(pyramid)]
        fused = self.fusion_blocks[0](pyramid[-1])
        # Blocks 1..3 take skips from pyramid[-2], pyramid[-3], pyramid[-4].
        for k, block in enumerate(self.fusion_blocks[1:], start=2):
            fused = block(fused, residual=pyramid[-k])
        pixels = self.project(fused).permute(0, 2, 3, 1)  # (B, H, W, C)
        logits = self.segmentation_head(pixels).permute(0, 3, 1, 2)
        if image_size is not None:
            logits = F.interpolate(logits, size=image_size, mode="bilinear",
                                   align_corners=False)
        return logits
# ── Weight loading from Scenic/Flax checkpoint ───────────────────────────────
def _load_npy_from_zip(zf, name):
"""Load a single .npy array from a zipfile."""
with zf.open(name) as f:
return np.load(io.BytesIO(f.read()))
def _conv_kernel_flax_to_torch(w):
"""Convert Flax conv kernel (H,W,Cin,Cout) β PyTorch (Cout,Cin,H,W)."""
return torch.from_numpy(w.transpose(3, 2, 0, 1).copy())
def _conv_transpose_kernel_flax_to_torch(w):
"""Convert Flax ConvTranspose kernel (H,W,Cin,Cout) β PyTorch (Cin,Cout,H,W)."""
return torch.from_numpy(w.transpose(2, 3, 0, 1).copy())
def _linear_kernel_flax_to_torch(w):
"""Convert Flax Dense kernel (in,out) β PyTorch Linear (out,in)."""
return torch.from_numpy(w.T.copy())
def _bias(w):
return torch.from_numpy(w.copy())
def load_dpt_weights(model: DPTDepthHead, zip_path: str):
    """Load Scenic/Flax DPT depth-head weights from a zip/npz archive.

    The archive holds one ``.npy`` member per Flax parameter, named by its
    parameter path (e.g. ``decoder/dpt/convs_0/kernel.npy``). Kernels are
    converted to PyTorch layout, and the state dict is loaded with
    ``strict=True``. The model must therefore have been built with
    ``readout_type="project"``, because ``readout_projects`` weights are
    always loaded.

    Args:
        model: Depth head to populate in place.
        zip_path: Path to the archive.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        KeyError: If an expected ``.npy`` member is missing from the archive.
        RuntimeError: From ``load_state_dict`` if keys or shapes mismatch.
    """
    prefix = "decoder/dpt/"
    rb = f"{prefix}reassemble_blocks/"
    sd = {}
    # The context manager closes the archive even when a member is missing
    # or fails to parse.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks: per-level 1x1 projections and readouts ---
        for i in range(4):
            sd[f"reassemble.out_projections.{i}.weight"] = (
                _conv_kernel_flax_to_torch(
                    npy(f"{rb}out_projection_{i}/kernel.npy")))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{rb}out_projection_{i}/bias.npy"))
            sd[f"reassemble.readout_projects.{i}.weight"] = (
                _linear_kernel_flax_to_torch(
                    npy(f"{rb}readout_projects_{i}/kernel.npy")))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{rb}readout_projects_{i}/bias.npy"))

        # --- resize_layers: 0/1 ConvTranspose2d, 2 Identity (no weights),
        # 3 strided Conv2d ---
        resize_kernel_converters = (
            (0, _conv_transpose_kernel_flax_to_torch),
            (1, _conv_transpose_kernel_flax_to_torch),
            (3, _conv_kernel_flax_to_torch),
        )
        for j, convert in resize_kernel_converters:
            src = f"{rb}resize_layers_{j}/"
            sd[f"reassemble.resize_layers.{j}.weight"] = convert(
                npy(f"{src}kernel.npy"))
            sd[f"reassemble.resize_layers.{j}.bias"] = _bias(
                npy(f"{src}bias.npy"))

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has only the main unit (Flax index 0). The others have
            # the residual unit at Flax index 0 and the main unit at index 1.
            units = ({"main_unit": 0} if i == 0
                     else {"residual_unit": 0, "main_unit": 1})
            for unit, flax_idx in units.items():
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(npy(
                            f"{fb}PreActResidualConvUnit_{flax_idx}/"
                            f"{conv}/kernel.npy")))
            # out_conv (1x1 Conv2d)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = (
                _conv_kernel_flax_to_torch(npy(f"{fb}Conv_0/kernel.npy")))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(npy(f"{prefix}project/bias.npy"))

        # --- Depth classification head (stored outside decoder/dpt/) ---
        sd["depth_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_depth_classif/kernel.npy"))
        sd["depth_head.bias"] = _bias(
            npy("decoder/pixel_depth_classif/bias.npy"))

    # strict=True raises on any key mismatch, so these lists are normally
    # empty. The warnings are kept as a safeguard.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT depth head weights ({len(sd)} tensors)")
    return model
def load_normals_weights(model: DPTNormalsHead, zip_path: str):
    """Load Scenic/Flax DPT surface-normals-head weights from a zip/npz archive.

    The archive holds one ``.npy`` member per Flax parameter, named by its
    parameter path (e.g. ``decoder/dpt/convs_0/kernel.npy``). Kernels are
    converted to PyTorch layout, and the state dict is loaded with
    ``strict=True``. The model must therefore have been built with
    ``readout_type="project"``, because ``readout_projects`` weights are
    always loaded.

    Args:
        model: Normals head to populate in place.
        zip_path: Path to the archive.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        KeyError: If an expected ``.npy`` member is missing from the archive.
        RuntimeError: From ``load_state_dict`` if keys or shapes mismatch.
    """
    prefix = "decoder/dpt/"
    rb = f"{prefix}reassemble_blocks/"
    sd = {}
    # The context manager closes the archive even when a member is missing
    # or fails to parse.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks: per-level 1x1 projections and readouts ---
        for i in range(4):
            sd[f"reassemble.out_projections.{i}.weight"] = (
                _conv_kernel_flax_to_torch(
                    npy(f"{rb}out_projection_{i}/kernel.npy")))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{rb}out_projection_{i}/bias.npy"))
            sd[f"reassemble.readout_projects.{i}.weight"] = (
                _linear_kernel_flax_to_torch(
                    npy(f"{rb}readout_projects_{i}/kernel.npy")))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{rb}readout_projects_{i}/bias.npy"))

        # --- resize_layers: 0/1 ConvTranspose2d, 2 Identity (no weights),
        # 3 strided Conv2d ---
        resize_kernel_converters = (
            (0, _conv_transpose_kernel_flax_to_torch),
            (1, _conv_transpose_kernel_flax_to_torch),
            (3, _conv_kernel_flax_to_torch),
        )
        for j, convert in resize_kernel_converters:
            src = f"{rb}resize_layers_{j}/"
            sd[f"reassemble.resize_layers.{j}.weight"] = convert(
                npy(f"{src}kernel.npy"))
            sd[f"reassemble.resize_layers.{j}.bias"] = _bias(
                npy(f"{src}bias.npy"))

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has only the main unit (Flax index 0). The others have
            # the residual unit at Flax index 0 and the main unit at index 1.
            units = ({"main_unit": 0} if i == 0
                     else {"residual_unit": 0, "main_unit": 1})
            for unit, flax_idx in units.items():
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(npy(
                            f"{fb}PreActResidualConvUnit_{flax_idx}/"
                            f"{conv}/kernel.npy")))
            # out_conv (1x1 Conv2d)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = (
                _conv_kernel_flax_to_torch(npy(f"{fb}Conv_0/kernel.npy")))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(npy(f"{prefix}project/bias.npy"))

        # --- Normals head (stored outside decoder/dpt/) ---
        sd["normals_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_normals/kernel.npy"))
        sd["normals_head.bias"] = _bias(
            npy("decoder/pixel_normals/bias.npy"))

    # strict=True raises on any key mismatch, so these lists are normally
    # empty. The warnings are kept as a safeguard.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT normals head weights ({len(sd)} tensors)")
    return model
def load_segmentation_weights(model: DPTSegmentationHead, zip_path: str):
    """Load Scenic/Flax DPT segmentation-head weights from a zip/npz archive.

    The archive holds one ``.npy`` member per Flax parameter, named by its
    parameter path (e.g. ``decoder/dpt/convs_0/kernel.npy``). Kernels are
    converted to PyTorch layout, and the state dict is loaded with
    ``strict=True``. The model must therefore have been built with
    ``readout_type="project"``, because ``readout_projects`` weights are
    always loaded.

    Args:
        model: Segmentation head to populate in place.
        zip_path: Path to the archive.

    Returns:
        The same ``model`` instance, with weights loaded.

    Raises:
        KeyError: If an expected ``.npy`` member is missing from the archive.
        RuntimeError: From ``load_state_dict`` if keys or shapes mismatch.
    """
    prefix = "decoder/dpt/"
    rb = f"{prefix}reassemble_blocks/"
    sd = {}
    # The context manager closes the archive even when a member is missing
    # or fails to parse.
    with zipfile.ZipFile(zip_path, "r") as zf:

        def npy(name):
            return _load_npy_from_zip(zf, name)

        # --- ReassembleBlocks: per-level 1x1 projections and readouts ---
        for i in range(4):
            sd[f"reassemble.out_projections.{i}.weight"] = (
                _conv_kernel_flax_to_torch(
                    npy(f"{rb}out_projection_{i}/kernel.npy")))
            sd[f"reassemble.out_projections.{i}.bias"] = _bias(
                npy(f"{rb}out_projection_{i}/bias.npy"))
            sd[f"reassemble.readout_projects.{i}.weight"] = (
                _linear_kernel_flax_to_torch(
                    npy(f"{rb}readout_projects_{i}/kernel.npy")))
            sd[f"reassemble.readout_projects.{i}.bias"] = _bias(
                npy(f"{rb}readout_projects_{i}/bias.npy"))

        # --- resize_layers: 0/1 ConvTranspose2d, 2 Identity (no weights),
        # 3 strided Conv2d ---
        resize_kernel_converters = (
            (0, _conv_transpose_kernel_flax_to_torch),
            (1, _conv_transpose_kernel_flax_to_torch),
            (3, _conv_kernel_flax_to_torch),
        )
        for j, convert in resize_kernel_converters:
            src = f"{rb}resize_layers_{j}/"
            sd[f"reassemble.resize_layers.{j}.weight"] = convert(
                npy(f"{src}kernel.npy"))
            sd[f"reassemble.resize_layers.{j}.bias"] = _bias(
                npy(f"{src}bias.npy"))

        # --- Convs (3x3, no bias) ---
        for i in range(4):
            sd[f"convs.{i}.weight"] = _conv_kernel_flax_to_torch(
                npy(f"{prefix}convs_{i}/kernel.npy"))

        # --- Fusion blocks ---
        for i in range(4):
            fb = f"{prefix}fusion_blocks_{i}/"
            # Block 0 has only the main unit (Flax index 0). The others have
            # the residual unit at Flax index 0 and the main unit at index 1.
            units = ({"main_unit": 0} if i == 0
                     else {"residual_unit": 0, "main_unit": 1})
            for unit, flax_idx in units.items():
                for conv in ("conv1", "conv2"):
                    sd[f"fusion_blocks.{i}.{unit}.{conv}.weight"] = (
                        _conv_kernel_flax_to_torch(npy(
                            f"{fb}PreActResidualConvUnit_{flax_idx}/"
                            f"{conv}/kernel.npy")))
            # out_conv (1x1 Conv2d)
            sd[f"fusion_blocks.{i}.out_conv.weight"] = (
                _conv_kernel_flax_to_torch(npy(f"{fb}Conv_0/kernel.npy")))
            sd[f"fusion_blocks.{i}.out_conv.bias"] = _bias(
                npy(f"{fb}Conv_0/bias.npy"))

        # --- Project ---
        sd["project.weight"] = _conv_kernel_flax_to_torch(
            npy(f"{prefix}project/kernel.npy"))
        sd["project.bias"] = _bias(npy(f"{prefix}project/bias.npy"))

        # --- Segmentation head (stored outside decoder/dpt/) ---
        sd["segmentation_head.weight"] = _linear_kernel_flax_to_torch(
            npy("decoder/pixel_segmentation/kernel.npy"))
        sd["segmentation_head.bias"] = _bias(
            npy("decoder/pixel_segmentation/bias.npy"))

    # strict=True raises on any key mismatch, so these lists are normally
    # empty. The warnings are kept as a safeguard.
    missing, unexpected = model.load_state_dict(sd, strict=True)
    if missing:
        print(f"WARNING: Missing keys: {missing}")
    if unexpected:
        print(f"WARNING: Unexpected keys: {unexpected}")
    print(f"Loaded DPT segmentation head weights ({len(sd)} tensors)")
    return model