# Source: LiDAR-Perfect-Depth / code/ppd/lpd/tests/verify_paper.py (uploaded by chenming-wu, commit 436b829, verified)
"""
Per-claim verification of the LPD implementation against paper.tex.

Each `check()` block ties one specific paper claim to the code that
realises it; failures are reported with the missing/incorrect piece. Run as::

    cd /mnt/sig/pixel-perfect-depth
    python -m ppd.lpd.tests.verify_paper

Exits with non-zero status if any check fails. Designed to be a
single-pass audit, not a unit-test suite — it executes real tensor ops
on small inputs to confirm shapes, equations, and gradient flow.
"""
from __future__ import annotations
import os
import sys
import math
import inspect
from typing import Callable
import torch
import torch.nn.functional as F
# Registry of (claim description, zero-argument check function) pairs, kept in
# registration order; `main()` iterates over it.
CHECKS: list[tuple[str, Callable[[], None]]] = []


def check(name: str):
    """Build a decorator that records the decorated function under *name*.

    The function is appended to ``CHECKS`` as ``(name, fn)`` and handed back
    unchanged, so decoration has no effect other than registration.
    """

    def register(fn: Callable[[], None]):
        entry = (name, fn)
        CHECKS.append(entry)
        return fn

    return register
def assert_close(actual: torch.Tensor, expected: torch.Tensor, msg: str, atol: float = 1e-5):
    """Raise ``AssertionError`` unless *actual* and *expected* are close.

    Closeness is decided by ``torch.allclose`` with the given *atol* (its
    relative tolerance is left at the library default). On mismatch the error
    text is *msg* followed by the largest absolute element-wise difference.
    """
    if torch.allclose(actual, expected, atol=atol):
        return
    worst = (actual - expected).abs().max().item()
    raise AssertionError(f"{msg}: max diff {worst:.3e}")
# ----------------------------------------------------------------- §3.1 image
@check("§3.1: sparse-prompt encoder pools at scales {4, 8, 16, 32}")
def _():
    """A default-constructed ``SparsePromptEncoder`` must report scales 4/8/16/32."""
    from ppd.lpd.prompt_encoder import SparsePromptEncoder

    encoder = SparsePromptEncoder()
    found = tuple(encoder.scales)
    assert found == (4, 8, 16, 32), f"scales={encoder.scales}"
@check("§3.1: encoder produces both depth and density per scale")
def _():
    """``masked_avg_pool`` must return a pooled-depth map and a density map.

    A 32x32 depth map is sampled every 8th pixel (16 observations, all with
    depth 1.0) and pooled with a 4x4 kernel, so each occupied cell contains
    exactly one observation out of 16 pixels.

    Raises:
        AssertionError: if the output shapes, the peak density, or the peak
            pooled depth deviate from the expected values.
    """
    from ppd.lpd.prompt_encoder import masked_avg_pool

    depth = torch.zeros(1, 1, 32, 32)
    depth[:, :, ::8, ::8] = 1.0
    mask = (depth > 0).float()
    pooled, density = masked_avg_pool(depth, mask, kernel=4)
    assert pooled.shape == density.shape == (1, 1, 8, 8), (
        f"pooled={tuple(pooled.shape)}, density={tuple(density.shape)}"
    )
    # An occupied 4x4 cell holds a single observation, hence density 1/16.
    peak_density = density.max().item()
    assert abs(peak_density - 1 / 16.0) < 1e-5, f"max density={peak_density}"
    # The masked average over a cell whose only observation is 1.0 should be
    # 1.0. Compare with a tolerance rather than `==`: the result is a computed
    # float quotient (and the implementation may stabilise its denominator
    # with an epsilon -- not visible from here).
    peak_depth = pooled.max().item()
    assert abs(peak_depth - 1.0) < 1e-4, f"max pooled depth={peak_depth}"
