|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Any, Iterable, Optional |
| import warnings |
|
|
| import torch |
|
|
| from .memory import CausalMemoryBank |
| from .types import MemoryRecord |
|
|
|
|
@dataclass
class _RawLatentSegment:
    """One appended chunk of raw latents kept by ``StreamingCache``.

    Every tensor is indexed along dim 0 by the segment's local time position;
    ``StreamingCache.add_raw_latents`` fills these from (T, B, ...) inputs.
    """

    # Raw latents for the kept time steps (built from a (T, B, C, H, W) input).
    latents: torch.Tensor
    # Frame index for each (time, batch) slot of ``latents``.
    frame_indices: torch.Tensor
    # Boolean flag per slot: whether the frame came from generation.
    source_is_generated: torch.Tensor
    # Optional per-slot pose; ``None`` when no pose was supplied.
    pose: Optional[torch.Tensor]
|
|
|
|
class StreamingCache:
    """Per-video DeMemWM streaming cache with strict no-eviction semantics.

    The cache is intentionally allowed to grow for the current video. It stores
    detached copies of raw latents (on CPU, pinned CPU, or -- with
    ``device="cuda"`` -- cloned on the tensor's current device) plus compressed
    ``MemoryRecord`` objects, while DiT readout tensors remain bounded by the
    caller's manual budgets.

    Nothing is ever evicted: ``eviction_policy`` must be ``"none"`` and
    ``no_evict`` must be true. ``max_records`` is a soft limit that only counts,
    warns or raises (see ``on_capacity_exceeded``); it never drops data.
    """

    def __init__(
        self,
        *,
        enabled: bool = True,
        device: str = "cpu",
        keep_raw_latents: str = "all",
        keep_compressed_records: bool = True,
        keep_prefix_anchors: bool = True,
        eviction_policy: str = "none",
        no_evict: bool = True,
        clear_between_videos: bool = True,
        max_records: Optional[int] = None,
        on_capacity_exceeded: str = "warn",
    ) -> None:
        """Configure the cache.

        Args:
            enabled: When false, ``add_raw_latents`` and ``add_records`` are no-ops.
            device: Storage location; one of ``"cpu"``, ``"pinned_cpu"``, ``"cuda"``.
            keep_raw_latents: Raw latents are stored only when this equals ``"all"``.
            keep_compressed_records: When false, ``add_records`` is a no-op.
            keep_prefix_anchors: When false, ``"anchor"`` records are not stored.
            eviction_policy: Must be ``"none"``.
            no_evict: Must be true.
            clear_between_videos: Stored and reported by ``diagnostics``; this
                class does not act on it itself.
            max_records: Optional soft limit on the total number of stored records.
            on_capacity_exceeded: ``"error"`` raises ``RuntimeError``, ``"warn"``
                emits a ``RuntimeWarning``; any other value only counts.

        Raises:
            ValueError: For any eviction configuration or an unsupported device.
        """
        self.enabled = bool(enabled)
        self.device = str(device or "cpu")
        self.keep_raw_latents = keep_raw_latents
        self.keep_compressed_records = bool(keep_compressed_records)
        self.keep_prefix_anchors = bool(keep_prefix_anchors)
        self.eviction_policy = str(eviction_policy or "none")
        self.no_evict = bool(no_evict)
        self.clear_between_videos = bool(clear_between_videos)
        self.max_records = max_records
        self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
        if self.eviction_policy != "none" or not self.no_evict:
            raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
        if self.device not in {"cpu", "pinned_cpu", "cuda"}:
            raise ValueError("cache.device must be one of: cpu, pinned_cpu, cuda")
        self.reset_count = 0
        # Always 0: kept for diagnostics parity since eviction is disallowed.
        self.evictions = 0
        self.capacity_exceeded_count = 0
        self.current_video_id: Any = None
        # Raw latent chunks in insertion (time) order.
        self._raw_segments: list[_RawLatentSegment] = []
        # kind -> batch index -> stored records.
        self._records: dict[str, dict[int, list[MemoryRecord]]] = {"anchor": {}, "revisit": {}}
        # (batch, frame) keys already covered by some raw segment.
        self._raw_keys: set[tuple[int, int]] = set()
        # (batch, frame) -> (segment index, local time position) of first occurrence.
        self._raw_index: dict[tuple[int, int], tuple[int, int]] = {}
        # Dedup keys for records: (kind, batch, chunk_id, start, end, is_generated).
        self._record_keys: set[tuple[str, int, str, int, int, bool]] = set()
        self._batch_size: Optional[int] = None

        # Memoized concatenation of all raw segments, valid while
        # _raw_concat_built == _raw_concat_version.
        self._raw_concat_version: int = 0
        self._raw_concat_built: int = -1
        self._raw_concat_cache: Optional[tuple] = None

        # Memoized device-side banks keyed by (kind, device, dtype, B), each tagged
        # with the _banks_version it was built at.
        self._banks_version: int = 0
        self._banks_built_cache: dict[tuple, tuple[int, list[CausalMemoryBank]]] = {}

    @classmethod
    def from_config(cls, cfg: Any, *, enabled_default: bool = True) -> "StreamingCache":
        """Build a cache from an attribute-style config object (``None`` = defaults)."""

        def get(name: str, default: Any) -> Any:
            return getattr(cfg, name, default) if cfg is not None else default

        return cls(
            enabled=bool(get("enabled", enabled_default)),
            device=str(get("device", "cpu")),
            keep_raw_latents=str(get("keep_raw_latents", "all")),
            keep_compressed_records=bool(get("keep_compressed_records", True)),
            keep_prefix_anchors=bool(get("keep_prefix_anchors", True)),
            eviction_policy=str(get("eviction_policy", "none")),
            no_evict=bool(get("no_evict", True)),
            clear_between_videos=bool(get("clear_between_videos", True)),
            max_records=get("max_records", None),
            on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
        )

    @property
    def batch_size(self) -> int:
        """Batch size fixed by the first raw-latent insert (0 before that)."""
        return int(self._batch_size or 0)

    @property
    def raw_segment_count(self) -> int:
        """Number of raw latent segments stored."""
        return len(self._raw_segments)

    @property
    def raw_frame_slots(self) -> int:
        """Total T*B slots across all stored raw segments."""
        return sum(int(seg.latents.shape[0] * seg.latents.shape[1]) for seg in self._raw_segments)

    @property
    def record_count(self) -> int:
        """Total number of stored records across kinds and batch entries."""
        return sum(len(records) for by_batch in self._records.values() for records in by_batch.values())

    @property
    def slot_count(self) -> int:
        """Sum of ``valid_slots`` over all stored records."""
        return sum(record.valid_slots for by_batch in self._records.values() for records in by_batch.values() for record in records)

    def records_count(self, kind: str | None = None) -> int:
        """Number of stored records of ``kind`` (all kinds when ``None``)."""
        if kind is None:
            return self.record_count
        return sum(len(records) for records in self._records.get(kind, {}).values())

    def reset(self, video_id: Any = None) -> None:
        """Drop everything cached so far and start tracking ``video_id``."""
        self.current_video_id = video_id
        self._raw_segments.clear()
        self._records = {"anchor": {}, "revisit": {}}
        self._raw_keys.clear()
        self._raw_index.clear()
        self._record_keys.clear()
        self._batch_size = None
        self.evictions = 0
        self.capacity_exceeded_count = 0
        self.reset_count += 1
        self._raw_concat_version += 1
        self._raw_concat_built = -1
        self._raw_concat_cache = None
        self._banks_version += 1
        self._banks_built_cache.clear()

    def _store_tensor(self, tensor: Optional[torch.Tensor], *, dtype: torch.dtype | None = None) -> Optional[torch.Tensor]:
        """Return a detached copy of ``tensor`` owned by the cache (``None`` passes through).

        Floating-point tensors are cast to ``dtype`` when given. With ``"cpu"`` or
        ``"pinned_cpu"`` the copy lives on the CPU; with ``"cuda"`` it is a clone
        on the tensor's current device.
        """
        if tensor is None:
            return None
        out = tensor.detach()
        if dtype is not None and out.is_floating_point():
            out = out.to(dtype=dtype)
        if self.device in {"cpu", "pinned_cpu"}:
            out = out.to(device="cpu", copy=True)
            if self.device == "pinned_cpu":
                try:
                    out = out.pin_memory()
                except RuntimeError:
                    # Best effort: if pinning fails, keep the ordinary (pageable)
                    # CPU copy instead of failing the store.
                    pass
        elif self.device == "cuda":
            out = out.clone()
        return out

    def _metadata_to_storage(self, metadata: dict) -> dict:
        """Copy ``metadata``, storing tensors via ``_store_tensor`` (recurses into dicts)."""
        out = {}
        for key, value in dict(metadata or {}).items():
            if torch.is_tensor(value):
                out[key] = self._store_tensor(value)
            elif isinstance(value, dict):
                out[key] = self._metadata_to_storage(value)
            else:
                out[key] = value
        return out

    def _metadata_to_device(self, metadata: dict, *, device: torch.device, dtype: torch.dtype) -> dict:
        """Copy ``metadata`` onto ``device``; floating tensors are cast to ``dtype``."""
        out = {}
        for key, value in dict(metadata or {}).items():
            if torch.is_tensor(value):
                tensor = value.to(device=device)
                out[key] = tensor.to(dtype=dtype) if tensor.is_floating_point() else tensor
            elif isinstance(value, dict):
                out[key] = self._metadata_to_device(value, device=device, dtype=dtype)
            else:
                out[key] = value
        return out

    def _record_to_storage(self, record: MemoryRecord) -> MemoryRecord:
        """Return a storage-side copy of ``record`` with all tensors detached and copied."""
        score = record.score
        return MemoryRecord(
            tokens=self._store_tensor(record.tokens),
            mask=self._store_tensor(record.mask),
            source_start=int(record.source_start),
            source_end=int(record.source_end),
            frame_indices=self._store_tensor(record.frame_indices),
            pose=self._store_tensor(record.pose),
            source_type=record.source_type,
            is_generated=bool(record.is_generated),
            # Tensor scores are copied into storage; non-tensor scores (e.g. plain
            # floats) are kept as-is rather than being silently dropped.
            score=self._store_tensor(score) if torch.is_tensor(score) else score,
            chunk_id=record.chunk_id,
            metadata=self._metadata_to_storage(record.metadata),
        )

    def _record_to_device(self, record: MemoryRecord, *, device: torch.device, dtype: torch.dtype) -> MemoryRecord:
        """Return a copy of a stored record on ``device``.

        ``tokens`` are cast to ``dtype`` and ``mask`` to ``bool``; other tensors
        (including a tensor ``score``) keep their dtype and are only moved.
        """
        score = record.score
        return MemoryRecord(
            tokens=record.tokens.to(device=device, dtype=dtype),
            mask=record.mask.to(device=device, dtype=torch.bool),
            source_start=int(record.source_start),
            source_end=int(record.source_end),
            frame_indices=record.frame_indices.to(device=device),
            pose=None if record.pose is None else record.pose.to(device=device),
            source_type=record.source_type,
            is_generated=bool(record.is_generated),
            # Move tensor scores alongside the other fields to avoid device mismatches.
            score=score.to(device=device) if torch.is_tensor(score) else score,
            chunk_id=record.chunk_id,
            metadata=self._metadata_to_device(record.metadata, device=device, dtype=dtype),
        )

    def _check_capacity(self) -> None:
        """Count and report (never evict) when the soft ``max_records`` limit is exceeded.

        Raises:
            RuntimeError: When over the limit and ``on_capacity_exceeded == "error"``.
        """
        if self.max_records is None or self.record_count <= int(self.max_records):
            return
        self.capacity_exceeded_count += 1
        msg = (
            "DeMemWMStreamingCache capacity exceeded "
            f"records={self.record_count}/{self.max_records}; "
            "no eviction performed because no_evict=true"
        )
        if self.on_capacity_exceeded == "error":
            raise RuntimeError(msg)
        if self.on_capacity_exceeded == "warn":
            warnings.warn(msg, RuntimeWarning, stacklevel=2)

    def add_raw_latents(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        pose: Optional[torch.Tensor] = None,
    ) -> None:
        """Store the time steps of ``latents`` that introduce a new (batch, frame) key.

        A time step is kept when at least one of its batch entries has a frame
        index not seen before in this video; fully duplicate steps are skipped.

        Args:
            latents: Tensor of shape (T, B, C, H, W).
            frame_indices: Tensor of shape (T, B).
            source_is_generated: Optional per-slot flags indexed along dim 0;
                defaults to all-False.
            pose: Optional pose tensor indexed along dim 0.

        Raises:
            ValueError: On wrong shapes or if B differs from earlier inserts.
        """
        if not self.enabled or self.keep_raw_latents != "all":
            return
        if latents.ndim != 5:
            raise ValueError("cached raw latents must have shape (T,B,C,H,W)")
        T, B = int(latents.shape[0]), int(latents.shape[1])
        if frame_indices.shape != (T, B):
            raise ValueError("cached frame_indices must have shape (T,B)")
        if self._batch_size is None:
            self._batch_size = B
        elif self._batch_size != B:
            raise ValueError("streaming cache batch size changed within a video")

        frame_rows = frame_indices.detach().cpu().tolist()
        keep_positions: list[int] = []
        kept_keys: list[list[tuple[int, int]]] = []
        for t, row in enumerate(frame_rows):
            keys = [(b, int(frame)) for b, frame in enumerate(row)]
            if any(key not in self._raw_keys for key in keys):
                keep_positions.append(t)
                kept_keys.append(keys)
                self._raw_keys.update(keys)
        if not keep_positions:
            return

        pos = torch.as_tensor(keep_positions, dtype=torch.long)
        seg_latents = latents.index_select(0, pos.to(device=latents.device))
        seg_frames = frame_indices.index_select(0, pos.to(device=frame_indices.device))
        if source_is_generated is None:
            seg_generated = torch.zeros(seg_frames.shape, device=seg_frames.device, dtype=torch.bool)
        else:
            seg_generated = source_is_generated.index_select(0, pos.to(device=source_is_generated.device)).bool()
        seg_pose = None if pose is None else pose.index_select(0, pos.to(device=pose.device))

        segment_idx = len(self._raw_segments)
        self._raw_segments.append(
            _RawLatentSegment(
                latents=self._store_tensor(seg_latents),
                frame_indices=self._store_tensor(seg_frames),
                source_is_generated=self._store_tensor(seg_generated),
                pose=self._store_tensor(seg_pose),
            )
        )
        # The first stored occurrence of each (batch, frame) key wins.
        for local_pos, keys in enumerate(kept_keys):
            for key in keys:
                self._raw_index.setdefault(key, (segment_idx, local_pos))

        # Invalidate the memoized concatenation.
        self._raw_concat_version += 1
        self._raw_concat_cache = None

    def add_records(self, kind: str, batch_idx: int, records: Iterable[MemoryRecord]) -> None:
        """Store new, non-duplicate records of ``kind`` for one batch entry.

        Records are deduplicated on (kind, batch, chunk_id, source_start,
        source_end, is_generated). ``"anchor"`` records are ignored when
        ``keep_prefix_anchors`` is false.

        Raises:
            ValueError: For an unknown ``kind``.
            RuntimeError: Via ``_check_capacity`` when configured to error.
        """
        if not self.enabled or not self.keep_compressed_records:
            return
        if kind not in self._records:
            raise ValueError(f"unsupported cache record kind: {kind}")
        batch_idx = int(batch_idx)
        bucket = self._records[kind].setdefault(batch_idx, [])
        if kind == "anchor" and not self.keep_prefix_anchors:
            # Anchors disabled: store nothing, but keep the bucket/capacity side effects.
            records = ()
        added_any = False
        for record in records:
            key = (
                kind,
                batch_idx,
                str(record.chunk_id or ""),
                int(record.source_start),
                int(record.source_end),
                bool(record.is_generated),
            )
            if key in self._record_keys:
                continue
            self._record_keys.add(key)
            bucket.append(self._record_to_storage(record))
            added_any = True
        if added_any:
            # Invalidate memoized device-side banks.
            self._banks_version += 1
            self._banks_built_cache.clear()
        self._check_capacity()

    def add_memory_banks(self, anchor_banks: list[CausalMemoryBank], revisit_banks: list[CausalMemoryBank]) -> None:
        """Store every bank's records, using the list position as the batch index."""
        for batch_idx, bank in enumerate(anchor_banks):
            self.add_records("anchor", batch_idx, bank.records)
        for batch_idx, bank in enumerate(revisit_banks):
            self.add_records("revisit", batch_idx, bank.records)

    def memory_banks(self, kind: str, *, device: torch.device, dtype: torch.dtype, batch_size: int | None = None) -> list[CausalMemoryBank]:
        """Return one ``CausalMemoryBank`` per batch entry holding ``kind`` records on ``device``.

        The batch count is ``batch_size`` if truthy, else the cache's batch size,
        else one past the largest stored batch index for ``kind``.

        NOTE: the returned list (and its banks) is memoized and returned again
        for the same (kind, device, dtype, B) until new records are added, so
        callers should not mutate it.

        Raises:
            ValueError: For an unknown ``kind``.
        """
        if kind not in self._records:
            raise ValueError(f"unsupported cache record kind: {kind}")
        B = int(batch_size or self.batch_size or (max(self._records[kind].keys()) + 1 if self._records[kind] else 0))
        cache_key = (kind, device, dtype, B)
        cached = self._banks_built_cache.get(cache_key)
        if cached is not None and cached[0] == self._banks_version:
            return cached[1]
        banks: list[CausalMemoryBank] = []
        for batch_idx in range(B):
            bank = CausalMemoryBank()
            for record in self._records[kind].get(batch_idx, []):
                bank.add_record(self._record_to_device(record, device=device, dtype=dtype))
            banks.append(bank)
        self._banks_built_cache[cache_key] = (self._banks_version, banks)
        return banks

    def records_for_batch(self, kind: str, batch_idx: int) -> tuple[MemoryRecord, ...]:
        """Return the stored (storage-side) records of ``kind`` for one batch entry.

        Raises:
            ValueError: For an unknown ``kind``.
        """
        if kind not in self._records:
            raise ValueError(f"unsupported cache record kind: {kind}")
        return tuple(self._records[kind].get(int(batch_idx), ()))

    def raw_latents_for_frames(
        self,
        *,
        batch_idx: int,
        frame_indices: torch.Tensor,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Gather cached raw latents of one batch entry for the given frames.

        Returns:
            Tensor of shape (N, 1, C, H, W) on ``device``/``dtype``, where N is
            ``frame_indices.numel()``, in the order of the flattened frames.

        Raises:
            KeyError: If any requested frame is not cached, or if nothing has
                been cached yet (no segment to infer the empty result's shape).
        """
        frames = frame_indices.detach().cpu().reshape(-1)
        rows = []
        batch_idx = int(batch_idx)
        for frame in frames.tolist():
            key = (batch_idx, int(frame))
            location = self._raw_index.get(key)
            if location is None:
                raise KeyError(f"raw latent for batch={batch_idx}, frame={int(frame)} is not cached")
            segment_idx, local_pos = location
            rows.append(self._raw_segments[segment_idx].latents[local_pos, batch_idx])
        if not rows:
            if not self._raw_segments:
                raise KeyError(f"no raw latents are cached for batch={batch_idx}")
            template = self._raw_segments[0].latents
            return template[:0, batch_idx:batch_idx + 1].to(device=device, dtype=dtype)
        return torch.stack(rows, dim=0).unsqueeze(1).to(device=device, dtype=dtype)

    def _select_time_positions(
        self,
        frame_indices: torch.Tensor,
        target_frame_indices: Optional[torch.Tensor],
        max_recent_frames: Optional[int],
        exclude_latest_local_frames: int = 0,
    ) -> torch.Tensor:
        """Return the time positions (dim 0 of a (T, B) ``frame_indices``) to keep.

        Without targets or a positive budget, every position is kept. Otherwise,
        for each target row and batch column, the ``max_recent_frames`` latest
        positions whose frame index is below ``target - exclude`` are kept; the
        result is the union over targets and batch columns, in ascending order.
        """
        T, B = frame_indices.shape
        if target_frame_indices is None or max_recent_frames is None or int(max_recent_frames) <= 0:
            return torch.arange(T, dtype=torch.long)
        targets = target_frame_indices.detach().cpu()
        if targets.ndim == 1:
            targets = targets[:, None].expand(-1, B)
        frames = frame_indices.detach().cpu()
        recent = int(max_recent_frames)
        exclude = max(0, int(exclude_latest_local_frames))

        # valid[n, t, b]: frame at position t is strictly before target n's cutoff.
        valid = frames.unsqueeze(0) < (targets.unsqueeze(1) - exclude)

        # Count valid positions newest-first and keep the first `recent` of them.
        valid_f = valid.flip(1)
        keep_f = (valid_f.long().cumsum(1) <= recent) & valid_f

        keep_any = keep_f.flip(1).any(dim=0).any(dim=1)
        return keep_any.nonzero(as_tuple=False).flatten()

    def materialize_raw_latents(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        max_recent_frames: Optional[int] = None,
        target_frame_indices: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: int = 0,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return cached ``(latents, frame_indices, is_generated, pose)`` on ``device``.

        With ``target_frame_indices`` and a positive ``max_recent_frames`` only the
        most recent eligible positions are returned (see
        ``_materialize_recent_raw_latents``); otherwise all cached positions are.
        ``pose`` is ``None`` unless every segment involved has one. Returns four
        ``None`` values when nothing is cached.
        """
        if not self._raw_segments:
            return None, None, None, None
        if target_frame_indices is not None and max_recent_frames is not None and int(max_recent_frames) > 0:
            return self._materialize_recent_raw_latents(
                device=device,
                dtype=dtype,
                max_recent_frames=int(max_recent_frames),
                target_frame_indices=target_frame_indices,
                exclude_latest_local_frames=exclude_latest_local_frames,
            )

        # Reuse the memoized concatenation unless segments were added since.
        if self._raw_concat_cache is None or self._raw_concat_built != self._raw_concat_version:
            latents = torch.cat([seg.latents for seg in self._raw_segments], dim=0)
            frame_indices = torch.cat([seg.frame_indices for seg in self._raw_segments], dim=0)
            generated = torch.cat([seg.source_is_generated for seg in self._raw_segments], dim=0)
            pose: Optional[torch.Tensor] = None
            if all(seg.pose is not None for seg in self._raw_segments):
                pose = torch.cat([seg.pose for seg in self._raw_segments if seg.pose is not None], dim=0)
            self._raw_concat_cache = (latents, frame_indices, generated, pose)
            self._raw_concat_built = self._raw_concat_version
        else:
            latents, frame_indices, generated, pose = self._raw_concat_cache
        pos = self._select_time_positions(frame_indices, target_frame_indices, max_recent_frames, exclude_latest_local_frames)
        if pos.numel() == 0:
            empty_latents = latents[:0].to(device=device, dtype=dtype)
            empty_frames = frame_indices[:0].to(device=device)
            empty_generated = generated[:0].to(device=device, dtype=torch.bool)
            empty_pose = None if pose is None else pose[:0].to(device=device)
            return empty_latents, empty_frames, empty_generated, empty_pose
        latents = latents.index_select(0, pos.to(device=latents.device)).to(device=device, dtype=dtype)
        frame_indices = frame_indices.index_select(0, pos.to(device=frame_indices.device)).to(device=device)
        generated = generated.index_select(0, pos.to(device=generated.device)).to(device=device, dtype=torch.bool)
        if pose is not None:
            pose = pose.index_select(0, pos.to(device=pose.device)).to(device=device)
        return latents, frame_indices, generated, pose

    def _materialize_recent_raw_latents(
        self,
        *,
        device: torch.device,
        dtype: torch.dtype,
        max_recent_frames: int,
        target_frame_indices: torch.Tensor,
        exclude_latest_local_frames: int = 0,
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the most recent eligible cached time positions, oldest first.

        Scans positions newest-first and keeps a position whenever some
        (target, batch) pair still needs frames and the position's frame index
        is below ``target - exclude_latest_local_frames``; each pair needs at
        most ``max_recent_frames``. Stops early once every pair is satisfied.

        Raises:
            ValueError: If the targets' batch dimension does not match the cache.
        """
        B = self.batch_size
        targets = target_frame_indices.detach().cpu()
        if targets.ndim == 1:
            targets = targets[:, None].expand(-1, B)
        elif targets.shape[1] == 1 and B > 1:
            targets = targets.expand(-1, B)
        if targets.shape[1] != B:
            raise ValueError("target_frame_indices batch dimension does not match streaming cache")

        recent = max(0, int(max_recent_frames))
        exclude = max(0, int(exclude_latest_local_frames))
        # The eligibility cutoff is loop-invariant; compute it once.
        cutoff = targets - exclude
        counts = torch.zeros(targets.shape, dtype=torch.long)
        selected: list[tuple[_RawLatentSegment, int]] = []

        for segment in reversed(self._raw_segments):
            frames = segment.frame_indices.detach().cpu()
            for local_pos in range(frames.shape[0] - 1, -1, -1):
                valid = frames[local_pos].unsqueeze(0) < cutoff
                needed = valid & (counts < recent)
                if not needed.any():
                    continue
                selected.append((segment, local_pos))
                counts += needed.long()
                if bool((counts >= recent).all().item()):
                    break
            if bool((counts >= recent).all().item()):
                break

        if not selected:
            template = self._raw_segments[0]
            empty_latents = template.latents[:0].to(device=device, dtype=dtype)
            empty_frames = template.frame_indices[:0].to(device=device)
            empty_generated = template.source_is_generated[:0].to(device=device, dtype=torch.bool)
            empty_pose = None if template.pose is None else template.pose[:0].to(device=device)
            return empty_latents, empty_frames, empty_generated, empty_pose

        # Selection was collected newest-first; return it in time order.
        selected.reverse()
        latents = torch.stack([segment.latents[local_pos] for segment, local_pos in selected], dim=0).to(device=device, dtype=dtype)
        frame_indices = torch.stack([segment.frame_indices[local_pos] for segment, local_pos in selected], dim=0).to(device=device)
        generated = torch.stack([segment.source_is_generated[local_pos] for segment, local_pos in selected], dim=0).to(device=device, dtype=torch.bool)
        pose = None
        if all(segment.pose is not None for segment, _ in selected):
            pose = torch.stack(
                [segment.pose[local_pos] for segment, local_pos in selected if segment.pose is not None],
                dim=0,
            ).to(device=device)
        return latents, frame_indices, generated, pose

    def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
        """Return a flat dict of counters and settings, keys prefixed by ``prefix``."""
        return {
            f"{prefix}_enabled": bool(self.enabled),
            f"{prefix}_records": int(self.record_count),
            f"{prefix}_anchor_records": int(self.records_count("anchor")),
            f"{prefix}_revisit_records": int(self.records_count("revisit")),
            f"{prefix}_slots": int(self.slot_count),
            f"{prefix}_raw_frame_slots": int(self.raw_frame_slots),
            f"{prefix}_raw_segments": int(self.raw_segment_count),
            f"{prefix}_evictions": int(self.evictions),
            f"{prefix}_resets": int(self.reset_count),
            f"{prefix}_capacity_exceeded": int(self.capacity_exceeded_count),
            f"{prefix}_device": self.device,
            f"{prefix}_current_video_id": self.current_video_id,
            f"{prefix}_clear_between_videos": bool(self.clear_between_videos),
            f"{prefix}_no_evict": bool(self.no_evict),
        }
|
|
|
|
# Alternate name for StreamingCache (the name used in its error/warning messages).
DeMemWMStreamingCache = StreamingCache
|
|