File size: 4,960 Bytes
436b829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Simulate sparse-LiDAR observations from dense ground-truth depth.

Patterns: random / scan-line / grid / hybrid. Used during training so the prompt
encoder sees realistic sparsity. Simulation runs on tensors so it can sit
inside the data loader or the training step.
"""
from __future__ import annotations

import math
import torch


def _validate(depth: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    assert depth.dim() == 4 and mask.dim() == 4, "expect (B,1,H,W)"
    assert depth.shape == mask.shape
    return depth, mask


def random_pattern(mask: torch.Tensor, density: float, generator: torch.Generator | None = None) -> torch.Tensor:
    """Keep each pixel of `mask` independently with probability `density`.

    One uniform sample is drawn per element of `mask` (on `mask.device`, from
    `generator` when given); elements whose sample falls below `density`
    survive. The result is that selection AND-ed with `mask`, so pixels that
    were already invalid stay invalid.
    """
    uniform = torch.rand(mask.shape, generator=generator, device=mask.device)
    selected = uniform.lt(density)
    return mask & selected


def scan_line_pattern(
    mask: torch.Tensor,
    n_lines: int,
    line_density: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Sample points along a handful of horizontal rows, per batch element.

    For every sample in the batch, a random subset of distinct rows is chosen
    (`n_lines` of them, clamped to the range [1, H]) and each pixel on those
    rows is kept independently with probability `line_density`. Only channel 0
    of the selection is populated. The selection is AND-ed with `mask` before
    being returned.
    """
    batch, _, height, width = mask.shape
    device = mask.device
    # The row count does not depend on the sample, so compute it once.
    n_rows = max(1, min(n_lines, height))

    selection = torch.zeros_like(mask)
    for sample_idx in range(batch):
        # Draw order (row permutation first, then per-pixel uniforms) fixes
        # how `generator` is consumed.
        picked_rows = torch.randperm(height, generator=generator, device=device)[:n_rows]
        row_hits = torch.rand((n_rows, width), generator=generator, device=device) < line_density

        plane = torch.zeros((height, width), dtype=torch.bool, device=device)
        plane[picked_rows] = row_hits
        selection[sample_idx, 0] = plane
    return mask & selection


def grid_pattern(mask: torch.Tensor, stride: int) -> torch.Tensor:
    """Keep the pixels lying on a regular lattice with spacing `stride`.

    Rows 0, stride, 2*stride, ... and the same columns are selected for every
    batch element and channel; the selection is AND-ed with `mask`.
    """
    # Unpacking four names makes a non-4-D mask fail here.
    _batch, _channels, _height, _width = mask.shape
    lattice = torch.zeros_like(mask)
    lattice[..., ::stride, ::stride] = True
    return mask & lattice


