| """ |
| Per-claim verification of the LPD implementation against paper.tex. |
| |
| Each `check()` block ties one specific paper claim to the code that |
| realises it; failures are reported with the missing/incorrect piece. Run as |
| |
| cd /mnt/sig/pixel-perfect-depth |
| python -m ppd.lpd.tests.verify_paper |
| |
| Exits with non-zero status if any check fails. Designed to be a |
| single-pass audit, not a unit-test suite — it executes real tensor ops |
| on small inputs to confirm shapes, equations, and gradient flow. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import math |
| import inspect |
| from typing import Callable |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
|
|
# Registry of (claim description, check function) pairs. The ``check``
# decorator appends to it at import time, so entries keep definition order.
CHECKS: list[tuple[str, Callable[[], None]]] = []


def check(name: str):
    """Return a decorator that registers a function in ``CHECKS`` under *name*.

    The decorated function is handed back unchanged; the only effect is the
    ``(name, function)`` pair appended to the module-level registry.
    """

    def register(func: Callable[[], None]):
        entry = (name, func)
        CHECKS.append(entry)
        return func

    return register
|
|
|
|
def assert_close(actual: torch.Tensor, expected: torch.Tensor, msg: str, atol: float = 1e-5):
    """Raise ``AssertionError`` unless *actual* matches *expected* elementwise.

    Args:
        actual: Tensor produced by the code under verification.
        expected: Reference tensor; must have exactly the same shape.
        msg: Prefix used in the failure message.
        atol: Absolute tolerance forwarded to ``torch.allclose`` (its default
            relative tolerance still applies).

    Raises:
        AssertionError: On a shape mismatch, or when any element differs by
            more than the tolerance; the message reports the shapes or the
            largest absolute difference respectively.
    """
    # torch.allclose broadcasts its operands, so a result of the wrong (but
    # broadcastable) shape would otherwise be accepted silently.
    if actual.shape != expected.shape:
        raise AssertionError(
            f"{msg}: shape mismatch {tuple(actual.shape)} vs {tuple(expected.shape)}"
        )
    if not torch.allclose(actual, expected, atol=atol):
        diff = (actual - expected).abs().max().item()
        raise AssertionError(f"{msg}: max diff {diff:.3e}")
|
|
|
|
| |
@check("§3.1: sparse-prompt encoder pools at scales {4, 8, 16, 32}")
def _():
    """Check the ``scales`` attribute of a default-constructed encoder."""
    from ppd.lpd.prompt_encoder import SparsePromptEncoder

    expected_scales = (4, 8, 16, 32)
    encoder = SparsePromptEncoder()
    assert tuple(encoder.scales) == expected_scales, f"scales={encoder.scales}"
|
|
|
|
@check("§3.1: encoder produces both depth and density per scale")
def _():
    """Check ``masked_avg_pool`` output shapes and peak values on a sparse grid."""
    from ppd.lpd.prompt_encoder import masked_avg_pool

    # One observed pixel every 8 pixels on a 32x32 map, so a 4x4 pooling
    # window contains at most a single observation.
    depth = torch.zeros(1, 1, 32, 32)
    depth[:, :, ::8, ::8] = 1.0
    mask = (depth > 0).float()

    pooled, density = masked_avg_pool(depth, mask, kernel=4)
    assert pooled.shape == density.shape == (1, 1, 8, 8)

    # One hit among the 16 pixels of a window -> density 1/16.
    assert (density.max() - 1 / 16.0).abs() < 1e-5

    # Every observed depth is 1.0, so the largest pooled value must be 1.0.
    assert pooled.max() == 1.0
|
|
|
|
@check("§3.1: encoder applies a two-layer CNN + linear projection")
def _():
    """Check the conv count of the first per-scale net and the type of ``fuse``."""
    # _SmallCNN is not used below; the import itself fails if the name is gone.
    from ppd.lpd.prompt_encoder import SparsePromptEncoder, _SmallCNN

    encoder = SparsePromptEncoder()

    # Count the Conv2d layers of the first per-scale branch.
    conv_layers = [
        layer
        for layer in encoder.per_scale[0].net
        if isinstance(layer, torch.nn.Conv2d)
    ]
    n_convs = len(conv_layers)
    assert n_convs == 2, f"expect two convs, got {n_convs}"

    # The fusion step is expected to be a plain linear projection.
    assert isinstance(encoder.fuse, torch.nn.Linear)
|
|
|
|
@check("§3.1: prompt-aware quantile log normalization produces ~[-0.5, 0.5]")
def _():
    """Check the output range of ``quantile_log_normalize`` on a fully observed ramp."""
    from ppd.lpd.prompt_encoder import quantile_log_normalize

    # 256 depths spread linearly over [0.5, 50], every pixel marked observed.
    depth = torch.linspace(0.5, 50.0, 256).reshape(1, 1, 16, 16)
    mask = torch.ones_like(depth)
    normalized = quantile_log_normalize(depth, mask)

    lowest = normalized.min().item()
    highest = normalized.max().item()
    # Loose bounds: the extremes only have to land near -0.5 and +0.5.
    assert -0.5 - 1e-3 <= lowest <= 0.05
    assert 0.5 - 0.05 <= highest <= 1.0
|
|
|
|
@check("§3.1 Eq.(1): prompt gate computes s_sem + g(p,ρ,t) ⊙ m(s_sem,p,ρ,t)")
def _():
    """Check that a freshly built ``PromptGate`` returns its semantic input unchanged."""
    from ppd.lpd.prompt_gate import PromptGate

    gate = PromptGate(embed_dim=32, timestep_dim=32, hidden=32)
    batch, tokens, dim = 2, 8, 32
    semantic_tokens = torch.randn(batch, tokens, dim)
    prompt_tokens = torch.randn(batch, tokens, dim)
    density = torch.rand(batch, tokens, 1)
    timestep_emb = torch.randn(batch, dim)

    joint = gate(semantic_tokens, prompt_tokens, density, timestep_emb)

    # At initialization the additive term is expected to vanish, so the
    # output must equal s_sem regardless of the random prompt inputs.
    assert_close(joint, semantic_tokens, "joint should equal s_sem at init", atol=1e-5)
|
|
|
|
@check("§3.1: m and g are zero-initialized so model starts as pretrained PPD")
def _():
    """Check that selected mixer and gate layers start with all-zero parameters."""
    from ppd.lpd.prompt_gate import PromptGate

    prompt_gate = PromptGate(embed_dim=16, timestep_dim=16, hidden=16)

    # Last layer of the mixer branch.
    mixer_out = prompt_gate.mixer[-1]
    assert torch.all(mixer_out.weight == 0)
    assert torch.all(mixer_out.bias == 0)

    # Second-to-last entry of the gate branch (presumably the final linear
    # layer ahead of an activation -- confirm against PromptGate).
    gate_out = prompt_gate.gate[-2]
    assert torch.all(gate_out.weight == 0)
    assert torch.all(gate_out.bias == 0)
|
|
|
|
@check("§3.1: timestep embedding is projected before entering the gate")
def _():
    """Check that ``TimestepEmbedder.mlp`` is a ``Sequential`` holding two linears."""
    from ppd.models.dit import TimestepEmbedder

    embedder = TimestepEmbedder(hidden_size=32)
    assert isinstance(embedder.mlp, torch.nn.Sequential)

    linear_layers = [m for m in embedder.mlp if isinstance(m, torch.nn.Linear)]
    assert len(linear_layers) == 2
|
|
|
|
| |
@check("§3.3 Eq.(5): LiDAR likelihood gradient = -M ⊙ (x - y) / R")
def _():
    """Check ``posterior_project`` against the likelihood term with no prior."""
    from ppd.lpd.posterior_projection import posterior_project

    x = torch.full((1, 1, 4, 4), 0.5)
    y = torch.full((1, 1, 4, 4), 0.0)
    mask = torch.ones_like(x)

    projected = posterior_project(
        x,
        sigma_t=torch.tensor(1.0),
        sparse_depth=y,
        sparse_mask=mask,
        R=0.1,
        mu_prior=None,
        P_prior=None,
        alpha=1.0,
    )

    # alpha = 1 and sigma = 1 give a unit step, so the expected update is x
    # plus the raw likelihood gradient.
    likelihood_grad = -mask * (x - y) / 0.1
    expected = x + 1.0 * likelihood_grad
    assert_close(projected, expected, "Eq.(5) projection")
|
|
|
|
@check("§3.3 Eq.(6): Kalman temporal-prior gradient = -(x - μ) / P")
def _():
    """Check ``posterior_project`` against the prior term with an empty LiDAR mask."""
    from ppd.lpd.posterior_projection import posterior_project

    x = torch.full((1, 1, 4, 4), 0.3)
    prior_mean = torch.full((1, 1, 4, 4), 0.1)
    prior_var = torch.full((1, 1, 4, 4), 0.5)

    # An all-zero mask removes the LiDAR likelihood, isolating the prior term.
    projected = posterior_project(
        x,
        sigma_t=torch.tensor(1.0),
        sparse_depth=torch.zeros_like(x),
        sparse_mask=torch.zeros_like(x),
        R=0.1,
        mu_prior=prior_mean,
        P_prior=prior_var,
        alpha=1.0,
    )

    prior_grad = -(x - prior_mean) / prior_var
    expected = x + 1.0 * prior_grad
    assert_close(projected, expected, "Eq.(6) Kalman prior gradient")
|
|
|
|
@check("§3.3 Eq.(7): η_τ = α · σ_τ²")
def _():
    """Check that the projection step size scales as ``alpha * sigma ** 2``."""
    from ppd.lpd.posterior_projection import posterior_project

    x = torch.full((1, 1, 2, 2), 1.0)
    y = torch.zeros_like(x)
    mask = torch.ones_like(x)
    sigma = torch.tensor(0.5)

    projected = posterior_project(
        x,
        sigma,
        sparse_depth=y,
        sparse_mask=mask,
        R=1.0,
        mu_prior=None,
        P_prior=None,
        alpha=2.0,
    )

    # alpha = 2 and sigma = 0.5 give eta = 0.5; R = 1 leaves the gradient unscaled.
    eta = 2.0 * 0.5 ** 2
    expected = x + eta * (-mask * (x - y) / 1.0)
    assert_close(projected, expected, "Eq.(7) step-size schedule")
|
|
|
|
| |
@check("§3.4 Algorithm 1: Kalman gain K = P / (P + σ²)")
def _():
    """Evaluate the scalar gain formula for P = 0.5, sigma^2 = 0.25 (expect 2/3)."""
    prior_var = torch.tensor(0.5)
    noise_var = torch.tensor(0.25)
    gain = prior_var / (prior_var + noise_var)
    assert (gain - 2 / 3).abs() < 1e-6
|
|
|
|
@check("§3.4 Algorithm 1: variance update P_τ = (1-K) P_{τ-1} ⇒ monotone decrease")
def _():
    """Iterate the scalar variance update and check that it never increases."""
    variance = torch.tensor(1.0)
    noise_schedule = [1.0, 0.5, 0.25, 0.0625]
    for noise_var in noise_schedule:
        gain = variance / (variance + noise_var)
        updated = (1 - gain) * variance
        # Small slack so floating-point rounding cannot trip the comparison.
        assert updated <= variance + 1e-9, "variance must not grow"
        variance = updated
|
|
|
|
@check("§3.4: Kalman state μ_τ = μ_{τ-1} + K (x̂_0 - μ_{τ-1})")
def _():
    """Evaluate one scalar state update; equal variances must yield the midpoint."""
    mean = torch.tensor(0.0)
    variance = torch.tensor(1.0)
    estimate = torch.tensor(1.0)
    noise_var = torch.tensor(1.0)

    gain = variance / (variance + noise_var)
    updated_mean = mean + gain * (estimate - mean)
    # K = 0.5 here, so the new mean sits halfway between 0 and 1.
    assert (updated_mean - 0.5).abs() < 1e-6
|
|
|
|
@check("§3.4: kalman_in_loop_sample returns (depth, posterior_variance)")
def _():
    """Check that the sampler signature exposes ``x_T`` and ``sparse_depth``.

    Only the signature is inspected here; the function is not called, so the
    return value itself is not verified by this check.
    """
    import inspect
    from ppd.lpd.kalman_in_loop import kalman_in_loop_sample

    parameters = inspect.signature(kalman_in_loop_sample).parameters
    assert "x_T" in parameters and "sparse_depth" in parameters
|
|
|
|
| |
@check("§3.5: predict step warps state and inflates variance by Q")
def _():
    """Check that ``predict`` with zero flow raises a zero variance to ``Q_base``."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    config = TemporalKalmanConfig(
        Q_base=0.1, alpha=0.0, P_init=0.0, occ_threshold=999.0
    )
    kalman = TemporalKalmanFilter(
        shape=(1, 1, 8, 8), device=torch.device("cpu"), config=config
    )
    # Seed a known state by hand: unit mean, zero variance.
    kalman.mu.fill_(1.0)
    kalman.P.fill_(0.0)
    kalman.has_state = True

    zero_flow = torch.zeros(1, 2, 8, 8)
    kalman.predict(flow_fwd=zero_flow, flow_bwd=zero_flow)

    # Starting from P = 0, the variance after prediction should be Q_base.
    assert (kalman.P - 0.1).abs().max() < 1e-5
