""" Simulate sparse-LiDAR observations from dense ground-truth depth. Patterns: random / scan-line / grid / hybrid. Used during training so the prompt encoder sees realistic sparsity. Simulation runs on tensors so it can sit inside the data loader or the training step. """ from __future__ import annotations import math import torch def _validate(depth: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert depth.dim() == 4 and mask.dim() == 4, "expect (B,1,H,W)" assert depth.shape == mask.shape return depth, mask def random_pattern(mask: torch.Tensor, density: float, generator: torch.Generator | None = None) -> torch.Tensor: """Bernoulli sparse mask with the given fraction of valid pixels kept.""" keep = torch.rand(mask.shape, generator=generator, device=mask.device) < density return mask & keep def scan_line_pattern( mask: torch.Tensor, n_lines: int, line_density: float, generator: torch.Generator | None = None, ) -> torch.Tensor: """Velodyne-like horizontal scan lines. Picks `n_lines` distinct row indices and within each row keeps a Bernoulli fraction `line_density` of points. """ B, _, H, W = mask.shape out = torch.zeros_like(mask) for b in range(B): n = max(1, min(n_lines, H)) rows = torch.randperm(H, generator=generator, device=mask.device)[:n] line_mask = torch.zeros((H, W), dtype=torch.bool, device=mask.device) line_mask[rows] = torch.rand((n, W), generator=generator, device=mask.device) < line_density out[b, 0] = line_mask return mask & out def grid_pattern(mask: torch.Tensor, stride: int) -> torch.Tensor: """Regular grid: keep every `stride` pixel along each axis.""" B, _, H, W = mask.shape out = torch.zeros_like(mask) out[:, :, ::stride, ::stride] = True return mask & out def hybrid_pattern( mask: torch.Tensor, density: float, n_lines: int, line_density: float, grid_stride: int, generator: torch.Generator | None = None, ) -> torch.Tensor: """Union of random + scan-line + grid (some real LiDARs fall in the middle).""" a = random_pattern(mask, density, generator=generator) b = scan_line_pattern(mask, n_lines, line_density, generator=generator) c = grid_pattern(mask, grid_stride) return a | b | c def simulate( depth: torch.Tensor, mask: torch.Tensor, pattern: str = "hybrid", *, density: float = 0.005, n_lines: int = 64, line_density: float = 0.5, grid_stride: int = 32, min_points: int = 16, max_attempts: int = 4, measurement_noise_std: float = 0.0, generator: torch.Generator | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Return (sparse_depth, sparse_mask) for a batch of dense GT. `min_points` guarantees at least that many observed pixels per sample (re-samples with looser params if not). Output sparse_depth is zero where mask is false. """ depth, mask = _validate(depth, mask) B = depth.shape[0] out_mask = torch.zeros_like(mask) for attempt in range(max_attempts): if pattern == "random": cur = random_pattern(mask, density, generator=generator) elif pattern == "scan_line": cur = scan_line_pattern(mask, n_lines, line_density, generator=generator) elif pattern == "grid": cur = grid_pattern(mask, grid_stride) elif pattern == "hybrid": cur = hybrid_pattern( mask, density, n_lines, line_density, grid_stride, generator=generator ) else: raise ValueError(f"unknown pattern: {pattern}") # update samples that already cleared the threshold per_sample = cur.flatten(1).sum(dim=1) ok = per_sample >= min_points out_mask = torch.where(ok[:, None, None, None], cur, out_mask) if ok.all(): break # loosen for next attempt: double density, grid step halves density = min(density * 2.0, 0.5) n_lines = min(n_lines * 2, depth.shape[-2]) line_density = min(line_density * 1.5, 1.0) grid_stride = max(grid_stride // 2, 2) sparse_depth = depth * out_mask.float() if measurement_noise_std > 0.0: noise = torch.randn(sparse_depth.shape, generator=generator, device=sparse_depth.device) sparse_depth = sparse_depth + noise * measurement_noise_std * out_mask.float() sparse_depth = sparse_depth.clamp_min(0.0) return sparse_depth, out_mask def random_pattern_choice(rng: torch.Generator | None = None) -> str: """Sample a pattern name uniformly. Used by the trainer to mix patterns per-step.""" options = ["random", "scan_line", "grid", "hybrid"] if rng is None: idx = int(torch.randint(0, len(options), (1,)).item()) else: idx = int(torch.randint(0, len(options), (1,), generator=rng).item()) return options[idx]