@check("§3.1: encoder applies a two-layer CNN + linear projection")
def _():
    """The first per-scale branch must hold two ``Conv2d`` layers and ``fuse`` must be ``Linear``."""
    # `_SmallCNN` is imported alongside the encoder but not otherwise used here.
    from ppd.lpd.prompt_encoder import SparsePromptEncoder, _SmallCNN

    encoder = SparsePromptEncoder()
    # Count the convolutions inside the first per-scale network.
    conv_layers = [
        layer for layer in encoder.per_scale[0].net
        if isinstance(layer, torch.nn.Conv2d)
    ]
    n_convs = len(conv_layers)
    assert n_convs == 2, f"expect two convs, got {n_convs}"
    # The final projection is expected to be a plain linear layer.
    assert isinstance(encoder.fuse, torch.nn.Linear)
@check("§3.1: prompt-aware quantile log normalization produces ~[-0.5, 0.5]")
def _():
    """Normalising a fully-observed 0.5..50 depth ramp must land in the accepted bands.

    The minimum must fall in [-0.5 - 1e-3, 0.05] and the maximum in
    [0.45, 1.0].
    """
    from ppd.lpd.prompt_encoder import quantile_log_normalize

    depth = torch.linspace(0.5, 50.0, 256).reshape(1, 1, 16, 16)
    mask = torch.ones_like(depth)  # every pixel counts as observed
    normalized = quantile_log_normalize(depth, mask)
    lowest = normalized.min().item()
    assert -0.5 - 1e-3 <= lowest <= 0.05
    highest = normalized.max().item()
    assert 0.5 - 0.05 <= highest <= 1.0
@check("§3.1 Eq.(1): prompt gate computes s_sem + g(p,ρ,t) ⊙ m(s_sem,p,ρ,t)")
def _():
    """A freshly constructed ``PromptGate`` must return its semantic input unchanged."""
    from ppd.lpd.prompt_gate import PromptGate

    batch, tokens, dim = 2, 8, 32
    gate = PromptGate(embed_dim=32, timestep_dim=32, hidden=32)
    s_sem = torch.randn(batch, tokens, dim)
    prompt = torch.randn(batch, tokens, dim)
    density = torch.rand(batch, tokens, 1)
    t_emb = torch.randn(batch, dim)
    joint = gate(s_sem, prompt, density, t_emb)
    # With zero-initialised last layers the mixer delta is 0 (and the gate
    # sits at sigmoid(0) = 0.5), so gate * delta vanishes and the output
    # should coincide with s_sem.
    assert_close(joint, s_sem, "joint should equal s_sem at init", atol=1e-5)
@check("§3.1: m and g are zero-initialized so model starts as pretrained PPD")
def _():
    """The mixer's last layer and the gate's second-to-last layer must be all zeros."""
    from ppd.lpd.prompt_gate import PromptGate

    gate = PromptGate(embed_dim=16, timestep_dim=16, hidden=16)
    # Final layer of the mixer.
    mixer_last = gate.mixer[-1]
    assert torch.all(mixer_last.weight == 0)
    assert torch.all(mixer_last.bias == 0)
    # Index -2 of the gate stack: the linear layer feeding the trailing Sigmoid.
    gate_pre_sigmoid = gate.gate[-2]
    assert torch.all(gate_pre_sigmoid.weight == 0)
    assert torch.all(gate_pre_sigmoid.bias == 0)
@check("§3.1: timestep embedding is projected before entering the gate")
def _():
    """``TimestepEmbedder.mlp`` must be a ``Sequential`` holding exactly two ``Linear`` layers."""
    # Per the original note, LPDDiT calls self.t_embedder(timestep), which
    # contains this 2-layer MLP.
    from ppd.models.dit import TimestepEmbedder

    embedder = TimestepEmbedder(hidden_size=32)
    assert isinstance(embedder.mlp, torch.nn.Sequential)
    linear_layers = [layer for layer in embedder.mlp if isinstance(layer, torch.nn.Linear)]
    assert len(linear_layers) == 2
