"""DDP-safe scene-sequential sampler for online temporal training.

Guarantees:
  1. Within each scene, frames are yielded in the order stored in
     ``scene_groups``; each list must already be sorted by timestamp.
  2. Scene order is shuffled per-epoch for training diversity.
  3. All ranks yield exactly the same number of micro-steps per epoch
     (balanced by greedy scene assignment + deterministic replay padding).
  4. Epoch boundaries and replay-scene starts are detectable by the caller
     via timestamp regression, so StreamPETR memory can be reset correctly.
"""

import random
from typing import Dict, Iterator, List, Sequence

from torch.utils.data import Sampler


class SceneSequentialSampler(Sampler[int]):
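    """Distributed temporal sampler with equal-step guarantee.

    Args:
        scene_groups: Maps scene id to its dataset indices, already sorted
            by timestamp.
        num_replicas: Number of DDP processes.
        rank: Rank of the current process, in ``[0, num_replicas)``.
        seed: Base seed; must be identical on all ranks.
        shuffle_scenes: Whether to shuffle the scene order each epoch.
        pad_to_multiple: Round the per-rank length up to a multiple of this
            value (values below 1 are treated as 1).
    """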

    def __init__(
        self,
        scene_groups: Dict[str, List[int]],
        num_replicas: int = 1,
        rank: int = 0,
        seed: int = 0,
        shuffle_scenes: bool = True,
        pad_to_multiple: int = 1,
    ):
        if num_replicas < 1:
            raise ValueError(f"num_replicas must be >= 1, got {num_replicas}")
        if rank < 0 or rank >= num_replicas:
            raise ValueError(f"rank must be in [0, {num_replicas}), got {rank}")
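        if not scene_groups:
            raise ValueError("scene_groups must not be empty")
        empty = sorted(sid for sid, frames in scene_groups.items() if not frames)
        if empty:
            # A rank holding only empty scenes could never reach the padded
            # length, so the replay loop in _build_indices would not terminate.
            raise ValueError(f"every scene needs at least one index, empty: {empty}")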

        self.scene_groups = scene_groups
        self.scene_ids = sorted(scene_groups.keys())
        self.num_replicas = int(num_replicas)
        self.rank = int(rank)
        self.seed = int(seed)
        self.shuffle_scenes = bool(shuffle_scenes)
        self.pad_to_multiple = max(1, int(pad_to_multiple))
        self.epoch = 0
        self._cached: List[int] = []
        self._target_len = 0

    def set_epoch(self, epoch: int) -> None:
        """Select the epoch used for shuffling and drop any cached indices."""
        self.epoch = int(epoch)
        self._cached = []
        self._target_len = 0

    def _scene_len(self, sid: str) -> int:
        return len(self.scene_groups[sid])

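    def _build_indices(self) -> List[int]:
        """Build this rank's index list for the current epoch.

        Every rank computes the same scene-to-rank partition (same seed and
        epoch), then keeps only its own share and pads it to the common length.
        """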
        rng = random.Random(self.seed + self.epoch)
        scene_order = list(self.scene_ids)
        if self.shuffle_scenes:
            rng.shuffle(scene_order)

        per_rank_scenes: List[List[str]] = [[] for _ in range(self.num_replicas)]
        per_rank_counts = [0] * self.num_replicas

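        # Greedy balancing: visit scenes longest first and give each one to
        # the rank with the fewest frames so far. The sort is stable, so
        # equal-length scenes keep their shuffled order.
        for sid in sorted(scene_order, key=self._scene_len, reverse=True):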
            target = min(range(self.num_replicas), key=lambda r: per_rank_counts[r])
            per_rank_scenes[target].append(sid)
            per_rank_counts[target] += self._scene_len(sid)

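        # With more ranks than scenes some ranks receive nothing; give each
        # of them one scene so that every rank has frames to yield.
        for rid in range(self.num_replicas):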
            if not per_rank_scenes[rid]:
                fallback = scene_order[rid % len(scene_order)]
                per_rank_scenes[rid].append(fallback)
                per_rank_counts[rid] += self._scene_len(fallback)

        # All ranks must yield the same count: take the heaviest load and
        # round it up to a multiple of pad_to_multiple.
        target_count = (
            (max(per_rank_counts) + self.pad_to_multiple - 1)
            // self.pad_to_multiple
        ) * self.pad_to_multiple
        self._target_len = target_count

        my_scenes = per_rank_scenes[self.rank]
        if self.shuffle_scenes:
            rng2 = random.Random(self.seed + self.epoch + self.rank)
            rng2.shuffle(my_scenes)

        indices: List[int] = []
        for sid in my_scenes:
            indices.extend(self.scene_groups[sid])

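        # Replay padding: repeat this rank's scenes from their first frame, in
        # the same order, until the common length is reached. The final
        # replayed scene may be cut short by the truncation below.
        if len(indices) < target_count:
            replay_pool = list(my_scenes)
            cursor = 0
            while len(indices) < target_count:
                sid = replay_pool[cursor % len(replay_pool)]
                indices.extend(self.scene_groups[sid])
                cursor += 1
            indices = indices[:target_count]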

        return indices

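    def __iter__(self) -> Iterator[int]:
        # The build is deterministic for a given (seed, epoch, rank), so a
        # list cached by __len__ can be reused until set_epoch clears it.
        if not self._cached:
            self._cached = self._build_indices()
        return iter(self._cached)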

    def __len__(self) -> int:
        if self._target_len == 0:
            self._cached = self._build_indices()
        return self._target_len
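

# --- Illustrative sketch, not part of the sampler's API ----------------------
# Guarantee 4 in the module docstring leaves reset detection to the caller.
# The helper below is one possible way to do it, assuming the caller can look
# up a numeric timestamp for every yielded index. The name
# `find_reset_positions` and its signature are hypothetical.
def find_reset_positions(timestamps: Sequence[float]) -> List[int]:
    """Return the positions at which temporal memory should probably be reset.

    Position 0 is always reported. Any later position whose timestamp is not
    greater than its predecessor's is treated as the start of a new or
    replayed scene. A boundary where the next scene happens to start later in
    time is not caught this way; compare scene ids if that case matters.
    """
    resets: List[int] = []
    prev = None
    for pos, ts in enumerate(timestamps):
        # First sample, or time did not advance: treat as a sequence start.
        if prev is None or ts <= prev:
            resets.append(pos)
        prev = ts
    return resets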