| """ |
| LPD-specific losses. |
| |
| `anchor_loss` (sparse-anchor consistency term from the paper): |
| L = (1 / |M|) * sum_{p ∈ M} | x̂_0(p) - y(p) | |
| |
| penalises deviations of the predicted clean depth x̂_0 from the sparse LiDAR |
| depth y at the observed pixels M, measured in the same normalized log-depth space in which the DiT makes its predictions. |
| """ |
| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def anchor_loss(
    pred_x0: torch.Tensor,
    sparse_target: torch.Tensor,
    sparse_mask: torch.Tensor,
    *,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Masked L1 between predicted x_0 and the sparse target at observed pixels.

    Computes ``sum(w * |pred_x0 - sparse_target|) / sum(w)`` where ``w`` is
    ``sparse_mask`` cast to a floating dtype. For a binary mask this is the
    mean absolute error over the observed set M.

    ``pred_x0`` and ``sparse_target`` are in the model's normalized space
    ([-0.5, 0.5] log-quantile).

    Args:
        pred_x0: Predicted clean sample.
        sparse_target: Sparse target. Values at positions where
            ``sparse_mask`` is zero are never used, so they may hold NaN/inf
            placeholders without affecting the loss or its gradient.
        sparse_mask: Observation mask, broadcastable with
            ``pred_x0 - sparse_target``. Nonzero entries mark observed
            pixels; non-binary values act as per-element weights.
        eps: Lower bound on the denominator so an empty mask yields 0
            instead of dividing by zero.

    Returns:
        Scalar loss tensor, accumulated in at least float32 precision.
    """
    # Broadcast first so the denominator counts exactly the elements that
    # contribute to the numerator; a mask broadcast against a larger
    # prediction would otherwise be under-counted.
    diff, weight = torch.broadcast_tensors(pred_x0 - sparse_target, sparse_mask)
    # Accumulate in >= float32, matching the promotion that multiplying by
    # ``mask.float()`` produced for half-precision inputs.
    acc_dtype = torch.promote_types(diff.dtype, torch.float32)
    weight = weight.to(acc_dtype)
    observed = weight != 0
    # Select observed entries instead of multiplying by the mask: NaN/inf at
    # unobserved pixels survive ``x * 0`` as NaN, in both the forward value
    # and the backward pass through ``abs``.
    residual = diff[observed].abs().to(acc_dtype)
    total = (residual * weight[observed]).sum()
    denom = weight.sum().clamp_min(eps)
    return total / denom
|
|