# ----------------------------------------------------------------- §3.3 score decomp
@check("§3.3 Eq.(5): LiDAR likelihood gradient = -M ⊙ (x - y) / R")
def _():
    """With no temporal prior, ``posterior_project`` must apply only the LiDAR term."""
    from ppd.lpd.posterior_projection import posterior_project

    shape = (1, 1, 4, 4)
    x = torch.full(shape, 0.5)
    y = torch.full(shape, 0.0)
    M = torch.ones_like(x)
    projected = posterior_project(
        x,
        sigma_t=torch.tensor(1.0),
        sparse_depth=y,
        sparse_mask=M,
        R=0.1,
        mu_prior=None,
        P_prior=None,
        alpha=1.0,
    )
    # Step size eta = alpha * sigma^2 = 1 and the prior term is absent, so
    # the expected update is x + 1 * (-M * (x - y) / R).
    likelihood_grad = -M * (x - y) / 0.1
    expected = x + 1.0 * likelihood_grad
    assert_close(projected, expected, "Eq.(5) projection")
@check("§3.3 Eq.(6): Kalman temporal-prior gradient = -(x - μ) / P")
def _():
    """With an all-zero sparse mask, ``posterior_project`` must apply only the prior term."""
    from ppd.lpd.posterior_projection import posterior_project

    shape = (1, 1, 4, 4)
    x = torch.full(shape, 0.3)
    mu = torch.full(shape, 0.1)
    P = torch.full(shape, 0.5)
    projected = posterior_project(
        x,
        sigma_t=torch.tensor(1.0),
        sparse_depth=torch.zeros_like(x),
        sparse_mask=torch.zeros_like(x),  # no LiDAR observations at all
        R=0.1,
        mu_prior=mu,
        P_prior=P,
        alpha=1.0,
    )
    prior_grad = -(x - mu) / P
    expected = x + 1.0 * prior_grad
    assert_close(projected, expected, "Eq.(6) Kalman prior gradient")
@check("§3.3 Eq.(7): η_τ = α · σ_τ²")
def _():
    """The projection step must be scaled by ``alpha * sigma**2``."""
    from ppd.lpd.posterior_projection import posterior_project

    x = torch.full((1, 1, 2, 2), 1.0)
    y = torch.zeros_like(x)
    M = torch.ones_like(x)
    sigma = torch.tensor(0.5)
    projected = posterior_project(
        x,
        sigma,
        sparse_depth=y,
        sparse_mask=M,
        R=1.0,
        mu_prior=None,
        P_prior=None,
        alpha=2.0,
    )
    step = 2.0 * 0.5 ** 2  # alpha * sigma^2 = 0.5
    expected = x + step * (-M * (x - y) / 1.0)
    assert_close(projected, expected, "Eq.(7) step-size schedule")
# ----------------------------------------------------------------- §3.4 KIL
@check("§3.4 Algorithm 1: Kalman gain K = P / (P + σ²)")
def _():
    """Arithmetic sanity check: P = 0.5 and sigma^2 = 0.25 give a gain of 2/3."""
    variance = torch.tensor(0.5)
    noise_var = torch.tensor(0.25)
    gain = variance / (variance + noise_var)
    assert (gain - 2 / 3).abs() < 1e-6
@check("§3.4 Algorithm 1: variance update P_τ = (1-K) P_{τ-1} ⇒ monotone decrease")
def _():
    """Iterating the variance update over a fixed noise schedule must never raise P."""
    variance = torch.tensor(1.0)
    for noise_var in (1.0, 0.5, 0.25, 0.0625):
        gain = variance / (variance + noise_var)
        updated = (1 - gain) * variance
        assert updated <= variance + 1e-9, "variance must not grow"
        variance = updated