|
|
|
|
@check("§3.5 Eq.(9): forward-backward error ε = ||p + f_fwd + f_bwd(p+f_fwd)||")
def _():
    """Check that exactly opposite constant flows give a near-zero error."""
    from ppd.lpd.temporal_kalman import forward_backward_error

    # Constant forward flow of 2.0 in channel 0, exactly undone by the
    # backward flow.
    forward = torch.zeros(1, 2, 8, 8)
    forward[:, 0] = 2.0
    backward = -forward

    eps = forward_backward_error(forward, backward)
    assert eps.max() < 1e-3, f"ε should be ~0, got {eps.max().item()}"
|
|
|
|
@check("§3.5: occluded pixels (ε > τ_occ) reset variance to P_max")
def _():
    """Check that inconsistent flows push the variance up to ``P_max``."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    config = TemporalKalmanConfig(P_max=99.0, occ_threshold=0.5)
    kalman = TemporalKalmanFilter(
        shape=(1, 1, 8, 8), device=torch.device("cpu"), config=config
    )
    kalman.mu.fill_(1.0)
    kalman.P.fill_(0.1)
    kalman.has_state = True

    # Forward flow of 5.0 in channel 0 with no compensating backward flow;
    # this is meant to exceed the 0.5 occlusion threshold.
    forward = torch.zeros(1, 2, 8, 8)
    forward[:, 0] = 5.0
    backward = torch.zeros_like(forward)

    kalman.predict(forward, backward)
    assert kalman.P.max() >= 99.0
|
|
|
|
@check("§3.5: update step Kalman gain K = P / (P + R) at observed pixels")
def _():
    """Check ``update`` on a fully observed map against the closed-form result."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    config = TemporalKalmanConfig(R=0.1, P_init=1.0)
    kalman = TemporalKalmanFilter(
        shape=(1, 1, 4, 4), device=torch.device("cpu"), config=config
    )
    observed_depth = torch.full((1, 1, 4, 4), 0.5)
    observed_mask = torch.ones_like(observed_depth)

    mean, variance = kalman.update(observed_depth, observed_mask)

    # Reference values assume the filter starts from mean 0 and variance P_init.
    gain = 1.0 / (1.0 + 0.1)
    expected_mean = 0.0 + gain * (0.5 - 0.0)
    expected_variance = (1 - gain) * 1.0
    assert (mean - expected_mean).abs().max() < 1e-5
    assert (variance - expected_variance).abs().max() < 1e-5
