Spaces:
Runtime error
Runtime error
File size: 5,776 Bytes
import random
from dataclasses import dataclass, is_dataclass, fields, replace
from typing import Any
import torch
from optgs.dataset.data_types import BatchedExample
from optgs.model.types import Gaussians
from optgs.scene_trainer.optimizer.optimizer import OptimizerState
def to_device(obj: Any, device: torch.device | str, detach=True) -> Any:
"""
Recursively moves all tensors (and nested dataclasses) to the given device.
- Skips None fields
- Works with nested dataclasses
- Works with lists/tuples of tensors or dataclasses
"""
if torch.is_tensor(obj):
if detach:
obj = obj.detach()
return obj.to(device)
elif is_dataclass(obj):
kwargs = {}
for f in fields(obj):
val = getattr(obj, f.name)
if val is not None:
kwargs[f.name] = to_device(val, device, detach=detach)
return replace(obj, **kwargs)
elif isinstance(obj, (list, tuple)):
return type(obj)(to_device(v, device, detach=detach) for v in obj)
elif isinstance(obj, dict):
return {k: to_device(v, device, detach=detach) for k, v in obj.items()}
else:
return obj # Leave unchanged (e.g., int, float, str)
@dataclass
class GaussianEpisodeEntry:
    """One snapshot record as stored in the episode replay buffer.

    The required fields come first; ``state`` and ``info`` are optional and
    default to ``None``.
    """

    # Integer identifier of the entry.
    # NOTE(review): presumably identifies the source episode/scene — confirm.
    id: int
    # Integer step counter for the snapshot.
    # NOTE(review): presumably the inner-step index within the episode — confirm.
    t: int
    # Batched example associated with this snapshot.
    batch: BatchedExample
    # Gaussians captured in this snapshot.
    gaussians: Gaussians
    # Optional optimizer state captured together with the gaussians.
    state: OptimizerState | None = None
    # Optional free-form metadata keyed by string.
    info: dict[str, Any] | None = None
