Spaces:
Running on Zero
Running on Zero
| # SPDX-License-Identifier: Apache-2.0 | |
| # Copyright (c) 2026 World Labs. | |
| """Depth-side decoding: model tokens → depth pixel map. | |
| Decoding steps: | |
| 1. Unpatchify the (tok_h, tok_w, patch_h * patch_w) token grid. | |
| 2. Divide out the post-normalization scalar `depth_value_scale`. | |
| 3. Apply the inverse of every recorded normalization stage, in reverse. | |
| Only `"contract"` (the mip-NeRF 360 radial squash, arXiv 2111.12077 | |
| Eq. 10) is genuinely invertible. `"unit_mean"` discards the mean depth | |
| during encoding, and `"normal"`, `"mean"` and `"none"` are no-ops in this | |
| module's encoder; all of them pass through unchanged on decode. | |
| """ | |
| from __future__ import annotations | |
| import dataclasses | |
| import einops | |
| import torch | |
| from torch import Tensor | |
_KNOWN_STAGES = frozenset({"contract", "normal", "mean", "unit_mean", "none"})


@dataclasses.dataclass(frozen=True)
class DepthConfig:
    """How depth was normalized at training time. Decoder reads these.

    Attributes:
        depth_normalize_mode: Normalization stages applied in order on encode
            and undone in reverse order on decode. A bare string is treated as
            a single stage name; any iterable of names is stored as a tuple.
            Every name must be one of ``_KNOWN_STAGES``.
        patch_size: Side length of the square pixel patch folded into each
            token.
        depth_value_scale: Positive scalar multiplied in after normalization
            on encode and divided out before un-normalization on decode.

    Raises:
        ValueError: On an unknown stage name, or a ``depth_value_scale`` /
            ``patch_size`` that is not strictly positive (NaN included).
    """

    depth_normalize_mode: tuple[str, ...] | str = ("unit_mean", "contract")
    patch_size: int = 16
    depth_value_scale: float = 1.0

    def __post_init__(self) -> None:
        stages = self.depth_normalize_mode
        # A bare string is one stage, not a sequence of characters. Any other
        # iterable (e.g. a list) is frozen into a tuple so the instance stays
        # hashable and immutable.
        stages = (stages,) if isinstance(stages, str) else tuple(stages)
        # The dataclass is frozen, so bypass its __setattr__ guard.
        object.__setattr__(self, "depth_normalize_mode", stages)
        unknown = set(stages) - _KNOWN_STAGES
        if unknown:
            raise ValueError(f"unknown stages {sorted(unknown)}; expected {sorted(_KNOWN_STAGES)}")
        # `not (x > 0)` also rejects NaN, which `x <= 0` would let through.
        if not self.depth_value_scale > 0:
            raise ValueError(f"depth_value_scale must be > 0; got {self.depth_value_scale}")
        if not self.patch_size > 0:
            raise ValueError(f"patch_size must be > 0; got {self.patch_size}")
def _inv_contract(z: Tensor, eps: float = 1e-6) -> Tensor:
    """Inverse of the mip-NeRF 360 radial contract bijection (operates on last dim).

    Contract maps R^n into the open ball of radius 2: identity for |x| <= 1,
    otherwise x -> (2 - 1/|x|) * x/|x|. This undoes it. For 1 < |z| < 2 the
    preimage has radius 1 / (2 - |z|) along the same direction. For |z| <= 1
    the contract was identity, so this is also identity in that region.

    Points at or beyond radius 2 have no preimage, but raw model predictions
    can land there. Their radius is clamped to just inside 2, so the result
    stays finite and keeps the direction (sign) of ``z``. Without the clamp
    the denominator is 0 or negative, giving inf or a sign-flipped depth.

    Args:
        z: Contracted tensor; the last dim is treated as the vector dim.
        eps: Minimum gap kept between the clamped radius and 2. It is raised
            automatically for low-precision float dtypes, where ``2 - eps``
            would otherwise round to 2.

    Returns:
        Tensor with the same shape as ``z``.
    """
    if z.is_floating_point():
        eps = max(eps, 4.0 * torch.finfo(z.dtype).eps)
    max_radius = 2.0 - eps
    # Clamping at 1 makes the denominator exactly 1 inside the unit ball
    # (identity). Clamping below 2 keeps it strictly positive.
    mag = z.square().sum(dim=-1, keepdim=True).clamp(min=1.0).sqrt().clamp(max=max_radius)
    # z / (|z| * (2 - |z|)) == (z / |z|) * 1 / (2 - |z|).
    # `2 - mag` is exact for mag in [1, 2], which avoids the cancellation
    # in `2|z| - |z|^2`.
    return z / (mag * (2.0 - mag))
