# Source: DeMemWM/algorithms/worldmem/dememwm/algorithm.py
# Commit 93d7b0a: "Clean DeMemWM deterministic memory slot handling"
# (repository page residue: "BonanDing's picture")
from __future__ import annotations
import math
from dataclasses import replace
from typing import Iterable
import torch
from einops import rearrange
from .cache import StreamingCache
from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens
from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics
from .injection import InjectionAdapter
from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
from .negatives import apply_revisit_eval_corruption
from .retrieval import deterministic_revisit_retrieval
from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum
from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors
class MemoryDiTMixin:
"""Standalone DeMemWM / Memory-DiT mixin.
Reuses the base video-DiT infrastructure while keeping memory construction and
injection under the standalone `dememwm` package. Legacy memory-method files
are not part of this path.
"""
# Parameter-name prefixes of the DeMemWM-specific modules; each prefix matches
# an attribute assigned in `_build_model`.
# NOTE(review): presumably consumed by `strict_checkpoint_key_check`, which is
# not visible in this chunk — confirm.
strict_key_prefixes = (
    "dememwm_dynamic_compressor.",
    "dememwm_anchor_proj.",
    "dememwm_revisit_proj.",
    "dememwm_revisit_gate.",
)
# Substring form of the same idea; the single entry is the marker that
# `_is_memory_adapter_param` tests for.
strict_key_substrings = (
    ".memory_token_cross_attn.",
)
# Names of diagnostics associated with training.
# NOTE(review): the code that filters logged values with this set is not
# visible here — confirm how it is applied.
_TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({
    "revisit_candidate_frame_count",
    "revisit_pose_preselect_input_count",
    "revisit_pose_preselect_selected_count",
    "revisit_exact_fov_candidate_count",
    "valid_revisit_frame_count",
    "no_valid_revisit_count",
    "revisit_selected_frame_count",
    "revisit_frame_fov_overlap_mean",
    "revisit_best_selected_frame_fov_overlap_mean",
    "revisit_best_selected_plucker_overlap_mean",
    "revisit_best_selected_gap_frames_mean",
    "revisit_gate_raw",
    "revisit_gate_eff",
    "revisit_learned_gate_mean",
    "revisit_effective_gate_mean",
    "generated_history_proxy_prob",
    "noise_bucket_target_count",
    "noise_bucket_high_target_count",
    "noise_bucket_mid_target_count",
    "noise_bucket_low_target_count",
})
# Validation set: every training key plus the two cache-size keys.
_VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({
    "cache_records",
    "cache_slots",
})
def _memory_cfg(self):
return getattr(self.cfg, "dememwm", None)
def _cfg_get(self, obj, name, default):
if obj is None:
return default
if isinstance(obj, dict):
return obj.get(name, default)
return getattr(obj, name, default)
def _cfg_has(self, obj, name: str) -> bool:
if obj is None:
return False
if isinstance(obj, dict):
return name in obj
try:
getattr(obj, name)
return True
except Exception:
return False
def _stage_policy_cfg(self):
return self._cfg_get(self._memory_cfg(), "stage_policy", None)
def _eval_ablation_cfg(self):
return self._cfg_get(self._memory_cfg(), "eval_ablation", None)
def _generated_history_proxy_cfg(self):
return self._cfg_get(self._memory_cfg(), "generated_history_proxy", None)
def _eval_ablation_state(self) -> tuple[bool, str]:
    """Return `(enabled, branch)` for the evaluation ablation.

    `enabled` defaults to False. The configured branch (default
    "A_plus_D_plus_R_normal") is passed through
    `normalize_eval_ablation_branch` before being returned.
    """
    ablation_cfg = self._eval_ablation_cfg()
    is_enabled = bool(self._cfg_get(ablation_cfg, "enabled", False))
    raw_branch = self._cfg_get(ablation_cfg, "branch", "A_plus_D_plus_R_normal")
    return is_enabled, normalize_eval_ablation_branch(raw_branch)
def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict:
    """Combine config switches, curriculum gates and the eval ablation.

    Returns a dict holding the curriculum state, the gates computed by
    `compute_stream_gates`, the resolved noise bucket, the per-stream config
    and effective enable flags, and the eval-ablation flags (including the
    forced revisit on/off markers).
    """
    memory_cfg = self._memory_cfg()
    injection_cfg = self._cfg_get(memory_cfg, "injection", None)
    config_enabled = {
        stream: self._stream_enabled(self._cfg_get(memory_cfg, stream, None))
        for stream in ("anchor", "dynamic", "revisit")
    }
    curriculum_state = self._curriculum_state()
    ablation_on, ablation_branch = self._eval_ablation_state()
    force_all = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False))
    # An explicitly supplied (truthy) bucket wins over the derived one.
    bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction)
    gates = compute_stream_gates(
        curriculum_state.stage,
        denoising_fraction=denoising_fraction,
        debug_force_all_streams=force_all,
        anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)),
        dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)),
        revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)),
    )
    anchor_on = bool(gates.anchor_enabled and config_enabled["anchor"])
    dynamic_on = bool(gates.dynamic_enabled and config_enabled["dynamic"])
    revisit_on = bool(gates.revisit_enabled and config_enabled["revisit"])
    if ablation_on:
        # Each ablation branch switches off a fixed subset of the streams.
        if ablation_branch in ("memory_off", "D_only"):
            anchor_on = False
        if ablation_branch in ("memory_off", "A_only"):
            dynamic_on = False
        if ablation_branch in ("memory_off", "A_only", "D_only", "A_plus_D"):
            revisit_on = False
    return {
        "curriculum_state": curriculum_state,
        "gates": gates,
        "resolved_noise_bucket": bucket,
        "anchor_config_enabled": config_enabled["anchor"],
        "dynamic_config_enabled": config_enabled["dynamic"],
        "revisit_config_enabled": config_enabled["revisit"],
        "anchor_effective_enabled": anchor_on,
        "dynamic_effective_enabled": dynamic_on,
        "revisit_stage_config_enabled": revisit_on,
        "eval_ablation_enabled": ablation_on,
        "eval_ablation_branch": ablation_branch,
        "force_revisit_off": bool(ablation_on and ablation_branch == "R_forced_off"),
        "force_revisit_on": bool(ablation_on and ablation_branch == "R_forced_on"),
    }
def _validate_config_contract(self) -> dict:
    """Validate the `dememwm` config once and return summary diagnostics.

    The result is cached on the instance: after the first successful call
    the stored diagnostics dict is returned without re-validating. A missing
    memory config is treated as valid and yields an empty dict.

    Raises:
        ValueError: when stale sections or fields are present, or when a
            numeric option is outside the range checked below.
    """
    # Fast path: validation already ran for this instance.
    if bool(getattr(self, "_dememwm_contract_validated", False)):
        return getattr(self, "_last_dememwm_config_diagnostics", {})
    memory_cfg = self._memory_cfg()
    if memory_cfg is None:
        self._dememwm_contract_validated = True
        self._last_dememwm_config_diagnostics = {}
        return {}
    # Reject top-level sections that are no longer part of the contract.
    stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
    if stale_sections:
        raise ValueError(f"stale DeMemWM config sections are not part of the final contract: {stale_sections}")
    # Reject ratio-style slot fields at the top level.
    ratio_fields = [
        name
        for name in ("anchor_ratio", "dynamic_ratio", "revisit_ratio", "revisit_max_ratio")
        if self._cfg_has(memory_cfg, name)
    ]
    if ratio_fields:
        raise ValueError(f"standalone DeMemWM derives stream slots from latent shape and compression settings, not ratio fields: {ratio_fields}")
    anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
    dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
    revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
    # Collect every stale per-stream field so the error lists all of them.
    stale_nested = []
    for section_name, section_cfg, field_names in (
        ("anchor", anchor_cfg, ("policy", "topk", "pin_prefix")),
        ("dynamic", dynamic_cfg, ("include_generated_recent",)),
        ("revisit", revisit_cfg, ("deterministic_only", "min_age_frames", "min_gap_frames", "topk", "max_chunks", "chunk_frames", "min_score", "time_weight", "pose_weight", "latent_weight", "pose_overlap_threshold", "action_overlap_threshold", "generated_penalty", "force_gate_zero_when_invalid")),
    ):
        stale_nested.extend(
            f"{section_name}.{field_name}" for field_name in field_names if self._cfg_has(section_cfg, field_name)
        )
    if stale_nested:
        raise ValueError(f"stale DeMemWM config fields are not part of the final contract: {stale_nested}")
    exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
    if exclude_latest_local_frames < 0:
        raise ValueError("dememwm.dynamic.exclude_latest_local_frames must be non-negative")
    if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
        raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
    # The FOV overlap threshold may be None; it is only range-checked when set.
    fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
    if fov_overlap_threshold is not None:
        fov_overlap_threshold = float(fov_overlap_threshold)
        if fov_overlap_threshold < 0.0:
            raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
    high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
    if high_quality_fov_threshold < 0.0:
        raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
    plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
    if plucker_weight < 0.0:
        raise ValueError("dememwm.revisit.plucker_weight must be non-negative")
    # Float-valued revisit options that must be strictly positive.
    for field_name, default in (
        ("fov_half_h", 52.5),
        ("fov_half_v", 37.5),
        ("fov_radius", 30.0),
        ("plucker_focal_length", 0.35),
    ):
        value = float(self._cfg_get(revisit_cfg, field_name, default))
        if value <= 0.0:
            raise ValueError(f"dememwm.revisit.{field_name} must be positive")
    # Integer-valued revisit options that must be strictly positive.
    for field_name, default in (
        ("fov_yaw_samples", 25),
        ("fov_pitch_samples", 20),
        ("fov_depth_samples", 20),
        ("plucker_grid_h", 4),
        ("plucker_grid_w", 4),
    ):
        value = int(self._cfg_get(revisit_cfg, field_name, default))
        if value <= 0:
            raise ValueError(f"dememwm.revisit.{field_name} must be positive")
    stage_policy_cfg = self._stage_policy_cfg()
    if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)):
        raise ValueError("final DeMemWM keeps noise_bucket logging enabled")
    # Generated-history proxy: probabilities in [0, 1], other values >= 0.
    proxy_cfg = self._generated_history_proxy_cfg()
    proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
    proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
    proxy_noise_std = float(self._cfg_get(proxy_cfg, "noise_std", 0.0))
    proxy_ramp_steps = int(self._cfg_get(proxy_cfg, "ramp_steps", 0))
    if proxy_max_prob < 0.0 or proxy_max_prob > 1.0:
        raise ValueError("dememwm.generated_history_proxy.max_prob must be in [0, 1]")
    if proxy_dropout_prob < 0.0 or proxy_dropout_prob > 1.0:
        raise ValueError("dememwm.generated_history_proxy.dropout_prob must be in [0, 1]")
    if proxy_noise_std < 0.0:
        raise ValueError("dememwm.generated_history_proxy.noise_std must be non-negative")
    if proxy_ramp_steps < 0:
        raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
    eval_ablation_cfg = self._eval_ablation_cfg()
    # Called for validation only; the normalized value is discarded.
    # NOTE(review): presumably raises on an unknown branch name — confirm in schedules.
    normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))
    diagnostics = {
        "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
        "revisit_deterministic_fov_plucker_retrieval": True,
        "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
        # A None threshold is reported as -1.0.
        "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
        "revisit_plucker_weight": plucker_weight,
        "stage_policy_noise_bucket_logging": True,
    }
    self._dememwm_contract_validated = True
    self._last_dememwm_config_diagnostics = diagnostics
    return diagnostics
def _stream_enabled(self, stream_cfg) -> bool:
return bool(self._cfg_get(stream_cfg, "enabled", True))
def _context_frame_count(self) -> int:
frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1))
return max(0, int(getattr(self, "context_frames", 0) or 0) // frame_stack)
def _local_context_exclusion_frames(self) -> int:
n_tokens = max(0, int(getattr(self, "n_tokens", 0) or 0))
frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1))
return n_tokens * frame_stack
def _curriculum_state(self, step: int | None = None):
    """Resolve the curriculum for `step` through `resolve_curriculum`.

    When `step` is None, `self.global_step` is used (0 if missing or falsy).
    """
    effective_step = int(getattr(self, "global_step", 0) or 0) if step is None else step
    return resolve_curriculum(self._memory_cfg(), effective_step)