def hybrid_pattern(
    mask: torch.Tensor,
    density: float,
    n_lines: int,
    line_density: float,
    grid_stride: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Combine the random, scan-line and grid patterns with a logical OR.

    `density` feeds the random pattern, `n_lines` / `line_density` the
    scan-line pattern and `grid_stride` the grid pattern. The random pattern
    is drawn before the scan-line pattern, so `generator` is consumed in that
    order.
    """
    combined = random_pattern(mask, density, generator=generator)
    combined = combined | scan_line_pattern(
        mask, n_lines, line_density, generator=generator
    )
    return combined | grid_pattern(mask, grid_stride)


def simulate(
    depth: torch.Tensor,
    mask: torch.Tensor,
    pattern: str = "hybrid",
    *,
    density: float = 0.005,
    n_lines: int = 64,
    line_density: float = 0.5,
    grid_stride: int = 32,
    min_points: int = 16,
    max_attempts: int = 4,
    measurement_noise_std: float = 0.0,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (sparse_depth, sparse_mask) for a batch of dense GT.

    A sparsity pattern is drawn for the whole batch up to `max_attempts`
    times, with looser parameters on each retry. Every sample keeps the first
    draw that reaches `min_points` observed pixels and is frozen from then on,
    so a retry triggered by one sample never densifies the others. A sample
    that never reaches `min_points` (for example because its validity mask
    has too few pixels) keeps the draw with the most points instead of an
    empty mask.

    Args:
        depth: dense depth, 4-D, same shape as `mask`.
        mask: validity mask of the dense depth; the pattern functions combine
            it with `&`.
        pattern: one of "random", "scan_line", "grid", "hybrid".
        density: keep-probability for the random pattern.
        n_lines: number of rows for the scan-line pattern.
        line_density: keep-probability along each scan line.
        grid_stride: spacing of the grid pattern.
        min_points: target minimum of observed pixels per sample.
        max_attempts: maximum number of draws; with 0 the returned mask is
            all-false.
        measurement_noise_std: std of Gaussian noise added to the observed
            depths (result clamped to be non-negative); 0 disables it.
        generator: optional RNG passed to every random draw.

    Returns:
        `(sparse_depth, sparse_mask)`. `sparse_depth` is exactly zero wherever
        `sparse_mask` is false, even if `depth` holds NaN/inf there.

    Raises:
        ValueError: if `pattern` is not a known name.
    """
    depth, mask = _validate(depth, mask)
    B = depth.shape[0]
    height = depth.shape[-2]

    out_mask = torch.zeros_like(mask)
    # Per-sample bookkeeping: has the sample reached min_points yet, and how
    # many points does the draw currently stored in out_mask have (-1 = none).
    settled = torch.zeros(B, dtype=torch.bool, device=mask.device)
    best_count = torch.full((B,), -1, dtype=torch.long, device=mask.device)

    for _ in range(max_attempts):
        if pattern == "random":
            cur = random_pattern(mask, density, generator=generator)
        elif pattern == "scan_line":
            cur = scan_line_pattern(mask, n_lines, line_density, generator=generator)
        elif pattern == "grid":
            cur = grid_pattern(mask, grid_stride)
        elif pattern == "hybrid":
            cur = hybrid_pattern(
                mask, density, n_lines, line_density, grid_stride, generator=generator
            )
        else:
            raise ValueError(f"unknown pattern: {pattern}")

        per_sample = cur.flatten(1).sum(dim=1)
        # Only unsettled samples may change, and only when this draw has more
        # points than what they already hold. A draw that reaches min_points
        # always qualifies, because an unsettled sample holds fewer than that.
        improved = ~settled & (per_sample > best_count)
        out_mask = torch.where(improved[:, None, None, None], cur, out_mask)
        best_count = torch.where(improved, per_sample, best_count)
        settled = settled | (per_sample >= min_points)
        if settled.all():
            break

        # Loosen for the next attempt. Each cap is relaxed to the caller's own
        # value when that is already looser than the cap, so a retry can never
        # be sparser than the previous attempt.
        density = min(density * 2.0, max(density, 0.5))
        n_lines = min(max(n_lines, 1) * 2, height)
        line_density = min(line_density * 1.5, 1.0)
        grid_stride = max(grid_stride // 2, min(grid_stride, 2))

    observed = out_mask.bool()
    weights = out_mask.float()
    # Multiplying alone would leave NaN/inf depths as NaN (NaN * 0 == NaN), so
    # unobserved pixels are explicitly reset to zero.
    sparse_depth = (depth * weights).masked_fill(~observed, 0.0)
    if measurement_noise_std > 0.0:
        noise = torch.randn(sparse_depth.shape, generator=generator, device=sparse_depth.device)
        sparse_depth = sparse_depth + noise * measurement_noise_std * weights
        sparse_depth = sparse_depth.clamp_min(0.0)
    return sparse_depth, out_mask


def random_pattern_choice(rng: torch.Generator | None = None) -> str:
    """Pick one of the four pattern names uniformly at random.

    The draw uses `rng` when one is supplied; passing `generator=None` to
    `torch.randint` selects torch's default generator, which matches calling
    it without the argument.
    """
    names = ("random", "scan_line", "grid", "hybrid")
    draw = torch.randint(0, len(names), (1,), generator=rng)
    return names[int(draw.item())]