@check("§3.4: Kalman state μ_τ = μ_{τ-1} + K (x̂_0 - μ_{τ-1})")
def _():
    """Arithmetic sanity check: equal P and sigma^2 move the mean halfway to the estimate."""
    prior_mean = torch.tensor(0.0)
    prior_var = torch.tensor(1.0)
    estimate = torch.tensor(1.0)
    noise_var = torch.tensor(1.0)
    gain = prior_var / (prior_var + noise_var)
    posterior_mean = prior_mean + gain * (estimate - prior_mean)
    assert (posterior_mean - 0.5).abs() < 1e-6
@check("§3.4: kalman_in_loop_sample returns (depth, posterior_variance)")
def _():
    """``kalman_in_loop_sample`` must accept ``x_T`` and ``sparse_depth`` parameters.

    Only the signature is inspected here; the function is not called.
    """
    from ppd.lpd.kalman_in_loop import kalman_in_loop_sample

    params = inspect.signature(kalman_in_loop_sample).parameters
    assert "x_T" in params and "sparse_depth" in params
# ----------------------------------------------------------------- §3.5 temporal Kalman
@check("§3.5: predict step warps state and inflates variance by Q")
def _():
    """After ``predict`` with zero flow, a zero variance map must equal ``Q_base`` (0.1)."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    cfg = TemporalKalmanConfig(Q_base=0.1, alpha=0.0, P_init=0.0, occ_threshold=999.0)
    kf = TemporalKalmanFilter(shape=(1, 1, 8, 8), device=torch.device("cpu"), config=cfg)
    # Seed the filter with a known state.
    kf.mu.fill_(1.0)
    kf.P.fill_(0.0)
    kf.has_state = True
    zero_flow = torch.zeros(1, 2, 8, 8)  # zero flow, i.e. an identity warp
    kf.predict(flow_fwd=zero_flow, flow_bwd=zero_flow)
    # With alpha = 0 the variance is expected to grow by exactly Q_base.
    assert (kf.P - 0.1).abs().max() < 1e-5
@check("§3.5 Eq.(9): forward-backward error ε = ||p + f_fwd + f_bwd(p+f_fwd)||")
def _():
    """A backward flow that exactly negates the forward flow must give near-zero error."""
    from ppd.lpd.temporal_kalman import forward_backward_error

    f_fwd = torch.zeros(1, 2, 8, 8)
    f_fwd[:, 0] = 2.0  # constant shift of +2 in the first flow channel (x)
    f_bwd = -f_fwd  # exact inverse of the forward flow
    eps = forward_backward_error(f_fwd, f_bwd)
    assert eps.max() < 1e-3, f"ε should be ~0, got {eps.max().item()}"
@check("§3.5: occluded pixels (ε > τ_occ) reset variance to P_max")
def _():
    """A 5 px forward flow with no backward flow must push the variance up to ``P_max``."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    cfg = TemporalKalmanConfig(P_max=99.0, occ_threshold=0.5)
    kf = TemporalKalmanFilter(shape=(1, 1, 8, 8), device=torch.device("cpu"), config=cfg)
    kf.mu.fill_(1.0)
    kf.P.fill_(0.1)
    kf.has_state = True
    f_fwd = torch.zeros(1, 2, 8, 8)
    f_fwd[:, 0] = 5.0  # 5 px forward motion
    f_bwd = torch.zeros_like(f_fwd)  # nothing maps back, so the round-trip error is 5
    kf.predict(f_fwd, f_bwd)
    assert kf.P.max() >= 99.0
@check("§3.5: update step Kalman gain K = P / (P + R) at observed pixels")
def _():
    """A fully-observed ``update`` must match the closed-form scalar Kalman step.

    The expected values take the prior mean to be 0.0 and the prior variance
    to be ``P_init``.
    """
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    prior_mean, prior_var, meas_noise, observation = 0.0, 1.0, 0.1, 0.5
    kf = TemporalKalmanFilter(
        shape=(1, 1, 4, 4),
        device=torch.device("cpu"),
        config=TemporalKalmanConfig(R=meas_noise, P_init=prior_var),
    )
    sparse_depth = torch.full((1, 1, 4, 4), observation)
    sparse_mask = torch.ones_like(sparse_depth)  # every pixel observed
    mu, P = kf.update(sparse_depth, sparse_mask)
    gain = prior_var / (prior_var + meas_noise)
    expected_mu = prior_mean + gain * (observation - prior_mean)
    expected_P = (1 - gain) * prior_var
    assert (mu - expected_mu).abs().max() < 1e-5
    assert (P - expected_P).abs().max() < 1e-5
