# Source commit 93d7b0a by BonanDing:
# "Clean DeMemWM deterministic memory slot handling"
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
from .types import MemoryRecord, MemorySourceType
@dataclass
class MemoryBankQuery:
    """Filter criteria consumed by ``CausalMemoryBank.query``.

    Fields, in positional-constructor order: ``target_frame``,
    ``source_type``, ``include_generated``, ``max_records``.
    """

    # Frame being predicted; the bank's query skips any record whose
    # ``source_end`` is greater than this value.
    target_frame: int
    # When set, only records with exactly this source type are returned.
    source_type: MemorySourceType | None = None
    # When False, records flagged ``is_generated`` are filtered out.
    include_generated: bool = True
    # Upper bound on the number of returned records; ``None`` means no bound.
    max_records: int | None = None
class CausalMemoryBank:
    """Small causal memory bank for DeMemWM records.

    Records are stored in insertion order. When ``max_records`` is set, the
    earliest-inserted records are evicted as soon as the bank exceeds that
    size. ``query`` only returns records whose ``source_end`` is <= the
    requested target frame, and ``assert_causal`` checks the same condition
    on an arbitrary selection of records.
    """

    def __init__(self, max_records: Optional[int] = None):
        # Maximum number of records retained: ``None`` means unbounded, and a
        # value <= 0 means nothing is retained.
        self.max_records = max_records
        self._records: list[MemoryRecord] = []

    def __len__(self) -> int:
        return len(self._records)

    @property
    def records(self) -> tuple[MemoryRecord, ...]:
        """Snapshot of the stored records in insertion order (a copy, not a live view)."""
        return tuple(self._records)

    def add_record(self, record: MemoryRecord) -> None:
        """Append ``record``, then evict the earliest records beyond ``max_records``.

        Raises:
            ValueError: if ``record`` is a PREFIX_GT anchor that is also
                flagged as generated.
        """
        if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated:
            raise ValueError("generated records cannot be high-trust prefix anchors")
        self._records.append(record)
        if self.max_records is not None:
            # Count the overflow explicitly instead of slicing with
            # ``[-max_records:]``: ``[-0:]`` is the whole list, so a capacity
            # of 0 would silently never evict anything.
            excess = len(self._records) - self.max_records
            if excess > 0:
                del self._records[:excess]

    def add_prefix_anchors(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        slots_per_anchor: Optional[int] = None,
    ) -> None:
        """Add one ground-truth PREFIX_GT record per entry of ``frame_indices``.

        Args:
            tokens: per-frame slot tokens. A 2-D tensor is treated as a single
                frame (a leading dim is added). The first dimension must equal
                the number of frame indices.
            mask: per-frame slot mask, indexed like ``tokens``. A 1-D mask is
                treated as a single frame.
            frame_indices: frame index of each entry, flattened to 1-D.
            pose: optional per-frame pose, indexed positionally (``pose[i]``).
            slots_per_anchor: if given, keep only the first this-many slots of
                each frame's tokens and mask.

        Raises:
            ValueError: if the number of token entries does not match the
                number of frame indices.
        """
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        for i, frame in enumerate(flat_frames.tolist()):
            rec_tokens = tokens[i]
            rec_mask = mask[i].bool()
            if slots_per_anchor is not None:
                rec_tokens = rec_tokens[:slots_per_anchor]
                rec_mask = rec_mask[:slots_per_anchor]
            self.add_record(
                MemoryRecord(
                    tokens=rec_tokens,
                    mask=rec_mask,
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=rec_tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=MemorySourceType.PREFIX_GT,
                    is_generated=False,
                    chunk_id=f"prefix_{int(frame)}",
                )
            )

    def add_chunk_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.PREFIX_GT,
        is_generated: bool = False,
        chunk_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single record that summarizes a chunk of frames.

        The record covers the half-open frame range
        ``[min(frame_indices), max(frame_indices) + 1)``.

        Args:
            tokens: chunk slot tokens of shape ``(M, D)``.
            mask: slot mask of shape ``(M,)``; stored as bool.
            frame_indices: frames the chunk was built from (flattened; moved
                to ``tokens.device``).
            pose: optional pose, stored as-is.
            source_type: provenance of the chunk.
            is_generated: whether the chunk came from model output.
            chunk_id: record id; defaults to
                ``"<source_type.value>_chunk_<start>_<end>"``.
            metadata: optional extra info, shallow-copied into a new dict.

        Raises:
            ValueError: on empty ``frame_indices`` or mis-shaped
                ``tokens``/``mask``, or (via ``add_record``) for a generated
                PREFIX_GT record.
        """
        flat_frames = frame_indices.detach().reshape(-1)
        if flat_frames.numel() == 0:
            raise ValueError("chunk frame_indices must be non-empty")
        if tokens.ndim != 2:
            raise ValueError("chunk tokens must have shape (M,D)")
        if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]:
            raise ValueError("chunk mask must have shape (M,)")
        start = int(flat_frames.min().item())
        end = int(flat_frames.max().item()) + 1
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=start,
                source_end=end,
                frame_indices=flat_frames.to(device=tokens.device),
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}",
                metadata=dict(metadata or {}),
            )
        )

    def add_frame_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_index: torch.Tensor | int,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.REVISIT,
        is_generated: bool = False,
        record_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single-frame record covering ``[frame, frame + 1)``.

        ``frame_index`` may be an int or a tensor. For a tensor, only its
        first element (after flattening) is used.
        ``record_id`` defaults to ``"<source_type.value>_frame_<frame>"``, and
        ``metadata`` is shallow-copied into a new dict.
        """
        frame = int(torch.as_tensor(frame_index).reshape(-1)[0].item())
        frame_tensor = torch.as_tensor([frame], device=tokens.device)
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=frame,
                source_end=frame + 1,
                frame_indices=frame_tensor,
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=record_id or f"{source_type.value}_frame_{frame}",
                metadata=dict(metadata or {}),
            )
        )

    def add_generated_records(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.GENERATED,
    ) -> None:
        """Add one model-generated record (``is_generated=True``) per frame index.

        Inputs use the same layout as ``add_prefix_anchors``. A 2-D ``tokens``
        or 1-D ``mask`` is treated as a single frame, and the i-th entry of
        ``tokens``, ``mask`` and ``pose`` belongs to the i-th flattened frame
        index. Each record's ``chunk_id`` is ``"generated_<frame>"``, whatever
        ``source_type`` is.

        Raises:
            ValueError: if ``source_type`` is PREFIX_GT, or if the number of
                token entries does not match the number of frame indices.
        """
        if source_type == MemorySourceType.PREFIX_GT:
            raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default")
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        # Same validation as add_prefix_anchors. Without it, surplus token
        # rows were silently dropped and missing rows surfaced as IndexError.
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        for i, frame in enumerate(flat_frames.tolist()):
            self.add_record(
                MemoryRecord(
                    tokens=tokens[i],
                    mask=mask[i].bool(),
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=source_type,
                    is_generated=True,
                    chunk_id=f"generated_{int(frame)}",
                )
            )

    def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
        """Return stored records that are causally valid for a target frame.

        Args:
            query: a ``MemoryBankQuery`` or a bare target frame (an ``int`` or
                a single-element tensor). For a bare target, ``kwargs`` are
                forwarded to ``MemoryBankQuery``. For a ``MemoryBankQuery``,
                ``kwargs`` are ignored.

        Returns:
            Matching records in insertion order. A record matches when its
            ``source_end`` is <= the target frame, its source type equals
            ``query.source_type`` (if set), and it is not generated when
            ``include_generated`` is False. At most ``max_records`` records
            are returned. A non-positive ``max_records`` yields an empty list.
        """
        if isinstance(query, (int, torch.Tensor)):
            query = MemoryBankQuery(target_frame=int(query), **kwargs)
        limit = query.max_records
        # Checked up front: the in-loop cap only runs after an append, so a
        # limit <= 0 would otherwise still return one record.
        if limit is not None and limit <= 0:
            return []
        target = int(query.target_frame)
        out: list[MemoryRecord] = []
        for record in self._records:
            if int(record.source_end) > target:
                continue
            if query.source_type is not None and record.source_type != query.source_type:
                continue
            if not query.include_generated and record.is_generated:
                continue
            out.append(record)
            if limit is not None and len(out) >= limit:
                break
        return out

    def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
        """Raise if any of ``records`` ends after ``target_frame``.

        Raised explicitly rather than with an ``assert`` statement, so the
        check still runs under ``python -O``.

        Raises:
            AssertionError: listing each offender's ``chunk_id``, or its
                ``[source_start,source_end)`` range when it has no id.
        """
        offenders = [r.chunk_id or f"[{r.source_start},{r.source_end})" for r in records if int(r.source_end) > int(target_frame)]
        if offenders:
            raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None):
    """Concatenate the slot tokens and masks of ``records`` along dim 0.

    Args:
        records: records whose ``tokens`` and ``mask`` are stacked in order.
        target_slots: if given, masked-out slots are dropped first and then
            only the first ``target_slots`` valid slots are kept. No padding
            is added, so fewer slots may be returned. The resulting mask is
            all True.

    Returns:
        ``(tokens, mask)`` with the mask as bool, or ``(None, None)`` when
        ``records`` is empty.
    """
    if not records:
        return None, None
    stacked_tokens = torch.cat(tuple(rec.tokens for rec in records), dim=0)
    stacked_mask = torch.cat(tuple(rec.mask.bool() for rec in records), dim=0)
    if target_slots is None:
        return stacked_tokens, stacked_mask
    keep = torch.nonzero(stacked_mask, as_tuple=False).flatten()
    return (
        stacked_tokens.index_select(0, keep)[:target_slots],
        stacked_mask.index_select(0, keep)[:target_slots],
    )