def _generated_history_proxy_prob(self, step: int | None = None) -> float:
cfg = self._generated_history_proxy_cfg()
if not bool(self._cfg_get(cfg, "enabled", False)):
return 0.0
max_prob = min(max(float(self._cfg_get(cfg, "max_prob", 0.0)), 0.0), 1.0)
if max_prob <= 0.0:
return 0.0
if step is None:
step = int(getattr(self, "global_step", 0) or 0)
start_step = int(self._cfg_get(cfg, "start_step", 0))
if step < start_step:
return 0.0
ramp_steps = int(self._cfg_get(cfg, "ramp_steps", 0))
if ramp_steps <= 0:
return max_prob
ramp_fraction = min(max(float(step - start_step) / float(ramp_steps), 0.0), 1.0)
return max_prob * ramp_fraction
def _apply_generated_history_proxy(
self,
source_latents: torch.Tensor,
source_is_generated: torch.Tensor | None,
context_frame_count: int | None = None,
target_start_frame: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
cfg = self._generated_history_proxy_cfg()
prob = self._generated_history_proxy_prob()
noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
diagnostics = {
"generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)),
"generated_history_proxy_prob": float(prob),
"generated_history_proxy_noise_std": float(noise_std),
"generated_history_proxy_dropout_prob": float(dropout_prob),
"generated_history_proxy_frame_count": 0,
"generated_history_proxy_frame_fraction": 0.0,
}
if source_is_generated is None:
source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
else:
source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
if prob <= 0.0 or source_latents.numel() == 0:
return source_latents, source_is_generated, diagnostics
eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
if context_frame_count is not None or target_start_frame is not None:
frame_positions = torch.arange(source_latents.shape[0], device=source_latents.device)[:, None]
if context_frame_count is not None:
eligible_mask &= frame_positions >= max(0, int(context_frame_count))
if target_start_frame is not None:
eligible_mask &= frame_positions < max(0, int(target_start_frame))
proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
proxy_count = int(proxy_mask.detach().long().sum().item())
total_count = max(1, int(proxy_mask.numel()))
diagnostics["generated_history_proxy_frame_count"] = proxy_count
diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count)
if proxy_count == 0:
return source_latents, source_is_generated, diagnostics
corrupt_latents = source_latents.clone()
frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
if noise_std > 0.0:
corrupt_latents = corrupt_latents + torch.randn_like(corrupt_latents) * float(noise_std) * frame_mask
if dropout_prob > 0.0:
dropout_mask = torch.rand(
(*source_latents.shape[:2], 1, source_latents.shape[-2], source_latents.shape[-1]),
device=source_latents.device,
) < dropout_prob
dropout_mask = dropout_mask & proxy_mask[:, :, None, None, None]
corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
source_is_generated = source_is_generated.clone()
source_is_generated |= proxy_mask
return corrupt_latents, source_is_generated, diagnostics
def _checkpoint_cfg(self):
return self._cfg_get(self._memory_cfg(), "checkpoint", None)
def _strict_eval_load_enabled(self) -> bool:
return bool(self._cfg_get(self._checkpoint_cfg(), "strict_dememwm_eval_load", True))
def _cache_cfg(self):
return self._cfg_get(self._memory_cfg(), "cache", None)
def _cache_enabled(self) -> bool:
return bool(self._cfg_get(self._cache_cfg(), "enabled", False))
def _new_streaming_cache(self, video_id=None) -> StreamingCache | None:
if not self._cache_enabled():
return None
cache = StreamingCache.from_config(self._cache_cfg(), enabled_default=True)
if cache.clear_between_videos:
cache.reset(video_id=video_id)
return cache
def _is_memory_adapter_param(self, name: str) -> bool:
return ".memory_token_cross_attn." in name
def _param_group_name(self, name: str, state=None) -> str:
state = state or self._curriculum_state()
if name.startswith("vae.") or name.startswith("validation_lpips_model."):
return "excluded_frozen"
if name.startswith(("dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.")):
return "dememwm_modules"
if self._is_memory_adapter_param(name):
return "memory_adapters"
if name.startswith("diffusion_model."):
return "full_dit"
return "dememwm_modules"
def _group_trainable(self, group_name: str, state) -> bool:
if group_name in {"dememwm_modules", "memory_adapters"}:
return True
if group_name == "full_dit":
return state.dit_full_trainable
return False
def _group_lr(self, group_name: str, state) -> float:
if group_name == "dememwm_modules":
return state.dememwm_lr
if group_name == "memory_adapters":
return state.memory_adapter_lr
if group_name == "full_dit":
return state.full_dit_lr
return 0.0
def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
    """Apply the curriculum's freeze policy and return the resolved state.

    Effects visible in this method:
      * When the (stage, dit_train_state, freeze_vae) key changes, every
        parameter gets `requires_grad` set (False only for the excluded
        group) and per-group tensor/scalar counts are recomputed and cached.
      * When `optimizer` is given, each param group's `lr` is overwritten
        with the group's curriculum LR, or 0.0 if the group is not trainable.
      * Diagnostics are stored on `self._last_dememwm_freeze_diagnostics`.
    """
    state = self._curriculum_state(step)
    # Keep DDP's trainable graph stable: DiT params stay requires_grad=True
    # from step 0 and are frozen by optimizer LR=0 until the full stage.
    # Re-walk only when curriculum diagnostics can change.
    freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
    last_key = getattr(self, "_last_freeze_key", None)
    if last_key != freeze_key:
        trainable_tensors = {
            "dememwm_modules": 0,
            "memory_adapters": 0,
            "full_dit": 0,
            "excluded_frozen": 0,
        }
        trainable_scalars = {key: 0 for key in trainable_tensors}
        requires_grad_tensors = {key: 0 for key in trainable_tensors}
        requires_grad_scalars = {key: 0 for key in trainable_tensors}
        for name, param in self.named_parameters():
            group_name = self._param_group_name(name, state)
            should_train = self._group_trainable(group_name, state)
            # NOTE(review): `_param_group_name` already maps every "vae." name
            # to "excluded_frozen", so the second clause looks redundant.
            if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
                should_train = False
                should_require_grad = False
            else:
                should_require_grad = True
            param.requires_grad_(should_require_grad)
            # "Trainable" (optimizer will move it) and "requires grad" (part of
            # the autograd graph) are counted separately.
            if should_train:
                trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1
                trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel())
            if should_require_grad:
                requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1
                requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel())
        self._last_freeze_key = freeze_key
        self._last_trainable_tensors = trainable_tensors
        self._last_trainable_scalars = trainable_scalars
        self._last_requires_grad_tensors = requires_grad_tensors
        self._last_requires_grad_scalars = requires_grad_scalars
    else:
        # Key unchanged: reuse the counts from the last walk.
        trainable_tensors = getattr(self, "_last_trainable_tensors", {})
        trainable_scalars = getattr(self, "_last_trainable_scalars", {})
        requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {})
        requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {})
    if optimizer is not None:
        # LRs are refreshed on every call, not only when the key changes.
        for param_group in optimizer.param_groups:
            group_name = param_group.get("name", "")
            trainable = self._group_trainable(group_name, state)
            param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0
    diagnostics = state.diagnostics()
    for group_name in ("dememwm_modules", "memory_adapters", "full_dit"):
        diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0)
        diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0)
        diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0)
        diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0)
        diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0
    self._last_dememwm_freeze_diagnostics = diagnostics
    return state
def configure_optimizers(self):
    """Build an AdamW optimizer with one named param group per non-empty group.

    Groups are emitted in the order dememwm_modules, memory_adapters,
    full_dit. A group that is not trainable at step 0 is still registered,
    with a learning rate of 0.0.

    Raises:
        RuntimeError: when no parameter falls into any of the three groups.
    """
    state = self._curriculum_state(0)
    self._apply_freeze_policy(step=0)
    group_order = ("dememwm_modules", "memory_adapters", "full_dit")
    grouped: dict[str, list[torch.nn.Parameter]] = {group_name: [] for group_name in group_order}
    for name, param in self.named_parameters():
        bucket = grouped.get(self._param_group_name(name, state))
        if bucket is not None:
            bucket.append(param)
    param_groups = []
    for group_name in group_order:
        members = grouped[group_name]
        if not members:
            continue
        if self._group_trainable(group_name, state):
            group_lr = self._group_lr(group_name, state)
        else:
            group_lr = 0.0
        param_groups.append({"params": members, "lr": group_lr, "name": group_name})
    if not param_groups:
        raise RuntimeError("DeMemWM optimizer found no trainable parameter groups")
    return torch.optim.AdamW(
        param_groups,
        weight_decay=self.cfg.weight_decay,
        betas=self.cfg.optimizer_beta,
    )
def on_train_start(self):
optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or []
for optimizer in optimizers:
self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0))
def on_train_batch_start(self, batch, batch_idx):
optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or []
for optimizer in optimizers:
self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0))
def on_after_backward(self):
    """Discard gradients of parameters whose group is not trainable right now."""
    current_step = int(getattr(self, "global_step", 0) or 0)
    state = self._apply_freeze_policy(step=current_step)
    for name, param in self.named_parameters():
        has_grad = param.grad is not None
        if has_grad and not self._group_trainable(self._param_group_name(name, state), state):
            param.grad = None
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    """Run one optimizer step with the freeze policy applied before and after.

    The policy is applied for the current step before stepping, then again
    for `global_step + 1` afterwards. `global_step` is read afresh each time.
    """

    def read_step() -> int:
        return int(getattr(self, "global_step", 0) or 0)

    self._apply_freeze_policy(optimizer, read_step())
    optimizer.step(closure=optimizer_closure)
    self._apply_freeze_policy(optimizer, read_step() + 1)
def on_load_checkpoint(self, checkpoint):
    """Run the parent hook, then the strict key check when it is enabled.

    The check receives `checkpoint["state_dict"]` when `checkpoint` is a dict
    holding that key, and `checkpoint` itself otherwise.
    """
    super().on_load_checkpoint(checkpoint)
    if not self._strict_eval_load_enabled():
        return
    if isinstance(checkpoint, dict):
        state_dict = checkpoint.get("state_dict", checkpoint)
    else:
        state_dict = checkpoint
    self.strict_checkpoint_key_check(state_dict)
def _preprocess_batch(self, batch):
    """Preprocess RGB or precomputed-latent Minecraft batches for DeMemWM.

    A 5-element batch carries an extra `image_hw` entry and marks `xs` as
    precomputed latents; any other batch is unpacked as 4 elements and marks
    `xs` as not being latents. Both facts are stored on the instance for
    `_as_latents` and `_image_size`.

    The first action along dim 1 is zeroed, every tensor is rearranged from
    batch-first to time-first, and camera-to-world matrices are derived from
    the pose conditions.

    Raises:
        NotImplementedError: when `self.action_cond_dim` is falsy.
    """
    from ..df_video import euler_to_camera_to_world_matrix

    has_image_hw = len(batch) == 5
    if has_image_hw:
        xs, conditions, pose_conditions, frame_index, image_hw = batch
    else:
        xs, conditions, pose_conditions, frame_index = batch
        image_hw = None
    self._last_dememwm_xs_are_latents = has_image_hw
    self._last_dememwm_image_hw = image_hw

    if not self.action_cond_dim:
        raise NotImplementedError("Only support external cond.")
    blank_first_action = torch.zeros_like(conditions[:, :1])
    conditions = torch.cat([blank_first_action, conditions[:, 1:]], 1)
    conditions = rearrange(conditions, "b t d -> t b d").contiguous()

    pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
    c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
    xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
    frame_index = rearrange(frame_index, "b t -> t b").contiguous()
    return xs, conditions, pose_conditions, c2w_mat, frame_index
def _as_latents(self, xs: torch.Tensor) -> torch.Tensor:
if bool(getattr(self, "_last_dememwm_xs_are_latents", False)):
return xs
return self.encode(xs)
def _image_size(self, xs: torch.Tensor) -> tuple[int, int]:
image_hw = getattr(self, "_last_dememwm_image_hw", None)
if image_hw is not None:
if torch.is_tensor(image_hw):
values = image_hw.detach().cpu().reshape(-1).tolist()
else:
values = list(image_hw)
if len(values) >= 2:
return int(values[0]), int(values[1])
return int(xs.shape[-2]), int(xs.shape[-1])
def _update_streaming_cache(
self,
cache: StreamingCache | None,
new_latents: torch.Tensor,
frame_indices: torch.Tensor,
pose: torch.Tensor | None = None,
source_is_generated: torch.Tensor | None = None,
action: torch.Tensor | None = None,
) -> None:
if cache is None or not cache.enabled or new_latents is None or new_latents.shape[0] == 0:
return
cache.add_raw_latents(new_latents, frame_indices, source_is_generated, pose)
if not cache.keep_compressed_records:
return
memory_cfg = self._memory_cfg()
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])]
anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None)
anchor_src_h, anchor_src_w = self._projected_spatial_grid_size(
int(new_latents.shape[-2]),
int(new_latents.shape[-1]),
self.dememwm_anchor_proj,
token_patch_size,
)
anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size(
anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8
)
anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False))
allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False))
# Prefix anchors are a per-video prefix resource. Do not add new prefix
# anchors for later committed segments unless explicitly generated anchors are allowed.
if cache.records_count("anchor") > 0 and not allow_generated_anchor:
anchor_indices = []
anchor_banks, revisit_banks = self._build_streaming_cache_records(
new_latents,
frame_indices,
source_is_generated,
pose,
action,
allow_generated_anchor,
anchor_indices,
anchor_pool_h,
anchor_pool_w,
anchor_diverse,
token_patch_size,
)
cache.add_memory_banks(anchor_banks, revisit_banks)
def _build_model(self):
    """Instantiate the diffusion backbone and the DeMemWM memory modules.

    Creates, in this order: the diffusion model, the dynamic compressor, the
    anchor and revisit projectors, the revisit gate, the injection adapter,
    the LPIPS metric, the VAE (eval mode, gradients disabled) and, when
    `self.require_pose_prediction` is set, the pose prediction network. The
    config contract is validated right after the diffusion model is built.
    """
    # Project-local imports are resolved lazily, at call time.
    from algorithms.common.metrics import LearnedPerceptualImagePatchSimilarity
    from .gates import RevisitRawGate
    from ..models.diffusion import Diffusion
    from ..models.pose_prediction import PosePredictionNet
    from ..models.vae import VAE_models
    self.diffusion_model = Diffusion(
        reference_length=self.memory_condition_length,
        x_shape=self.x_stacked_shape,
        action_cond_dim=self.action_cond_dim,
        pose_cond_dim=self.pose_cond_dim,
        is_causal=self.causal,
        cfg=self.cfg.diffusion,
        is_dit=True,
        use_plucker=self.use_plucker,
        relative_embedding=self.relative_embedding,
        state_embed_only_on_qk=self.state_embed_only_on_qk,
        use_memory_attention=False,
        add_timestamp_embedding=self.add_timestamp_embedding,
        memory_token_cross_attention=getattr(self.cfg, "memory_token_cross_attention", True),
        memory_cross_attn_layers=getattr(self.cfg, "memory_cross_attn_layers", None),
        ref_mode=self.ref_mode,
    )
    memory_cfg = self._memory_cfg()
    self._validate_config_contract()
    injection_cfg = self._cfg_get(memory_cfg, "injection", None)
    dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
    hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
    token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
    max_source_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
    self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        patch_size=token_patch_size,
        conv_kernel_t=int(self._cfg_get(dynamic_cfg, "conv_kernel_t", 3)),
        conv_stride_t=int(self._cfg_get(dynamic_cfg, "conv_stride_t", 2)),
        max_source_frames=max_source_frames,
        exclude_latest_local_frames=int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)),
    )
    # Both spatial projectors share one hyper-parameter set.
    spatial_mid_channels = self.x_stacked_shape[0] * token_patch_size * token_patch_size
    self.dememwm_anchor_proj = SpatialConv2DMemoryProjector(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        mid_channels=spatial_mid_channels,
        kernel_size=3,
    )
    self.dememwm_revisit_proj = SpatialConv2DMemoryProjector(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        mid_channels=spatial_mid_channels,
        kernel_size=3,
    )
    self.dememwm_revisit_gate = RevisitRawGate()
    self.dememwm_injection_adapter = InjectionAdapter()
    self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
    # The VAE is put in eval mode and excluded from autograd.
    self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval()
    for param in self.vae.parameters():
        param.requires_grad_(False)
    if self.require_pose_prediction:
        self.pose_prediction_model = PosePredictionNet()
