# File: modality_forcing/flux_rgbd/depth/preprocess.py
# Committed by bartduis in e298226 ("Initial public release"); 4.71 kB.
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 World Labs.
"""Depth-side encoding and decoding: depth pixel map ↔ model tokens.

Decoding steps:
1. Unpatchify the (tok_h, tok_w, patch_h * patch_w) token grid.
2. Divide out the post-normalization scalar `depth_value_scale`.
3. Apply the inverse of every recorded normalization stage, in reverse.

Only `"contract"` (the mip-NeRF 360 radial squash, arXiv 2111.12077
Eq. 10) is genuinely invertible. The other stages are either no-ops or,
like `"unit_mean"`, discard per-sample statistics during encoding, so they
pass through unchanged on decode. Encoding (`encode_depth`) runs the same
steps forward, in the opposite order.
"""
from __future__ import annotations
import dataclasses
import einops
import torch
from torch import Tensor
# Stage names accepted in `DepthConfig.depth_normalize_mode`.
_KNOWN_STAGES = frozenset({"contract", "normal", "mean", "unit_mean", "none"})
@dataclasses.dataclass(frozen=True)
class DepthConfig:
    """How depth was normalized at training time. Decoder reads these.

    Attributes:
        depth_normalize_mode: Normalization stages, applied in order on encode
            and undone in reverse order on decode. A single stage name may be
            given as a plain string; any iterable of names is stored as a
            tuple.
        patch_size: Side length, in pixels, of the square patch packed into
            each token.
        depth_value_scale: Positive scalar multiplied in after normalization
            on encode and divided out on decode.
    """

    depth_normalize_mode: tuple[str, ...] | str = ("unit_mean", "contract")
    patch_size: int = 16
    depth_value_scale: float = 1.0

    def __post_init__(self) -> None:
        """Canonicalize `depth_normalize_mode` to a tuple and validate fields.

        Raises:
            ValueError: on an unknown stage name, a non-positive (or NaN)
                `depth_value_scale`, or a non-positive `patch_size`.
        """
        stages = self.depth_normalize_mode
        # Always store a tuple. A bare string becomes a 1-tuple, so it is not
        # iterated per character. A list would make this frozen dataclass
        # unhashable.
        stages = (stages,) if isinstance(stages, str) else tuple(stages)
        object.__setattr__(self, "depth_normalize_mode", stages)
        unknown = set(stages) - _KNOWN_STAGES
        if unknown:
            raise ValueError(f"unknown stages {sorted(unknown)}; expected {sorted(_KNOWN_STAGES)}")
        # Written as `not (x > 0)` so that NaN is rejected as well.
        if not self.depth_value_scale > 0:
            raise ValueError(f"depth_value_scale must be > 0; got {self.depth_value_scale}")
        if self.patch_size <= 0:
            raise ValueError(f"patch_size must be > 0; got {self.patch_size}")
def _inv_contract(z: Tensor) -> Tensor:
    """Inverse of the mip-NeRF 360 radial contract bijection (operates on last dim).

    Contract maps R^n into the open ball of radius 2; this undoes it.
    - For |z| <= 1 the contract was identity, so this is also identity.
    - For 1 < |z| < 2 the preimage has radius 1 / (2 - |z|), in the same
      direction.
    - For |z| >= 2 (e.g. a model prediction that overshoots) there is no
      preimage. The magnitude is clamped to the largest value below 2 for the
      dtype, so the input decodes to a large finite radius along the same
      direction. Without the clamp, |z| == 2 gives inf and |z| > 2 gives a
      sign-flipped value.

    Assumes `z` is a floating-point tensor (`torch.finfo` is used).
    """
    # Largest value strictly below 2 for this dtype: the ULP in [1, 2) is eps.
    max_mag = 2.0 - torch.finfo(z.dtype).eps
    mag = z.square().sum(dim=-1, keepdim=True).sqrt().clamp(min=1.0)
    squashed = mag.clamp(max=max_mag)
    # x = (z / |z|) / (2 - |z|). `2 - squashed` is exact for squashed in
    # [1, 2) (Sterbenz), which avoids the cancellation in 2|z| - |z|^2.
    return z / (mag * (2.0 - squashed))
def _undo_stage(depth: Tensor, stage: str) -> Tensor:
    """Invert a single normalization stage on a decoded depth tensor.

    Only ``"contract"`` has an inverse; any other stage name returns `depth`
    itself, untouched. The contract inverse acts per element:
    - when the trailing dim already has size 1, it serves as the vector dim;
    - otherwise a temporary size-1 dim is appended and removed afterwards.
    """
    if stage == "contract":
        if depth.shape[-1] != 1:
            return _inv_contract(depth.unsqueeze(-1)).squeeze(-1)
        return _inv_contract(depth)
    # "normal" / "mean" / "unit_mean" / "none": nothing to undo.
    return depth
