# Code snapshot upload "2task with caption" (commit 95f6448, user guoyb0).
"""DDP-safe scene-sequential sampler for online temporal training.
Guarantees:
1. Within each scene, frames are yielded in strict timestamp order.
2. Scene order is shuffled per-epoch for training diversity.
3. All ranks yield exactly the same number of micro-steps per epoch
(balanced by greedy scene assignment + deterministic replay padding).
4. Epoch boundaries and replay-scene starts are detectable by the caller
via timestamp regression, so StreamPETR memory can be reset correctly.
"""
import math
import random
from collections import Counter
from typing import Dict, Iterator, List, Sequence, Tuple
from torch.utils.data import Sampler
def _lcm(a: int, b: int) -> int:
a = int(a)
b = int(b)
if a == 0 or b == 0:
return 0
return abs(a * b) // math.gcd(a, b)
def _assign_scenes_greedily(
scene_ids: Sequence[str],
scene_costs: Dict[str, int],
num_replicas: int,
) -> Tuple[List[List[str]], List[int]]:
per_rank_scenes: List[List[str]] = [[] for _ in range(num_replicas)]
per_rank_costs = [0] * num_replicas
for sid in sorted(scene_ids, key=lambda s: scene_costs[s], reverse=True):
target = min(range(num_replicas), key=lambda r: per_rank_costs[r])
per_rank_scenes[target].append(sid)
per_rank_costs[target] += int(scene_costs[sid])
return per_rank_scenes, per_rank_costs
class SceneSequentialSampler(Sampler[int]):
    """Distributed temporal sampler with equal-step guarantee.

    Scenes are load-balanced across ranks each epoch; within a rank, the
    frames of each assigned scene are emitted contiguously in their stored
    order. Ranks that come up short replay their own scenes (deterministic
    cycling) until every rank yields exactly the same number of indices.
    """

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_groups:
            raise ValueError("scene_groups must not be empty")
        self.scene_groups = scene_groups
        # Sorted base order so every rank derives the same plan before shuffling.
        self.scene_ids = sorted(scene_groups)
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Invalidate the cached plan so the next epoch is rebuilt fresh."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        """Number of frames in scene ``sid``."""
        return len(self.scene_groups[sid])

    def _build_indices(self) -> List[int]:
        """Build this rank's epoch-deterministic frame-index stream."""
        # Shared RNG: seeded identically on every rank so the shuffled scene
        # order (and hence the greedy assignment) agrees across the world.
        shared_rng = random.Random(self.seed + self.epoch)
        lineup = list(self.scene_ids)
        if self.shuffle_scenes:
            shared_rng.shuffle(lineup)
        costs = {sid: self._scene_len(sid) for sid in lineup}
        buckets, loads = _assign_scenes_greedily(lineup, costs, self.num_replicas)
        # If there are fewer scenes than ranks, empty ranks borrow one
        # deterministically (duplicated across ranks on purpose).
        for rid, bucket in enumerate(buckets):
            if bucket:
                continue
            spare = lineup[rid % len(lineup)]
            bucket.append(spare)
            loads[rid] += self._scene_len(spare)
        quota = max(loads)
        # Round the common epoch length up to the requested multiple.
        rem = quota % self.pad_to_multiple
        if rem:
            quota += self.pad_to_multiple - rem
        self._target_len = quota
        mine = buckets[self.rank]
        if self.shuffle_scenes:
            # Rank-local RNG: only reorders this rank's own scene playlist.
            random.Random(self.seed + self.epoch + self.rank).shuffle(mine)
        indices: List[int] = []
        for sid in mine:
            indices.extend(self.scene_groups[sid])
        # Replay padding: cycle over own scenes until the quota is met, then
        # trim the overshoot (may truncate the last replayed scene).
        cursor = 0
        while len(indices) < quota:
            indices.extend(self.scene_groups[mine[cursor % len(mine)]])
            cursor += 1
        del indices[quota:]
        return indices

    def __iter__(self) -> Iterator[int]:
        # Rebuild on every iteration so set_epoch() takes effect.
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Lazily build once so len() is available before the first __iter__.
        if not self._target_len:
            self._cached = self._build_indices()
        return self._target_len