def _project_latent_patch_tokens(
self,
latents: torch.Tensor,
projection: torch.nn.Module,
patch_size: int,
) -> torch.Tensor:
# (T,B,C,H,W) -> (B,T,T_frame,D). Conv2D projectors keep T_frame=H*W.
if bool(getattr(projection, "projects_spatial_latents", False)):
return projection(latents)
patch_vectors = latent_patch_tokens(latents, patch_size)
return projection(patch_vectors).permute(1, 0, 2, 3).contiguous()
def _projected_spatial_grid_size(
self,
latent_h: int,
latent_w: int,
projection: torch.nn.Module,
patch_size: int,
) -> tuple[int, int]:
if bool(getattr(projection, "projects_spatial_latents", False)):
return int(latent_h), int(latent_w)
return int(latent_h) // int(patch_size), int(latent_w) // int(patch_size)
def _take_uniform_slots(self, tokens: torch.Tensor, num_slots: int) -> torch.Tensor:
if tokens.ndim != 2:
raise ValueError("tokens must have shape (N,D)")
num_slots = max(0, int(num_slots))
if num_slots == 0:
return tokens[:0]
if tokens.shape[0] <= num_slots:
return tokens
idx = torch.linspace(0, tokens.shape[0] - 1, num_slots, device=tokens.device).round().long()
return tokens.index_select(0, idx)
def _spatial_pool_tokens(
    self,
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """Delegate to the module-level `spatial_pool_tokens`, passing the arguments through in order."""
    pooled = spatial_pool_tokens(tokens, pool_h, pool_w, src_h, src_w)
    return pooled
def _resolve_spatial_pool_size(
self,
compress_cfg,
src_h: int,
src_w: int,
default_pool_h: int,
default_pool_w: int,
) -> tuple[int, int]:
ratio = self._cfg_get(compress_cfg, "downsample_ratio", None)
ratio_h = self._cfg_get(compress_cfg, "downsample_h", ratio)
ratio_w = self._cfg_get(compress_cfg, "downsample_w", ratio)
if ratio_h is not None or ratio_w is not None:
if ratio_h is None:
ratio_h = ratio_w
if ratio_w is None:
ratio_w = ratio_h
ratio_h = float(ratio_h)
ratio_w = float(ratio_w)
if ratio_h <= 0.0 or ratio_w <= 0.0:
raise ValueError("DeMemWM compress downsample ratios must be positive")
return (
max(1, int(math.ceil(float(src_h) / ratio_h))),
max(1, int(math.ceil(float(src_w) / ratio_w))),
)
pool_h = int(self._cfg_get(compress_cfg, "pool_h", default_pool_h))
pool_w = int(self._cfg_get(compress_cfg, "pool_w", default_pool_w))
if pool_h <= 0 or pool_w <= 0:
raise ValueError("DeMemWM compress pool_h/pool_w must be positive")
return pool_h, pool_w
def _select_diverse_anchor_positions(
self,
source_positions: torch.Tensor,
pose: torch.Tensor | None,
num_anchors: int,
) -> torch.Tensor:
num_anchors = max(0, int(num_anchors))
if num_anchors == 0:
return source_positions[:0]
if source_positions.numel() <= num_anchors or pose is None:
return source_positions[:num_anchors]
poses = pose.float()
pairwise = torch.cdist(poses, poses)
if not bool((pairwise > 0).any().item()):
return source_positions[:num_anchors]
available = torch.ones((int(source_positions.numel()),), device=poses.device, dtype=torch.bool)
if num_anchors == 1:
selected = [int(pairwise.mean(dim=1).argmax().item())]
else:
first, second = divmod(int(pairwise.argmax().item()), int(pairwise.shape[1]))
selected = [int(first), int(second)]
for idx in selected:
available[idx] = False
dists = pairwise[selected].min(dim=0).values
dists = dists.masked_fill(~available, float("-inf"))
for _ in range(num_anchors - len(selected)):
farthest = int(dists.argmax().item())
if not bool(available[farthest].item()):
break
selected.append(farthest)
available[farthest] = False
d_new = pairwise[farthest]
dists = torch.minimum(dists, d_new)
dists = dists.masked_fill(~available, float("-inf"))
return source_positions[torch.tensor(sorted(selected), device=source_positions.device)]
def _build_streaming_cache_records(
self,
source_latents: torch.Tensor,
source_frame_indices: torch.Tensor,
source_is_generated: torch.Tensor | None,
pose: torch.Tensor | None,
action: torch.Tensor | None,
allow_generated_anchor: bool,
anchor_indices: list[int],
anchor_pool_h: int,
anchor_pool_w: int,
anchor_diverse: bool,
token_patch_size: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]:
if source_latents.ndim != 5:
raise ValueError("source_latents must have shape (T,B,C,H,W)")
if source_frame_indices.ndim != 2:
raise ValueError("source_frame_indices must have shape (T,B)")
T_src, B = source_frame_indices.shape
if source_latents.shape[:2] != (T_src, B):
raise ValueError("source_latents and source_frame_indices must share T/B dimensions")
_, _, _, latent_H, latent_W = source_latents.shape
src_h, src_w = self._projected_spatial_grid_size(
latent_H,
latent_W,
self.dememwm_anchor_proj,
token_patch_size,
)
param = next(iter(self.dememwm_anchor_proj.parameters()))
project_device = param.device
project_dtype = param.dtype
hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0])
generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device)
anchor_banks: list[CausalMemoryBank] = []
revisit_banks: list[CausalMemoryBank] = []
def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
if tensor is None or tensor.ndim < 2:
return None
pos = positions.to(device=tensor.device)
if tensor.shape[0] == T_src and tensor.shape[1] == B:
return tensor[pos, batch_idx]
if tensor.shape[0] == B and tensor.shape[1] == T_src:
return tensor[batch_idx, pos]
return None
def _metadata_subset(positions: torch.Tensor, batch_idx: int):
return {"dememwm_revisit_metadata_only": True}
def _pose_subset(positions: torch.Tensor, batch_idx: int):
return _tensor_subset(pose, positions, batch_idx)
def _add_anchor_records(bank: CausalMemoryBank, batch_idx: int, positions: torch.Tensor, generated_anchor: bool) -> None:
if positions.numel() == 0:
return
projected = self._project_latent_patch_tokens(
source_latents.index_select(0, positions.to(device=source_latents.device))[:, batch_idx:batch_idx + 1].to(device=project_device, dtype=project_dtype),
self.dememwm_anchor_proj,
token_patch_size,
)[0]
src_frames = source_frame_indices[:, batch_idx]
for local_idx, source_pos in enumerate(positions):
source_pos_i = int(source_pos.item())
anchor_tokens = self._spatial_pool_tokens(projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
n_slots = anchor_tokens.shape[0]
record_mask = torch.ones((n_slots,), device=anchor_tokens.device, dtype=torch.bool)
if generated_anchor:
bank.add_generated_records(
anchor_tokens.unsqueeze(0),
record_mask.unsqueeze(0),
src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device),
source_type=MemorySourceType.GENERATED,
)
else:
bank.add_prefix_anchors(
anchor_tokens.unsqueeze(0),
record_mask.unsqueeze(0),
src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device),
slots_per_anchor=n_slots,
)
for batch_idx in range(B):
anchor_bank = CausalMemoryBank()
revisit_bank = CausalMemoryBank()
src_frames = source_frame_indices[:, batch_idx]
if generated is None:
non_generated = torch.ones_like(src_frames, dtype=torch.bool)
else:
non_generated = ~generated[:, batch_idx]
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
if source_positions.numel() > 0:
if anchor_diverse:
anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
if anchor_source_positions.numel() > 0:
anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
selected_anchor_positions = self._select_diverse_anchor_positions(
anchor_source_positions, anchor_pose, len(anchor_indices)
)
else:
selected_anchor_positions = source_positions[:0]
else:
selected_list = []
for anchor_idx in anchor_indices:
if 0 <= int(anchor_idx) < source_positions.numel():
selected_list.append(source_positions[int(anchor_idx)])
selected_anchor_positions = torch.stack(selected_list).long() if selected_list else source_positions[:0]
if selected_anchor_positions.numel() > 0:
_add_anchor_records(anchor_bank, batch_idx, selected_anchor_positions.long(), False)
dummy_tokens = torch.zeros((1, hidden_size), device=source_frame_indices.device, dtype=project_dtype)
dummy_mask = torch.ones((1,), device=source_frame_indices.device, dtype=torch.bool)
for prefix, positions, source_type, is_generated in (
("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
(
"generated",
torch.empty(0, device=source_frame_indices.device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(),
MemorySourceType.GENERATED,
True,
),
):
if positions.numel() == 0:
continue
for source_pos in positions.to(device=source_frame_indices.device, dtype=torch.long):
source_pos_i = int(source_pos.item())
frame_index = src_frames[source_pos_i]
frame = int(frame_index.detach().item())
frame_pos = source_pos.reshape(1)
revisit_bank.add_frame_record(
dummy_tokens,
dummy_mask,
frame_index,
pose=_pose_subset(frame_pos, batch_idx),
source_type=source_type,
metadata=_metadata_subset(frame_pos, batch_idx),
is_generated=is_generated,
record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}",
)
if allow_generated_anchor and generated is not None and anchor_indices:
generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten()
_add_anchor_records(anchor_bank, batch_idx, generated_positions[:len(anchor_indices)].long(), True)
anchor_banks.append(anchor_bank)
revisit_banks.append(revisit_bank)
return anchor_banks, revisit_banks
def _build_causal_memory_banks(
self,
anchor_projected: torch.Tensor,
revisit_projected: torch.Tensor,
source_frame_indices: torch.Tensor,
source_is_generated: torch.Tensor | None,
pose: torch.Tensor | None,
action: torch.Tensor | None,
allow_generated_anchor: bool,
anchor_indices: list[int],
anchor_pool_h: int,
anchor_pool_w: int,
revisit_pool_h: int,
revisit_pool_w: int,
src_h: int,
src_w: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]:
# projected tensors use the same batch/source convention as
# _project_latent_patch_tokens: (B, T_src, T_frame, D), while frame indices are
# (T_src, B). Build separate banks because anchor and revisit records
# come from different projections.
if anchor_projected.ndim != 4 or revisit_projected.ndim != 4:
raise ValueError("anchor/revisit projected tensors must have shape (B,T_src,T_frame,D)")
B, T_src, _, _ = anchor_projected.shape
if revisit_projected.shape[:3] != anchor_projected.shape[:3]:
raise ValueError("anchor/revisit projected tensors must share batch/source/token dimensions")
generated = None if source_is_generated is None else source_is_generated.bool().to(source_frame_indices.device)
anchor_banks: list[CausalMemoryBank] = []
revisit_banks: list[CausalMemoryBank] = []
def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
if tensor is None or tensor.ndim < 2:
return None
if tensor.shape[0] == T_src and tensor.shape[1] == B:
return tensor[positions, batch_idx]
if tensor.shape[0] == B and tensor.shape[1] == T_src:
return tensor[batch_idx, positions]
return None
def _metadata_subset(positions: torch.Tensor, batch_idx: int):
return {}
def _pose_subset(positions: torch.Tensor, batch_idx: int):
return _tensor_subset(pose, positions, batch_idx)
for batch_idx in range(B):
anchor_bank = CausalMemoryBank()
revisit_bank = CausalMemoryBank()
src_frames = source_frame_indices[:, batch_idx]
if generated is None:
non_generated = torch.ones_like(src_frames, dtype=torch.bool)
else:
non_generated = ~generated[:, batch_idx]
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
if source_positions.numel() > 0:
selected_anchor_positions = []
for anchor_idx in anchor_indices:
if 0 <= int(anchor_idx) < source_positions.numel():
selected_anchor_positions.append(source_positions[int(anchor_idx)])
for source_pos in selected_anchor_positions:
source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos)
anchor_tokens = self._spatial_pool_tokens(
anchor_projected[batch_idx, source_pos_i],
anchor_pool_h, anchor_pool_w, src_h, src_w,
)
n_slots = anchor_tokens.shape[0]
record_mask = torch.ones(
(n_slots,),
device=anchor_projected.device,
dtype=torch.bool,
)
anchor_bank.add_prefix_anchors(
anchor_tokens.unsqueeze(0),
record_mask.unsqueeze(0),
src_frames[source_pos_i:source_pos_i + 1],
slots_per_anchor=n_slots,
)
for source_pos in source_positions:
source_pos_i = int(source_pos.item())
frame_index = src_frames[source_pos_i]
frame = int(frame_index.detach().item())
frame_pos = source_pos.reshape(1)
frame_tokens = self._spatial_pool_tokens(
revisit_projected[batch_idx, source_pos_i],
revisit_pool_h, revisit_pool_w, src_h, src_w,
)
frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool)
revisit_bank.add_frame_record(
frame_tokens,
frame_mask,
frame_index,
pose=_pose_subset(frame_pos, batch_idx),
source_type=MemorySourceType.PREFIX_GT,
metadata=_metadata_subset(frame_pos, batch_idx),
is_generated=False,
record_id=f"prefix_revisit_b{batch_idx}_f{frame}",
)
if generated is not None:
generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten()
if generated_positions.numel() > 0:
for source_pos in generated_positions:
source_pos_i = int(source_pos.item())
frame_index = src_frames[source_pos_i]
frame = int(frame_index.detach().item())
frame_pos = source_pos.reshape(1)
frame_tokens = self._spatial_pool_tokens(
revisit_projected[batch_idx, source_pos_i],
revisit_pool_h, revisit_pool_w, src_h, src_w,
)
frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool)
revisit_bank.add_frame_record(
frame_tokens,
frame_mask,
frame_index,
pose=_pose_subset(frame_pos, batch_idx),
source_type=MemorySourceType.GENERATED,
metadata=_metadata_subset(frame_pos, batch_idx),
is_generated=True,
record_id=f"generated_revisit_b{batch_idx}_f{frame}",
)
if allow_generated_anchor:
for source_pos in generated_positions[:len(anchor_indices)]:
source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos)
anchor_tokens = self._spatial_pool_tokens(
anchor_projected[batch_idx, source_pos_i],
anchor_pool_h, anchor_pool_w, src_h, src_w,
)
record_mask = torch.ones((anchor_tokens.shape[0],), device=anchor_projected.device, dtype=torch.bool)
anchor_bank.add_generated_records(
anchor_tokens.unsqueeze(0),
record_mask.unsqueeze(0),
src_frames[source_pos_i:source_pos_i + 1],
source_type=MemorySourceType.GENERATED,
)
anchor_banks.append(anchor_bank)
revisit_banks.append(revisit_bank)
return anchor_banks, revisit_banks
def _build_preselected_causal_memory_banks(
self,
committed_latents: torch.Tensor,
source_frame_indices: torch.Tensor,
source_is_generated: torch.Tensor | None,
pose: torch.Tensor | None,
action: torch.Tensor | None,
target_frame_indices: torch.Tensor,
target_pose: torch.Tensor | None,
target_action: torch.Tensor | None,
target_video_ids,
allow_generated_anchor: bool,
anchor_indices: list[int],
anchor_pool_h: int,
anchor_pool_w: int,
anchor_diverse: bool,
revisit_pool_h: int,
revisit_pool_w: int,
revisit_max_frames: int,
exclude_local_context_frames: int,
fov_overlap_threshold,
plucker_weight: float,
revisit_retrieval_kwargs: dict | None,
token_patch_size: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]:
if committed_latents.ndim != 5:
raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
T_src, B, _, H, W = committed_latents.shape
if source_frame_indices.shape != (T_src, B):
raise ValueError("source_frame_indices must have shape (T_src,B)")
if target_frame_indices.ndim == 1:
target_frame_indices = target_frame_indices[:, None]
if target_frame_indices.shape[1] != B:
raise ValueError("target_frame_indices must have batch dimension B")
T_tgt = target_frame_indices.shape[0]
stream_device = committed_latents.device
hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0])
src_h, src_w = self._projected_spatial_grid_size(
H,
W,
self.dememwm_anchor_proj,
token_patch_size,
)
tokens_per_frame = src_h * src_w
generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device)
anchor_banks: list[CausalMemoryBank] = []
revisit_banks: list[CausalMemoryBank] = []
dummy_tokens = committed_latents.new_zeros((1, hidden_size))
dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
preselection_candidate_count = 0
preselection_valid_candidate_label_count = 0
preselection_selected_count = 0
projected_anchor_frames = 0
projected_revisit_frames = 0
projected_revisit_records = 0
retrieval_kwargs = dict(revisit_retrieval_kwargs or {})
# Pre-convert pose tensors to stream_device once so that the
# _tensor_subset / _target_tensor closures below never trigger a
# device transfer on every call.
if pose is not None:
pose = pose.to(device=stream_device)
if target_pose is not None:
target_pose = target_pose.to(device=stream_device)
def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
if tensor is None or tensor.ndim < 2:
return None
if tensor.shape[0] == T_src and tensor.shape[1] == B:
return tensor[positions, batch_idx]
if tensor.shape[0] == B and tensor.shape[1] == T_src:
return tensor[batch_idx, positions]
return None
def _target_tensor(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
if tensor is None or tensor.ndim < 2:
return None
if tensor.shape[0] == T_tgt and tensor.shape[1] == B:
return tensor[target_idx, batch_idx]
if tensor.shape[0] == B and tensor.shape[1] == T_tgt:
return tensor[batch_idx, target_idx]
return None
def _target_video_id(batch_idx: int, target_idx: int):
if target_video_ids is None:
return None
if torch.is_tensor(target_video_ids):
ids = target_video_ids.detach().cpu()
if ids.ndim == 0:
return ids.item()
if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B:
return ids[target_idx, batch_idx].item()
if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt:
return ids[batch_idx, target_idx].item()
return None
if isinstance(target_video_ids, (list, tuple)):
if len(target_video_ids) == B:
return target_video_ids[batch_idx]
if len(target_video_ids) == T_tgt:
row = target_video_ids[target_idx]
if isinstance(row, (list, tuple)) and len(row) == B:
return row[batch_idx]
return row
return target_video_ids
def _metadata_subset(positions: torch.Tensor, batch_idx: int):
return {}
def _pose_subset(positions: torch.Tensor, batch_idx: int):
return _tensor_subset(pose, positions, batch_idx)
def _candidate_record(
*,
batch_idx: int,
frame_position: torch.Tensor,
source_type: MemorySourceType,
is_generated: bool,
record_id: str,
) -> MemoryRecord:
frame_values = source_frame_indices[frame_position, batch_idx].to(device=stream_device)
frame = int(frame_values.reshape(-1)[0].item())
return MemoryRecord(
tokens=dummy_tokens,
mask=dummy_mask,
source_start=frame,
source_end=frame + 1,
frame_indices=frame_values.reshape(1),
pose=_pose_subset(frame_position, batch_idx),
source_type=source_type,
is_generated=bool(is_generated),
chunk_id=record_id,
metadata=_metadata_subset(frame_position, batch_idx),
)
for batch_idx in range(B):
anchor_bank = CausalMemoryBank()
revisit_bank = CausalMemoryBank()
src_frames = source_frame_indices[:, batch_idx]
if generated is None:
non_generated = torch.ones_like(src_frames, dtype=torch.bool)
else:
non_generated = ~generated[:, batch_idx]
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long)
if anchor_indices and source_positions.numel() > 0:
if anchor_diverse:
anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
if anchor_source_positions.numel() > 0:
anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
anchor_positions = self._select_diverse_anchor_positions(
anchor_source_positions, anchor_pose, len(anchor_indices)
).to(device=stream_device, dtype=torch.long)
else:
selected_anchor_positions = []
for anchor_idx in anchor_indices:
if 0 <= int(anchor_idx) < source_positions.numel():
selected_anchor_positions.append(source_positions[int(anchor_idx)])
if selected_anchor_positions:
anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
if anchor_positions.numel() > 0:
projected_anchor_frames += int(anchor_positions.numel())
anchor_projected = self._project_latent_patch_tokens(
committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
self.dememwm_anchor_proj,
token_patch_size,
)[0]
for local_idx, source_pos in enumerate(anchor_positions):
source_pos_i = int(source_pos.item())
anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
n_slots = anchor_tokens.shape[0]
record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool)
anchor_bank.add_prefix_anchors(
anchor_tokens.unsqueeze(0),
record_mask.unsqueeze(0),
src_frames[source_pos_i:source_pos_i + 1],
slots_per_anchor=n_slots,
)
candidate_records: list[MemoryRecord] = []
candidate_positions: dict[str, torch.Tensor] = {}
src_frames_cpu = src_frames.detach().cpu()
target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long)
latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames)
for prefix, positions, source_type, is_generated in (
("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
(
"generated",
torch.empty(0, device=stream_device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(),
MemorySourceType.GENERATED,
True,
),
):
if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0:
continue
positions_cpu = positions.detach().cpu().to(dtype=torch.long)
for frame_position_cpu in positions_cpu:
frame = int(src_frames_cpu[int(frame_position_cpu.item())].item())
if frame >= latest_valid_source_frame_exclusive:
continue
frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long)
record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
candidate_positions[record_id] = frame_position
candidate_records.append(_candidate_record(
batch_idx=batch_idx,
frame_position=frame_position,
source_type=source_type,
is_generated=is_generated,
record_id=record_id,
))
selected_frame_record_ids: set[str] = set()
selected_frame_metadata: dict[str, dict] = {}
for target_idx in range(T_tgt):
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
result = deterministic_revisit_retrieval(
candidate_records,
target_frame=target_frame,
target_pose=_target_tensor(target_pose, batch_idx, target_idx),
target_summary=None,
topk=revisit_max_frames,
exclude_local_context_frames=exclude_local_context_frames,
fov_overlap_threshold=fov_overlap_threshold,
plucker_weight=plucker_weight,
target_video_id=_target_video_id(batch_idx, target_idx),
**retrieval_kwargs,
)
preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0))
preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
for selected_record in result.records:
if selected_record.chunk_id is None:
continue
record_id = str(selected_record.chunk_id)
selected_frame_record_ids.add(record_id)
selected_frame_metadata[record_id] = dict(selected_record.metadata)
for record in candidate_records:
if record.chunk_id not in selected_frame_record_ids:
continue
record_id = str(record.chunk_id)
frame_position = candidate_positions[record_id]
projected_revisit_records += 1
projected_revisit_frames += int(frame_position.numel())
revisit_projected = self._project_latent_patch_tokens(
committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
self.dememwm_revisit_proj,
token_patch_size,
)[0]
frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, src_h, src_w)
frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool)
record_metadata = dict(record.metadata)
record_metadata.update(selected_frame_metadata.get(record_id, {}))
revisit_bank.add_frame_record(
frame_tokens,
frame_mask,
record.frame_indices.reshape(-1)[0],
pose=record.pose,
source_type=record.source_type,
metadata=record_metadata,
is_generated=record.is_generated,
record_id=record.chunk_id,
)
anchor_banks.append(anchor_bank)
revisit_banks.append(revisit_bank)
diagnostics = {
"preselected_anchor_projected_frame_count": projected_anchor_frames,
"preselected_revisit_projected_frame_count": projected_revisit_frames,
"preselected_revisit_projected_frame_record_count": projected_revisit_records,
"preselected_revisit_candidate_frame_count": preselection_candidate_count,
"preselected_revisit_candidate_count": preselection_candidate_count,
"preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count,
"preselected_revisit_selected_frame_count": preselection_selected_count,
"preselected_revisit_selected_count": preselection_selected_count,
}
return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
def _causal_cached_revisit_records(
self,
records: Iterable[MemoryRecord],
target_frame: int,
) -> list[MemoryRecord]:
target_frame = int(target_frame)
return [record for record in records if int(record.source_end) <= target_frame]
def _records_to_stream(
    self,
    records,
    target_slots: int,
    hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Stack record tokens into a fixed-length ``(target_slots, hidden_size)`` stream.

    Returns:
        ``(tokens, mask, max_source_frame)``. Rows beyond the stacked tokens
        are zero with a ``False`` mask. ``max_source_frame`` is the largest
        ``record.max_source_frame`` seen, or ``-1`` when there are no records.
    """
    slot_budget = max(0, int(target_slots))
    materialized = list(records)
    stacked_tokens, stacked_mask = stack_record_tokens(materialized, target_slots=slot_budget)
    source_frames = [int(entry.max_source_frame) for entry in materialized]
    max_source_frame = max(source_frames) if source_frames else -1

    # Nothing usable was stacked (or no slots requested): emit an all-masked stream.
    if slot_budget == 0 or stacked_tokens is None or stacked_mask is None:
        empty_tokens = torch.zeros((slot_budget, hidden_size), device=device, dtype=dtype)
        empty_mask = torch.zeros((slot_budget,), device=device, dtype=torch.bool)
        return empty_tokens, empty_mask, max_source_frame

    used = min(slot_budget, stacked_tokens.shape[0])
    tokens = stacked_tokens[:used].to(device=device, dtype=dtype)
    mask = stacked_mask[:used].to(device=device, dtype=torch.bool)
    shortfall = slot_budget - used
    if shortfall > 0:
        # Right-pad with zero tokens that are masked out.
        tokens = torch.cat([tokens, tokens.new_zeros(shortfall, hidden_size)], dim=0)
        mask = torch.cat([mask, torch.zeros(shortfall, device=device, dtype=torch.bool)], dim=0)
    return tokens, mask, max_source_frame
def _project_streaming_revisit_records(
self,
*,
cache: StreamingCache,
batch_idx: int,
records: list[MemoryRecord],
device: torch.device,
dtype: torch.dtype,
token_patch_size: int,
revisit_pool_h: int,
revisit_pool_w: int,
projection_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord],
) -> list[MemoryRecord]:
projected_records: list[MemoryRecord] = []
for record in records:
if not bool(record.metadata.get("dememwm_revisit_metadata_only", False)):
projected_records.append(record)
continue
selected_frame_index = record.metadata.get("dememwm_selected_frame_index")
if selected_frame_index is None:
best_frame_idx = record.frame_indices[torch.argmax(record.frame_indices)].reshape(1)
else:
best_frame_idx = torch.as_tensor(
[int(selected_frame_index)],
device=record.frame_indices.device,
dtype=record.frame_indices.dtype,
)
key = (
int(batch_idx),
str(record.chunk_id or ""),
int(record.source_start),
int(record.source_end),
int(best_frame_idx.detach().cpu().reshape(-1)[0].item()),
bool(record.is_generated),
)
cached = projection_cache.get(key)
if cached is not None:
projected_records.append(cached)
continue
raw_latents = cache.raw_latents_for_frames(
batch_idx=batch_idx,
frame_indices=best_frame_idx,
device=device,
dtype=dtype,
)
revisit_projected = self._project_latent_patch_tokens(
raw_latents,
self.dememwm_revisit_proj,
token_patch_size,
)[0]
_proj_src_h, _proj_src_w = self._projected_spatial_grid_size(
raw_latents.shape[3],
raw_latents.shape[4],
self.dememwm_revisit_proj,
token_patch_size,
)
frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, _proj_src_h, _proj_src_w)
frame_mask = torch.ones((frame_tokens.shape[0],), device=device, dtype=torch.bool)
metadata = {
key: (value.to(device=device) if torch.is_tensor(value) else value)
for key, value in record.metadata.items()
}
metadata["dememwm_revisit_metadata_only"] = False
projected = MemoryRecord(
tokens=frame_tokens,
mask=frame_mask,
source_start=int(record.source_start),
source_end=int(record.source_end),
frame_indices=record.frame_indices.to(device=device),
pose=None if record.pose is None else record.pose.to(device=device),
source_type=record.source_type,
is_generated=bool(record.is_generated),
score=record.score,
chunk_id=record.chunk_id,
metadata=metadata,
)
projection_cache[key] = projected
projected_records.append(projected)
return projected_records
def build_memory_streams(
self,
committed_latents: torch.Tensor | None,
source_frame_indices: torch.Tensor | None,
target_frame_indices: torch.Tensor | None = None,
pose: torch.Tensor | None = None,
target_pose: torch.Tensor | None = None,
action: torch.Tensor | None = None,
target_action: torch.Tensor | None = None,
target_video_ids=None,
source_is_generated: torch.Tensor | None = None,
denoising_fraction: float | None = None,
noise_bucket: str | None = None,
noise_bucket_ids: torch.Tensor | None = None,
streaming_cache: StreamingCache | None = None,
) -> MemoryStreamTensors:
# Build the anchor, dynamic and revisit memory streams (tokens, masks, gates
# and a diagnostics dict) for every (batch, target frame) pair and return them
# packed in a MemoryStreamTensors.
#
# Inputs as used below:
# - committed_latents: shape[1] is taken as the batch size and shape[-2:] as
#   the latent height/width (presumably (T_src, B, C, H, W) -- TODO confirm).
#   May be None only when the streaming cache already holds compressed records.
# - source_frame_indices: frame index of each committed latent, indexed as
#   [source_idx, batch_idx] in the dynamic pre-selection.
# - target_frame_indices: indexed as [target_idx, batch_idx]; a 1-D tensor gets
#   a trailing axis of size 1. Defaults to source_frame_indices.
# - pose / target_pose: target_pose falls back to pose for revisit retrieval.
# - target_video_ids: tensor, list/tuple or scalar; resolved per target by
#   _target_video_id_or_none.
# - action / target_action: accepted but not read anywhere in this method.
# - denoising_fraction / noise_bucket: forwarded to _effective_gate_state.
# - noise_bucket_ids: only forwarded to summarize_noise_bucket_diagnostics.
# - streaming_cache: used only when not None and its `enabled` attribute is truthy.
#
# Raises:
# - ValueError if neither target_frame_indices nor source_frame_indices is given,
#   or if there are no cache records and committed_latents/source_frame_indices
#   is missing.
# - AssertionError if a produced stream's slot count differs from the
#   configured slot count.
if target_frame_indices is None:
if source_frame_indices is None:
raise ValueError("target_frame_indices or source_frame_indices is required")
target_frame_indices = source_frame_indices
# Per-stream config sections.
memory_cfg = self._memory_cfg()
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
injection_cfg = self._cfg_get(memory_cfg, "injection", None)
contract_diag = self._validate_config_contract()
# Resolve which streams are enabled/forced for this denoising step and noise bucket.
gate_state = self._effective_gate_state(
denoising_fraction=denoising_fraction,
noise_bucket=noise_bucket,
)
anchor_config_enabled = gate_state["anchor_config_enabled"]
dynamic_config_enabled = gate_state["dynamic_config_enabled"]
revisit_config_enabled = gate_state["revisit_config_enabled"]
curriculum_state = gate_state["curriculum_state"]
eval_ablation_enabled = gate_state["eval_ablation_enabled"]
eval_ablation_branch = gate_state["eval_ablation_branch"]
resolved_noise_bucket = gate_state["resolved_noise_bucket"]
gates = gate_state["gates"]
anchor_effective_enabled = gate_state["anchor_effective_enabled"]
dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"]
force_revisit_off = gate_state["force_revisit_off"]
force_revisit_on = gate_state["force_revisit_on"]
token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])]
anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None)
# Latent spatial size: taken from the committed latents when present,
# otherwise from self.x_stacked_shape.
pool_latent_h = int(committed_latents.shape[-2]) if committed_latents is not None else int(self.x_stacked_shape[-2])
pool_latent_w = int(committed_latents.shape[-1]) if committed_latents is not None else int(self.x_stacked_shape[-1])
anchor_src_h, anchor_src_w = self._projected_spatial_grid_size(
pool_latent_h,
pool_latent_w,
self.dememwm_anchor_proj,
token_patch_size,
)
anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size(
anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8
)
# Fixed anchor slot count: one pooled grid per configured anchor index.
anchor_num_tokens = len(anchor_indices) * anchor_pool_h * anchor_pool_w
anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False))
allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False))
revisit_max_frames = int(self._cfg_get(revisit_cfg, "max_frames", 2))
revisit_compress_cfg = self._cfg_get(revisit_cfg, "compress", None)
revisit_src_h, revisit_src_w = self._projected_spatial_grid_size(
pool_latent_h,
pool_latent_w,
self.dememwm_revisit_proj,
token_patch_size,
)
revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
)
# Fixed revisit slot count: one pooled grid per retrievable revisit frame.
revisit_target_slots = revisit_max_frames * revisit_pool_h * revisit_pool_w
recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
dynamic_recent_exclusion_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
revisit_context_window_exclusion_frames = self._local_context_exclusion_frames()
fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
# Extra keyword arguments forwarded unchanged to the revisit retrieval
# (both the bank builder and deterministic_revisit_retrieval receive them).
revisit_retrieval_kwargs = {
"high_quality_fov_threshold": high_quality_fov_threshold,
"fov_half_h": float(self._cfg_get(revisit_cfg, "fov_half_h", 52.5)),
"fov_half_v": float(self._cfg_get(revisit_cfg, "fov_half_v", 37.5)),
"fov_yaw_samples": int(self._cfg_get(revisit_cfg, "fov_yaw_samples", 25)),
"fov_pitch_samples": int(self._cfg_get(revisit_cfg, "fov_pitch_samples", 20)),
"fov_depth_samples": int(self._cfg_get(revisit_cfg, "fov_depth_samples", 20)),
"fov_radius": float(self._cfg_get(revisit_cfg, "fov_radius", 30.0)),
"pose_preselect_topk": self._cfg_get(revisit_cfg, "pose_preselect_topk", 64),
"plucker_grid_h": int(self._cfg_get(revisit_cfg, "plucker_grid_h", 4)),
"plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
"plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
}
preselection_diag = {}
use_cache_revisit_records = False
revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None
# The streaming cache is ignored unless it is provided and reports itself enabled.
cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0}
# Output device/dtype follow the committed latents; without them, the first
# parameter of the anchor projector is used as the reference.
if committed_latents is not None:
stream_device = committed_latents.device
stream_dtype = committed_latents.dtype
else:
param = next(iter(self.dememwm_anchor_proj.parameters()))
stream_device = param.device
stream_dtype = param.dtype
target_frame_indices = target_frame_indices.to(device=stream_device)
# Give 1-D targets a trailing axis so they can be indexed [target_idx, batch_idx].
if target_frame_indices.ndim == 1:
target_frame_indices = target_frame_indices[:, None]
use_cache_records = cache is not None and cache.keep_compressed_records and cache.record_count > 0
# Dynamic-stream inputs default to the committed window and are dropped
# entirely when the dynamic stream is not effectively enabled.
dynamic_latents = committed_latents if dynamic_effective_enabled else None
dynamic_frame_indices = source_frame_indices if dynamic_effective_enabled else None
dynamic_generated = source_is_generated if dynamic_effective_enabled else None
dynamic_pose = pose if dynamic_effective_enabled else None
# When the cache holds raw frames, its materialized latents replace the
# committed window as the dynamic source (only if materialization returns latents).
if dynamic_effective_enabled and cache is not None and cache.raw_frame_slots > 0:
raw_latents, raw_frames, raw_generated, raw_pose = cache.materialize_raw_latents(
device=stream_device,
dtype=stream_dtype,
max_recent_frames=recent_frames,
target_frame_indices=target_frame_indices,
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
)
if raw_latents is not None:
dynamic_latents = raw_latents
dynamic_frame_indices = raw_frames
dynamic_generated = raw_generated
dynamic_pose = raw_pose
# Cache path: anchor banks come from the cache, while revisit candidates are
# kept as per-batch record tuples and projected only after retrieval
# (see the _project_streaming_revisit_records call further down).
if use_cache_records:
B = target_frame_indices.shape[1]
hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
anchor_banks = (
cache.memory_banks("anchor", device=stream_device, dtype=stream_dtype, batch_size=B)
if anchor_effective_enabled else [CausalMemoryBank() for _ in range(B)]
)
revisit_banks = [CausalMemoryBank() for _ in range(B)]
revisit_record_batches = (
[cache.records_for_batch("revisit", batch_idx) for batch_idx in range(B)]
if revisit_stage_config_enabled else [tuple() for _ in range(B)]
)
use_cache_revisit_records = bool(revisit_stage_config_enabled)
# tokens_per_frame is only reported in the diagnostics dict of this method.
if dynamic_latents is not None and dynamic_latents.ndim == 5 and dynamic_latents.shape[0] > 0:
tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size(
dynamic_latents.shape[-2],
dynamic_latents.shape[-1],
self.dememwm_anchor_proj,
token_patch_size,
)
tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w
else:
latent_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 0
latent_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 0
tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size(
latent_h,
latent_w,
self.dememwm_anchor_proj,
token_patch_size,
)
tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w
# No usable cache records: build anchor/revisit banks from the committed latents.
else:
if committed_latents is None or source_frame_indices is None:
raise ValueError("committed_latents/source_frame_indices are required when no streaming cache records are available")
B = committed_latents.shape[1]
hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
target_pose_source = target_pose if target_pose is not None else pose
anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks(
committed_latents,
source_frame_indices.to(device=stream_device),
None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
None if pose is None else pose.to(device=stream_device),
None,
target_frame_indices,
None if target_pose_source is None else target_pose_source.to(device=stream_device),
None,
target_video_ids,
allow_generated_anchor,
anchor_indices,
anchor_pool_h,
anchor_pool_w,
anchor_diverse,
revisit_pool_h,
revisit_pool_w,
revisit_max_frames,
revisit_context_window_exclusion_frames,
fov_overlap_threshold,
plucker_weight,
revisit_retrieval_kwargs,
token_patch_size,
)
revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
T_tgt = target_frame_indices.shape[0]
anchor_slots = max(0, anchor_num_tokens)
revisit_slots = max(0, revisit_target_slots)
# Unless generated frames are allowed as anchors, restrict anchor queries to
# PREFIX_GT records and exclude generated ones.
anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
anchor_include_generated = allow_generated_anchor
# ---- Anchor stream: query each batch's bank per target and pack the records
# into a fixed number of slots; results are stacked to (B, T_tgt, ...).
anchor_token_rows = []
anchor_mask_rows = []
anchor_max_rows = []
for batch_idx, anchor_bank in enumerate(anchor_banks):
batch_token_rows = []
batch_mask_rows = []
batch_max_rows = []
for target_idx in range(T_tgt):
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
records = anchor_bank.query(
MemoryBankQuery(
target_frame=target_frame,
source_type=anchor_source_type,
include_generated=anchor_include_generated,
max_records=len(anchor_indices),
)
)
anchor_bank.assert_causal(target_frame, records)
stream_tokens, stream_mask, max_source_frame = self._records_to_stream(
records,
anchor_slots,
hidden_size,
stream_device,
stream_dtype,
)
batch_token_rows.append(stream_tokens)
batch_mask_rows.append(stream_mask)
batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
anchor_token_rows.append(torch.stack(batch_token_rows, dim=0))
anchor_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
anchor_max_rows.append(torch.stack(batch_max_rows, dim=0))
anchor_tokens = torch.stack(anchor_token_rows, dim=0)
anchor_mask = torch.stack(anchor_mask_rows, dim=0)
anchor_max = torch.stack(anchor_max_rows, dim=0)
# ---- Dynamic stream. Without dynamic source latents, emit an all-zero,
# fully masked placeholder sized from the compressor's tokens_per_target,
# together with neutral diagnostics (-1 marks "no source").
if dynamic_latents is None or dynamic_frame_indices is None or dynamic_latents.shape[0] == 0:
_fallback_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 18
_fallback_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 32
dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
dynamic_diag = {
"selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
"max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
"generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device),
"dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
"dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
"dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
}
else:
# Pre-select dynamic source frame positions using only frame index metadata
# before touching latents, so we pass a small slice instead of the full
# 1000-frame tensor to the compressor.
_dfi = dynamic_frame_indices.to(device=stream_device)
_max_src = self.dememwm_dynamic_compressor.max_source_frames
_needed: list[int] = []
for _b in range(B):
for _j in range(T_tgt):
_target = int(target_frame_indices[_j, _b].item())
# Source positions whose frame index lies more than the exclusion window
# before the target; only the last _max_src of them are kept.
_valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
_needed.extend(_valid[-_max_src:].tolist())
if _needed:
# Union of needed positions over all (batch, target) pairs, de-duplicated and sorted.
_needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
_dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
_dynamic_fi_small = _dfi.index_select(0, _needed_idx)
_dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
_dynamic_gen_small = (
dynamic_generated.to(device=stream_device, dtype=torch.bool).index_select(0, _needed_idx)
if dynamic_generated is not None else None
)
else:
# Nothing qualifies: pass zero-length slices so the compressor still runs.
_dynamic_latents_small = dynamic_latents[:0]
_dynamic_fi_small = _dfi[:0]
_dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
_dynamic_gen_small = None
dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor(
_dynamic_latents_small,
_dynamic_fi_small,
_dynamic_pose_small,
target_frame_indices,
_dynamic_gen_small,
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
)
# Collapse the per-target gap diagnostics to scalars over entries >= 0;
# -1 is reported when no entry is valid.
dynamic_min_gap_tensor = torch.as_tensor(
dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
device=stream_device,
)
dynamic_max_gap_tensor = torch.as_tensor(
dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
device=stream_device,
)
dynamic_gap_valid = dynamic_min_gap_tensor >= 0
dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1
dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0
dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1
# Helper: pick the entry for one (batch, target) pair from a tensor laid out as
# (T_tgt, B, ...) or (B, T_tgt, ...); the first layout is tried first. Returns
# None for a missing tensor, fewer than 2 dims, or a layout matching neither.
def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
if tensor is None or tensor.ndim < 2:
return None
tensor_dev = tensor.to(device=stream_device)
if tensor_dev.shape[0] == T_tgt and tensor_dev.shape[1] == B:
return tensor_dev[target_idx, batch_idx]
if tensor_dev.shape[0] == B and tensor_dev.shape[1] == T_tgt:
return tensor_dev[batch_idx, target_idx]
return None
# Helper: resolve the video id for one (batch, target) pair from a tensor
# (0-d, (T_tgt, B) or (B, T_tgt)), a list/tuple (per batch, per target, or
# nested per target then per batch), or any other value, returned as-is.
def _target_video_id_or_none(batch_idx: int, target_idx: int):
if target_video_ids is None:
return None
if torch.is_tensor(target_video_ids):
ids = target_video_ids.detach().cpu()
if ids.ndim == 0:
return ids.item()
if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B:
return ids[target_idx, batch_idx].item()
if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt:
return ids[batch_idx, target_idx].item()
return None
if isinstance(target_video_ids, (list, tuple)):
if len(target_video_ids) == B:
return target_video_ids[batch_idx]
if len(target_video_ids) == T_tgt:
row = target_video_ids[target_idx]
if isinstance(row, (list, tuple)) and len(row) == B:
return row[batch_idx]
return row
return target_video_ids
# ---- Revisit stream: per (batch, target) gather causal candidates, run the
# deterministic retrieval, then pack the selected records into fixed slots.
target_pose_source = target_pose if target_pose is not None else pose
revisit_token_rows = []
revisit_mask_rows = []
revisit_max_rows = []
# Per-target bookkeeping, all shaped (B, T_tgt); -1 marks "nothing selected".
valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
revisit_result_diagnostics = []
# Shared across all targets and passed to _project_streaming_revisit_records
# as its projection_cache.
projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
if revisit_record_batches is None:
revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
for batch_idx in range(B):
revisit_bank = revisit_banks[batch_idx]
batch_token_rows = []
batch_mask_rows = []
batch_max_rows = []
for target_idx in range(T_tgt):
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
if use_cache_revisit_records:
candidate_records = self._causal_cached_revisit_records(
revisit_record_batches[batch_idx],
target_frame,
)
else:
candidate_records = revisit_bank.query(
MemoryBankQuery(
target_frame=target_frame,
include_generated=True,
)
)
result = deterministic_revisit_retrieval(
candidate_records,
target_frame=target_frame,
target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
target_summary=None,
topk=revisit_max_frames,
exclude_local_context_frames=revisit_context_window_exclusion_frames,
fov_overlap_threshold=fov_overlap_threshold,
plucker_weight=plucker_weight,
target_video_id=_target_video_id_or_none(batch_idx, target_idx),
**revisit_retrieval_kwargs,
)
selected_records = result.records
# Records selected on the cache path are run through
# _project_streaming_revisit_records before being packed into slots.
if use_cache_revisit_records and selected_records:
selected_records = self._project_streaming_revisit_records(
cache=cache,
batch_idx=batch_idx,
records=selected_records,
device=stream_device,
dtype=stream_dtype,
token_patch_size=token_patch_size,
revisit_pool_h=revisit_pool_h,
revisit_pool_w=revisit_pool_w,
projection_cache=projected_revisit_record_cache,
)
revisit_result_diagnostics.append(result.diagnostics)
# The *_frame_count keys are preferred; the shorter *_count keys act as fallbacks.
revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
revisit_bank.assert_causal(target_frame, selected_records)
if selected_records:
valid_revisit_mask[batch_idx, target_idx] = True
stream_tokens, stream_mask, max_source_frame = self._records_to_stream(
selected_records,
revisit_slots,
hidden_size,
stream_device,
stream_dtype,
)
revisit_causal_max[batch_idx, target_idx] = max_source_frame
# Eval-only ablation: optionally corrupt the packed revisit tokens and
# remember which targets were affected.
if eval_corruption_enabled:
stream_tokens, was_corrupted = apply_revisit_eval_corruption(
tokens=stream_tokens,
mask=stream_mask,
branch=eval_ablation_branch,
target_frame=target_frame,
)
eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
batch_token_rows.append(stream_tokens)
batch_mask_rows.append(stream_mask)
batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
revisit_tokens = torch.stack(revisit_token_rows, dim=0)
revisit_mask = torch.stack(revisit_mask_rows, dim=0)
revisit_max = torch.stack(revisit_max_rows, dim=0)
# ---- Sanity checks: produced slot counts must equal the configured ones.
if anchor_tokens.shape[-2] != anchor_num_tokens:
raise AssertionError(f"anchor slot count mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}")
if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
_expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
)
if dynamic_tokens.shape[-2] != _expected_dyn:
raise AssertionError(f"dynamic slot count mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}")
if revisit_tokens.shape[-2] != revisit_target_slots:
raise AssertionError(f"revisit slot count mismatch: got {revisit_tokens.shape[-2]}, expected {revisit_target_slots}")
# ---- Gates. Anchor/dynamic gates are the scheduled values, or 0.0 when the
# stream is not effectively enabled.
anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
# The raw revisit gate comes from the optional dememwm_revisit_gate module;
# without that module it is all ones.
gate_module = getattr(self, "dememwm_revisit_gate", None)
if gate_module is None:
revisit_gate_raw = torch.ones((B, T_tgt), device=stream_device, dtype=stream_dtype)
else:
revisit_gate_raw = gate_module(
valid_revisit_mask=valid_revisit_mask,
best_selected_fov_overlap=revisit_best_selected_fov_overlap,
best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
selected_gap_frames=revisit_selected_gap_frames,
).to(device=stream_device, dtype=stream_dtype)
valid_revisit_eff_mask = valid_revisit_mask
# Effective revisit gate: zero when the stage is disabled or forced off; the
# validity mask itself when forced on; otherwise mask * raw gate * scheduled gate.
if not revisit_stage_config_enabled or force_revisit_off:
revisit_gate = torch.zeros_like(revisit_gate_raw)
elif force_revisit_on:
revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * torch.ones_like(revisit_gate_raw)
else:
revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * revisit_gate_raw * float(gates.revisit_gate)
revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item())
# Disabled streams expose all-False masks; a disabled revisit stage also has
# its per-target diagnostics and gates reset.
if not anchor_effective_enabled:
anchor_mask = torch.zeros_like(anchor_mask)
if not dynamic_effective_enabled:
dynamic_mask = torch.zeros_like(dynamic_mask)
if not revisit_stage_config_enabled:
revisit_mask = torch.zeros_like(revisit_mask)
valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
revisit_candidate_count = torch.zeros_like(revisit_candidate_count)
revisit_selected_count = torch.zeros_like(revisit_selected_count)
revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
revisit_gate = torch.zeros_like(revisit_gate)
no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask)
# Count (batch, target) entries whose max source frame is not strictly before
# the target frame; entries < 0 mean "no source" and are ignored.
causal_violation_count = 0
for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max):
if source_max is None:
continue
source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device)
valid = source_max_t >= 0
if valid.any():
causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item())
# Diagnostics: the spread dicts come first, so on a key collision the entries
# written later in this literal take precedence.
diagnostics = {
**curriculum_state.diagnostics(),
**getattr(self, "_last_dememwm_freeze_diagnostics", {}),
**contract_diag,
**cache_diag,
**preselection_diag,
**revisit_diag,
"dememwm_stage": gates.stage,
"dememwm_gate_reason": gates.reason,
"anchor_config_enabled": anchor_config_enabled,
"dynamic_config_enabled": dynamic_config_enabled,
"revisit_config_enabled": revisit_config_enabled,
"anchor_effective_enabled": anchor_effective_enabled,
"dynamic_effective_enabled": dynamic_effective_enabled,
"revisit_effective_enabled": revisit_effective_enabled,
"revisit_stage_config_enabled": revisit_stage_config_enabled,
"revisit_gate_raw": revisit_gate_raw.detach(),
"revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
"no_valid_revisit_mask": no_valid_revisit_mask,
"valid_revisit_eff_mask": valid_revisit_eff_mask,
"revisit_candidate_frame_count_per_target": revisit_candidate_count,
"revisit_selected_frame_count_per_target": revisit_selected_count,
"revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
"revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap,
"revisit_selected_gap_frames_per_target": revisit_selected_gap_frames,
"revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
"revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()),
**summarize_noise_bucket_diagnostics(
noise_bucket=resolved_noise_bucket,
valid_revisit_mask=valid_revisit_mask,
no_valid_revisit_mask=no_valid_revisit_mask,
noise_bucket_ids=noise_bucket_ids,
),
**summarize_eval_ablation_diagnostics(
enabled=eval_ablation_enabled,
branch=eval_ablation_branch,
valid_revisit_mask=valid_revisit_mask,
no_valid_revisit_mask=no_valid_revisit_mask,
eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None,
),
"token_patch_size": token_patch_size,
"tokens_per_frame": tokens_per_frame,
"anchor_token_slots": int(anchor_tokens.shape[-2]),
"anchor_target_slots": anchor_num_tokens,
"anchor_pool_h": anchor_pool_h,
"anchor_pool_w": anchor_pool_w,
"dynamic_token_slots": int(dynamic_tokens.shape[-2]),
"dynamic_target_slots": int(dynamic_tokens.shape[-2]),
"dynamic_min_gap_to_target": dynamic_min_gap_to_target,
"dynamic_max_gap_to_target": dynamic_max_gap_to_target,
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
"revisit_token_slots": int(revisit_tokens.shape[-2]),
"revisit_target_slots": revisit_target_slots,
"revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
"revisit_pool_h": revisit_pool_h,
"revisit_pool_w": revisit_pool_w,
"revisit_max_frames": revisit_max_frames,
"anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0,
"dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0,
"revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0,
"causal_violation_count": causal_violation_count,
"anchor_max_source_frame": anchor_max,
"dynamic_max_source_frame": dynamic_diag.get("max_source_frame"),
"revisit_max_source_frame": revisit_max,
"dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"),
}
# The corruption mask is only exposed when an eval corruption branch is active.
if eval_corruption_enabled:
diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask
return MemoryStreamTensors(
anchor_tokens=anchor_tokens,
anchor_mask=anchor_mask,
dynamic_tokens=dynamic_tokens,
dynamic_mask=dynamic_mask,
revisit_tokens=revisit_tokens,
revisit_mask=revisit_mask,
anchor_gate=anchor_gate,
dynamic_gate=dynamic_gate,
revisit_gate=revisit_gate,
revisit_gate_raw=revisit_gate_raw,
valid_revisit_mask=valid_revisit_mask,
no_valid_revisit_mask=no_valid_revisit_mask,
diagnostics=diagnostics,
)
def _refresh_stream_gates(
    self,
    streams: MemoryStreamTensors,
    denoising_fraction: float | None = None,
    noise_bucket: str | None = None,
) -> MemoryStreamTensors:
    """Re-evaluate the gates of already-built memory streams.

    The gate state is resolved again for ``denoising_fraction`` /
    ``noise_bucket`` and the anchor, dynamic and revisit gates, the revisit
    validity masks and the gate-related diagnostics are recomputed. All other
    fields of ``streams`` are carried over by ``dataclasses.replace``; the
    input object is not modified.
    """
    state = self._effective_gate_state(
        denoising_fraction=denoising_fraction,
        noise_bucket=noise_bucket,
    )
    gates = state["gates"]

    device = streams.anchor_tokens.device
    dtype = streams.anchor_tokens.dtype
    B, T_tgt = streams.anchor_tokens.shape[:2]
    grid = (B, T_tgt)

    # Missing validity information is treated as "no valid revisit anywhere".
    valid_mask = streams.valid_revisit_mask
    valid_mask = (
        torch.zeros(grid, device=device, dtype=torch.bool)
        if valid_mask is None
        else valid_mask.to(device=device, dtype=torch.bool)
    )

    diagnostics = dict(streams.diagnostics)

    def _per_target(name: str, fill_value: float = 0.0) -> torch.Tensor:
        # Read a stored diagnostic as a float32 (B, T_tgt) tensor, broadcasting
        # scalars and substituting ``fill_value`` when the entry is absent.
        stored = diagnostics.get(name)
        if stored is None:
            return torch.full(grid, float(fill_value), device=device, dtype=torch.float32)
        as_float = torch.as_tensor(stored, device=device, dtype=torch.float32)
        if as_float.ndim == 0:
            return torch.full(grid, float(as_float.item()), device=device, dtype=torch.float32)
        return as_float.expand(grid)

    fov_overlap = _per_target("revisit_best_selected_fov_overlap_per_target")
    plucker_overlap = _per_target("revisit_best_selected_plucker_overlap_per_target")
    gap_frames = _per_target("revisit_selected_gap_frames_per_target", -1.0)

    anchor_on = state["anchor_effective_enabled"]
    dynamic_on = state["dynamic_effective_enabled"]
    revisit_stage_on = state["revisit_stage_config_enabled"]

    anchor_gate = gates.anchor_gate if anchor_on else 0.0
    dynamic_gate = gates.dynamic_gate if dynamic_on else 0.0

    # Raw revisit gate: learned module when present, otherwise all ones.
    learned_gate = getattr(self, "dememwm_revisit_gate", None)
    if learned_gate is not None:
        raw_gate = learned_gate(
            valid_revisit_mask=valid_mask,
            best_selected_fov_overlap=fov_overlap,
            best_selected_plucker_overlap=plucker_overlap,
            selected_gap_frames=gap_frames,
        ).to(device=device, dtype=dtype)
    else:
        raw_gate = torch.ones(grid, device=device, dtype=dtype)

    # Effective revisit gate: off, forced on for valid targets, or scheduled.
    eff_mask = valid_mask
    if not revisit_stage_on or state["force_revisit_off"]:
        eff_gate = torch.zeros_like(raw_gate)
    elif state["force_revisit_on"]:
        eff_gate = eff_mask.to(device=device, dtype=dtype) * torch.ones_like(raw_gate)
    else:
        eff_gate = eff_mask.to(device=device, dtype=dtype) * raw_gate * float(gates.revisit_gate)

    if revisit_stage_on:
        no_valid_mask = ~valid_mask
    else:
        # A disabled revisit stage reports neither valid nor missing revisits.
        valid_mask = torch.zeros_like(valid_mask)
        eff_mask = torch.zeros_like(eff_mask)
        raw_gate = torch.zeros_like(raw_gate)
        eff_gate = torch.zeros_like(eff_gate)
        no_valid_mask = torch.zeros_like(valid_mask)

    corrupted_mask = diagnostics.get("eval_corrupted_revisit_mask")
    if corrupted_mask is not None:
        corrupted_mask = torch.as_tensor(corrupted_mask, device=device, dtype=torch.bool)

    revisit_on = bool(revisit_stage_on and (eff_gate > 0).any().item())

    diagnostics.update(state["curriculum_state"].diagnostics())
    diagnostics.update({
        "dememwm_stage": gates.stage,
        "dememwm_gate_reason": gates.reason,
        "anchor_config_enabled": state["anchor_config_enabled"],
        "dynamic_config_enabled": state["dynamic_config_enabled"],
        "revisit_config_enabled": state["revisit_config_enabled"],
        "anchor_effective_enabled": anchor_on,
        "dynamic_effective_enabled": dynamic_on,
        "revisit_effective_enabled": revisit_on,
        "revisit_stage_config_enabled": revisit_stage_on,
        "revisit_gate_raw": raw_gate.detach(),
        "revisit_gate_eff": eff_gate.detach() if torch.is_tensor(eff_gate) else torch.tensor(float(eff_gate)),
        "no_valid_revisit_mask": no_valid_mask,
        "valid_revisit_eff_mask": eff_mask,
        "revisit_learned_gate_mean": float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0,
        "revisit_effective_gate_mean": float(eff_gate.detach().float().mean().item()) if eff_gate.numel() else 0.0,
    })
    diagnostics.update(
        summarize_noise_bucket_diagnostics(
            noise_bucket=state["resolved_noise_bucket"],
            valid_revisit_mask=valid_mask,
            no_valid_revisit_mask=no_valid_mask,
        )
    )
    diagnostics.update(
        summarize_eval_ablation_diagnostics(
            enabled=state["eval_ablation_enabled"],
            branch=state["eval_ablation_branch"],
            valid_revisit_mask=valid_mask,
            no_valid_revisit_mask=no_valid_mask,
            eval_corrupted_revisit_mask=corrupted_mask,
        )
    )

    return replace(
        streams,
        anchor_gate=anchor_gate,
        dynamic_gate=dynamic_gate,
        revisit_gate=eff_gate,
        revisit_gate_raw=raw_gate,
        valid_revisit_mask=valid_mask,
        no_valid_revisit_mask=no_valid_mask,
        diagnostics=diagnostics,
    )
def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> tuple[dict, dict]:
memory_kwargs, diagnostics = self.dememwm_injection_adapter(streams, device=streams.anchor_tokens.device, dtype=streams.anchor_tokens.dtype)
return memory_kwargs, diagnostics
def build_memory_kwargs(self, *args, **kwargs) -> tuple[dict, dict]:
    """Build memory streams and convert them to model kwargs in one call.

    All positional and keyword arguments are forwarded unchanged to
    ``self.build_memory_streams``; the resulting streams are passed to
    ``self._streams_to_kwargs``.

    Returns:
        The ``(memory_kwargs, diagnostics)`` pair from ``_streams_to_kwargs``.
    """
    return self._streams_to_kwargs(self.build_memory_streams(*args, **kwargs))
def _memory_adapter_delta_diagnostics(self) -> dict:
    """Fetch adapter delta diagnostics from the wrapped DiT model, if available.

    Looks up ``self.diffusion_model.model.memory_adapter_delta_diagnostics``
    and calls it. Any missing link in that attribute chain yields an empty
    dict instead of an error.
    """
    diffusion_model = getattr(self, "diffusion_model", None)
    dit_model = getattr(diffusion_model, "model", None)
    provider = getattr(dit_model, "memory_adapter_delta_diagnostics", None)
    return {} if provider is None else provider()
def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None:
    """Log scalar-like diagnostics under ``namespace`` via ``self.log``.

    Key filtering depends on the namespace:
      * ``"training/dememwm"`` uses ``self._TRAIN_DIAGNOSTIC_LOG_KEYS``;
      * any other namespace ending in ``"/dememwm"`` uses
        ``self._VALIDATION_DIAGNOSTIC_LOG_KEYS``;
      * everything else is unfiltered.

    Non-empty tensors are reduced to their float mean; bool/int/float values
    are logged as floats. Strings, ``None``, empty tensors and any other
    value types are skipped silently.
    """
    if namespace == "training/dememwm":
        whitelist = self._TRAIN_DIAGNOSTIC_LOG_KEYS
    elif namespace.endswith("/dememwm"):
        whitelist = self._VALIDATION_DIAGNOSTIC_LOG_KEYS
    else:
        whitelist = None

    for name, entry in diagnostics.items():
        if whitelist is not None and name not in whitelist:
            continue

        if torch.is_tensor(entry):
            if entry.numel() == 0:
                # The mean of an empty tensor is undefined; nothing to report.
                continue
            scalar = entry.float().mean().item()
        elif isinstance(entry, (bool, int, float)):
            scalar = float(entry)
        else:
            # Strings, None and unsupported containers are not loggable scalars.
            continue

        self.log(f"{namespace}/{name}", scalar, prog_bar=False, sync_dist=True)
def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx):
    """Build the pose conditioning (and frame indices) for a training window.

    Three modes, selected by ``self.use_plucker`` / ``self.relative_embedding``:
      * no Plucker: return ``pose_conditions`` cast to ``xs.dtype`` and ``None``;
      * Plucker, absolute: convert all of ``c2w_mat`` at once and return
        ``frame_idx`` unchanged;
      * Plucker, relative: for every frame ``i`` convert that frame's pose
        stacked in front of the last ``self.memory_condition_length`` poses of
        the window, and pair it with frame indices expressed relative to frame
        ``i`` (so the frame's own entry is zero). Per-frame results are
        concatenated along dim 0.

    Returns:
        ``(pose_condition, frame_idx_or_None)``.
    """
    from ..df_video import convert_to_plucker

    image_height, image_width = self._image_size(xs)

    if not self.use_plucker:
        return pose_conditions.to(xs.dtype), None

    def to_plucker(poses):
        # Second positional argument is passed as 0 exactly as in every call
        # site here; its meaning is defined by convert_to_plucker.
        return convert_to_plucker(
            poses,
            0,
            focal_length=self.focal_length,
            image_height=image_height,
            image_width=image_width,
        ).to(xs.dtype)

    if not self.relative_embedding:
        return to_plucker(c2w_mat), frame_idx

    # Reference frames are the trailing memory-condition slots of the window;
    # use empty slices when no reference frames are configured.
    if self.memory_condition_length:
        ref_c2w = c2w_mat[-self.memory_condition_length:]
    else:
        ref_c2w = c2w_mat[:0]
    if self.memory_condition_length:
        ref_idx = frame_idx[-self.memory_condition_length:]
    else:
        ref_idx = frame_idx[:0]

    plucker_chunks = []
    relative_idx_chunks = []
    for pos in range(c2w_mat.shape[0]):
        own_pose = c2w_mat[pos:pos + 1]
        own_idx = frame_idx[pos:pos + 1]
        plucker_chunks.append(to_plucker(torch.cat([own_pose, ref_c2w]).clone()))
        relative_idx_chunks.append(
            torch.cat([own_idx - own_idx, ref_idx - own_idx]).clone()
        )
    return torch.cat(plucker_chunks), torch.cat(relative_idx_chunks)
def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]:
    """Pick the ``[start, end)`` frame range of the training denoising window.

    The window length is ``self.n_tokens`` clamped to ``[1, total_frames]``
    (with a floor of 1 even for an empty clip). When the clip is longer than
    the window, the start is drawn uniformly with ``torch.randint`` on
    ``device`` from ``[min(context_frames, max_start), max_start]`` where
    ``context_frames = self._context_frame_count()``.

    Returns:
        ``(start, end)`` with ``end - start`` equal to the window length.
    """
    frame_count = max(0, int(total_frames))
    window_len = max(1, min(int(self.n_tokens), frame_count))
    last_start = max(0, frame_count - window_len)

    # Clip no longer than the window: the only possible window starts at 0.
    if last_start == 0:
        return 0, window_len

    first_start = min(self._context_frame_count(), last_start)
    if first_start == last_start:
        # A single admissible start; avoid drawing a random number.
        start = first_start
    else:
        drawn = torch.randint(first_start, last_start + 1, (1,), device=device)
        start = int(drawn.item())
    return start, start + window_len
def training_step(self, batch, batch_idx):
    """Run one denoising training step on a random window of the clip.

    A contiguous window (see ``_training_window_bounds``) is cut out of the
    long clip and denoised, while the memory streams are built from the whole
    clip via ``build_memory_kwargs``. When ``self.memory_condition_length`` is
    non-zero, the trailing slots of the window act as reference frames: their
    noise level is set to the stabilization level, their actions are zeroed,
    and their loss entries are dropped.

    Diagnostics are logged every 20th batch.

    Returns:
        ``{"loss": loss_total}``.
    """
    latents, actions, poses, c2w_mat, frame_idx = self._preprocess_batch(batch)
    latents = self._as_latents(latents)

    # Randomly select a contiguous n_tokens denoising window inside the long
    # clip. DeMemWM memory streams are selected causally from frames before
    # each target, then only those selected frames are projected.
    win_start, win_end = self._training_window_bounds(latents.shape[0], latents.device)
    window = slice(win_start, win_end)
    window_latents = latents[window]
    window_actions = actions[window].clone()
    window_frame_idx = frame_idx[window]
    window_pose_condition, window_frame_idx_list = self._training_pose_condition(
        window_latents, poses[window], c2w_mat[window], window_frame_idx
    )

    noise_levels = self._generate_noise_levels(window_latents)
    reference_length = self.memory_condition_length
    if reference_length:
        noise_levels[-reference_length:] = self.diffusion_model.stabilization_level
        window_actions[-reference_length:] *= 0

    generated_mask = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
    memory_source_latents, generated_mask, proxy_diagnostics = self._apply_generated_history_proxy(
        latents,
        generated_mask,
        context_frame_count=self._context_frame_count(),
        target_start_frame=win_start,
    )

    timesteps = int(getattr(self, "timesteps", 0) or 0)
    bucket = noise_bucket_from_noise_levels(noise_levels, timesteps)
    bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps)
    denoise_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
    if bucket_ids is not None:
        bucket_ids = bucket_ids.transpose(0, 1)

    memory_kwargs, diagnostics = self.build_memory_kwargs(
        memory_source_latents,
        frame_idx,
        target_frame_indices=window_frame_idx,
        pose=poses,
        target_pose=poses[window],
        action=actions,
        target_action=window_actions,
        source_is_generated=generated_mask,
        denoising_fraction=denoise_fraction,
        noise_bucket=bucket,
        noise_bucket_ids=bucket_ids,
    )
    diagnostics.update(proxy_diagnostics)

    _, loss = self.diffusion_model(
        window_latents,
        window_actions,
        window_pose_condition,
        noise_levels=noise_levels,
        reference_length=reference_length,
        frame_idx=window_frame_idx_list,
        **memory_kwargs,
    )
    diagnostics.update(self._memory_adapter_delta_diagnostics())

    # Reference slots carry no denoising target; exclude them from the loss.
    if reference_length:
        loss = loss[:-reference_length]
    loss_denoise = self.reweight_loss(loss, None)
    loss_total = loss_denoise

    diagnostics["training_window_start"] = int(win_start)
    diagnostics["training_window_end"] = int(win_end)
    diagnostics["training_window_size"] = int(win_end - win_start)
    diagnostics["loss_denoise"] = float(loss_denoise.detach().item())
    diagnostics["loss_total"] = float(loss_total.detach().item())

    should_log = batch_idx % 20 == 0
    if should_log:
        self.log("training/loss", loss_total.detach().cpu())
        self._log_memory_diagnostics("training/dememwm", diagnostics)
    return {"loss": loss_total}
def validation_step(self, batch, batch_idx, namespace="validation"):
    """Autoregressively sample a clip chunk by chunk and store the decoded result.

    The first ``self.context_frames // self.frame_stack`` latent frames are
    copied from the input as context. The remaining frames are generated in
    chunks of at most ``self.chunk_size`` frames (or all at once when
    ``chunk_size <= 0``). For every chunk, memory streams are built once and
    their gates are refreshed at each sampling step.

    Args:
        batch: Raw batch, forwarded to ``self._preprocess_batch``.
        batch_idx: Batch index; used to form the streaming-cache video id.
        namespace: Logging prefix. Diagnostics of the last sampling step are
            logged under ``f"{namespace}/dememwm"``.

    Returns:
        None. A ``(decoded_prediction, decoded_input)`` pair, both on CPU and
        both excluding the context frames, is appended to
        ``self.validation_step_outputs``.
    """
    # NOTE(review): ``np`` is not referenced anywhere in this method.
    import numpy as np
    from tqdm import tqdm
    memory_condition_length = self.memory_condition_length
    xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
    total_frame = xs_raw.shape[0]
    if bool(getattr(self, "_last_dememwm_xs_are_latents", False)):
        # Input is already in latent space; no encoding needed.
        xs = xs_raw.cpu()
    elif total_frame > 10:
        # Encode in ten slices along dim 0 and move each result to CPU right
        # away (presumably to bound peak device memory).
        xs = torch.cat([self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu() for i in range(10)])
    else:
        xs = self.encode(xs_raw).cpu()
    n_frames, batch_size, *_ = xs.shape
    curr_frame = 0
    n_context_frames = self.context_frames // self.frame_stack
    # Seed the prediction buffer with the context frames.
    xs_pred = xs[:n_context_frames].clone()
    curr_frame += n_context_frames
    # May be None; every use below is guarded by a ``streaming_cache is not None`` check.
    streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
    # Number of leading frames already committed to the streaming cache.
    cached_until = 0
    pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
    last_diagnostics = None
    while curr_frame < n_frames:
        # Commit any frames in [cached_until, curr_frame) that the cache has not seen yet.
        if streaming_cache is not None and curr_frame > cached_until:
            new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
            if curr_frame > n_context_frames:
                # Frames past the context boundary are flagged as generated.
                rel_start = max(0, n_context_frames - cached_until)
                new_generated[rel_start:] = True
            self._update_streaming_cache(
                streaming_cache,
                xs_pred[cached_until:curr_frame],
                frame_idx[cached_until:curr_frame],
                pose=pose_conditions[cached_until:curr_frame],
                source_is_generated=new_generated,
                action=conditions[cached_until:curr_frame],
            )
            cached_until = curr_frame
        # Number of new frames to generate in this iteration.
        horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
        assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
        scheduling_matrix = self._generate_scheduling_matrix(horizon)
        # Append clipped Gaussian noise for the frames about to be denoised.
        chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
        chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
        xs_pred = torch.cat([xs_pred, chunk], 0)
        # First frame of the sliding window that is fed to the model.
        start_frame = max(0, curr_frame + horizon - self.n_tokens)
        pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
        if memory_condition_length:
            # Gather the selected frames per batch element and append them at the
            # end of the buffer; they are stripped again after sampling.
            random_idx = self._generate_condition_indices(curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon)
            xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
        else:
            random_idx = torch.empty((0, batch_size), dtype=torch.long, device=frame_idx.device)
        input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
            start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
            image_width=self._image_size(xs_raw)[1], image_height=self._image_size(xs_raw)[0]
        )
        target_idx = frame_idx[start_frame:curr_frame + horizon].to(input_condition.device)
        use_streaming_cache = streaming_cache is not None and streaming_cache.record_count > 0
        target_pose = pose_conditions[start_frame:curr_frame + horizon].to(input_condition.device)
        target_action = conditions[start_frame:curr_frame + horizon].to(input_condition.device)
        if use_streaming_cache:
            # The cache already holds the committed history; pass no explicit source tensors.
            committed_latents = None
            committed_idx = None
            generated_flags = None
            source_pose = None
            source_action = None
        else:
            # No usable cache: pass all committed frames directly as memory source.
            committed_latents = xs_pred[:curr_frame].to(input_condition.device)
            committed_idx = frame_idx[:curr_frame].to(input_condition.device)
            generated_flags = torch.zeros(committed_idx.shape, device=input_condition.device, dtype=torch.bool)
            if curr_frame > n_context_frames:
                generated_flags[n_context_frames:] = True
            source_pose = pose_conditions[:curr_frame].to(input_condition.device)
            source_action = conditions[:curr_frame].to(input_condition.device)
        # Built once per chunk; only the gates are refreshed per sampling step below.
        memory_streams = self.build_memory_streams(
            committed_latents,
            committed_idx,
            target_frame_indices=target_idx,
            pose=source_pose,
            target_pose=target_pose,
            action=source_action,
            target_action=target_action,
            source_is_generated=generated_flags,
            denoising_fraction=None,
            streaming_cache=streaming_cache,
        )
        for m in range(scheduling_matrix.shape[0] - 1):
            from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
            # Fraction of sampling steps completed after this step, in (0, 1].
            denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0)
            step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
            memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams)
            # Write the step result back into the CPU-side buffer in place.
            xs_pred[start_frame:] = self.diffusion_model.sample_step(
                xs_pred[start_frame:].to(input_condition.device),
                input_condition,
                input_pose_condition,
                from_noise_levels[start_frame:],
                to_noise_levels[start_frame:],
                current_frame=curr_frame,
                mode="validation",
                reference_length=memory_condition_length,
                frame_idx=frame_idx_list,
                **memory_kwargs,
            ).cpu()
        if memory_condition_length:
            # Drop the reference frames appended before sampling.
            xs_pred = xs_pred[:-memory_condition_length]
        curr_frame += horizon
        # Commit the freshly generated frames to the streaming cache.
        if streaming_cache is not None and curr_frame > cached_until:
            new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
            if curr_frame > n_context_frames:
                rel_start = max(0, n_context_frames - cached_until)
                new_generated[rel_start:] = True
            self._update_streaming_cache(
                streaming_cache,
                xs_pred[cached_until:curr_frame],
                frame_idx[cached_until:curr_frame],
                pose=pose_conditions[cached_until:curr_frame],
                source_is_generated=new_generated,
                action=conditions[cached_until:curr_frame],
            )
            cached_until = curr_frame
            # Attach cache statistics to the diagnostics of the last sampling step.
            if last_diagnostics is not None:
                last_diagnostics.update(streaming_cache.diagnostics("cache"))
        pbar.update(horizon)
    pbar.close()
    # Only the diagnostics of the final sampling step of the final chunk are logged.
    if last_diagnostics is not None:
        self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics)
    xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
    xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
    self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
    return
def strict_checkpoint_key_check(self, state_dict: dict, required_prefixes: Iterable[str] | None = None) -> None:
    """Verify that a checkpoint contains the required DeMemWM parameters.

    Every key is considered both as-is and with one leading ``"model."``,
    ``"module."`` or ``"algo."`` wrapper prefix removed. The check passes when
    each required prefix starts at least one such key and each marker in
    ``self.strict_key_substrings`` occurs in at least one such key.

    Args:
        state_dict: Checkpoint mapping; only its keys are inspected.
        required_prefixes: Prefixes to require. Falls back to
            ``self.strict_key_prefixes`` when ``None`` or empty.

    Raises:
        RuntimeError: If any required prefix or marker is not covered.
    """
    wanted_prefixes = tuple(required_prefixes or self.strict_key_prefixes)
    wrapper_prefixes = ("", "model.", "module.", "algo.")

    # The empty wrapper keeps the original key; the others add a stripped variant.
    candidate_keys = [
        name.removeprefix(wrapper)
        for name in map(str, state_dict.keys())
        for wrapper in wrapper_prefixes
        if not wrapper or name.startswith(wrapper)
    ]

    missing_prefixes = [
        wanted
        for wanted in wanted_prefixes
        if not any(candidate.startswith(wanted) for candidate in candidate_keys)
    ]
    missing_substrings = [
        marker
        for marker in self.strict_key_substrings
        if all(marker not in candidate for candidate in candidate_keys)
    ]

    if not (missing_prefixes or missing_substrings):
        return
    raise RuntimeError(
        "DeMemWM checkpoint is missing required DeMemWM key coverage: "
        f"prefixes={missing_prefixes}, memory_adapter_markers={missing_substrings}"
    )
# Compatibility aliases for old DeMemWM test and experiment call sites.
# Each legacy ``dememwm``-flavoured name below is bound to the very same
# object as its current name, so both spellings share one implementation.
# These are plain assignments: rebinding a current name later does not
# update its alias.
dememwm_strict_key_prefixes = strict_key_prefixes
dememwm_strict_key_substrings = strict_key_substrings
_DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS
_DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS
_dememwm_cfg = _memory_cfg
_dememwm_stage_policy_cfg = _stage_policy_cfg
_dememwm_eval_ablation_cfg = _eval_ablation_cfg
_dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
_dememwm_eval_ablation_state = _eval_ablation_state
_dememwm_effective_gate_state = _effective_gate_state
_dememwm_validate_config_contract = _validate_config_contract
_dememwm_stream_enabled = _stream_enabled
_dememwm_context_frame_count = _context_frame_count
_dememwm_local_context_exclusion_frames = _local_context_exclusion_frames
_dememwm_curriculum_state = _curriculum_state
_dememwm_generated_history_proxy_prob = _generated_history_proxy_prob
_dememwm_apply_generated_history_proxy = _apply_generated_history_proxy
_dememwm_checkpoint_cfg = _checkpoint_cfg
_dememwm_strict_eval_load_enabled = _strict_eval_load_enabled
_dememwm_cache_cfg = _cache_cfg
_dememwm_cache_enabled = _cache_enabled
_dememwm_new_streaming_cache = _new_streaming_cache
_dememwm_is_memory_adapter_param = _is_memory_adapter_param
_dememwm_param_group_name = _param_group_name
_dememwm_group_trainable = _group_trainable
_dememwm_group_lr = _group_lr
_dememwm_apply_freeze_policy = _apply_freeze_policy
_dememwm_as_latents = _as_latents
_dememwm_image_size = _image_size
_dememwm_update_streaming_cache = _update_streaming_cache
_build_dememwm_streaming_cache_records = _build_streaming_cache_records
_build_dememwm_causal_memory_banks = _build_causal_memory_banks
_build_dememwm_preselected_causal_memory_banks = _build_preselected_causal_memory_banks
_dememwm_records_to_stream = _records_to_stream
build_dememwm_memory_streams = build_memory_streams
_dememwm_refresh_stream_gates = _refresh_stream_gates
_dememwm_streams_to_kwargs = _streams_to_kwargs
build_dememwm_memory_kwargs = build_memory_kwargs
_dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics
_log_dememwm_diagnostics = _log_memory_diagnostics
_dememwm_training_window_bounds = _training_window_bounds
strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
# Legacy name for the mixin class itself.
DeMemWMMemoryDiTMixin = MemoryDiTMixin