from __future__ import annotations from dataclasses import dataclass from typing import Iterable, Optional import torch from .types import MemoryRecord, MemorySourceType @dataclass class MemoryBankQuery: target_frame: int source_type: Optional[MemorySourceType] = None include_generated: bool = True max_records: Optional[int] = None class CausalMemoryBank: """Small causal memory bank for DeMemWM records.""" def __init__(self, max_records: Optional[int] = None): self.max_records = max_records self._records: list[MemoryRecord] = [] def __len__(self) -> int: return len(self._records) @property def records(self) -> tuple[MemoryRecord, ...]: return tuple(self._records) def add_record(self, record: MemoryRecord) -> None: if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated: raise ValueError("generated records cannot be high-trust prefix anchors") self._records.append(record) if self.max_records is not None and len(self._records) > self.max_records: self._records = self._records[-self.max_records:] def add_prefix_anchors( self, tokens: torch.Tensor, mask: torch.Tensor, frame_indices: torch.Tensor, pose: Optional[torch.Tensor] = None, slots_per_anchor: Optional[int] = None, ) -> None: if tokens.ndim == 2: tokens = tokens.unsqueeze(0) if mask.ndim == 1: mask = mask.unsqueeze(0) flat_frames = frame_indices.detach().reshape(-1) if tokens.shape[0] != flat_frames.numel(): raise ValueError("tokens first dimension must match number of frame indices") for i, frame in enumerate(flat_frames.tolist()): rec_tokens = tokens[i] rec_mask = mask[i].bool() if slots_per_anchor is not None: rec_tokens = rec_tokens[:slots_per_anchor] rec_mask = rec_mask[:slots_per_anchor] self.add_record( MemoryRecord( tokens=rec_tokens, mask=rec_mask, source_start=int(frame), source_end=int(frame) + 1, frame_indices=torch.as_tensor([frame], device=rec_tokens.device), pose=None if pose is None else pose[i], source_type=MemorySourceType.PREFIX_GT, is_generated=False, chunk_id=f"prefix_{int(frame)}", ) ) def add_chunk_record( self, tokens: torch.Tensor, mask: torch.Tensor, frame_indices: torch.Tensor, pose: Optional[torch.Tensor] = None, source_type: MemorySourceType = MemorySourceType.PREFIX_GT, is_generated: bool = False, chunk_id: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: flat_frames = frame_indices.detach().reshape(-1) if flat_frames.numel() == 0: raise ValueError("chunk frame_indices must be non-empty") if tokens.ndim != 2: raise ValueError("chunk tokens must have shape (M,D)") if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]: raise ValueError("chunk mask must have shape (M,)") start = int(flat_frames.min().item()) end = int(flat_frames.max().item()) + 1 self.add_record( MemoryRecord( tokens=tokens, mask=mask.bool(), source_start=start, source_end=end, frame_indices=flat_frames.to(device=tokens.device), pose=pose, source_type=source_type, is_generated=bool(is_generated), chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}", metadata=dict(metadata or {}), ) ) def add_frame_record( self, tokens: torch.Tensor, mask: torch.Tensor, frame_index: torch.Tensor | int, pose: Optional[torch.Tensor] = None, source_type: MemorySourceType = MemorySourceType.REVISIT, is_generated: bool = False, record_id: Optional[str] = None, metadata: Optional[dict] = None, ) -> None: frame_tensor = torch.as_tensor([int(torch.as_tensor(frame_index).reshape(-1)[0].item())], device=tokens.device) frame = int(frame_tensor.item()) self.add_record( MemoryRecord( tokens=tokens, mask=mask.bool(), source_start=frame, source_end=frame + 1, frame_indices=frame_tensor, pose=pose, source_type=source_type, is_generated=bool(is_generated), chunk_id=record_id or f"{source_type.value}_frame_{frame}", metadata=dict(metadata or {}), ) ) def add_generated_records( self, tokens: torch.Tensor, mask: torch.Tensor, frame_indices: torch.Tensor, pose: Optional[torch.Tensor] = None, source_type: MemorySourceType = MemorySourceType.GENERATED, ) -> None: if source_type == MemorySourceType.PREFIX_GT: raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default") if tokens.ndim == 2: tokens = tokens.unsqueeze(0) if mask.ndim == 1: mask = mask.unsqueeze(0) flat_frames = frame_indices.detach().reshape(-1) for i, frame in enumerate(flat_frames.tolist()): self.add_record( MemoryRecord( tokens=tokens[i], mask=mask[i].bool(), source_start=int(frame), source_end=int(frame) + 1, frame_indices=torch.as_tensor([frame], device=tokens.device), pose=None if pose is None else pose[i], source_type=source_type, is_generated=True, chunk_id=f"generated_{int(frame)}", ) ) def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]: if isinstance(query, int): query = MemoryBankQuery(target_frame=query, **kwargs) out: list[MemoryRecord] = [] for record in self._records: if int(record.source_end) > int(query.target_frame): continue if query.source_type is not None and record.source_type != query.source_type: continue if not query.include_generated and record.is_generated: continue out.append(record) if query.max_records is not None and len(out) >= query.max_records: break return out def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None: offenders = [r.chunk_id or f"[{r.source_start},{r.source_end})" for r in records if int(r.source_end) > int(target_frame)] if offenders: raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}") def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None): if not records: return None, None tokens = torch.cat([r.tokens for r in records], dim=0) mask = torch.cat([r.mask.bool() for r in records], dim=0) if target_slots is not None: valid_idx = mask.nonzero(as_tuple=False).flatten() tokens = tokens.index_select(0, valid_idx)[:target_slots] mask = mask.index_select(0, valid_idx)[:target_slots] return tokens, mask