|
|
|
|
@check("§3.5: at unobserved pixels (mask=0), state passes through unchanged")
def _():
    """Check that ``update`` with an empty mask returns the prior mean and variance."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    config = TemporalKalmanConfig(R=0.1, P_init=0.5)
    kalman = TemporalKalmanFilter(
        shape=(1, 1, 4, 4), device=torch.device("cpu"), config=config
    )
    kalman.mu.fill_(0.7)

    no_depth = torch.zeros(1, 1, 4, 4)
    no_mask = torch.zeros_like(no_depth)
    mean, variance = kalman.update(no_depth, no_mask)

    # Nothing was observed: the mean stays at 0.7 and the variance at P_init.
    assert (mean - 0.7).abs().max() < 1e-6
    assert (variance - 0.5).abs().max() < 1e-6
|
|
|
|
@check("§3.5: metric uncertainty = exp(sqrt(P)) - 1")
def _():
    """Check ``metric_uncertainty`` on a single pixel with variance ``P_init``."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    config = TemporalKalmanConfig(P_init=0.25)
    kalman = TemporalKalmanFilter(
        shape=(1, 1, 1, 1), device=torch.device("cpu"), config=config
    )

    reported = kalman.metric_uncertainty().item()
    reference = math.exp(math.sqrt(0.25)) - 1
    assert abs(reported - reference) < 1e-5