class SceneUnitTaskBalancedSampler(Sampler[int]):
    """Scene-sequential sampler with unit-level 1:1:1 task balancing.

    Each scene is first converted into ordered timestamp units. A unit can
    contain one detection sample, one planning sample, and six caption samples.
    The sampler keeps scene/time monotonicity and emits raw dataset indices with
    an exact unit-level balance across detection/planning/caption, while
    naturally expanding caption to six raw samples.
    """

    # Task roster; iterating this tuple fixes task order in all stats dicts.
    TASKS = ("detection", "planning", "caption")
    # Raw samples of one fully balanced unit (1 detection + 1 planning +
    # 6 captions per the class docstring); epoch lengths are padded to a
    # multiple of lcm(pad_to_multiple, 8) so epochs end on whole blocks.
    UNIT_RAW_BLOCK_SIZE = 8

    def __init__(
        self,
        scene_unit_groups: Dict[str, List[Dict[str, object]]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        """Validate DDP parameters and pre-build per-scene balanced sequences.

        Args:
            scene_unit_groups: scene_id -> time-ordered list of unit dicts;
                a unit may carry "detection_idx", "planning_idx" and/or
                "caption_indices" (see ``_task_available``).
            num_replicas: world size (>= 1).
            rank: this process's rank in [0, num_replicas).
            seed: base seed for the per-epoch deterministic shuffles.
            shuffle_scenes: shuffle scene order each epoch when True.
            pad_to_multiple: extra alignment constraint on the epoch length.

        Raises:
            ValueError: on invalid DDP params, empty input, or (from
                ``_build_scene_sequences``) if no scene survives balancing.
        """
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
        if not scene_unit_groups:
            raise ValueError("scene_unit_groups must not be empty")
        self.scene_unit_groups = scene_unit_groups
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0
        self._last_epoch_stats: Dict[str, object] = {}
        # All of the following are populated by _build_scene_sequences().
        self.scene_ids: List[str] = []
        self._scene_sequences: Dict[str, List[int]] = {}
        self._scene_unit_counts: Dict[str, Dict[str, int]] = {}
        self._scene_raw_counts: Dict[str, Dict[str, int]] = {}
        self._scene_lengths: Dict[str, int] = {}
        self._skipped_scenes: List[str] = []
        self._summary: Dict[str, object] = {}
        self._build_scene_sequences()

    def set_epoch(self, epoch: int) -> None:
        """Invalidate cached epoch state so the next epoch is rebuilt."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0
        self._last_epoch_stats = {}

    def _task_available(self, unit: Dict[str, object], task: str) -> bool:
        """Return True if ``unit`` carries at least one sample for ``task``."""
        if task == "detection":
            return unit.get("detection_idx") is not None
        if task == "planning":
            return unit.get("planning_idx") is not None
        if task == "caption":
            # Caption availability means a non-empty list of raw indices.
            return bool(unit.get("caption_indices"))
        raise ValueError(f"Unsupported task: {task}")

    def _emit_task(self, raw_indices: List[int], raw_counts: Counter, unit: Dict[str, object], task: str) -> None:
        """Append ``task``'s raw dataset indices from ``unit`` and tally counts.

        Mutates both ``raw_indices`` and ``raw_counts`` in place.
        """
        if task == "detection":
            raw_indices.append(int(unit["detection_idx"]))
            raw_counts["detection"] += 1
            return
        if task == "planning":
            raw_indices.append(int(unit["planning_idx"]))
            raw_counts["planning"] += 1
            return
        if task == "caption":
            # One caption unit expands to all of its raw caption samples.
            caption_indices = [int(v) for v in unit["caption_indices"]]
            raw_indices.extend(caption_indices)
            raw_counts["caption"] += len(caption_indices)
            return
        raise ValueError(f"Unsupported task: {task}")

    def _build_scene_sequence(
        self, scene_id: str, units: List[Dict[str, object]]
    ) -> Tuple[List[int], Dict[str, int], Dict[str, int]]:
        """Build one scene's time-ordered, exactly task-balanced index stream.

        Returns:
            (raw_indices, selected unit counts per task, raw sample counts
            per task). All three are empty/zero when the scene lacks some
            task entirely.

        Raises:
            RuntimeError: if the greedy selection fails to hit the exact
                per-task quota (should not happen; defensive invariant).
        """
        available_unit_counts = {
            task: sum(1 for unit in units if self._task_available(unit, task))
            for task in self.TASKS
        }
        # A scene can contribute at most min(per-task availability) balanced units.
        target_units = min(available_unit_counts.values())
        if target_units <= 0:
            # At least one task never occurs in this scene -> caller skips it.
            return [], {task: 0 for task in self.TASKS}, {task: 0 for task in self.TASKS}
        # remaining[task][i] = number of units at positions >= i that offer
        # `task` (suffix counts; feeds the must_take feasibility check below).
        remaining = {task: [0] * (len(units) + 1) for task in self.TASKS}
        for i in range(len(units) - 1, -1, -1):
            unit = units[i]
            for task in self.TASKS:
                remaining[task][i] = remaining[task][i + 1] + int(self._task_available(unit, task))
        selected_unit_counts: Counter = Counter()
        raw_counts: Counter = Counter()
        raw_indices: List[int] = []
        for i, unit in enumerate(units):
            selected_tasks: List[str] = []
            # Planning is taken greedily until its quota fills; the other two
            # tasks pace themselves against the running planning count.
            if self._task_available(unit, "planning") and selected_unit_counts["planning"] < target_units:
                selected_tasks.append("planning")
                selected_unit_counts["planning"] += 1
            for task in ("detection", "caption"):
                if not self._task_available(unit, task):
                    continue
                if selected_unit_counts[task] >= target_units:
                    continue
                need = target_units - selected_unit_counts[task]
                remaining_with_task = remaining[task][i]
                # Take the task now if it trails planning (keeps the tasks
                # interleaved in time) or if skipping would leave too few
                # later opportunities to reach the quota.
                lagging_planning = selected_unit_counts[task] < selected_unit_counts["planning"]
                must_take = need >= remaining_with_task
                if lagging_planning or must_take:
                    selected_tasks.append(task)
                    selected_unit_counts[task] += 1
            # Fixed within-unit emission order: detection, planning, caption.
            for task in ("detection", "planning", "caption"):
                if task in selected_tasks:
                    self._emit_task(raw_indices, raw_counts, unit, task)
        expected = {task: int(target_units) for task in self.TASKS}
        actual = {task: int(selected_unit_counts[task]) for task in self.TASKS}
        if actual != expected:
            raise RuntimeError(
                "Failed to build exact scene-unit balance for "
                f"scene={scene_id}: target={expected}, actual={actual}"
            )
        return raw_indices, actual, {task: int(raw_counts[task]) for task in self.TASKS}

    def _build_scene_sequences(self) -> None:
        """Precompute every scene's balanced sequence plus dataset-wide stats.

        Raises:
            ValueError: if every scene was skipped (no task-complete scenes).
        """
        summary_unit_counts: Counter = Counter()
        summary_raw_counts: Counter = Counter()
        # Sorted iteration keeps scene_ids and stats deterministic across runs.
        for scene_id in sorted(self.scene_unit_groups.keys()):
            raw_indices, unit_counts, raw_counts = self._build_scene_sequence(
                scene_id, self.scene_unit_groups[scene_id]
            )
            if not raw_indices:
                # Scene lacked at least one task entirely; record and drop it.
                self._skipped_scenes.append(scene_id)
                continue
            self.scene_ids.append(scene_id)
            self._scene_sequences[scene_id] = raw_indices
            self._scene_unit_counts[scene_id] = unit_counts
            self._scene_raw_counts[scene_id] = raw_counts
            self._scene_lengths[scene_id] = len(raw_indices)
            summary_unit_counts.update(unit_counts)
            summary_raw_counts.update(raw_counts)
        if not self.scene_ids:
            raise ValueError("No scenes remain after scene-unit balancing")
        self._summary = {
            "num_scenes": len(self.scene_ids),
            "skipped_scenes": len(self._skipped_scenes),
            "unit_counts": {task: int(summary_unit_counts[task]) for task in self.TASKS},
            "raw_counts": {task: int(summary_raw_counts[task]) for task in self.TASKS},
            "raw_total": int(sum(summary_raw_counts.values())),
        }

    def get_summary(self) -> Dict[str, object]:
        """Return a copy of the dataset-wide balancing summary."""
        return dict(self._summary)

    def get_last_epoch_stats(self) -> Dict[str, object]:
        """Return a copy of the stats from the latest _build_indices() call."""
        return dict(self._last_epoch_stats)

    def _build_indices(self) -> List[int]:
        """Assemble this rank's epoch-deterministic raw index stream.

        Also records per-epoch statistics into ``self._last_epoch_stats``.
        """
        # Shared RNG (identical on every rank) -> identical scene assignment.
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)
        scene_costs = {sid: self._scene_lengths[sid] for sid in scene_order}
        per_rank_scenes, per_rank_counts = _assign_scenes_greedily(
            scene_order, scene_costs, self.num_replicas
        )
        # Guarantee every rank owns at least one scene (duplicating if needed).
        for rid in range(self.num_replicas):
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_lengths[fallback]
        target_count = max(per_rank_counts)
        # Align the common epoch length to whole unit blocks AND the caller's
        # requested multiple.
        block_multiple = _lcm(self.pad_to_multiple, self.UNIT_RAW_BLOCK_SIZE)
        if target_count % block_multiple != 0:
            target_count = (
                (target_count + block_multiple - 1)
                // block_multiple
            ) * block_multiple
        self._target_len = target_count
        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            # Rank-local RNG: only reorders this rank's own scene playlist.
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)
        indices: List[int] = []
        unit_counts_epoch: Counter = Counter()
        raw_counts_epoch: Counter = Counter()
        for sid in my_scenes:
            indices.extend(self._scene_sequences[sid])
            unit_counts_epoch.update(self._scene_unit_counts[sid])
            raw_counts_epoch.update(self._scene_raw_counts[sid])
        prepad_len = len(indices)
        replay_counts: Counter = Counter()
        if len(indices) < target_count:
            # Deterministic replay padding: cycle over own scenes until the
            # common target length is reached, then truncate the overshoot
            # (the last replayed scene may be cut mid-way).
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self._scene_sequences[sid])
                unit_counts_epoch.update(self._scene_unit_counts[sid])
                raw_counts_epoch.update(self._scene_raw_counts[sid])
                replay_counts[sid] += 1
                cursor += 1
            indices = indices[:target_count]
        self._last_epoch_stats = {
            "num_scenes_total": len(self.scene_ids),
            "num_scenes_rank": len(my_scenes),
            "num_skipped_scenes": len(self._skipped_scenes),
            "prepad_len": int(prepad_len),
            "target_len": int(target_count),
            "replay_extra": int(target_count - prepad_len),
            "unit_counts": {task: int(unit_counts_epoch[task]) for task in self.TASKS},
            "raw_counts": {task: int(raw_counts_epoch[task]) for task in self.TASKS},
            "replayed_scenes": {sid: int(v) for sid, v in replay_counts.items()},
        }
        return indices

    def __iter__(self) -> Iterator[int]:
        # Rebuild on every iteration so set_epoch() takes effect.
        self._cached = self._build_indices()
        return iter(self._cached)

    def __len__(self) -> int:
        # Lazily build once so len() is available before the first __iter__.
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len