File size: 5,776 Bytes
78d2329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import random
from dataclasses import dataclass, is_dataclass, fields, replace
from typing import Any

import torch

from optgs.dataset.data_types import BatchedExample
from optgs.model.types import Gaussians
from optgs.scene_trainer.optimizer.optimizer import OptimizerState


def to_device(obj: Any, device: torch.device | str, detach=True) -> Any:
    """
    Recursively moves all tensors (and nested dataclasses) to the given device.
    - Skips None fields
    - Works with nested dataclasses
    - Works with lists/tuples of tensors or dataclasses
    """
    if torch.is_tensor(obj):
        if detach:
            obj = obj.detach()
        return obj.to(device)

    elif is_dataclass(obj):
        kwargs = {}
        for f in fields(obj):
            val = getattr(obj, f.name)
            if val is not None:
                kwargs[f.name] = to_device(val, device, detach=detach)
        return replace(obj, **kwargs)

    elif isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device, detach=detach) for v in obj)

    elif isinstance(obj, dict):
        return {k: to_device(v, device, detach=detach) for k, v in obj.items()}

    else:
        return obj  # Leave unchanged (e.g., int, float, str)


@dataclass
class GaussianEpisodeEntry:
    """Record bundling a batch, its Gaussians and optional optimizer state.

    NOTE(review): presumably this is the entry type stored in
    ``EpisodeReplayBuffer`` (whose ``sample`` mentions an entry's ``batch``
    attribute) -- confirm against the callers.
    """

    id: int  # identifier of the entry; how it is assigned is not visible here -- TODO confirm
    t: int  # presumably the inner-step index at which the snapshot was taken -- TODO confirm
    batch: BatchedExample  # data batch associated with this entry
    gaussians: Gaussians  # Gaussians associated with this entry
    state: OptimizerState | None = None  # optional optimizer state; defaults to None
    info: dict[str, Any] | None = None  # optional free-form extra data; defaults to None


@dataclass
class ReplayBufferCfg:
    """Configuration read by ``EpisodeReplayBuffer``.

    The ``simulate_ahead*`` fields are not consumed by the buffer class in
    this file; their descriptions are the author's and should be checked
    against the code that uses them.
    """

    capacity: int  # number of snapshots to store
    sample_batch_size: int  # number of snapshots to sample when resuming training (the buffer asserts this is 1)
    sample_prob: float | int  # probability of sampling from the buffer vs starting fresh
    insert_prob: float | int  # probability of pushing a new sample to the buffer once it is full
    return_prob: float | int  # probability of returning the sampled snapshot (vs discarding it)
    simulate_ahead: bool  # whether to simulate ahead before returning the updated snapshot
    simulate_ahead_min_steps: int  # min steps to simulate ahead
    simulate_ahead_max_steps: int  # max steps to simulate ahead
    simulate_ahead_grow: int  # number of steps to scale up the max steps over meta iterations
    max_t: int | None  # maximum number of inner steps per episode; entries with t >= max_t are not pushed (None disables the check)
    push_only_if_not_full: bool  # only push if buffer is not full
    remove_strategy_when_full: str  # strategy to remove entries when buffer is full: "oldest" or "random"