|
|
|
|
| |
@check("§3.6 Eq.(8): ρ̃(p) = ρ(p) · (1 + P(p)/max P)")
def _():
    """Check ``modulate_density`` against the closed form on four tokens."""
    from ppd.lpd.uncertainty_modulation import modulate_density

    density = torch.full((1, 4, 1), 0.5)
    variance_map = torch.tensor([0.0, 0.5, 1.0, 2.0]).reshape(1, 1, 1, 4)
    modulated = modulate_density(density, variance_map)

    # Lay the four variances out per token; max P is 2.0 for this input.
    per_token_variance = variance_map.reshape(1, 4, 1)
    expected = 0.5 * (1 + per_token_variance / 2.0)
    assert_close(modulated, expected, "Eq.(8) modulation")
|
|
|
|
| |
@check("§3.7: anchor loss is L1(x̂_0 - y) over observed pixels")
def _():
    """Check ``anchor_loss`` on four pixels, three of them observed."""
    from ppd.lpd.losses import anchor_loss

    prediction = torch.tensor([[[[0.0, 0.5, 1.0, 0.0]]]]).float()
    target = torch.tensor([[[[0.5, 0.5, 0.5, 0.0]]]]).float()
    observed = torch.tensor([[[[1.0, 1.0, 1.0, 0.0]]]])

    # Absolute errors at the observed pixels are 0.5, 0.0, 0.5 -> mean 1/3.
    value = anchor_loss(prediction, target, observed).item()
    assert abs(value - 1.0 / 3) < 1e-6