def _undo_stage(depth: Tensor, stage: str) -> Tensor:
    """Invert one encode-time normalization stage on a depth tensor.

    Only ``"contract"`` does any work: each element is inverted as a scalar
    via ``_inv_contract``. Every other stage name ("normal", "mean",
    "unit_mean", "none") returns ``depth`` itself, untouched.
    """
    if stage == "contract":
        # A trailing singleton axis can serve directly as the 1-D vector dim
        # that _inv_contract reduces over. Otherwise add one temporarily so
        # each element is treated as its own scalar.
        has_unit_axis = depth.shape[-1] == 1
        as_vectors = depth if has_unit_axis else depth.unsqueeze(-1)
        restored = _inv_contract(as_vectors)
        return restored if has_unit_axis else restored.squeeze(-1)
    return depth
def decode_depth(tokens: Tensor, config: DepthConfig) -> Tensor:
    """Turn a grid of depth tokens back into a per-pixel depth map.

    Input shape:  (..., tok_h, tok_w, patch_size * patch_size * c)
    Output shape: (..., tok_h * patch_size, tok_w * patch_size, c)
    (c == 1 for a single-channel depth map.)

    ``config.depth_value_scale`` is divided out first. The configured
    normalization stages are then undone from last to first.
    """
    side = config.patch_size
    pixels = einops.rearrange(tokens, "... h w (p1 p2 c) -> ... (h p1) (w p2) c", p1=side, p2=side)

    scale = config.depth_value_scale
    if scale != 1.0:
        pixels = pixels / scale

    stages_last_first = reversed(config.depth_normalize_mode)
    for stage in stages_last_first:
        pixels = _undo_stage(pixels, stage)
    return pixels
def _contract(z: Tensor) -> Tensor:
    """Apply the mip-NeRF 360 radial contraction along the last dim.

    Vectors with norm <= 1 pass through unchanged. Longer vectors keep their
    direction but have their norm r remapped to 2 - 1/r, landing inside the
    ball of radius 2. This is the forward map that `_inv_contract` undoes.
    """
    radius = z.norm(dim=-1, keepdim=True)
    # torch.where evaluates both branches, so guard the division even though
    # near-zero radii end up taking the identity branch.
    denom = radius.clamp(min=1e-12)
    squash = (2.0 - 1.0 / denom) / denom
    keep = torch.ones_like(radius)
    return z * torch.where(radius <= 1.0, keep, squash)
def _apply_stage(depth: Tensor, stage: str) -> Tensor:
    """Forward of `_undo_stage`: apply one encode-time normalization stage.

    Only ``"contract"`` and ``"unit_mean"`` are active. ``"normal"``,
    ``"mean"`` and ``"none"`` return ``depth`` unchanged.

    ``unit_mean`` divides each map by its own mean, so the depth is
    scale-normalised the way the model was trained. The mean is taken over
    the trailing (H, W, C) axes with ``keepdim=True``, so maps stacked along
    leading batch axes each get their own normalizer instead of sharing one
    global mean. It is pass-through on decode (the absolute scale is not
    recoverable), so encode/decode round-trips only up to that scale.

    ``contract`` is applied element-wise via ``_contract``.
    """
    if stage == "unit_mean":
        # Reduce over at most the last three axes (H, W, C). Any leading
        # axes are treated as batch axes.
        map_dims = tuple(range(-min(depth.ndim, 3), 0))
        per_map_mean = depth.mean(dim=map_dims, keepdim=True)
        return depth / per_map_mean.clamp(min=1e-12)
    if stage != "contract":  # "normal"/"mean"/"none" are pass-through.
        return depth
    if depth.shape[-1] == 1:
        return _contract(depth)
    # Add a singleton vector axis so each element is contracted as a scalar.
    return _contract(depth.unsqueeze(-1)).squeeze(-1)
def encode_depth(depth_map: Tensor, config: DepthConfig) -> Tensor:
    """Fold a per-pixel depth map into depth-stream tokens (inverse of `decode_depth`).

    Input shape:  (..., H, W, c)   (c == 1 for a depth map)
    Output shape: (..., H / patch_size, W / patch_size, patch_size**2 * c)
    H and W must be divisible by ``config.patch_size``; the rearrange
    pattern enforces this.

    Used for ``mode="d2i"`` to turn a depth map into the depth-stream tokens
    the model conditions on. The configured stages are applied first to last,
    then ``config.depth_value_scale`` is multiplied in. When `unit_mean` is
    among the stages it is scale-normalising, so only the relative depth
    structure matters and the input need not be metric.
    """
    values = depth_map
    for stage in config.depth_normalize_mode:
        values = _apply_stage(values, stage)

    scale = config.depth_value_scale
    if scale != 1.0:
        values = values * scale

    side = config.patch_size
    return einops.rearrange(values, "... (h p1) (w p2) c -> ... h w (p1 p2 c)", p1=side, p2=side)