File size: 794 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""
LPD-specific losses.

`anchor_loss` (sparse-anchor consistency term from the paper):
    L = (1 / |M|) * sum_{p ∈ M} | x̂_0(p) - y(p) |

where M is the set of pixels with an observed sparse-LiDAR measurement. It
penalises deviations of the predicted clean depth x̂_0 from the sparse target
y at those pixels, in the same normalized log-depth space the DiT predicts in.
"""
from __future__ import annotations

import torch


def anchor_loss(
    pred_x0: torch.Tensor,
    sparse_target: torch.Tensor,
    sparse_mask: torch.Tensor,
    *,
    eps: float = 1e-6,
) -> torch.Tensor:
    """L1 between predicted x_0 and the sparse target at observed pixels.

    Computes ``sum(m * |pred_x0 - sparse_target|) / max(sum(m), eps)`` with
    ``m = sparse_mask.float()``, i.e. the (mask-weighted) mean absolute error
    over observed entries. All value inputs are in the model's normalized
    space ([-0.5, 0.5] log-quantile).

    Args:
        pred_x0: Predicted clean sample.
        sparse_target: Sparse target. Only entries where ``sparse_mask`` is
            non-zero are read, so unobserved entries may hold any placeholder
            value, including NaN/inf.
        sparse_mask: Observation mask (bool or numeric weights), broadcastable
            against ``pred_x0 - sparse_target``.
        eps: Lower bound on the normalizer so an all-zero mask yields 0
            instead of NaN.

    Returns:
        Scalar tensor holding the masked mean absolute error.
    """
    m = sparse_mask.float()
    observed = m != 0
    # Neutralise unobserved targets *before* subtracting: multiplying a
    # NaN/inf residual by a zero mask still yields NaN, in both the forward
    # value and the gradient w.r.t. pred_x0.
    safe_target = torch.where(observed, sparse_target, torch.zeros_like(sparse_target))
    diff = (pred_x0 - safe_target).abs() * m
    # Normalise over the same (broadcast) shape the numerator was summed
    # over, so a mask shared across e.g. channels still gives a per-element
    # mean rather than one inflated by the broadcast factor.
    denom = torch.broadcast_to(m, diff.shape).sum().clamp_min(eps)
    return diff.sum() / denom