|
|
|
|
@check("§3.7: total training loss combines MSE + λ_a anchor + λ_g grad")
def _():
    """Check by source inspection that the training module names the loss terms.

    This is a textual check only: it confirms the identifiers appear in the
    module source, not how they are combined.
    """
    train_module = __import__("ppd.lpd.lpd_train", fromlist=["LiDARPerfectDepth"])
    source = inspect.getsource(train_module)
    for identifier in ("lambda_anchor", "anchor_loss", "multi_scale_grad_loss"):
        assert identifier in source
|
|
|
|
@check("§3.7: backbone freeze leaves only prompt-encoder + gate trainable")
def _():
    """Check which parameters remain trainable after ``freeze_backbone()``.

    Every trainable parameter name must start with ``sparse_prompt_encoder``
    or ``prompt_gate``, and at least one such parameter must exist.

    Raises:
        AssertionError: If nothing is trainable, or a trainable parameter
            lies outside the two allowed sub-modules.
    """
    from ppd.lpd.lpd_dit import LPDDiT

    model = LPDDiT(hidden_size=128, depth=4, num_heads=4, patch_size=8)
    model.freeze_backbone()

    allowed_prefixes = ("sparse_prompt_encoder", "prompt_gate")
    trainable = [
        name for name, param in model.named_parameters() if param.requires_grad
    ]

    # Without this guard the loop below passes vacuously when freeze_backbone()
    # freezes every parameter, which would contradict the claim being checked.
    assert trainable, "freeze_backbone() left no trainable parameters"
    for name in trainable:
        assert name.startswith(allowed_prefixes), f"unexpected trainable: {name}"
