|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Iterable, Optional |
|
|
| import torch |
|
|
| from .types import MemoryRecord, MemorySourceType |
|
|
|
|
@dataclass
class MemoryBankQuery:
    """Filter parameters for ``CausalMemoryBank.query``.

    Only records that end at or before ``target_frame`` are eligible; the
    remaining fields narrow that set further.
    """

    # Frame being predicted; eligible records must satisfy source_end <= target_frame.
    target_frame: int
    # When set, only records with exactly this source type are returned.
    source_type: Optional[MemorySourceType] = None
    # When False, records flagged ``is_generated`` are skipped.
    include_generated: bool = True
    # Upper bound on the number of returned records; None means no limit.
    max_records: Optional[int] = None
|
|
|
|
class CausalMemoryBank:
    """Small causal memory bank for DeMemWM records.

    Records are stored in insertion order. When ``max_records`` is set, the
    oldest records are evicted once the bank grows past that size. ``query``
    only returns records whose ``source_end`` does not exceed the requested
    target frame, and ``assert_causal`` can re-check an arbitrary selection.
    """

    def __init__(self, max_records: Optional[int] = None):
        """Create an empty bank.

        Args:
            max_records: Maximum number of records retained. ``None`` means
                unbounded; ``0`` retains nothing.

        Raises:
            ValueError: If ``max_records`` is negative.
        """
        if max_records is not None and max_records < 0:
            raise ValueError("max_records must be non-negative or None")
        self.max_records = max_records
        self._records: list[MemoryRecord] = []

    def __len__(self) -> int:
        """Return the number of records currently stored."""
        return len(self._records)

    @property
    def records(self) -> tuple[MemoryRecord, ...]:
        """Snapshot of stored records, oldest first (a copy, so callers cannot mutate the bank)."""
        return tuple(self._records)

    def add_record(self, record: MemoryRecord) -> None:
        """Append ``record`` and evict the oldest records beyond ``max_records``.

        Raises:
            ValueError: If ``record`` is generated but labelled ``PREFIX_GT``.
        """
        if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated:
            raise ValueError("generated records cannot be high-trust prefix anchors")
        self._records.append(record)
        if self.max_records is not None:
            # Drop by count instead of slicing with ``[-max_records:]``: for
            # max_records == 0 that slice is ``[-0:]``, which keeps everything.
            overflow = len(self._records) - self.max_records
            if overflow > 0:
                del self._records[:overflow]

    def add_prefix_anchors(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        slots_per_anchor: Optional[int] = None,
    ) -> None:
        """Add one non-generated ``PREFIX_GT`` record per frame index.

        Args:
            tokens: Per-anchor token tensors stacked along dim 0; a 2-D tensor
                is treated as a single anchor.
            mask: Per-anchor masks stacked along dim 0 (a 1-D mask is treated
                as a single anchor); converted to bool.
            frame_indices: Frame index for each anchor; flattened, so its
                element count must equal ``tokens.shape[0]``.
            pose: Optional per-anchor pose, indexed by anchor position.
            slots_per_anchor: If given, keep only the first this-many slots of
                each anchor's tokens and mask.

        Raises:
            ValueError: If the number of anchors does not match the number of
                frame indices.
        """
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        for i, frame in enumerate(flat_frames.tolist()):
            rec_tokens = tokens[i]
            rec_mask = mask[i].bool()
            if slots_per_anchor is not None:
                rec_tokens = rec_tokens[:slots_per_anchor]
                rec_mask = rec_mask[:slots_per_anchor]
            self.add_record(
                MemoryRecord(
                    tokens=rec_tokens,
                    mask=rec_mask,
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=rec_tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=MemorySourceType.PREFIX_GT,
                    is_generated=False,
                    chunk_id=f"prefix_{int(frame)}",
                )
            )

    def add_chunk_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.PREFIX_GT,
        is_generated: bool = False,
        chunk_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single record covering a span of frames.

        The record spans ``[min(frame_indices), max(frame_indices) + 1)``.

        Args:
            tokens: Chunk tokens of shape ``(M, D)``.
            mask: Slot mask of shape ``(M,)``; converted to bool.
            frame_indices: Frames covered by the chunk; flattened, must be
                non-empty.
            pose: Optional pose stored as-is.
            source_type: Record source type. With the default ``PREFIX_GT``,
                ``is_generated=True`` is rejected by ``add_record``.
            is_generated: Whether the chunk came from the generator.
            chunk_id: Optional id; defaults to
                ``"<source_type.value>_chunk_<start>_<end>"``.
            metadata: Optional metadata; shallow-copied into the record.

        Raises:
            ValueError: On empty frame indices or mismatched token/mask shapes.
        """
        flat_frames = frame_indices.detach().reshape(-1)
        if flat_frames.numel() == 0:
            raise ValueError("chunk frame_indices must be non-empty")
        if tokens.ndim != 2:
            raise ValueError("chunk tokens must have shape (M,D)")
        if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]:
            raise ValueError("chunk mask must have shape (M,)")
        start = int(flat_frames.min().item())
        end = int(flat_frames.max().item()) + 1
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=start,
                source_end=end,
                frame_indices=flat_frames.to(device=tokens.device),
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}",
                metadata=dict(metadata or {}),
            )
        )

    def add_frame_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_index: torch.Tensor | int,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.REVISIT,
        is_generated: bool = False,
        record_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single record covering exactly one frame.

        Args:
            tokens: Record tokens, stored as-is.
            mask: Slot mask; converted to bool.
            frame_index: An int or tensor. If it holds several values, only
                the first (after flattening) is used.
            pose: Optional pose stored as-is.
            source_type: Record source type (default ``REVISIT``).
            is_generated: Whether the frame came from the generator.
            record_id: Optional id; defaults to
                ``"<source_type.value>_frame_<frame>"``.
            metadata: Optional metadata; shallow-copied into the record.

        Raises:
            ValueError: If ``frame_index`` contains no values.
        """
        flat_index = torch.as_tensor(frame_index).reshape(-1)
        if flat_index.numel() == 0:
            raise ValueError("frame_index must contain at least one value")
        frame = int(flat_index[0].item())
        frame_tensor = torch.as_tensor([frame], device=tokens.device)
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=frame,
                source_end=frame + 1,
                frame_indices=frame_tensor,
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=record_id or f"{source_type.value}_frame_{frame}",
                metadata=dict(metadata or {}),
            )
        )

    def add_generated_records(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.GENERATED,
    ) -> None:
        """Add one generated (``is_generated=True``) record per frame index.

        Args:
            tokens: Per-frame token tensors stacked along dim 0; a 2-D tensor
                is treated as a single frame.
            mask: Per-frame masks stacked along dim 0 (a 1-D mask is treated
                as a single frame); converted to bool.
            frame_indices: Frame index for each entry; flattened, so its
                element count must equal ``tokens.shape[0]``.
            pose: Optional per-frame pose, indexed by position.
            source_type: Must not be ``PREFIX_GT``.

        Raises:
            ValueError: If ``source_type`` is ``PREFIX_GT`` or the number of
                token entries does not match the number of frame indices.
        """
        if source_type == MemorySourceType.PREFIX_GT:
            raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default")
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        # Same guard as add_prefix_anchors: without it, surplus tokens were
        # silently dropped and missing ones raised an opaque IndexError.
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        for i, frame in enumerate(flat_frames.tolist()):
            self.add_record(
                MemoryRecord(
                    tokens=tokens[i],
                    mask=mask[i].bool(),
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=source_type,
                    is_generated=True,
                    chunk_id=f"generated_{int(frame)}",
                )
            )

    def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
        """Return causally valid records matching ``query``.

        Args:
            query: A ``MemoryBankQuery``, or an int target frame. For an int,
                ``kwargs`` are forwarded to ``MemoryBankQuery``; they are
                ignored when a query object is passed.

        Returns:
            Records in insertion order with ``source_end <= target_frame``,
            filtered by ``source_type`` / ``include_generated``. Truncated to
            at most ``max_records`` entries. NOTE: truncation keeps the
            earliest-inserted matches, not the most recent ones.
        """
        if isinstance(query, int):
            query = MemoryBankQuery(target_frame=query, **kwargs)
        target = int(query.target_frame)
        limit = query.max_records
        out: list[MemoryRecord] = []
        for record in self._records:
            # Check the cap before appending so that max_records == 0 yields
            # an empty result instead of one record.
            if limit is not None and len(out) >= limit:
                break
            if int(record.source_end) > target:
                continue
            if query.source_type is not None and record.source_type != query.source_type:
                continue
            if not query.include_generated and record.is_generated:
                continue
            out.append(record)
        return out

    def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
        """Raise if any of ``records`` ends after ``target_frame``.

        Raises:
            AssertionError: Lists the offending records by ``chunk_id``, or by
                their ``[source_start,source_end)`` span when the id is empty.
        """
        target = int(target_frame)
        offenders = [
            r.chunk_id or f"[{r.source_start},{r.source_end})"
            for r in records
            if int(r.source_end) > target
        ]
        if offenders:
            raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
|
|
|
|
def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None):
    """Concatenate the tokens and masks of ``records`` along dim 0.

    Args:
        records: Records whose ``tokens`` and ``mask`` are concatenated in order.
        target_slots: If ``None``, every slot is returned together with its
            bool mask. Otherwise only the valid (mask-true) slots are kept,
            truncated to at most ``target_slots`` (no padding is applied).

    Returns:
        ``(tokens, mask)``, or ``(None, None)`` when ``records`` is empty.
    """
    if not records:
        return None, None

    all_tokens = torch.cat([rec.tokens for rec in records], dim=0)
    all_mask = torch.cat([rec.mask.bool() for rec in records], dim=0)
    if target_slots is None:
        return all_tokens, all_mask

    # Integer-index the rows whose mask is set, then cap the count.
    keep = torch.nonzero(all_mask, as_tuple=False).reshape(-1)
    return all_tokens[keep][:target_slots], all_mask[keep][:target_slots]
|
|