File size: 4,960 Bytes
436b829 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 | """
Simulate sparse-LiDAR observations from dense ground-truth depth.
Patterns: random / scan-line / grid / hybrid. Used during training so the prompt
encoder sees realistic sparsity. Simulation runs on tensors so it can sit
inside the data loader or the training step.
"""
from __future__ import annotations
import math
import torch
def _validate(depth: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
assert depth.dim() == 4 and mask.dim() == 4, "expect (B,1,H,W)"
assert depth.shape == mask.shape
return depth, mask
def random_pattern(mask: torch.Tensor, density: float, generator: torch.Generator | None = None) -> torch.Tensor:
    """Keep each set pixel of ``mask`` independently with probability ``density``.

    One uniform draw in [0, 1) is taken per element on ``mask.device``. A pixel
    survives when its draw is below ``density`` and it is already set in ``mask``.
    """
    uniform = torch.rand(mask.shape, generator=generator, device=mask.device)
    survivors = uniform < density
    return mask & survivors
def scan_line_pattern(
    mask: torch.Tensor,
    n_lines: int,
    line_density: float,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Velodyne-like horizontal scan lines.

    For every sample, picks ``n_lines`` distinct row indices (clamped to the
    range ``[1, H]``) and keeps each pixel on those rows with probability
    ``line_density``. Lines are written into channel 0 only, and the result is
    intersected with ``mask``.
    """
    batch, _, height, width = mask.shape
    device = mask.device
    # The row count is identical for every sample, so compute it once.
    n_rows = max(1, min(n_lines, height))
    lines = torch.zeros_like(mask)
    for idx in range(batch):
        # Per-sample draw order (row permutation, then per-pixel coin flips)
        # determines the output for a seeded generator.
        chosen_rows = torch.randperm(height, generator=generator, device=device)[:n_rows]
        coin_flips = torch.rand((n_rows, width), generator=generator, device=device)
        row_mask = torch.zeros((height, width), dtype=torch.bool, device=device)
        row_mask[chosen_rows] = coin_flips < line_density
        lines[idx, 0] = row_mask
    return mask & lines
def grid_pattern(mask: torch.Tensor, stride: int) -> torch.Tensor:
    """Regular grid: keep every ``stride``-th pixel along each of the last two axes.

    The grid is anchored at pixel (0, 0) and is the same for every sample and
    channel. No randomness is involved. The result is intersected with ``mask``.

    Raises:
        ValueError: if ``stride`` is less than 1.
    """
    if stride < 1:
        # Checked explicitly so the message names the parameter instead of
        # surfacing torch's generic slice-step error.
        raise ValueError(f"grid stride must be >= 1, got {stride}")
    grid = torch.zeros_like(mask)
    grid[..., ::stride, ::stride] = True
    return mask & grid
def hybrid_pattern(
    mask: torch.Tensor,
    density: float,
    n_lines: int,
    line_density: float,
    grid_stride: int,
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    """Union of the random, scan-line, and grid patterns (for LiDARs between those styles).

    The random and scan-line draws share ``generator`` and are taken in that
    order. Each component is already intersected with ``mask``, so the union
    never marks a pixel outside ``mask``.
    """
    scattered = random_pattern(mask, density, generator=generator)
    scanned = scan_line_pattern(mask, n_lines, line_density, generator=generator)
    combined = scattered | scanned
    combined = combined | grid_pattern(mask, grid_stride)
    return combined
def simulate(
    depth: torch.Tensor,
    mask: torch.Tensor,
    pattern: str = "hybrid",
    *,
    density: float = 0.005,
    n_lines: int = 64,
    line_density: float = 0.5,
    grid_stride: int = 32,
    min_points: int = 16,
    max_attempts: int = 4,
    measurement_noise_std: float = 0.0,
    generator: torch.Generator | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (sparse_depth, sparse_mask) for a batch of dense GT.

    Every attempt draws a sparse mask for the whole batch using ``pattern``.
    A sample is accepted on the first attempt in which it has at least
    ``min_points`` observed pixels, and is never re-drawn after that. Samples
    still below the threshold are re-drawn on the next attempt with looser
    parameters: ``density``, ``n_lines`` and ``line_density`` grow, and
    ``grid_stride`` shrinks. A sample that never reaches ``min_points`` keeps
    the mask from the final (loosest) attempt rather than an empty one.
    ``min_points`` is therefore best-effort: it cannot be met when the
    sample's ``mask`` has fewer valid pixels than that.

    Args:
        depth: dense depth; must be 4-D and the same shape as ``mask``.
        mask: validity mask for ``depth``. It is combined with ``&``/``|``,
            so a boolean tensor is presumably expected — TODO confirm.
        pattern: one of ``"random"``, ``"scan_line"``, ``"grid"``, ``"hybrid"``.
        density: keep probability for the random pattern.
        n_lines: number of scan lines per sample.
        line_density: keep probability for pixels on a scan line.
        grid_stride: spacing of the regular grid.
        min_points: per-sample target number of observed pixels.
        max_attempts: number of draws; values below 1 are treated as 1.
        measurement_noise_std: std of Gaussian noise added to observed depths
            (0 disables noise).
        generator: optional RNG used for all random draws.

    Returns:
        ``(sparse_depth, sparse_mask)``, where ``sparse_depth`` is zero
        wherever ``sparse_mask`` is false.

    Raises:
        ValueError: if ``pattern`` is not a known name. Malformed
            ``depth``/``mask`` are rejected by ``_validate``.
    """
    depth, mask = _validate(depth, mask)
    # Validate up front so a bad name is reported even before any draw.
    if pattern not in ("random", "scan_line", "grid", "hybrid"):
        raise ValueError(f"unknown pattern: {pattern}")

    height = depth.shape[-2]
    n_attempts = max(1, max_attempts)
    out_mask = torch.zeros_like(mask)
    # Per-sample flag: True once a sample's mask has been accepted.
    accepted = torch.zeros(mask.shape[0], dtype=torch.bool, device=mask.device)
    for attempt in range(n_attempts):
        if pattern == "random":
            cur = random_pattern(mask, density, generator=generator)
        elif pattern == "scan_line":
            cur = scan_line_pattern(mask, n_lines, line_density, generator=generator)
        elif pattern == "grid":
            cur = grid_pattern(mask, grid_stride)
        else:  # "hybrid" (validated above)
            cur = hybrid_pattern(
                mask, density, n_lines, line_density, grid_stride, generator=generator
            )

        enough = cur.flatten(1).sum(dim=1) >= min_points
        pending = ~accepted
        # Only fill samples that have not been accepted yet, so earlier
        # (sparser) successful draws are kept. On the final attempt, also fill
        # samples that still fall short: they get the loosest draw instead of
        # an all-zero mask.
        take = pending if attempt == n_attempts - 1 else pending & enough
        out_mask = torch.where(take[:, None, None, None], cur, out_mask)
        accepted |= enough
        if bool(accepted.all()):
            break

        # Loosen for the next attempt. No update makes the pattern sparser
        # than the previous attempt. The caps only bound how far loosening
        # goes, and never pull a looser starting value back down.
        density = max(density, min(density * 2.0, 0.5))
        n_lines = min(n_lines * 2, height)
        line_density = min(line_density * 1.5, 1.0)
        grid_stride = min(grid_stride, max(grid_stride // 2, 2))

    weight = out_mask.float()
    sparse_depth = depth * weight
    if measurement_noise_std > 0.0:
        noise = torch.randn(sparse_depth.shape, generator=generator, device=sparse_depth.device)
        # Noise is applied only to observed pixels; the clamp keeps depths
        # non-negative.
        # NOTE(review): a reading clamped to 0 still has sparse_mask True.
        sparse_depth = (sparse_depth + noise * measurement_noise_std * weight).clamp_min(0.0)
    return sparse_depth, out_mask
def random_pattern_choice(rng: torch.Generator | None = None) -> str:
    """Pick one of the four pattern names uniformly at random.

    Used by the trainer to mix patterns per step. When ``rng`` is None,
    torch's default generator is used.
    """
    choices = ("random", "scan_line", "grid", "hybrid")
    drawn = torch.randint(0, len(choices), (1,), generator=rng)
    return choices[int(drawn.item())]
|