|
|
|
|
| |
@check("§4.1: sparse simulator implements random / scan_line / grid / hybrid")
def _():
    """Run ``simulate`` once per sampling pattern and sanity-check the outputs."""
    from ppd.lpd.sparse_simulator import simulate

    dense_depth = torch.ones(1, 1, 32, 32)
    valid_mask = torch.ones_like(dense_depth, dtype=torch.bool)
    patterns = ["random", "scan_line", "grid", "hybrid"]

    for pattern in patterns:
        sparse_depth, sparse_mask = simulate(
            dense_depth,
            valid_mask,
            pattern=pattern,
            density=0.05,
            n_lines=4,
            line_density=0.5,
            grid_stride=8,
            min_points=4,
        )
        assert sparse_mask.sum() > 0, f"{pattern} produced no observations"

        # Depth must be zero wherever the returned mask is off.
        # NOTE(review): the source's indentation was lost, so whether this
        # assertion ran per pattern or once after the loop is unconfirmed.
        assert (sparse_depth[~sparse_mask] == 0).all().item()
|
|
|
|
| |
@check("§4.4: temporal Kalman defaults — R=0.01, Q_base=0.005, α=0.5, P_max=10, τ_occ=2.0")
def _():
    """Check the default field values of ``TemporalKalmanConfig``."""
    from ppd.lpd.temporal_kalman import TemporalKalmanConfig

    defaults = TemporalKalmanConfig()
    expected_defaults = {
        "R": 0.01,
        "Q_base": 0.005,
        "alpha": 0.5,
        "P_max": 10.0,
        "occ_threshold": 2.0,
    }
    for field_name, value in expected_defaults.items():
        assert getattr(defaults, field_name) == value
|
|
|
|
@check("§4.4: posterior projection R_proj defaults to 0.1")
def _():
    """Check the default ``R_proj`` of ``KalmanInLoopConfig``."""
    from ppd.lpd.kalman_in_loop import KalmanInLoopConfig

    defaults = KalmanInLoopConfig()
    assert defaults.R_proj == 0.1
|
|
|
|
@check("§4.4: PPD weights load with smart partial loading (strict=False)")
def _():
    """Check by source inspection that the training module loads weights non-strictly.

    Textual check only: it looks for the two substrings in the module source.
    """
    train_module = __import__("ppd.lpd.lpd_train", fromlist=["LiDARPerfectDepth"])
    source = inspect.getsource(train_module)
    for marker in ("strict=False", "_load_ppd_weights"):
        assert marker in source
