| """DDP-safe scene-sequential sampler for online temporal training. |
| |
| Guarantees: |
| 1. Within each scene, frames are yielded in strict timestamp order. |
| 2. Scene order is shuffled per-epoch for training diversity. |
| 3. All ranks yield exactly the same number of micro-steps per epoch |
| (balanced by greedy scene assignment + deterministic replay padding). |
| 4. Epoch boundaries and replay-scene starts are detectable by the caller |
| via timestamp regression, so StreamPETR memory can be reset correctly. |
| """ |
|
|
| import math |
| import random |
| from collections import Counter |
| from typing import Dict, Iterator, List, Sequence, Tuple |
|
|
| from torch.utils.data import Sampler |
|
|
|
|
| def _lcm(a: int, b: int) -> int: |
| a = int(a) |
| b = int(b) |
| if a == 0 or b == 0: |
| return 0 |
| return abs(a * b) // math.gcd(a, b) |
|
|
|
|
| def _assign_scenes_greedily( |
| scene_ids: Sequence[str], |
| scene_costs: Dict[str, int], |
| num_replicas: int, |
| ) -> Tuple[List[List[str]], List[int]]: |
| per_rank_scenes: List[List[str]] = [[] for _ in range(num_replicas)] |
| per_rank_costs = [0] * num_replicas |
|
|
| for sid in sorted(scene_ids, key=lambda s: scene_costs[s], reverse=True): |
| target = min(range(num_replicas), key=lambda r: per_rank_costs[r]) |
| per_rank_scenes[target].append(sid) |
| per_rank_costs[target] += int(scene_costs[sid]) |
|
|
| return per_rank_scenes, per_rank_costs |
|
|
|
|
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee.

    Scenes are assigned to ranks with a greedy longest-first heuristic so
    per-rank workloads are near-balanced; each rank then deterministically
    replays its own scenes until it reaches the global target length, so all
    ranks take exactly the same number of steps per epoch. Within a scene,
    dataset indices are emitted in their stored (timestamp) order.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        """Validate arguments and store configuration.

        Args:
            scene_groups: Mapping scene id -> dataset indices; each list is
                assumed to be in timestamp order (preserved as-is).
            num_replicas: Number of DDP ranks participating.
            rank: This process's rank in [0, num_replicas).
            seed: Base seed combined with the epoch for deterministic shuffles.
            shuffle_scenes: Shuffle scene order each epoch when True.
            pad_to_multiple: Round the per-rank epoch length up to a multiple
                of this value (clamped to >= 1).

        Raises:
            ValueError: On invalid num_replicas/rank or empty scene_groups.
        """
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")

        self.scene_groups = scene_groups
        # Sorted for a deterministic base order independent of dict insertion.
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch and invalidate cached indices (call before __iter__)."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        """Number of dataset indices in scene ``sid``."""
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        """Build this rank's full, replay-padded index list for the epoch."""
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)

        scene_costs = {sid: self._scene_len(sid) for sid in scene_order}
        per_rank_scenes, per_rank_counts = _assign_scenes_greedily(
            scene_order, scene_costs, self.num_replicas
        )

        # Every rank must own at least one scene; borrow one deterministically
        # (the same scene may then appear on two ranks, which is acceptable).
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)

        # All ranks pad up to the heaviest rank's count, rounded up to
        # pad_to_multiple, so every rank yields the same number of steps.
        target_count = max(per_rank_counts)
        if target_count % self.pad_to_multiple != 0:
            target_count = (
                (target_count + self.pad_to_multiple - 1)
                // self.pad_to_multiple
            ) * self.pad_to_multiple
        self._target_len = target_count

        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Independent per-rank shuffle of this rank's scene visiting order.
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)

        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])

        if len(indices) < target_count:
            # BUGFIX: restrict the replay pool to scenes that actually
            # contribute indices. A pool of all-empty scenes previously made
            # this loop spin forever without ever growing `indices`.
            replay_pool = [sid for sid in my_scenes if self.scene_groups[sid]]
            if not replay_pool:
                raise RuntimeError(
                    f"rank {self.rank} was assigned only empty scenes and "
                    f"cannot pad to {target_count} steps"
                )
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            # The last replayed scene may be cut mid-scene; the caller detects
            # the restart via timestamp regression (see module docstring).
            indices = indices[:target_count]

        return indices

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Lazily build once so len() is valid before the first __iter__ call.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
|
|
|
|
class SceneUnitTaskBalancedSampler(Sampler[int]):
    """Scene-sequential sampler with unit-level 1:1:1 task balancing.

    Each scene is first converted into ordered timestamp units. A unit can
    contain one detection sample, one planning sample, and six caption samples.
    The sampler keeps scene/time monotonicity and emits raw dataset indices with
    an exact unit-level balance across detection/planning/caption, while
    naturally expanding caption to six raw samples.

    Unit dicts are read via the keys "detection_idx" (index or None),
    "planning_idx" (index or None) and "caption_indices" (possibly empty
    list of indices) — see _task_available / _emit_task.
    """

    # Closed task vocabulary; _task_available/_emit_task dispatch on these.
    TASKS = ("detection", "planning", "caption")
    # Raw samples produced by one fully selected unit: 1 detection +
    # 1 planning + 6 captions. _build_indices rounds the per-rank epoch
    # length up to a multiple of lcm(pad_to_multiple, UNIT_RAW_BLOCK_SIZE).
    UNIT_RAW_BLOCK_SIZE = 8

    def __init__(
        self,
        scene_unit_groups: Dict[str, List[Dict[str, object]]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        """Validate arguments, store config, and pre-build per-scene sequences.

        Args:
            scene_unit_groups: Mapping scene id -> timestamp-ordered unit
                dicts (expected keys listed in the class docstring).
            num_replicas: Number of DDP ranks participating.
            rank: This process's rank in [0, num_replicas).
            seed: Base seed combined with epoch (and rank) for shuffles.
            shuffle_scenes: Shuffle scene order each epoch when True.
            pad_to_multiple: Extra rounding factor for the epoch length,
                combined with UNIT_RAW_BLOCK_SIZE via lcm (clamped to >= 1).

        Raises:
            ValueError: On invalid num_replicas/rank, empty input, or (from
                _build_scene_sequences) when no scene survives balancing.
        """
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_unit_groups:
            raise ValueError("scene_unit_groups must not be empty")

        self.scene_unit_groups = scene_unit_groups
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0
        self._last_epoch_stats: Dict[str, object] = {}

        # Per-scene state filled by _build_scene_sequences(); scenes that
        # cannot be balanced (some task never available) are skipped.
        self.scene_ids: List[str] = []
        self._scene_sequences: Dict[str, List[int]] = {}
        self._scene_unit_counts: Dict[str, Dict[str, int]] = {}
        self._scene_raw_counts: Dict[str, Dict[str, int]] = {}
        self._scene_lengths: Dict[str, int] = {}
        self._skipped_scenes: List[str] = []
        self._summary: Dict[str, object] = {}
        self._build_scene_sequences()

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch and invalidate cached indices/stats (call before __iter__)."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0
        self._last_epoch_stats = {}

    def _task_available(self, unit: Dict[str, object], task: str) -> bool:
        """Return True if `unit` carries a sample for `task`.

        Raises:
            ValueError: If `task` is not one of TASKS.
        """
        if task == "detection":
            return unit.get("detection_idx") is not None
        if task == "planning":
            return unit.get("planning_idx") is not None
        if task == "caption":
            # Caption availability means a non-empty list of raw indices.
            return bool(unit.get("caption_indices"))
        raise ValueError(f"Unsupported task: {task}")

    def _emit_task(self, raw_indices: List[int], raw_counts: Counter, unit: Dict[str, object], task: str) -> None:
        """Append `unit`'s raw dataset indices for `task` and update counters.

        Mutates `raw_indices` and `raw_counts` in place. Caption emits all of
        the unit's caption indices (nominally six), the other tasks emit one.

        Raises:
            ValueError: If `task` is not one of TASKS.
        """
        if task == "detection":
            raw_indices.append(int(unit["detection_idx"]))
            raw_counts["detection"] += 1
            return
        if task == "planning":
            raw_indices.append(int(unit["planning_idx"]))
            raw_counts["planning"] += 1
            return
        if task == "caption":
            caption_indices = [int(v) for v in unit["caption_indices"]]
            raw_indices.extend(caption_indices)
            raw_counts["caption"] += len(caption_indices)
            return
        raise ValueError(f"Unsupported task: {task}")

    def _build_scene_sequence(
        self, scene_id: str, units: List[Dict[str, object]]
    ) -> Tuple[List[int], Dict[str, int], Dict[str, int]]:
        """Build one scene's balanced raw-index sequence in timestamp order.

        Selects exactly `target_units` units per task, where target_units is
        the availability of the scarcest task, so the unit-level task counts
        come out exactly equal.

        Returns:
            (raw_indices, unit_counts_per_task, raw_counts_per_task); all
            empty/zero when some task is never available in the scene.

        Raises:
            RuntimeError: If the greedy selection fails to reach the exact
                per-task target (invariant check).
        """
        available_unit_counts = {
            task: sum(1 for unit in units if self._task_available(unit, task))
            for task in self.TASKS
        }
        # Balance to the scarcest task; a scene missing a task entirely
        # yields target 0 and is reported as skipped by the caller.
        target_units = min(available_unit_counts.values())
        if target_units <= 0:
            return [], {task: 0 for task in self.TASKS}, {task: 0 for task in self.TASKS}

        # Suffix availability: remaining[task][i] = number of units at
        # position >= i that offer `task` (computed back-to-front).
        remaining = {task: [0] * (len(units) + 1) for task in self.TASKS}
        for i in range(len(units) - 1, -1, -1):
            unit = units[i]
            for task in self.TASKS:
                remaining[task][i] = remaining[task][i + 1] + int(self._task_available(unit, task))

        selected_unit_counts: Counter = Counter()
        raw_counts: Counter = Counter()
        raw_indices: List[int] = []

        for i, unit in enumerate(units):
            selected_tasks: List[str] = []

            # Planning is taken eagerly whenever available and under quota.
            if self._task_available(unit, "planning") and selected_unit_counts["planning"] < target_units:
                selected_tasks.append("planning")
                selected_unit_counts["planning"] += 1

            # Detection/caption are taken only when they lag planning (keeps
            # the three tasks interleaved) or when skipping would make the
            # quota unreachable with the units still ahead (must_take).
            for task in ("detection", "caption"):
                if not self._task_available(unit, task):
                    continue
                if selected_unit_counts[task] >= target_units:
                    continue
                need = target_units - selected_unit_counts[task]
                remaining_with_task = remaining[task][i]
                lagging_planning = selected_unit_counts[task] < selected_unit_counts["planning"]
                must_take = need >= remaining_with_task
                if lagging_planning or must_take:
                    selected_tasks.append(task)
                    selected_unit_counts[task] += 1

            # Emit in a fixed task order within the unit.
            for task in ("detection", "planning", "caption"):
                if task in selected_tasks:
                    self._emit_task(raw_indices, raw_counts, unit, task)

        # Invariant: the greedy pass must hit the target for every task.
        expected = {task: int(target_units) for task in self.TASKS}
        actual = {task: int(selected_unit_counts[task]) for task in self.TASKS}
        if actual != expected:
            raise RuntimeError(
                "Failed to build exact scene-unit balance for "
                f"scene={scene_id}: target={expected}, actual={actual}"
            )

        return raw_indices, actual, {task: int(raw_counts[task]) for task in self.TASKS}

    def _build_scene_sequences(self) -> None:
        """Pre-build balanced sequences for all scenes and the global summary.

        Scenes whose balanced sequence is empty are recorded in
        _skipped_scenes and excluded from sampling.

        Raises:
            ValueError: If every scene is skipped.
        """
        summary_unit_counts: Counter = Counter()
        summary_raw_counts: Counter = Counter()

        # Sorted scene order makes the build deterministic.
        for scene_id in sorted(self.scene_unit_groups.keys()):
            raw_indices, unit_counts, raw_counts = self._build_scene_sequence(
                scene_id, self.scene_unit_groups[scene_id]
            )
            if not raw_indices:
                self._skipped_scenes.append(scene_id)
                continue
            self.scene_ids.append(scene_id)
            self._scene_sequences[scene_id] = raw_indices
            self._scene_unit_counts[scene_id] = unit_counts
            self._scene_raw_counts[scene_id] = raw_counts
            self._scene_lengths[scene_id] = len(raw_indices)
            summary_unit_counts.update(unit_counts)
            summary_raw_counts.update(raw_counts)

        if not self.scene_ids:
            raise ValueError("No scenes remain after scene-unit balancing")

        self._summary = {
            "num_scenes": len(self.scene_ids),
            "skipped_scenes": len(self._skipped_scenes),
            "unit_counts": {task: int(summary_unit_counts[task]) for task in self.TASKS},
            "raw_counts": {task: int(summary_raw_counts[task]) for task in self.TASKS},
            "raw_total": int(sum(summary_raw_counts.values())),
        }

    def get_summary(self) -> Dict[str, object]:
        """Return a copy of the dataset-wide balancing summary."""
        return dict(self._summary)

    def get_last_epoch_stats(self) -> Dict[str, object]:
        """Return a copy of the stats from the most recent _build_indices call."""
        return dict(self._last_epoch_stats)

    def _build_indices(self) -> List[int]:
        """Build this rank's full, replay-padded raw-index list for the epoch.

        Also records _last_epoch_stats as a side effect.
        """
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)

        scene_costs = {sid: self._scene_lengths[sid] for sid in scene_order}
        per_rank_scenes, per_rank_counts = _assign_scenes_greedily(
            scene_order, scene_costs, self.num_replicas
        )

        # Every rank must own at least one scene; borrow one deterministically.
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_lengths[fallback]

        # Pad all ranks to the heaviest rank's count, rounded up to a multiple
        # of lcm(pad_to_multiple, UNIT_RAW_BLOCK_SIZE).
        target_count = max(per_rank_counts)
        block_multiple = _lcm(self.pad_to_multiple, self.UNIT_RAW_BLOCK_SIZE)
        if target_count % block_multiple != 0:
            target_count = (
                (target_count + block_multiple - 1)
                // block_multiple
            ) * block_multiple
        self._target_len = target_count

        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Independent per-rank shuffle of this rank's scene visiting order.
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)

        indices: List[int] = []
        unit_counts_epoch: Counter = Counter()
        raw_counts_epoch: Counter = Counter()
        for sid in my_scenes:
            indices.extend(self._scene_sequences[sid])
            unit_counts_epoch.update(self._scene_unit_counts[sid])
            raw_counts_epoch.update(self._scene_raw_counts[sid])

        prepad_len = len(indices)
        replay_counts: Counter = Counter()
        if len(indices) < target_count:
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self._scene_sequences[sid])
                # NOTE(review): these counters include the whole replayed
                # scene even when the truncation below drops part of it, so
                # the reported unit/raw counts can slightly overstate what is
                # actually yielded — confirm whether consumers rely on
                # exactness here. The truncation point is a multiple of
                # UNIT_RAW_BLOCK_SIZE but may still fall mid-unit, since
                # per-unit emission sizes vary — TODO confirm.
                unit_counts_epoch.update(self._scene_unit_counts[sid])
                raw_counts_epoch.update(self._scene_raw_counts[sid])
                replay_counts[sid] += 1
                cursor += 1
            indices = indices[:target_count]

        # Stats for the caller (see get_last_epoch_stats).
        self._last_epoch_stats = {
            "num_scenes_total": len(self.scene_ids),
            "num_scenes_rank": len(my_scenes),
            "num_skipped_scenes": len(self._skipped_scenes),
            "prepad_len": int(prepad_len),
            "target_len": int(target_count),
            "replay_extra": int(target_count - prepad_len),
            "unit_counts": {task: int(unit_counts_epoch[task]) for task in self.TASKS},
            "raw_counts": {task: int(raw_counts_epoch[task]) for task in self.TASKS},
            "replayed_scenes": {sid: int(v) for sid, v in replay_counts.items()},
        }
        return indices

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Lazily build once so len() is valid before the first __iter__ call.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
|
|