@dataclass
class ReplayBufferCfg:
capacity: int # number of snapshots to store
sample_batch_size: int # number of snapshots to sample when resuming training
sample_prob: float | int # probability of sampling from the buffer vs starting fresh
insert_prob: float | int # probability of pushing to the buffer a new sample
return_prob: float | int # probability of returning the sampled snapshot (vs discarding it)
simulate_ahead: bool # whether to simulate ahead before returning the updated snapshot
simulate_ahead_min_steps: int # min steps to simulate ahead
simulate_ahead_max_steps: int # max steps to simulate ahead
simulate_ahead_grow: int # number of steps to scale up the max steps over meta iterations
max_t: int | None # maximum number of inner steps per episode
push_only_if_not_full: bool # only push if buffer is not full
remove_strategy_when_full: str # strategy to remove entries when buffer is full: "oldest" or "random"
class EpisodeReplayBuffer:
    """Bounded in-memory store of training snapshots that can be re-sampled.

    Snapshots are kept in a plain list (``self.buffer``), oldest first.
    ``push`` adds one snapshot and evicts another when the configured
    capacity is exceeded; ``sample`` removes and returns a random snapshot.
    The ``should_*`` helpers combine the buffer's fill level with coin flips
    driven by the probabilities in the config.
    """

    # Maps a ``flipcoin`` action to the config field holding its probability.
    _PROB_FIELD_BY_ACTION = {
        "sample": "sample_prob",
        "insert": "insert_prob",
        "return": "return_prob",
    }

    def __init__(self, cfg: "ReplayBufferCfg"):
        """Create an empty buffer driven by ``cfg``.

        Only ``cfg.sample_batch_size == 1`` is supported.
        """
        self.cfg = cfg
        self.buffer = []
        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."

    def push(self, entry, to_cpu=True):
        """Store one snapshot (intermediate state of training).

        If the buffer would exceed ``cfg.capacity``, one entry is evicted
        according to ``cfg.remove_strategy_when_full``:

        - ``"oldest"``: the entry at the front of the buffer is removed;
        - ``"random"``: a random entry other than the new one is removed.

        Args:
            entry: Snapshot to store.
            to_cpu: When true, the entry is detached and moved to the CPU
                with ``to_device`` before it is stored.

        Raises:
            ValueError: If an eviction is needed and the configured strategy
                is neither ``"oldest"`` nor ``"random"``. The buffer is left
                unmodified in that case.
        """
        strategy = self.cfg.remove_strategy_when_full
        will_overflow = len(self.buffer) >= self.cfg.capacity
        # Validate before mutating so a failed push cannot leave the buffer
        # above capacity.
        if will_overflow and strategy not in ("oldest", "random"):
            raise ValueError("Invalid remove strategy when full")

        if to_cpu:
            entry = to_device(entry, 'cpu', detach=True)
        self.buffer.append(entry)
        if not will_overflow:
            return

        if strategy == "oldest":
            self.buffer.pop(0)  # remove oldest if full
        else:
            # Pick among the older entries only; the new one sits at the last
            # index. With no older entry (capacity <= 0) the new entry is the
            # only candidate, which matches what "oldest" does in that case.
            last_old = len(self.buffer) - 2
            victim = random.randint(0, last_old) if last_old >= 0 else 0
            del self.buffer[victim]

    def sample(self, device, leave_batch_fn=None):
        """Remove a random snapshot from the buffer and return it on ``device``.

        Args:
            device: Device the returned snapshot is moved to via ``to_device``.
            leave_batch_fn: Currently unused; accepted for interface
                compatibility.

        Returns:
            The sampled entry, detached and moved to ``device``.

        Raises:
            ValueError: If the buffer holds fewer than
                ``cfg.sample_batch_size`` entries.
        """
        batch_size = self.cfg.sample_batch_size
        if len(self.buffer) < batch_size:
            raise ValueError("Not enough elements in the buffer to sample")
        # Check before removing anything so an unsupported batch size cannot
        # silently drop entries from the buffer.
        assert batch_size == 1, "Only batch size of 1 is supported for now."

        (idx,) = random.sample(range(len(self.buffer)), batch_size)
        entry = self.buffer.pop(idx)
        return to_device(entry, device)

    def flipcoin(self, action: str):
        """Draw a random boolean using the probability configured for ``action``.

        Args:
            action: ``"sample"``, ``"insert"`` or ``"return"``; selects
                ``cfg.sample_prob``, ``cfg.insert_prob`` or
                ``cfg.return_prob`` respectively.

        Returns:
            ``True`` when a uniform draw from ``[0, 1)`` falls below the
            selected probability.

        Raises:
            ValueError: If ``action`` is not one of the supported values.
        """
        field_name = self._PROB_FIELD_BY_ACTION.get(action)
        if field_name is None:
            raise ValueError("action must be 'sample', 'insert' or 'return'")
        return random.random() < getattr(self.cfg, field_name)

    def should_sample(self):
        """Return whether the caller should resume from a buffered snapshot.

        Sampling is only considered once the buffer has reached capacity;
        after that the decision is a ``"sample"`` coin flip.
        """
        if len(self.buffer) < self.cfg.capacity:
            return False
        return len(self.buffer) >= self.cfg.sample_batch_size and self.flipcoin("sample")

    def should_push(self, new_sample: bool, t: int):
        """Return whether a snapshot at step ``t`` should be pushed.

        Args:
            new_sample: ``True`` selects the ``"insert"`` probability,
                ``False`` the ``"return"`` probability.
            t: Step counter of the snapshot, compared against ``cfg.max_t``.

        Returns:
            ``False`` if the buffer is full and ``cfg.push_only_if_not_full``
            is set, or if ``t`` has reached ``cfg.max_t``; ``True`` while the
            buffer is below capacity; otherwise the result of the coin flip.
        """
        is_full = len(self.buffer) >= self.cfg.capacity
        if self.cfg.push_only_if_not_full and is_full:
            return False  # do not push if buffer is full
        if self.cfg.max_t is not None and t >= self.cfg.max_t:
            return False  # do not store entries beyond max_t
        if not is_full:
            # Always fill the buffer if possible
            return True
        return self.flipcoin("insert" if new_sample else "return")

    def __len__(self):
        """Return the number of snapshots currently stored."""
        return len(self.buffer)

    def clear(self):
        """Remove every stored snapshot."""
        self.buffer.clear()