Learn2Splat / optgs /meta_trainer /replay_buffer.py
SteEsp's picture
Add Docker-based Learn2Splat demo (viser GUI)
78d2329 verified
import random
from dataclasses import dataclass, is_dataclass, fields, replace
from typing import Any
import torch
from optgs.dataset.data_types import BatchedExample
from optgs.model.types import Gaussians
from optgs.scene_trainer.optimizer.optimizer import OptimizerState
def to_device(obj: Any, device: torch.device | str, detach=True) -> Any:
"""
Recursively moves all tensors (and nested dataclasses) to the given device.
- Skips None fields
- Works with nested dataclasses
- Works with lists/tuples of tensors or dataclasses
"""
if torch.is_tensor(obj):
if detach:
obj = obj.detach()
return obj.to(device)
elif is_dataclass(obj):
kwargs = {}
for f in fields(obj):
val = getattr(obj, f.name)
if val is not None:
kwargs[f.name] = to_device(val, device, detach=detach)
return replace(obj, **kwargs)
elif isinstance(obj, (list, tuple)):
return type(obj)(to_device(v, device, detach=detach) for v in obj)
elif isinstance(obj, dict):
return {k: to_device(v, device, detach=detach) for k, v in obj.items()}
else:
return obj # Leave unchanged (e.g., int, float, str)
@dataclass
class GaussianEpisodeEntry:
    """One snapshot held by the replay buffer.

    NOTE(review): field meanings below are inferred from names and types only;
    confirm against the code that creates the entries.
    """

    # Identifier of the entry (presumably of the episode/scene it belongs to).
    id: int
    # Step counter; presumably the inner-step index that callers compare with
    # the configured ``max_t``.
    t: int
    # Data batch associated with the snapshot.
    batch: BatchedExample
    # Gaussian parameters at the time of the snapshot.
    gaussians: Gaussians
    # Optional optimizer state saved with the snapshot.
    state: OptimizerState | None = None
    # Optional free-form metadata.
    info: dict[str, Any] | None = None
@dataclass
class ReplayBufferCfg:
capacity: int # number of snapshots to store
sample_batch_size: int # number of snapshots to sample when resuming training
sample_prob: float | int # probability of sampling from the buffer vs starting fresh
insert_prob: float | int # probability of pushing to the buffer a new sample
return_prob: float | int # probability of returning the sampled snapshot (vs discarding it)
simulate_ahead: bool # whether to simulate ahead before returning the updated snapshot
simulate_ahead_min_steps: int # min steps to simulate ahead
simulate_ahead_max_steps: int # max steps to simulate ahead
simulate_ahead_grow: int # number of steps to scale up the max steps over meta iterations
max_t: int | None # maximum number of inner steps per episode
push_only_if_not_full: bool # only push if buffer is not full
remove_strategy_when_full: str # strategy to remove entries when buffer is full: "oldest" or "random"
class EpisodeReplayBuffer:
    """Bounded list of training snapshots that are sampled destructively.

    Entries are appended with :meth:`push` and removed when drawn with
    :meth:`sample`. Once the number of entries exceeds ``cfg.capacity`` one
    entry is evicted according to ``cfg.remove_strategy_when_full``.
    """

    def __init__(self, cfg: ReplayBufferCfg):
        self.cfg = cfg
        self.buffer = []
        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."

    def push(self, entry, to_cpu=True):
        """Store one snapshot (intermediate state of training).

        Args:
            entry: Snapshot to store.
            to_cpu: When true, a detached CPU copy of the entry is stored.

        Raises:
            ValueError: If the buffer overflows and
                ``cfg.remove_strategy_when_full`` is neither ``"oldest"`` nor
                ``"random"``.
        """
        if to_cpu:
            entry = to_device(entry, 'cpu', detach=True)
        self.buffer.append(entry)

        if len(self.buffer) <= self.cfg.capacity:
            return

        strategy = self.cfg.remove_strategy_when_full
        if strategy == "oldest":
            self.buffer.pop(0)  # remove oldest if full
        elif strategy == "random":
            # Pick among everything except the entry that was just appended.
            # The max() guards the degenerate case where the new entry is the
            # only element (capacity 0), which previously crashed in randint.
            idx = random.randrange(max(len(self.buffer) - 1, 1))
            del self.buffer[idx]
        else:
            raise ValueError("Invalid remove strategy when full")

    def sample(self, device, leave_batch_fn=None):
        """Return and remove a random element from the buffer.

        Args:
            device: Device the sampled entry is moved to.
            leave_batch_fn: Currently unused; kept for interface compatibility.

        Returns:
            The sampled entry, detached and moved to ``device``.

        Raises:
            ValueError: If the buffer holds fewer than
                ``cfg.sample_batch_size`` entries.
        """
        if len(self.buffer) < self.cfg.sample_batch_size:
            raise ValueError("Not enough elements in the buffer to sample")

        # Sample random entries
        indices = random.sample(range(len(self.buffer)), self.cfg.sample_batch_size)
        sampled_entries = [self.buffer[i] for i in indices]

        # Remove from buffer by index (must go in reverse to avoid shifting)
        for idx in sorted(indices, reverse=True):
            del self.buffer[idx]

        assert self.cfg.sample_batch_size == 1, "Only batch size of 1 is supported for now."
        entry = sampled_entries[0]

        # Move to device
        return to_device(entry, device)

    def flipcoin(self, action: str) -> bool:
        """Draw a Bernoulli sample using the probability configured for ``action``.

        Args:
            action: One of ``"sample"``, ``"insert"`` or ``"return"``.

        Raises:
            ValueError: If ``action`` is not one of the supported names.
        """
        if action == "sample":
            return random.random() < self.cfg.sample_prob
        elif action == "insert":
            return random.random() < self.cfg.insert_prob
        elif action == "return":
            return random.random() < self.cfg.return_prob
        else:
            raise ValueError(
                f"action must be 'sample', 'insert' or 'return', got {action!r}"
            )

    def should_sample(self) -> bool:
        """Decide whether to resume from the buffer; never while it is still filling up."""
        if len(self.buffer) < self.cfg.capacity:
            return False
        return len(self.buffer) >= self.cfg.sample_batch_size and self.flipcoin("sample")

    def should_push(self, new_sample: bool, t: int) -> bool:
        """Decide whether a snapshot at inner step ``t`` should be stored.

        Args:
            new_sample: True selects the ``"insert"`` probability, False the
                ``"return"`` probability, once the buffer is full.
            t: Step counter of the snapshot, compared with ``cfg.max_t``.
        """
        if self.cfg.push_only_if_not_full and len(self.buffer) >= self.cfg.capacity:
            return False  # do not push if buffer is full
        if self.cfg.max_t is not None and t >= self.cfg.max_t:
            return False  # do not store entries beyond max_t
        if len(self.buffer) < self.cfg.capacity:
            # Always fill the buffer if possible
            return True
        return self.flipcoin("insert" if new_sample else "return")

    def __len__(self) -> int:
        return len(self.buffer)

    def clear(self):
        """Remove every stored entry."""
        self.buffer.clear()