def decode_depth(tokens: Tensor, config: DepthConfig) -> Tensor:
    """Turn a model token grid back into a depth pixel map.

    Args:
        tokens: Token grid of shape (..., tok_h, tok_w, patch_size**2 * c).
            With a single depth channel the last dim is patch_size**2.
        config: Normalization settings the tokens were encoded with.

    Returns:
        Depth map of shape (..., tok_h * patch_size, tok_w * patch_size, c).
        `depth_value_scale` is divided out and the recorded stages are undone
        last-to-first.
    """
    side = config.patch_size
    # Scatter each token's packed patch back onto the pixel grid.
    pixels = einops.rearrange(
        tokens,
        "... h w (p1 p2 c) -> ... (h p1) (w p2) c",
        p1=side,
        p2=side,
    )
    scale = config.depth_value_scale
    if scale != 1.0:  # a division by exactly 1 would be a no-op
        pixels = pixels / scale
    # Encode applied the stages first-to-last, so peel them off last-to-first.
    for stage in reversed(config.depth_normalize_mode):
        pixels = _undo_stage(pixels, stage)
    return pixels
def _contract(z: Tensor) -> Tensor:
    """Forward mip-NeRF 360 radial contract (inverse of `_inv_contract`).

    Identity for |z| <= 1; otherwise squashes radius r into (2 - 1/r) while
    keeping the direction. Operates on the last dim.
    """
    # With the radius clamped to >= 1, the general scale (2 - 1/r) / r equals
    # exactly 1 inside the unit ball. That removes the `torch.where` over an
    # unsafe branch and its tiny epsilon. In fp16, 1e-12 underflows to 0, so
    # that discarded branch became inf and could poison gradients with NaN.
    mag = z.norm(dim=-1, keepdim=True).clamp(min=1.0)
    return z * ((2.0 - 1.0 / mag) / mag)
def _apply_stage(depth: Tensor, stage: str) -> Tensor:
    """Forward of `_undo_stage`. Only `contract` and `unit_mean` are active.

    `unit_mean` divides each depth map by its own mean, so the depth is
    scale-normalised the way the model was trained. It is pass-through on
    decode (the absolute scale is not recoverable), so encode/decode
    round-trips only up to that per-map scale. `contract` is applied per
    element.

    Args:
        depth: Depth map(s), shaped (..., H, W, 1) per `encode_depth`'s input.
            A tensor with at most 3 dims is treated as a single map.
        stage: Stage name; anything other than `unit_mean`/`contract` is
            returned unchanged.
    """
    if stage == "unit_mean":
        # Reduce over the trailing (H, W, C) dims only, so each map in a batch
        # is normalised by its own mean rather than by a batch-wide mean.
        if depth.ndim > 3:
            mean = depth.mean(dim=(-3, -2, -1), keepdim=True)
        else:
            mean = depth.mean()
        return depth / mean.clamp(min=1e-12)
    if stage != "contract":  # "normal"/"mean"/"none" are pass-through.
        return depth
    if depth.shape[-1] == 1:
        return _contract(depth)
    return _contract(depth.unsqueeze(-1)).squeeze(-1)
def encode_depth(depth_map: Tensor, config: DepthConfig) -> Tensor:
    """Depth pixel map -> token grid (inverse of `decode_depth`).

    Args:
        depth_map: Depth of shape (..., H, W, 1). H and W must be multiples
            of `config.patch_size` for the patch packing to succeed.
        config: Normalization settings matching those used at training time.

    Returns:
        Tokens of shape (..., H/patch_size, W/patch_size, patch_size**2).

    Used for ``mode="d2i"`` to turn a depth map into the depth-stream tokens
    the model conditions on. When `unit_mean` is among the stages, it is
    scale-normalising, so only the relative depth structure matters and the
    input need not be metric.
    """
    normalized = depth_map
    # Run every stage in its recorded order; decode undoes them in reverse.
    for stage in config.depth_normalize_mode:
        normalized = _apply_stage(normalized, stage)
    scale = config.depth_value_scale
    if scale != 1.0:
        normalized = normalized * scale
    side = config.patch_size
    # Gather each side x side pixel block into a single token vector.
    return einops.rearrange(
        normalized,
        "... (h p1) (w p2) c -> ... h w (p1 p2 c)",
        p1=side,
        p2=side,
    )