| """DDP-safe scene-sequential sampler for online temporal training. |
| |
| Guarantees: |
| 1. Within each scene, frames are yielded in strict timestamp order. |
| 2. Scene order is shuffled per-epoch for training diversity. |
| 3. All ranks yield exactly the same number of micro-steps per epoch |
| (balanced by greedy scene assignment + deterministic replay padding). |
| 4. Epoch boundaries and replay-scene starts are detectable by the caller |
| via timestamp regression, so StreamPETR memory can be reset correctly. |
| """ |
|
|
| import random |
| from typing import Dict, Iterator, List, Sequence |
|
|
| from torch.utils.data import Sampler |
|
|
|
|
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with an equal-step guarantee.

    Frames within a scene are yielded in the order the caller stored them
    (assumed timestamp order).  Scenes are assigned to ranks by greedy
    longest-first bin packing so all ranks take the same number of
    micro-steps per epoch, with deterministic scene replay used as padding.

    Args:
        scene_groups: Mapping of scene id -> dataset indices for that scene,
            already sorted by timestamp.  Every scene must be non-empty.
        num_replicas: World size of the DDP job.
        rank: This process's rank in ``[0, num_replicas)``.
        seed: Base seed for the per-epoch shuffles.
        shuffle_scenes: Shuffle scene order each epoch (frame order within a
            scene is always preserved).
        pad_to_multiple: Round each rank's step count up to a multiple of
            this value (e.g. gradient-accumulation steps).

    Raises:
        ValueError: On invalid ``num_replicas``/``rank``, empty
            ``scene_groups``, or any scene with no indices.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")
        # An empty scene contributes no frames, so a rank assigned only
        # empty scenes would make the replay-padding loop in
        # _build_indices spin forever.  Reject them up front.
        empty_scenes = [sid for sid, idxs in scene_groups.items() if not idxs]
        if empty_scenes:
            raise ValueError(
                f"scene_groups contains scenes with no indices: {empty_scenes}"
            )

        self.scene_groups = scene_groups
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        # Lazily-built plan for the current epoch, shared by __len__ / __iter__.
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch so every rank reshuffles consistently (DDP contract)."""
        self.epoch = int(epoch)
        # Invalidate the cached plan; it is rebuilt lazily on demand.
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        """Number of frames in scene ``sid``."""
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        """Build this rank's index sequence for the current epoch.

        Side effect: records the common per-rank length in ``self._target_len``.
        """
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)

        # Greedy longest-first bin packing: each scene goes to the rank with
        # the fewest frames so far.  The pre-shuffle only breaks ties between
        # equal-length scenes (Python's sort is stable).
        per_rank_scenes: List[List[str]] = [[] for _ in range(self.num_replicas)]
        per_rank_counts = [0] * self.num_replicas

        for sid in sorted(scene_order, key=self._scene_len, reverse=True):
            target = min(range(self.num_replicas), key=lambda r: per_rank_counts[r])
            per_rank_scenes[target].append(sid)
            per_rank_counts[target] += self._scene_len(sid)

        # More ranks than scenes: give an otherwise-empty rank a duplicate
        # scene so every rank has something to iterate.
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)

        # All ranks must take the same number of micro-steps, rounded up to
        # pad_to_multiple (computed AFTER the fallback so it covers it).
        target_count = max(per_rank_counts)
        remainder = target_count % self.pad_to_multiple
        if remainder:
            target_count += self.pad_to_multiple - remainder
        self._target_len = target_count

        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Injective (epoch, rank) seed mix.  The previous
            # ``seed + epoch + rank`` collided for pairs with equal sums,
            # giving rank r at epoch e the exact shuffle stream of
            # rank r+1 at epoch e-1.
            rng2 = random.Random(
                (self.seed + self.epoch) * self.num_replicas + self.rank
            )
            rng2.shuffle(my_scenes)

        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])

        # Deterministic replay padding: cycle through this rank's own scenes
        # until the target is met, then truncate.  The caller detects each
        # replayed scene start via timestamp regression (see module docs).
        if len(indices) < target_count:
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            indices = indices[:target_count]

        return indices

    def __iter__(self) -> Iterator[int]:
        # Reuse a plan already built by __len__ for this epoch; the build is
        # deterministic given (seed, epoch, rank), so this never changes the
        # yielded sequence — it only avoids a redundant rebuild.
        if not self._cached:
            self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Length is only known after a build (it depends on greedy packing
        # plus padding); build lazily and cache for the upcoming __iter__.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
|
|