@check("§3.5: at unobserved pixels (mask=0), state passes through unchanged")
def _():
    """``update`` with an all-zero mask must return the existing mean and variance."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    cfg = TemporalKalmanConfig(R=0.1, P_init=0.5)
    kf = TemporalKalmanFilter(shape=(1, 1, 4, 4), device=torch.device("cpu"), config=cfg)
    kf.mu.fill_(0.7)
    sparse_depth = torch.zeros(1, 1, 4, 4)
    sparse_mask = torch.zeros_like(sparse_depth)  # no pixel is observed
    mu, P = kf.update(sparse_depth, sparse_mask)
    assert (mu - 0.7).abs().max() < 1e-6
    assert (P - 0.5).abs().max() < 1e-6
@check("§3.5: metric uncertainty = exp(sqrt(P)) - 1")
def _():
    """For a single pixel with ``P_init=0.25`` the uncertainty must be exp(0.5) - 1."""
    from ppd.lpd.temporal_kalman import TemporalKalmanFilter, TemporalKalmanConfig

    variance = 0.25
    kf = TemporalKalmanFilter(
        shape=(1, 1, 1, 1),
        device=torch.device("cpu"),
        config=TemporalKalmanConfig(P_init=variance),
    )
    expected = math.exp(math.sqrt(variance)) - 1
    actual = kf.metric_uncertainty().item()
    assert abs(actual - expected) < 1e-5
# ----------------------------------------------------------------- §3.6 modulation
@check("§3.6 Eq.(8): ρ̃(p) = ρ(p) · (1 + P(p)/max P)")
def _():
    """``modulate_density`` must scale a constant density by ``1 + P / max(P)``."""
    from ppd.lpd.uncertainty_modulation import modulate_density

    rho = torch.full((1, 4, 1), 0.5)
    P_full = torch.tensor([0.0, 0.5, 1.0, 2.0]).reshape(1, 1, 1, 4)
    rho_tilde = modulate_density(rho, P_full)
    # Lay the four variance values out like rho, i.e. (1, 4, 1).
    P_tokens = P_full.reshape(1, 4, 1)
    # max P is 2.0, so the expected result is 0.5 * (1 + P / 2.0).
    expected = 0.5 * (1 + P_tokens / 2.0)
    assert_close(rho_tilde, expected, "Eq.(8) modulation")
# ----------------------------------------------------------------- §3.7 training
@check("§3.7: anchor loss is L1(x̂_0 - y) over observed pixels")
def _():
    """``anchor_loss`` on a four-pixel example with three observed pixels must be 1/3."""
    from ppd.lpd.losses import anchor_loss

    prediction = torch.tensor([[[[0.0, 0.5, 1.0, 0.0]]]]).float()
    target = torch.tensor([[[[0.5, 0.5, 0.5, 0.0]]]]).float()
    observed = torch.tensor([[[[1.0, 1.0, 1.0, 0.0]]]])
    # Observed absolute errors: 0.5 + 0.0 + 0.5 = 1.0 over 3 observed pixels.
    value = anchor_loss(prediction, target, observed).item()
    assert abs(value - 1.0 / 3) < 1e-6
@check("§3.7: total training loss combines MSE + λ_a anchor + λ_g grad")
def _():
    """The source text of ``ppd.lpd.lpd_train`` must mention the expected loss identifiers.

    This is a textual check on the module source; no training code is run.
    """
    module = __import__("ppd.lpd.lpd_train", fromlist=["LiDARPerfectDepth"])
    src = inspect.getsource(module)
    for identifier in ("lambda_anchor", "anchor_loss", "multi_scale_grad_loss"):
        assert identifier in src
@check("§3.7: backbone freeze leaves only prompt-encoder + gate trainable")
def _():
    """After ``freeze_backbone()`` only prompt-encoder and gate parameters may train.

    Two conditions are enforced:

    * at least one parameter is still trainable, so the check cannot pass
      vacuously when everything has been frozen;
    * every trainable parameter name starts with ``sparse_prompt_encoder``
      or ``prompt_gate``.

    Raises:
        AssertionError: naming either the empty trainable set or the full
            list of unexpected trainable parameters.
    """
    from ppd.lpd.lpd_dit import LPDDiT

    model = LPDDiT(hidden_size=128, depth=4, num_heads=4, patch_size=8)
    model.freeze_backbone()
    allowed_prefixes = ("sparse_prompt_encoder", "prompt_gate")
    trainable = [name for name, param in model.named_parameters() if param.requires_grad]
    assert trainable, "freeze_backbone() left no trainable parameters"
    # Report every offender at once rather than stopping at the first one.
    unexpected = [name for name in trainable if not name.startswith(allowed_prefixes)]
    assert not unexpected, f"unexpected trainable: {unexpected}"
# ----------------------------------------------------------------- §4.1 sparse simulator
@check("§4.1: sparse simulator implements random / scan_line / grid / hybrid")
def _():
    """Each sampling pattern must yield observations and zero depth outside the mask."""
    from ppd.lpd.sparse_simulator import simulate

    d = torch.ones(1, 1, 32, 32)
    m = torch.ones_like(d, dtype=torch.bool)
    shared_kwargs = dict(
        density=0.05, n_lines=4, line_density=0.5, grid_stride=8, min_points=4,
    )
    for pat in ("random", "scan_line", "grid", "hybrid"):
        sd, sm = simulate(d, m, pattern=pat, **shared_kwargs)
        assert sm.sum() > 0, f"{pat} produced no observations"
        # Wherever the mask is false the sparse depth must be zero.
        assert (sd[~sm] == 0).all().item()
# ----------------------------------------------------------------- §4.4 implementation
@check("§4.4: temporal Kalman defaults — R=0.01, Q_base=0.005, α=0.5, P_max=10, τ_occ=2.0")
def _():
    """A default ``TemporalKalmanConfig`` must carry the five listed default values."""
    from ppd.lpd.temporal_kalman import TemporalKalmanConfig

    cfg = TemporalKalmanConfig()
    expected_defaults = (
        ("R", 0.01),
        ("Q_base", 0.005),
        ("alpha", 0.5),
        ("P_max", 10.0),
        ("occ_threshold", 2.0),
    )
    for field, value in expected_defaults:
        assert getattr(cfg, field) == value
@check("§4.4: posterior projection R_proj defaults to 0.1")
def _():
    """A default ``KalmanInLoopConfig`` must have ``R_proj == 0.1``."""
    from ppd.lpd.kalman_in_loop import KalmanInLoopConfig

    default_r_proj = KalmanInLoopConfig().R_proj
    assert default_r_proj == 0.1
@check("§4.4: PPD weights load with smart partial loading (strict=False)")
def _():
    """The source text of ``ppd.lpd.lpd_train`` must mention non-strict weight loading.

    This is a textual check on the module source; no weights are loaded.
    """
    module = __import__("ppd.lpd.lpd_train", fromlist=["LiDARPerfectDepth"])
    src = inspect.getsource(module)
    for needle in ("strict=False", "_load_ppd_weights"):
        assert needle in src
# ----------------------------------------------------------------- end-to-end shape sanity
@check("end-to-end: LPDDiT forward at training resolution returns (B,1,H,W)")
def _():
    """A small ``LPDDiT`` forward pass on a 64x64 input must return shape (B, 1, H, W)."""
    from ppd.lpd.lpd_dit import LPDDiT

    model = LPDDiT(hidden_size=128, depth=4, num_heads=4, patch_size=8)
    B, H, W = 1, 64, 64
    x = torch.randn(B, 4, H, W)
    n_tokens = (H // 16) * (W // 16)
    sem = torch.randn(B, n_tokens, 1024)
    t = torch.tensor([100.0])
    # Sparse depth: one sample of 0.3 on every 8th pixel, zero elsewhere.
    sparse_depth = torch.zeros(B, 1, H, W)
    sparse_depth[:, :, ::8, ::8] = 0.3
    sparse_mask = sparse_depth > 0
    out = model(x, sem, t, sparse_depth=sparse_depth, sparse_mask=sparse_mask)
    assert out.shape == (B, 1, H, W)
@check("end-to-end: KIL sampler produces depth + variance maps with the right shapes")
def _():
    """A 4-step ``kalman_in_loop_sample`` run must return two (B, 1, H, W) tensors.

    The predictor is a stub that always returns zeros and the sparse mask is
    empty, so only the output shapes are checked.
    """
    from ppd.utils.diffusion.timesteps import Timesteps
    from ppd.utils.diffusion.schedule import LinearSchedule
    from ppd.utils.diffusion.sampler import EulerSampler
    from ppd.lpd.kalman_in_loop import kalman_in_loop_sample, KalmanInLoopConfig

    schedule = LinearSchedule(T=1000)
    ts = Timesteps(T=1000, steps=4, device=torch.device("cpu"))
    sampler = EulerSampler(schedule=schedule, timesteps=ts, prediction_type="velocity")

    B, H, W = 1, 32, 32
    x_T = torch.randn(B, 1, H, W)
    cond = torch.randn(B, 3, H, W)
    sparse_depth = torch.zeros(B, 1, H, W)
    sparse_mask = torch.zeros_like(sparse_depth)

    def zero_x0(x_tau, tau):
        # Stub predictor: ignores tau and returns zeros shaped like x_tau.
        return torch.zeros_like(x_tau)

    depth, variance = kalman_in_loop_sample(
        dit_predict_x0=zero_x0,
        sampler=sampler,
        timesteps=list(ts),
        x_T=x_T,
        cond=cond,
        semantics_fn=lambda: None,
        sparse_depth=sparse_depth,
        sparse_mask=sparse_mask,
        config=KalmanInLoopConfig(),
    )
    expected_shape = (B, 1, H, W)
    assert depth.shape == expected_shape
    assert variance.shape == expected_shape
# ----------------------------------------------------------------- runner
def main() -> int:
    """Run every registered check and print one result line per check.

    Each entry of ``CHECKS`` is called in registration order. A check that
    returns normally is reported with a tick; any ``Exception`` it raises is
    reported with a cross, the exception type, and its message, and the run
    continues with the next check. A summary line is printed at the end.

    Returns:
        0 if no check failed (including when no checks are registered),
        1 otherwise.
    """
    ok = fail = 0
    # default=0 keeps an empty registry from raising ValueError inside max().
    width = max((len(name) for name, _ in CHECKS), default=0)
    for name, fn in CHECKS:
        try:
            fn()
        except Exception as e:  # top-level boundary: report and keep going
            # Two spaces separate the padded name from the error so the
            # longest name does not run straight into the exception type.
            print(f" ✗ {name.ljust(width)}  {type(e).__name__}: {e}")
            fail += 1
        else:
            # Kept outside the try so a printing problem is never counted as
            # a failed check.
            print(f" ✓ {name.ljust(width)}")
            ok += 1
    print(f"\n{ok} passed, {fail} failed (of {len(CHECKS)})")
    return 0 if fail == 0 else 1
if __name__ == "__main__":
    # Put the current working directory at the front of the import path
    # (presumably so the lazily imported `ppd` package resolves from the
    # repository root the script is launched from -- confirm), then exit
    # with main()'s status code.
    sys.path.insert(0, os.getcwd())
    exit_status = main()
    sys.exit(exit_status)