chenming-wu's picture
code
436b829 verified
"""
LPD-specific losses.
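`anchor_loss` implements the paper's sparse-anchor consistency equation:
    L = (1 / |M|) * sum_{p ∈ M} | x̂_0(p) - y(p) |
where M is the set of pixels with a sparse-LiDAR observation, x̂_0 is the
predicted clean depth and y is the sparse target. It penalises deviations of
the prediction at those observed pixels, measured in the same normalized
log-depth space the DiT predicts in.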
"""
from __future__ import annotations
import torch
def anchor_loss(
    pred_x0: torch.Tensor,
    sparse_target: torch.Tensor,
    sparse_mask: torch.Tensor,
    *,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Mask-weighted mean absolute error between prediction and sparse target.

    The absolute difference ``|pred_x0 - sparse_target|`` is multiplied by the
    mask (cast to float32), summed over every element, and divided by the sum
    of the mask. Per the original author, all inputs live in the model's
    normalized space ([-0.5, 0.5] log-quantile).

    Args:
        pred_x0: Predicted clean sample x_0.
        sparse_target: Sparse target values; entries where the mask is zero
            are multiplied by zero and so do not add to the sum.
        sparse_mask: Mask marking observed pixels; any dtype that can be cast
            to float (e.g. bool or 0/1 floats).
        eps: Lower bound on the normaliser, so an all-zero mask gives a zero
            loss instead of a division by zero.

    Returns:
        A 0-dim tensor holding the loss.

    NOTE(review): the normaliser is the sum of the mask exactly as passed in.
    If the mask is broadcast against ``pred_x0`` (e.g. a singleton channel
    axis), the count is not scaled by the broadcast factor — confirm that
    callers pass matching shapes.
    """
    weights = sparse_mask.to(torch.float32)
    abs_err = torch.abs(pred_x0 - sparse_target)
    # Multiply rather than index: keeps the op dense and the graph simple.
    weighted_total = torch.sum(abs_err * weights)
    # Clamp the observed-pixel count so an empty mask cannot divide by zero.
    observed = torch.clamp(torch.sum(weights), min=eps)
    return weighted_total / observed