File size: 7,784 Bytes
b47a1ce 93d7b0a b47a1ce 93d7b0a b47a1ce 93d7b0a b47a1ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
import torch
from .types import MemoryRecord, MemorySourceType
@dataclass
class MemoryBankQuery:
    """Filter parameters consumed by ``CausalMemoryBank.query``.

    Every stored record is tested against these fields; a record is returned
    only if it passes all of them.
    """

    # Frame being predicted; only records whose ``source_end`` is <= this
    # value are eligible.
    target_frame: int
    # Restrict results to a single source type; ``None`` accepts any type.
    source_type: MemorySourceType | None = None
    # When False, records flagged ``is_generated`` are skipped.
    include_generated: bool = True
    # Upper bound on the number of returned records; ``None`` means no limit.
    max_records: int | None = None
class CausalMemoryBank:
    """Small causal memory bank for DeMemWM records.

    Records are stored in insertion order. ``query`` only returns records
    whose ``source_end`` is <= the requested target frame, and
    ``assert_causal`` raises if a given selection breaks that rule.

    Args:
        max_records: Optional capacity. When it is exceeded, the
            earliest-inserted records are evicted. ``None`` means unbounded;
            ``0`` stores nothing.

    Raises:
        ValueError: If ``max_records`` is negative.
    """

    def __init__(self, max_records: Optional[int] = None):
        if max_records is not None and max_records < 0:
            raise ValueError("max_records must be non-negative or None")
        self.max_records = max_records
        self._records: list[MemoryRecord] = []

    def __len__(self) -> int:
        """Return the number of stored records."""
        return len(self._records)

    @property
    def records(self) -> tuple[MemoryRecord, ...]:
        """Stored records as a tuple snapshot, in insertion order."""
        return tuple(self._records)

    def add_record(self, record: MemoryRecord) -> None:
        """Append ``record`` and evict the earliest records beyond ``max_records``.

        Raises:
            ValueError: If a generated record is labelled ``PREFIX_GT``.
        """
        if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated:
            raise ValueError("generated records cannot be high-trust prefix anchors")
        self._records.append(record)
        cap = self.max_records
        if cap is not None and len(self._records) > cap:
            # Delete the overflow from the front. A ``[-cap:]`` slice would
            # keep every record when cap == 0, because ``-0 == 0``.
            del self._records[: len(self._records) - cap]

    @staticmethod
    def _batched(
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Promote single-record inputs to batched form and validate lengths.

        Returns ``(tokens, mask, flat_frames)``. ``tokens`` and ``mask`` carry
        a leading per-record dimension. ``flat_frames`` is the detached,
        flattened ``frame_indices``.

        Raises:
            ValueError: If the leading dims of ``tokens`` / ``mask`` do not
                match the number of frame indices.
        """
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        if mask.shape[0] != tokens.shape[0]:
            raise ValueError("mask first dimension must match tokens first dimension")
        return tokens, mask, flat_frames

    def add_prefix_anchors(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        slots_per_anchor: Optional[int] = None,
    ) -> None:
        """Add one non-generated ``PREFIX_GT`` single-frame record per frame index.

        Args:
            tokens: Per-anchor tokens; a 2-D tensor is treated as one anchor.
                Presumably ``(N, M, D)`` when batched -- TODO confirm.
            mask: Per-anchor slot mask; a 1-D tensor is treated as one anchor.
            frame_indices: One frame index per anchor (flattened).
            pose: Optional poses, indexed by anchor position.
            slots_per_anchor: If given, keep only the first this-many slots
                of every anchor.

        Raises:
            ValueError: If the ``tokens`` / ``mask`` leading dims do not match
                the number of frame indices.
        """
        tokens, mask, flat_frames = self._batched(tokens, mask, frame_indices)
        for i, frame in enumerate(flat_frames.tolist()):
            rec_tokens = tokens[i]
            rec_mask = mask[i].bool()
            if slots_per_anchor is not None:
                rec_tokens = rec_tokens[:slots_per_anchor]
                rec_mask = rec_mask[:slots_per_anchor]
            self.add_record(
                MemoryRecord(
                    tokens=rec_tokens,
                    mask=rec_mask,
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=rec_tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=MemorySourceType.PREFIX_GT,
                    is_generated=False,
                    chunk_id=f"prefix_{int(frame)}",
                )
            )

    def add_chunk_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.PREFIX_GT,
        is_generated: bool = False,
        chunk_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single record covering a multi-frame chunk.

        The record spans ``[min(frame_indices), max(frame_indices) + 1)``.
        ``chunk_id`` defaults to ``"<source_type.value>_chunk_<start>_<end>"``.
        ``metadata`` is shallow-copied.

        Raises:
            ValueError: If ``frame_indices`` is empty, ``tokens`` is not 2-D,
                or ``mask`` is not 1-D with the same length as ``tokens``.
        """
        flat_frames = frame_indices.detach().reshape(-1)
        if flat_frames.numel() == 0:
            raise ValueError("chunk frame_indices must be non-empty")
        if tokens.ndim != 2:
            raise ValueError("chunk tokens must have shape (M,D)")
        if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]:
            raise ValueError("chunk mask must have shape (M,)")
        start = int(flat_frames.min().item())
        end = int(flat_frames.max().item()) + 1
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=start,
                source_end=end,
                frame_indices=flat_frames.to(device=tokens.device),
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}",
                metadata=dict(metadata or {}),
            )
        )

    def add_frame_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_index: torch.Tensor | int,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.REVISIT,
        is_generated: bool = False,
        record_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single-frame record spanning ``[frame, frame + 1)``.

        ``frame_index`` may be an int or a tensor; for a tensor, only its
        first (flattened) element is used. ``record_id`` defaults to
        ``"<source_type.value>_frame_<frame>"``. ``metadata`` is shallow-copied.
        """
        # Resolve the Python int first; building a device tensor and then
        # calling .item() on it forced a needless device round trip.
        frame = int(torch.as_tensor(frame_index).reshape(-1)[0].item())
        frame_tensor = torch.as_tensor([frame], device=tokens.device)
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=frame,
                source_end=frame + 1,
                frame_indices=frame_tensor,
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=record_id or f"{source_type.value}_frame_{frame}",
                metadata=dict(metadata or {}),
            )
        )

    def add_generated_records(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.GENERATED,
    ) -> None:
        """Add one ``is_generated=True`` single-frame record per frame index.

        Input shapes follow ``add_prefix_anchors``. ``chunk_id`` is
        ``"generated_<frame>"`` regardless of ``source_type``.

        Raises:
            ValueError: If ``source_type`` is ``PREFIX_GT``, or the ``tokens`` /
                ``mask`` leading dims do not match the number of frame indices
                (previously surplus rows were silently dropped).
        """
        if source_type == MemorySourceType.PREFIX_GT:
            raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default")
        tokens, mask, flat_frames = self._batched(tokens, mask, frame_indices)
        for i, frame in enumerate(flat_frames.tolist()):
            self.add_record(
                MemoryRecord(
                    tokens=tokens[i],
                    mask=mask[i].bool(),
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=source_type,
                    is_generated=True,
                    chunk_id=f"generated_{int(frame)}",
                )
            )

    def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
        """Return stored records that are causally valid for a target frame.

        Args:
            query: A ``MemoryBankQuery``, or a frame index (``int`` or a
                single-element tensor) used as ``target_frame``.
            **kwargs: ``MemoryBankQuery`` fields. With a frame index they build
                the query; with a ``MemoryBankQuery`` they override its fields.

        Returns:
            Matching records in insertion order: ``source_end <= target_frame``,
            optional ``source_type`` match, optional exclusion of generated
            records, and at most ``max_records`` of them (``0`` yields none).
        """
        if isinstance(query, (int, torch.Tensor)):
            query = MemoryBankQuery(target_frame=int(query), **kwargs)
        elif kwargs:
            # Apply overrides instead of silently dropping them.
            from dataclasses import replace

            query = replace(query, **kwargs)
        target = int(query.target_frame)
        limit = query.max_records
        out: list[MemoryRecord] = []
        for record in self._records:
            # Check the cap before appending so max_records <= 0 returns [].
            if limit is not None and len(out) >= limit:
                break
            if int(record.source_end) > target:
                continue
            if query.source_type is not None and record.source_type != query.source_type:
                continue
            if not query.include_generated and record.is_generated:
                continue
            out.append(record)
        return out

    def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
        """Raise if any of ``records`` ends after ``target_frame``.

        Uses an explicit ``raise`` rather than ``assert``, so the check also
        runs under ``python -O``.

        Raises:
            AssertionError: Listing each offender by ``chunk_id``, or by its
                ``[source_start,source_end)`` range when ``chunk_id`` is falsy.
        """
        target = int(target_frame)
        offenders = [
            r.chunk_id or f"[{r.source_start},{r.source_end})"
            for r in records
            if int(r.source_end) > target
        ]
        if offenders:
            raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None):
    """Concatenate the tokens and boolean masks of ``records`` along dim 0.

    Args:
        records: Records whose ``tokens`` / ``mask`` are concatenated in order.
        target_slots: If given, drop masked-out slots and keep at most this
            many of the remaining valid slots, in order. The result is not
            padded, so it may be shorter than ``target_slots``.

    Returns:
        ``(tokens, mask)``, or ``(None, None)`` when ``records`` is empty.
        When ``target_slots`` is set, the returned mask is all True.
    """
    if not records:
        return None, None

    stacked_tokens = torch.cat([rec.tokens for rec in records], dim=0)
    stacked_mask = torch.cat([rec.mask.bool() for rec in records], dim=0)
    if target_slots is None:
        return stacked_tokens, stacked_mask

    # Positions of valid (True) slots, flattened from nonzero's (k, 1) output.
    keep = stacked_mask.nonzero(as_tuple=False).reshape(-1)
    kept_tokens = stacked_tokens.index_select(0, keep)
    kept_mask = stacked_mask.index_select(0, keep)
    return kept_tokens[:target_slots], kept_mask[:target_slots]
|