# Atlas-online / src / dataset / scene_sampler.py
"""DDP-safe scene-sequential sampler for online temporal training.
Guarantees:
1. Within each scene, frames are yielded in strict timestamp order.
2. Scene order is shuffled per-epoch for training diversity.
3. All ranks yield exactly the same number of micro-steps per epoch
(balanced by greedy scene assignment + deterministic replay padding).
4. Epoch boundaries and replay-scene starts are detectable by the caller
via timestamp regression, so StreamPETR memory can be reset correctly.
"""
import random
from typing import Dict, Iterator, List, Sequence
from torch.utils.data import Sampler
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee.

    Frames within each scene are yielded in their stored order; whole scenes
    are distributed across ranks via greedy longest-first assignment so every
    rank performs the same number of micro-steps per epoch, with deterministic
    scene replay used as padding up to the heaviest rank's count.

    Args:
        scene_groups: Mapping of scene id -> dataset indices for that scene.
            Indices are yielded in the given list order, so the caller must
            supply them pre-sorted by timestamp.
        num_replicas: Total number of DDP ranks.
        rank: This process's rank in ``[0, num_replicas)``.
        seed: Base seed; combined with the epoch (and rank) for shuffling.
        shuffle_scenes: Shuffle scene order each epoch (frame order within a
            scene is never shuffled).
        pad_to_multiple: Round the per-rank step count up to a multiple of
            this value (e.g. gradient-accumulation steps).

    Raises:
        ValueError: On invalid ``num_replicas``/``rank``, an empty
            ``scene_groups``, or any scene with zero frames.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")
        # A zero-frame scene would make the replay-padding loop in
        # _build_indices() extend by nothing forever (infinite loop), so
        # reject it up front.
        empty_scenes = [sid for sid, frames in scene_groups.items() if not frames]
        if empty_scenes:
            raise ValueError(f"scene_groups contains empty scenes: {empty_scenes}")
        self.scene_groups = scene_groups
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []  # last materialized index list for this rank
        self._target_len = 0  # common per-rank step count; 0 = not built yet

    def set_epoch(self, epoch: int) -> None:
        """Advance the epoch; invalidates caches so shuffling re-runs."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        """Number of frames in scene ``sid``."""
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        """Construct this rank's index list for the current epoch.

        Every rank runs the same seeded RNG and greedy assignment, so all
        ranks agree on the scene partition and on ``self._target_len``
        (set here as a side effect) without any communication.
        """
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)
        # Greedy longest-first bin packing: each scene goes to the currently
        # lightest rank.  The stable sort preserves the shuffled order among
        # equal-length scenes, so the per-epoch shuffle still matters.
        per_rank_scenes: List[List[str]] = [[] for _ in range(self.num_replicas)]
        per_rank_counts = [0] * self.num_replicas
        for sid in sorted(scene_order, key=self._scene_len, reverse=True):
            target = min(range(self.num_replicas), key=per_rank_counts.__getitem__)
            per_rank_scenes[target].append(sid)
            per_rank_counts[target] += self._scene_len(sid)
        # With more ranks than scenes a rank can end up empty; give it a
        # deterministic duplicate scene so it still has work to do.
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)
        # All ranks pad up to the heaviest rank's count, rounded up to a
        # multiple of pad_to_multiple, so every rank runs the same number
        # of micro-steps (no DDP stragglers / hangs).
        target_count = max(per_rank_counts)
        if target_count % self.pad_to_multiple != 0:
            target_count = (
                (target_count + self.pad_to_multiple - 1)
                // self.pad_to_multiple
            ) * self.pad_to_multiple
        self._target_len = target_count
        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Rank-local shuffle of this rank's own scene order.
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)
        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])
        if len(indices) < target_count:
            # Deterministic replay padding: cycle through this rank's scenes
            # again from their first frame; the caller detects the restart
            # via timestamp regression and resets temporal memory.
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            # Truncation may cut a replayed scene short; that is acceptable
            # because replayed frames are padding, not fresh supervision.
            indices = indices[:target_count]
        return indices

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # The length depends on the greedy assignment, so build lazily the
        # first time it is requested for the current epoch.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len