|
|
|
|
| |
@check("end-to-end: LPDDiT forward at training resolution returns (B,1,H,W)")
def _():
    """Run one ``LPDDiT`` forward pass on random inputs and check the output shape."""
    from ppd.lpd.lpd_dit import LPDDiT

    model = LPDDiT(hidden_size=128, depth=4, num_heads=4, patch_size=8)
    batch, height, width = 1, 64, 64

    latent = torch.randn(batch, 4, height, width)
    # One semantic token per 16x16 patch, 1024 features each.
    semantics = torch.randn(batch, (height // 16) * (width // 16), 1024)
    timestep = torch.tensor([100.0])

    # Sparse depth of 0.3 on every 8th pixel; the mask marks those pixels.
    sparse_depth = torch.zeros(batch, 1, height, width)
    sparse_depth[:, :, ::8, ::8] = 0.3
    sparse_mask = sparse_depth > 0

    prediction = model(
        latent, semantics, timestep,
        sparse_depth=sparse_depth, sparse_mask=sparse_mask,
    )
    assert prediction.shape == (batch, 1, height, width)
|
|
|
|
@check("end-to-end: KIL sampler produces depth + variance maps with the right shapes")
def _():
    """Run the Kalman-in-loop sampler with a stub predictor and check output shapes."""
    from ppd.utils.diffusion.timesteps import Timesteps
    from ppd.utils.diffusion.schedule import LinearSchedule
    from ppd.utils.diffusion.sampler import EulerSampler
    from ppd.lpd.kalman_in_loop import kalman_in_loop_sample, KalmanInLoopConfig

    schedule = LinearSchedule(T=1000)
    steps = Timesteps(T=1000, steps=4, device=torch.device("cpu"))
    sampler = EulerSampler(
        schedule=schedule, timesteps=steps, prediction_type="velocity"
    )

    batch, height, width = 1, 32, 32
    noise = torch.randn(batch, 1, height, width)
    cond = torch.randn(batch, 3, height, width)
    # No LiDAR observations at all: depth and mask are both all-zero.
    sparse_depth = torch.zeros(batch, 1, height, width)
    sparse_mask = torch.zeros_like(sparse_depth)

    def predict(x_tau, tau):
        # Stub network: always predicts zeros, whatever the step.
        return torch.zeros_like(x_tau)

    depth, variance = kalman_in_loop_sample(
        dit_predict_x0=predict,
        sampler=sampler,
        timesteps=list(steps),
        x_T=noise,
        cond=cond,
        semantics_fn=lambda: None,
        sparse_depth=sparse_depth,
        sparse_mask=sparse_mask,
        config=KalmanInLoopConfig(),
    )
    expected_shape = (batch, 1, height, width)
    assert depth.shape == expected_shape
    assert variance.shape == expected_shape
|
|
|
|
| |
def main() -> int:
    """Run every registered check, print a per-check report, and return an exit code.

    Each entry of ``CHECKS`` is called in registration order. An exception
    raised by a check is reported as a failure and the remaining checks
    still run.

    Returns:
        ``0`` when no check failed (including when none are registered),
        ``1`` otherwise.
    """
    ok = fail = 0
    # default=0 keeps an empty registry from raising ValueError in max().
    width = max((len(name) for name, _ in CHECKS), default=0)
    for name, fn in CHECKS:
        try:
            fn()
        except Exception as e:
            # Top-level boundary: any failure is reported, never propagated.
            print(f" ✗ {name.ljust(width)} → {type(e).__name__}: {e}")
            fail += 1
        else:
            print(f" ✓ {name.ljust(width)}")
            ok += 1
    print(f"\n{ok} passed, {fail} failed (of {len(CHECKS)})")
    return 0 if fail == 0 else 1
|
|
|
|
if __name__ == "__main__":
    # Put the current working directory first on sys.path -- presumably so
    # the lazy ``ppd`` imports inside the checks resolve against the checkout
    # the script is launched from.
    working_dir = os.getcwd()
    sys.path.insert(0, working_dir)
    exit_code = main()
    sys.exit(exit_code)
|
|