"""DDP-safe scene-sequential sampler for online temporal training.

Guarantees:
1. Within each scene, frames are yielded in strict timestamp order.
2. Scene order is shuffled per-epoch for training diversity.
3. All ranks yield exactly the same number of micro-steps per epoch
   (balanced by greedy scene assignment + deterministic replay padding).
4. Epoch boundaries and replay-scene starts are detectable by the caller
   via timestamp regression, so StreamPETR memory can be reset correctly.
"""
import random
from typing import Dict, Iterator, List

from torch.utils.data import Sampler


class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee."""

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")
        self.scene_groups = scene_groups
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        # Lazily built per-epoch state; invalidated by set_epoch().
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch and invalidate cached state (as with DistributedSampler)."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)
        # Greedy longest-scene-first assignment: every rank runs this with the
        # same RNG state, so all ranks agree on the full assignment. The prior
        # shuffle still matters because the stable sort keeps its tie-breaking.
        per_rank_scenes: List[List[str]] = [[] for _ in range(self.num_replicas)]
        per_rank_counts = [0] * self.num_replicas
        for sid in sorted(scene_order, key=self._scene_len, reverse=True):
            target = min(range(self.num_replicas), key=lambda r: per_rank_counts[r])
            per_rank_scenes[target].append(sid)
            per_rank_counts[target] += self._scene_len(sid)
        # If there are fewer scenes than ranks, give each empty rank a
        # duplicate scene so every rank has at least one scene to replay.
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)
        # All ranks pad up to the busiest rank, rounded to pad_to_multiple, so
        # every rank yields the same number of micro-steps per epoch.
        target_count = max(per_rank_counts)
        if target_count % self.pad_to_multiple != 0:
            target_count = (
                (target_count + self.pad_to_multiple - 1) // self.pad_to_multiple
            ) * self.pad_to_multiple
        self._target_len = target_count
        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Mix rank into the seed multiplicatively; a plain seed+epoch+rank
            # sum would give rank r at epoch e the same scene order as rank
            # r+1 at epoch e-1.
            rng2 = random.Random(self.seed + self.epoch * self.num_replicas + self.rank)
            rng2.shuffle(my_scenes)
        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])
        # Deterministic replay padding: cycle through this rank's own scenes
        # (in full, in timestamp order) until the target length is reached;
        # the final replayed scene may be truncated mid-sequence.
        if len(indices) < target_count:
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            indices = indices[:target_count]
        return indices

    def __iter__(self) -> Iterator[int]:
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        if self._target_len == 0:
            # _build_indices() sets self._target_len as a side effect.
            self._cached = self._build_indices()
        return self._target_len
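

if __name__ == "__main__":
    # Minimal smoke test with made-up scene groups; in real use, each index
    # list comes from the dataset and must already be sorted by timestamp.
    demo_groups = {
        "scene_a": [0, 1, 2, 3, 4],
        "scene_b": [5, 6],
        "scene_c": [7, 8, 9],
    }
    for demo_rank in range(2):
        sampler = SceneSequentialSampler(
            demo_groups, num_replicas=2, rank=demo_rank, seed=0
        )
        sampler.set_epoch(0)
        # Both ranks report the same step count despite unequal scene sizes.
        print(f"rank {demo_rank}: steps={len(sampler)} indices={list(sampler)}")
    # In training, wire it into a DataLoader roughly as
    #     DataLoader(dataset, sampler=sampler, batch_size=1)
    # (batch_size=1 keeps each rank's stream strictly frame-sequential).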