class EpisodeReplayBuffer:
    """Bounded list of training snapshots with a probabilistic push/sample policy.

    Entries are kept in insertion order in ``self.buffer``. ``push`` adds an
    entry and evicts one when the capacity is exceeded; ``sample`` removes and
    returns a random entry. ``should_sample`` / ``should_push`` implement the
    policy that decides when callers ought to do either.
    """

    def __init__(self, cfg: "ReplayBufferCfg"):
        """Create an empty buffer driven by ``cfg``.

        Only ``cfg.sample_batch_size == 1`` is supported.
        """
        self.cfg = cfg
        self.buffer = []
        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."

    def push(self, entry, to_cpu=True):
        """Store one snapshot (intermediate state of training).

        If the buffer would exceed ``cfg.capacity``, one previously stored
        snapshot is evicted according to ``cfg.remove_strategy_when_full``:
        ``"oldest"`` drops the first entry, ``"random"`` drops a uniformly
        chosen entry other than the one just added.

        Args:
            entry: Snapshot to store.
            to_cpu: When true, the entry is detached and moved to the CPU
                (via ``to_device``) before being stored.

        Raises:
            ValueError: If an eviction is required and the configured removal
                strategy is neither ``"oldest"`` nor ``"random"``. The buffer
                is left unmodified in that case.
        """
        will_overflow = len(self.buffer) >= self.cfg.capacity
        strategy = None
        if will_overflow:
            # Validate before mutating so a bad config cannot leave the
            # buffer above capacity.
            strategy = self.cfg.remove_strategy_when_full
            if strategy not in ("oldest", "random"):
                raise ValueError("Invalid remove strategy when full")

        if to_cpu:
            entry = to_device(entry, 'cpu', detach=True)
        self.buffer.append(entry)

        if will_overflow:
            if strategy == "oldest":
                self.buffer.pop(0)
            else:
                # Any slot except the last one, which holds the new entry.
                victim = random.randint(0, len(self.buffer) - 2)
                del self.buffer[victim]

    def sample(self, device, leave_batch_fn=None):
        """Remove a random entry from the buffer and return it on ``device``.

        Args:
            device: Device the returned entry is moved to (via ``to_device``,
                which also detaches tensors by default).
            leave_batch_fn: Currently ignored; accepted for interface
                compatibility.

        Returns:
            The sampled entry (a single entry, not a list).

        Raises:
            ValueError: If the buffer holds fewer than
                ``cfg.sample_batch_size`` entries.
        """
        if len(self.buffer) < self.cfg.sample_batch_size:
            raise ValueError("Not enough elements in the buffer to sample")

        # Sample random entries
        indices = random.sample(range(len(self.buffer)), self.cfg.sample_batch_size)
        picked = [self.buffer[i] for i in indices]

        # Delete from the highest index down so earlier deletions do not
        # shift the positions still to be removed.
        for idx in sorted(indices, reverse=True):
            del self.buffer[idx]

        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."
        return to_device(picked[0], device)

    def flipcoin(self, action: str) -> bool:
        """Draw a biased coin for ``action``.

        Args:
            action: One of ``"sample"``, ``"insert"`` or ``"return"``; selects
                ``cfg.sample_prob``, ``cfg.insert_prob`` or ``cfg.return_prob``
                respectively.

        Returns:
            True with the configured probability for ``action``.

        Raises:
            ValueError: If ``action`` is not one of the three known actions.
        """
        prob_attr = {
            "sample": "sample_prob",
            "insert": "insert_prob",
            "return": "return_prob",
        }.get(action)
        if prob_attr is None:
            raise ValueError(
                f"action must be 'sample', 'insert' or 'return', got {action!r}"
            )
        return random.random() < getattr(self.cfg, prob_attr)

    def should_sample(self) -> bool:
        """Decide whether the caller should draw an entry from the buffer.

        Never samples while the buffer is below capacity; once it is full,
        samples with probability ``cfg.sample_prob``.
        """
        if len(self.buffer) < self.cfg.capacity:
            return False
        return len(self.buffer) >= self.cfg.sample_batch_size and self.flipcoin("sample")

    def should_push(self, new_sample: bool, t: int) -> bool:
        """Decide whether an entry at inner step ``t`` should be pushed.

        Args:
            new_sample: True selects ``cfg.insert_prob`` for the coin flip,
                False selects ``cfg.return_prob``.
            t: Step index of the entry; entries with ``t >= cfg.max_t`` are
                never stored (when ``max_t`` is set).

        Returns:
            False if the buffer is full and ``cfg.push_only_if_not_full`` is
            set, or if ``t`` has reached ``cfg.max_t``. True if the buffer
            still has room. Otherwise the outcome of the coin flip.
        """
        if self.cfg.push_only_if_not_full and len(self.buffer) >= self.cfg.capacity:
            return False  # do not push if buffer is full

        if self.cfg.max_t is not None and t >= self.cfg.max_t:
            return False  # do not store entries beyond max_t

        if len(self.buffer) < self.cfg.capacity:
            # Always fill the buffer if possible
            return True

        return self.flipcoin("insert" if new_sample else "return")

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self.buffer)

    def clear(self):
        """Remove every stored entry."""
        self.buffer.clear()