diff --git "a/algorithms/worldmem/dememwm/algorithm.py" "b/algorithms/worldmem/dememwm/algorithm.py" new file mode 100644--- /dev/null +++ "b/algorithms/worldmem/dememwm/algorithm.py" @@ -0,0 +1,2464 @@ + +from __future__ import annotations + +import math +from dataclasses import replace +from typing import Iterable + +import torch +from einops import rearrange + +from .cache import StreamingCache +from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens +from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics +from .injection import InjectionAdapter +from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens +from .negatives import apply_revisit_eval_corruption +from .retrieval import deterministic_revisit_retrieval +from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum +from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors + + +class MemoryDiTMixin: + """Standalone DeMemWM / Memory-DiT mixin. + + Reuses the base video-DiT infrastructure while keeping memory construction and + injection under the standalone `dememwm` package. Legacy memory-method files + are not part of this path. + """ + + strict_key_prefixes = ( + "dememwm_dynamic_compressor.", + "dememwm_anchor_proj.", + "dememwm_revisit_proj.", + "dememwm_revisit_gate.", + ) + strict_key_substrings = ( + ".memory_token_cross_attn.", + ) + _TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({ + "revisit_candidate_frame_count", + "revisit_pose_preselect_input_count", + "revisit_pose_preselect_selected_count", + "revisit_exact_fov_candidate_count", + "valid_revisit_frame_count", + "valid_revisit_target_count", + "no_valid_revisit_count", + "revisit_selected_frame_count", + "revisit_frame_fov_overlap_mean", + "revisit_best_selected_frame_fov_overlap_mean", + "revisit_best_selected_plucker_overlap_mean", + "revisit_best_selected_gap_frames_mean", + "revisit_gate_raw", + "revisit_gate_eff", + "revisit_learned_gate_mean", + "revisit_effective_gate_mean", + "generated_history_proxy_prob", + "noise_bucket_target_count", + "noise_bucket_high_target_count", + "noise_bucket_mid_target_count", + "noise_bucket_low_target_count", + }) + _VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({ + "cache_records", + "cache_slots", + }) + + def _memory_cfg(self): + return getattr(self.cfg, "dememwm", None) + + def _cfg_get(self, obj, name, default): + if obj is None: + return default + if isinstance(obj, dict): + return obj.get(name, default) + return getattr(obj, name, default) + + def _cfg_has(self, obj, name: str) -> bool: + if obj is None: + return False + if isinstance(obj, dict): + return name in obj + try: + getattr(obj, name) + return True + except Exception: + return False + + def _stage_policy_cfg(self): + return self._cfg_get(self._memory_cfg(), "stage_policy", None) + + def _eval_ablation_cfg(self): + return self._cfg_get(self._memory_cfg(), "eval_ablation", None) + + def _generated_history_proxy_cfg(self): + return self._cfg_get(self._memory_cfg(), "generated_history_proxy", None) + + def _eval_ablation_state(self) -> tuple[bool, str]: + cfg = self._eval_ablation_cfg() + enabled = bool(self._cfg_get(cfg, "enabled", False)) + branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal")) + return enabled, branch + + def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict: + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + injection_cfg = self._cfg_get(memory_cfg, "injection", None) + anchor_config_enabled = self._stream_enabled(anchor_cfg) + dynamic_config_enabled = self._stream_enabled(dynamic_cfg) + revisit_config_enabled = self._stream_enabled(revisit_cfg) + curriculum_state = self._curriculum_state() + eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state() + debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False)) + resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction) + gates = compute_stream_gates( + curriculum_state.stage, + denoising_fraction=denoising_fraction, + debug_force_all_streams=debug_force, + anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)), + dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)), + revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)), + ) + anchor_effective_enabled = bool(gates.anchor_enabled and anchor_config_enabled) + dynamic_effective_enabled = bool(gates.dynamic_enabled and dynamic_config_enabled) + revisit_stage_config_enabled = bool(gates.revisit_enabled and revisit_config_enabled) + if eval_ablation_enabled: + if eval_ablation_branch == "memory_off": + anchor_effective_enabled = False + dynamic_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "A_only": + dynamic_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "D_only": + anchor_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "A_plus_D": + revisit_stage_config_enabled = False + return { + "curriculum_state": curriculum_state, + "gates": gates, + "resolved_noise_bucket": resolved_noise_bucket, + "anchor_config_enabled": anchor_config_enabled, + "dynamic_config_enabled": dynamic_config_enabled, + "revisit_config_enabled": revisit_config_enabled, + "anchor_effective_enabled": anchor_effective_enabled, + "dynamic_effective_enabled": dynamic_effective_enabled, + "revisit_stage_config_enabled": revisit_stage_config_enabled, + "eval_ablation_enabled": eval_ablation_enabled, + "eval_ablation_branch": eval_ablation_branch, + "force_revisit_off": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_off"), + "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"), + } + + def _validate_config_contract(self) -> dict: + if bool(getattr(self, "_dememwm_contract_validated", False)): + return getattr(self, "_last_dememwm_config_diagnostics", {}) + memory_cfg = self._memory_cfg() + if memory_cfg is None: + self._dememwm_contract_validated = True + self._last_dememwm_config_diagnostics = {} + return {} + + stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)] + if stale_sections: + raise ValueError(f"stale DeMemWM config sections are not part of the final contract: {stale_sections}") + ratio_fields = [ + name + for name in ("anchor_ratio", "dynamic_ratio", "revisit_ratio", "revisit_max_ratio") + if self._cfg_has(memory_cfg, name) + ] + if ratio_fields: + raise ValueError(f"standalone DeMemWM uses fixed manual token budgets, not ratio fields: {ratio_fields}") + + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + stale_nested = [] + for section_name, section_cfg, field_names in ( + ("anchor", anchor_cfg, ("policy", "topk", "pin_prefix")), + ("dynamic", dynamic_cfg, ("include_generated_recent",)), + ("revisit", revisit_cfg, ("deterministic_only", "min_age_frames", "min_gap_frames", "topk", "max_chunks", "chunk_frames", "min_score", "time_weight", "pose_weight", "latent_weight", "pose_overlap_threshold", "action_overlap_threshold", "generated_penalty", "force_gate_zero_when_invalid")), + ): + stale_nested.extend( + f"{section_name}.{field_name}" for field_name in field_names if self._cfg_has(section_cfg, field_name) + ) + if stale_nested: + raise ValueError(f"stale DeMemWM config fields are not part of the final contract: {stale_nested}") + + exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) + if exclude_latest_local_frames < 0: + raise ValueError("dememwm.dynamic.exclude_latest_local_frames must be non-negative") + if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)): + raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval") + fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) + if fov_overlap_threshold is not None: + fov_overlap_threshold = float(fov_overlap_threshold) + if fov_overlap_threshold < 0.0: + raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative") + high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) + if high_quality_fov_threshold < 0.0: + raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative") + plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) + if plucker_weight < 0.0: + raise ValueError("dememwm.revisit.plucker_weight must be non-negative") + for field_name, default in ( + ("fov_half_h", 52.5), + ("fov_half_v", 37.5), + ("fov_radius", 30.0), + ("plucker_focal_length", 0.35), + ): + value = float(self._cfg_get(revisit_cfg, field_name, default)) + if value <= 0.0: + raise ValueError(f"dememwm.revisit.{field_name} must be positive") + for field_name, default in ( + ("fov_yaw_samples", 25), + ("fov_pitch_samples", 20), + ("fov_depth_samples", 20), + ("plucker_grid_h", 4), + ("plucker_grid_w", 4), + ): + value = int(self._cfg_get(revisit_cfg, field_name, default)) + if value <= 0: + raise ValueError(f"dememwm.revisit.{field_name} must be positive") + stage_policy_cfg = self._stage_policy_cfg() + if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)): + raise ValueError("final DeMemWM keeps noise_bucket logging enabled") + proxy_cfg = self._generated_history_proxy_cfg() + proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0)) + proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0)) + proxy_noise_std = float(self._cfg_get(proxy_cfg, "noise_std", 0.0)) + proxy_ramp_steps = int(self._cfg_get(proxy_cfg, "ramp_steps", 0)) + if proxy_max_prob < 0.0 or proxy_max_prob > 1.0: + raise ValueError("dememwm.generated_history_proxy.max_prob must be in [0, 1]") + if proxy_dropout_prob < 0.0 or proxy_dropout_prob > 1.0: + raise ValueError("dememwm.generated_history_proxy.dropout_prob must be in [0, 1]") + if proxy_noise_std < 0.0: + raise ValueError("dememwm.generated_history_proxy.noise_std must be non-negative") + if proxy_ramp_steps < 0: + raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative") + eval_ablation_cfg = self._eval_ablation_cfg() + normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal")) + + diagnostics = { + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + "revisit_deterministic_fov_plucker_retrieval": True, + "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(), + "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold, + "revisit_high_quality_fov_threshold": high_quality_fov_threshold, + "revisit_plucker_weight": plucker_weight, + "stage_policy_noise_bucket_logging": True, + } + self._dememwm_contract_validated = True + self._last_dememwm_config_diagnostics = diagnostics + return diagnostics + + def _stream_enabled(self, stream_cfg) -> bool: + return bool(self._cfg_get(stream_cfg, "enabled", True)) + + def _context_frame_count(self) -> int: + frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) + return max(0, int(getattr(self, "context_frames", 0) or 0) // frame_stack) + + def _local_context_exclusion_frames(self) -> int: + n_tokens = max(0, int(getattr(self, "n_tokens", 0) or 0)) + frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) + return n_tokens * frame_stack + + def _curriculum_state(self, step: int | None = None): + if step is None: + step = int(getattr(self, "global_step", 0) or 0) + return resolve_curriculum(self._memory_cfg(), step) + + def _generated_history_proxy_prob(self, step: int | None = None) -> float: + cfg = self._generated_history_proxy_cfg() + if not bool(self._cfg_get(cfg, "enabled", False)): + return 0.0 + max_prob = min(max(float(self._cfg_get(cfg, "max_prob", 0.0)), 0.0), 1.0) + if max_prob <= 0.0: + return 0.0 + if step is None: + step = int(getattr(self, "global_step", 0) or 0) + start_step = int(self._cfg_get(cfg, "start_step", 0)) + if step < start_step: + return 0.0 + ramp_steps = int(self._cfg_get(cfg, "ramp_steps", 0)) + if ramp_steps <= 0: + return max_prob + ramp_fraction = min(max(float(step - start_step) / float(ramp_steps), 0.0), 1.0) + return max_prob * ramp_fraction + + def _apply_generated_history_proxy( + self, + source_latents: torch.Tensor, + source_is_generated: torch.Tensor | None, + context_frame_count: int | None = None, + target_start_frame: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + cfg = self._generated_history_proxy_cfg() + prob = self._generated_history_proxy_prob() + noise_std = float(self._cfg_get(cfg, "noise_std", 0.0)) + dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0)) + diagnostics = { + "generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)), + "generated_history_proxy_prob": float(prob), + "generated_history_proxy_noise_std": float(noise_std), + "generated_history_proxy_dropout_prob": float(dropout_prob), + "generated_history_proxy_frame_count": 0, + "generated_history_proxy_frame_fraction": 0.0, + } + if source_is_generated is None: + source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) + else: + source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool) + if prob <= 0.0 or source_latents.numel() == 0: + return source_latents, source_is_generated, diagnostics + + eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) + if context_frame_count is not None or target_start_frame is not None: + frame_positions = torch.arange(source_latents.shape[0], device=source_latents.device)[:, None] + if context_frame_count is not None: + eligible_mask &= frame_positions >= max(0, int(context_frame_count)) + if target_start_frame is not None: + eligible_mask &= frame_positions < max(0, int(target_start_frame)) + proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask + proxy_count = int(proxy_mask.detach().long().sum().item()) + total_count = max(1, int(proxy_mask.numel())) + diagnostics["generated_history_proxy_frame_count"] = proxy_count + diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count) + if proxy_count == 0: + return source_latents, source_is_generated, diagnostics + + corrupt_latents = source_latents.clone() + frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype) + if noise_std > 0.0: + corrupt_latents = corrupt_latents + torch.randn_like(corrupt_latents) * float(noise_std) * frame_mask + if dropout_prob > 0.0: + dropout_mask = torch.rand( + (*source_latents.shape[:2], 1, source_latents.shape[-2], source_latents.shape[-1]), + device=source_latents.device, + ) < dropout_prob + dropout_mask = dropout_mask & proxy_mask[:, :, None, None, None] + corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents) + source_is_generated = source_is_generated.clone() + source_is_generated |= proxy_mask + return corrupt_latents, source_is_generated, diagnostics + + def _checkpoint_cfg(self): + return self._cfg_get(self._memory_cfg(), "checkpoint", None) + + def _strict_eval_load_enabled(self) -> bool: + return bool(self._cfg_get(self._checkpoint_cfg(), "strict_dememwm_eval_load", True)) + + def _cache_cfg(self): + return self._cfg_get(self._memory_cfg(), "cache", None) + + def _cache_enabled(self) -> bool: + return bool(self._cfg_get(self._cache_cfg(), "enabled", False)) + + def _new_streaming_cache(self, video_id=None) -> StreamingCache | None: + if not self._cache_enabled(): + return None + cache = StreamingCache.from_config(self._cache_cfg(), enabled_default=True) + if cache.clear_between_videos: + cache.reset(video_id=video_id) + return cache + + def _is_memory_adapter_param(self, name: str) -> bool: + return ".memory_token_cross_attn." in name + + def _param_group_name(self, name: str, state=None) -> str: + state = state or self._curriculum_state() + if name.startswith("vae.") or name.startswith("validation_lpips_model."): + return "excluded_frozen" + if name.startswith(("dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.")): + return "dememwm_modules" + if self._is_memory_adapter_param(name): + return "memory_adapters" + if name.startswith("diffusion_model."): + return "full_dit" + return "dememwm_modules" + + def _group_trainable(self, group_name: str, state) -> bool: + if group_name in {"dememwm_modules", "memory_adapters"}: + return True + if group_name == "full_dit": + return state.dit_full_trainable + return False + + def _group_lr(self, group_name: str, state) -> float: + if group_name == "dememwm_modules": + return state.dememwm_lr + if group_name == "memory_adapters": + return state.memory_adapter_lr + if group_name == "full_dit": + return state.full_dit_lr + return 0.0 + + def _apply_freeze_policy(self, optimizer=None, step: int | None = None): + state = self._curriculum_state(step) + + # Keep DDP's trainable graph stable: DiT params stay requires_grad=True + # from step 0 and are frozen by optimizer LR=0 until the full stage. + # Re-walk only when curriculum diagnostics can change. + freeze_key = (state.stage, state.dit_train_state, state.freeze_vae) + last_key = getattr(self, "_last_freeze_key", None) + if last_key != freeze_key: + trainable_tensors = { + "dememwm_modules": 0, + "memory_adapters": 0, + "full_dit": 0, + "excluded_frozen": 0, + } + trainable_scalars = {key: 0 for key in trainable_tensors} + requires_grad_tensors = {key: 0 for key in trainable_tensors} + requires_grad_scalars = {key: 0 for key in trainable_tensors} + for name, param in self.named_parameters(): + group_name = self._param_group_name(name, state) + should_train = self._group_trainable(group_name, state) + if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae): + should_train = False + should_require_grad = False + else: + should_require_grad = True + param.requires_grad_(should_require_grad) + if should_train: + trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1 + trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel()) + if should_require_grad: + requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1 + requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel()) + self._last_freeze_key = freeze_key + self._last_trainable_tensors = trainable_tensors + self._last_trainable_scalars = trainable_scalars + self._last_requires_grad_tensors = requires_grad_tensors + self._last_requires_grad_scalars = requires_grad_scalars + else: + trainable_tensors = getattr(self, "_last_trainable_tensors", {}) + trainable_scalars = getattr(self, "_last_trainable_scalars", {}) + requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {}) + requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {}) + + if optimizer is not None: + for param_group in optimizer.param_groups: + group_name = param_group.get("name", "") + trainable = self._group_trainable(group_name, state) + param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0 + + diagnostics = state.diagnostics() + for group_name in ("dememwm_modules", "memory_adapters", "full_dit"): + diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0) + diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0) + diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0) + diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0) + diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0 + self._last_dememwm_freeze_diagnostics = diagnostics + return state + + def configure_optimizers(self): + state = self._curriculum_state(0) + self._apply_freeze_policy(step=0) + grouped: dict[str, list[torch.nn.Parameter]] = { + "dememwm_modules": [], + "memory_adapters": [], + "full_dit": [], + } + for name, param in self.named_parameters(): + group_name = self._param_group_name(name, state) + if group_name in grouped: + grouped[group_name].append(param) + param_groups = [] + for group_name in ("dememwm_modules", "memory_adapters", "full_dit"): + params = grouped[group_name] + if params: + trainable = self._group_trainable(group_name, state) + param_groups.append({ + "params": params, + "lr": self._group_lr(group_name, state) if trainable else 0.0, + "name": group_name, + }) + if not param_groups: + raise RuntimeError("DeMemWM optimizer found no trainable parameter groups") + return torch.optim.AdamW( + param_groups, + weight_decay=self.cfg.weight_decay, + betas=self.cfg.optimizer_beta, + ) + + def on_train_start(self): + optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] + for optimizer in optimizers: + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) + + def on_train_batch_start(self, batch, batch_idx): + optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] + for optimizer in optimizers: + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) + + def on_after_backward(self): + step = int(getattr(self, "global_step", 0) or 0) + state = self._apply_freeze_policy(step=step) + for name, param in self.named_parameters(): + if param.grad is None: + continue + group_name = self._param_group_name(name, state) + if not self._group_trainable(group_name, state): + param.grad = None + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) + optimizer.step(closure=optimizer_closure) + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0) + 1) + + def on_load_checkpoint(self, checkpoint): + super().on_load_checkpoint(checkpoint) + if self._strict_eval_load_enabled(): + state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint + self.strict_checkpoint_key_check(state_dict) + + def _preprocess_batch(self, batch): + """Preprocess RGB or precomputed-latent Minecraft batches for DeMemWM. + + MinecraftVideoLatentDataset returns an extra image_hw tensor. Keep the + DeMemWM path on VAE latents while preserving RGB image size for Plucker + pose embeddings. This mirrors the existing latent-dataset contract + without routing through the legacy SSM memory implementation. + """ + from ..df_video import euler_to_camera_to_world_matrix + + if len(batch) == 5: + xs, conditions, pose_conditions, frame_index, image_hw = batch + self._last_dememwm_xs_are_latents = True + self._last_dememwm_image_hw = image_hw + else: + xs, conditions, pose_conditions, frame_index = batch + self._last_dememwm_xs_are_latents = False + self._last_dememwm_image_hw = None + + if self.action_cond_dim: + conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) + conditions = rearrange(conditions, "b t d -> t b d").contiguous() + else: + raise NotImplementedError("Only support external cond.") + + pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous() + c2w_mat = euler_to_camera_to_world_matrix(pose_conditions) + xs = rearrange(xs, "b t c ... -> t b c ...").contiguous() + frame_index = rearrange(frame_index, "b t -> t b").contiguous() + return xs, conditions, pose_conditions, c2w_mat, frame_index + + def _as_latents(self, xs: torch.Tensor) -> torch.Tensor: + if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): + return xs + return self.encode(xs) + + def _image_size(self, xs: torch.Tensor) -> tuple[int, int]: + image_hw = getattr(self, "_last_dememwm_image_hw", None) + if image_hw is not None: + if torch.is_tensor(image_hw): + values = image_hw.detach().cpu().reshape(-1).tolist() + else: + values = list(image_hw) + if len(values) >= 2: + return int(values[0]), int(values[1]) + return int(xs.shape[-2]), int(xs.shape[-1]) + + def _update_streaming_cache( + self, + cache: StreamingCache | None, + new_latents: torch.Tensor, + frame_indices: torch.Tensor, + pose: torch.Tensor | None = None, + source_is_generated: torch.Tensor | None = None, + action: torch.Tensor | None = None, + ) -> None: + if cache is None or not cache.enabled or new_latents is None or new_latents.shape[0] == 0: + return + cache.add_raw_latents(new_latents, frame_indices, source_is_generated, pose) + if not cache.keep_compressed_records: + return + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) + anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] + anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) + anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( + int(new_latents.shape[-2]), + int(new_latents.shape[-1]), + self.dememwm_anchor_proj, + token_patch_size, + ) + anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( + anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 + ) + anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) + allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) + # Prefix anchors are a per-video prefix resource. Do not add new prefix + # anchors for later committed segments unless explicitly generated anchors are allowed. + if cache.records_count("anchor") > 0 and not allow_generated_anchor: + anchor_indices = [] + anchor_banks, revisit_banks = self._build_streaming_cache_records( + new_latents, + frame_indices, + source_is_generated, + pose, + action, + allow_generated_anchor, + anchor_indices, + anchor_pool_h, + anchor_pool_w, + anchor_diverse, + token_patch_size, + ) + cache.add_memory_banks(anchor_banks, revisit_banks) + + def _build_model(self): + from algorithms.common.metrics import LearnedPerceptualImagePatchSimilarity + from .gates import RevisitRawGate + from ..models.diffusion import Diffusion + from ..models.pose_prediction import PosePredictionNet + from ..models.vae import VAE_models + + self.diffusion_model = Diffusion( + reference_length=self.memory_condition_length, + x_shape=self.x_stacked_shape, + action_cond_dim=self.action_cond_dim, + pose_cond_dim=self.pose_cond_dim, + is_causal=self.causal, + cfg=self.cfg.diffusion, + is_dit=True, + use_plucker=self.use_plucker, + relative_embedding=self.relative_embedding, + state_embed_only_on_qk=self.state_embed_only_on_qk, + use_memory_attention=False, + add_timestamp_embedding=self.add_timestamp_embedding, + memory_token_cross_attention=getattr(self.cfg, "memory_token_cross_attention", True), + memory_cross_attn_layers=getattr(self.cfg, "memory_cross_attn_layers", None), + ref_mode=self.ref_mode, + ) + memory_cfg = self._memory_cfg() + self._validate_config_contract() + injection_cfg = self._cfg_get(memory_cfg, "injection", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) + token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) + max_source_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8)) + self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor( + latent_channels=self.x_stacked_shape[0], + dit_hidden_size=hidden_size, + patch_size=token_patch_size, + conv_kernel_t=int(self._cfg_get(dynamic_cfg, "conv_kernel_t", 3)), + conv_stride_t=int(self._cfg_get(dynamic_cfg, "conv_stride_t", 2)), + max_source_frames=max_source_frames, + exclude_latest_local_frames=int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)), + ) + spatial_mid_channels = self.x_stacked_shape[0] * token_patch_size * token_patch_size + self.dememwm_anchor_proj = SpatialConv2DMemoryProjector( + latent_channels=self.x_stacked_shape[0], + dit_hidden_size=hidden_size, + mid_channels=spatial_mid_channels, + kernel_size=3, + ) + self.dememwm_revisit_proj = SpatialConv2DMemoryProjector( + latent_channels=self.x_stacked_shape[0], + dit_hidden_size=hidden_size, + mid_channels=spatial_mid_channels, + kernel_size=3, + ) + self.dememwm_revisit_gate = RevisitRawGate() + self.dememwm_injection_adapter = InjectionAdapter() + self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity() + self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval() + for param in self.vae.parameters(): + param.requires_grad_(False) + if self.require_pose_prediction: + self.pose_prediction_model = PosePredictionNet() + + def _project_latent_patch_tokens( + self, + latents: torch.Tensor, + projection: torch.nn.Module, + patch_size: int, + ) -> torch.Tensor: + # (T,B,C,H,W) -> (B,T,T_frame,D). Conv2D projectors keep T_frame=H*W. + if bool(getattr(projection, "projects_spatial_latents", False)): + return projection(latents) + patch_vectors = latent_patch_tokens(latents, patch_size) + return projection(patch_vectors).permute(1, 0, 2, 3).contiguous() + + def _projected_spatial_grid_size( + self, + latent_h: int, + latent_w: int, + projection: torch.nn.Module, + patch_size: int, + ) -> tuple[int, int]: + if bool(getattr(projection, "projects_spatial_latents", False)): + return int(latent_h), int(latent_w) + return int(latent_h) // int(patch_size), int(latent_w) // int(patch_size) + + def _take_uniform_slots(self, tokens: torch.Tensor, num_slots: int) -> torch.Tensor: + if tokens.ndim != 2: + raise ValueError("tokens must have shape (N,D)") + num_slots = max(0, int(num_slots)) + if num_slots == 0: + return tokens[:0] + if tokens.shape[0] <= num_slots: + return tokens + idx = torch.linspace(0, tokens.shape[0] - 1, num_slots, device=tokens.device).round().long() + return tokens.index_select(0, idx) + + def _spatial_pool_tokens( + self, + tokens: torch.Tensor, + pool_h: int, + pool_w: int, + src_h: int, + src_w: int, + ) -> torch.Tensor: + return spatial_pool_tokens(tokens, pool_h, pool_w, src_h, src_w) + + def _resolve_spatial_pool_size( + self, + compress_cfg, + src_h: int, + src_w: int, + default_pool_h: int, + default_pool_w: int, + ) -> tuple[int, int]: + ratio = self._cfg_get(compress_cfg, "downsample_ratio", None) + ratio_h = self._cfg_get(compress_cfg, "downsample_h", ratio) + ratio_w = self._cfg_get(compress_cfg, "downsample_w", ratio) + if ratio_h is not None or ratio_w is not None: + if ratio_h is None: + ratio_h = ratio_w + if ratio_w is None: + ratio_w = ratio_h + ratio_h = float(ratio_h) + ratio_w = float(ratio_w) + if ratio_h <= 0.0 or ratio_w <= 0.0: + raise ValueError("DeMemWM compress downsample ratios must be positive") + return ( + max(1, int(math.ceil(float(src_h) / ratio_h))), + max(1, int(math.ceil(float(src_w) / ratio_w))), + ) + pool_h = int(self._cfg_get(compress_cfg, "pool_h", default_pool_h)) + pool_w = int(self._cfg_get(compress_cfg, "pool_w", default_pool_w)) + if pool_h <= 0 or pool_w <= 0: + raise ValueError("DeMemWM compress pool_h/pool_w must be positive") + return pool_h, pool_w + + def _select_diverse_anchor_positions( + self, + source_positions: torch.Tensor, + pose: torch.Tensor | None, + num_anchors: int, + ) -> torch.Tensor: + num_anchors = max(0, int(num_anchors)) + if num_anchors == 0: + return source_positions[:0] + if source_positions.numel() <= num_anchors or pose is None: + return source_positions[:num_anchors] + poses = pose.float() + selected = [0] + dists = torch.cdist(poses[0:1], poses).squeeze(0) + for _ in range(num_anchors - 1): + farthest = int(dists.argmax().item()) + selected.append(farthest) + d_new = torch.cdist(poses[farthest:farthest + 1], poses).squeeze(0) + dists = torch.minimum(dists, d_new) + return source_positions[torch.tensor(sorted(selected), device=source_positions.device)] + + def _build_streaming_cache_records( + self, + source_latents: torch.Tensor, + source_frame_indices: torch.Tensor, + source_is_generated: torch.Tensor | None, + pose: torch.Tensor | None, + action: torch.Tensor | None, + allow_generated_anchor: bool, + anchor_indices: list[int], + anchor_pool_h: int, + anchor_pool_w: int, + anchor_diverse: bool, + token_patch_size: int, + ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: + if source_latents.ndim != 5: + raise ValueError("source_latents must have shape (T,B,C,H,W)") + if source_frame_indices.ndim != 2: + raise ValueError("source_frame_indices must have shape (T,B)") + T_src, B = source_frame_indices.shape + if source_latents.shape[:2] != (T_src, B): + raise ValueError("source_latents and source_frame_indices must share T/B dimensions") + _, _, _, latent_H, latent_W = source_latents.shape + src_h, src_w = self._projected_spatial_grid_size( + latent_H, + latent_W, + self.dememwm_anchor_proj, + token_patch_size, + ) + + param = next(iter(self.dememwm_anchor_proj.parameters())) + project_device = param.device + project_dtype = param.dtype + hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0]) + generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device) + anchor_banks: list[CausalMemoryBank] = [] + revisit_banks: list[CausalMemoryBank] = [] + + def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): + if tensor is None or tensor.ndim < 2: + return None + pos = positions.to(device=tensor.device) + if tensor.shape[0] == T_src and tensor.shape[1] == B: + return tensor[pos, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_src: + return tensor[batch_idx, pos] + return None + + def _metadata_subset(positions: torch.Tensor, batch_idx: int): + return {"dememwm_revisit_metadata_only": True} + + def _pose_subset(positions: torch.Tensor, batch_idx: int): + return _tensor_subset(pose, positions, batch_idx) + + def _add_anchor_records(bank: CausalMemoryBank, batch_idx: int, positions: torch.Tensor, generated_anchor: bool) -> None: + if positions.numel() == 0: + return + projected = self._project_latent_patch_tokens( + source_latents.index_select(0, positions.to(device=source_latents.device))[:, batch_idx:batch_idx + 1].to(device=project_device, dtype=project_dtype), + self.dememwm_anchor_proj, + token_patch_size, + )[0] + src_frames = source_frame_indices[:, batch_idx] + for local_idx, source_pos in enumerate(positions): + source_pos_i = int(source_pos.item()) + anchor_tokens = self._spatial_pool_tokens(projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w) + n_slots = anchor_tokens.shape[0] + record_mask = torch.ones((n_slots,), device=anchor_tokens.device, dtype=torch.bool) + if generated_anchor: + bank.add_generated_records( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device), + source_type=MemorySourceType.GENERATED, + ) + else: + bank.add_prefix_anchors( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device), + slots_per_anchor=n_slots, + ) + + for batch_idx in range(B): + anchor_bank = CausalMemoryBank() + revisit_bank = CausalMemoryBank() + src_frames = source_frame_indices[:, batch_idx] + if generated is None: + non_generated = torch.ones_like(src_frames, dtype=torch.bool) + else: + non_generated = ~generated[:, batch_idx] + + source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() + if source_positions.numel() > 0: + if anchor_diverse: + anchor_pose = _pose_subset(source_positions, batch_idx) + selected_anchor_positions = self._select_diverse_anchor_positions( + source_positions, anchor_pose, len(anchor_indices) + ) + else: + selected_list = [] + for anchor_idx in anchor_indices: + if 0 <= int(anchor_idx) < source_positions.numel(): + selected_list.append(source_positions[int(anchor_idx)]) + selected_anchor_positions = torch.stack(selected_list).long() if selected_list else source_positions[:0] + if selected_anchor_positions.numel() > 0: + _add_anchor_records(anchor_bank, batch_idx, selected_anchor_positions.long(), False) + + dummy_tokens = torch.zeros((1, hidden_size), device=source_frame_indices.device, dtype=project_dtype) + dummy_mask = torch.ones((1,), device=source_frame_indices.device, dtype=torch.bool) + for prefix, positions, source_type, is_generated in ( + ("prefix", source_positions, MemorySourceType.PREFIX_GT, False), + ( + "generated", + torch.empty(0, device=source_frame_indices.device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(), + MemorySourceType.GENERATED, + True, + ), + ): + if positions.numel() == 0: + continue + for source_pos in positions.to(device=source_frame_indices.device, dtype=torch.long): + source_pos_i = int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + revisit_bank.add_frame_record( + dummy_tokens, + dummy_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=source_type, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=is_generated, + record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}", + ) + + if allow_generated_anchor and generated is not None and anchor_indices: + generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() + _add_anchor_records(anchor_bank, batch_idx, generated_positions[:len(anchor_indices)].long(), True) + + anchor_banks.append(anchor_bank) + revisit_banks.append(revisit_bank) + return anchor_banks, revisit_banks + + + def _build_causal_memory_banks( + self, + anchor_projected: torch.Tensor, + revisit_projected: torch.Tensor, + source_frame_indices: torch.Tensor, + source_is_generated: torch.Tensor | None, + pose: torch.Tensor | None, + action: torch.Tensor | None, + allow_generated_anchor: bool, + anchor_indices: list[int], + anchor_pool_h: int, + anchor_pool_w: int, + revisit_pool_h: int, + revisit_pool_w: int, + src_h: int, + src_w: int, + ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: + # projected tensors use the same batch/source convention as + # _project_latent_patch_tokens: (B, T_src, T_frame, D), while frame indices are + # (T_src, B). Build separate banks because anchor and revisit records + # come from different projections. + if anchor_projected.ndim != 4 or revisit_projected.ndim != 4: + raise ValueError("anchor/revisit projected tensors must have shape (B,T_src,T_frame,D)") + B, T_src, _, _ = anchor_projected.shape + if revisit_projected.shape[:3] != anchor_projected.shape[:3]: + raise ValueError("anchor/revisit projected tensors must share batch/source/token dimensions") + generated = None if source_is_generated is None else source_is_generated.bool().to(source_frame_indices.device) + anchor_banks: list[CausalMemoryBank] = [] + revisit_banks: list[CausalMemoryBank] = [] + + def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): + if tensor is None or tensor.ndim < 2: + return None + if tensor.shape[0] == T_src and tensor.shape[1] == B: + return tensor[positions, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_src: + return tensor[batch_idx, positions] + return None + + def _metadata_subset(positions: torch.Tensor, batch_idx: int): + return {} + + def _pose_subset(positions: torch.Tensor, batch_idx: int): + return _tensor_subset(pose, positions, batch_idx) + + for batch_idx in range(B): + anchor_bank = CausalMemoryBank() + revisit_bank = CausalMemoryBank() + src_frames = source_frame_indices[:, batch_idx] + if generated is None: + non_generated = torch.ones_like(src_frames, dtype=torch.bool) + else: + non_generated = ~generated[:, batch_idx] + + source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() + if source_positions.numel() > 0: + selected_anchor_positions = [] + for anchor_idx in anchor_indices: + if 0 <= int(anchor_idx) < source_positions.numel(): + selected_anchor_positions.append(source_positions[int(anchor_idx)]) + for source_pos in selected_anchor_positions: + source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos) + anchor_tokens = self._spatial_pool_tokens( + anchor_projected[batch_idx, source_pos_i], + anchor_pool_h, anchor_pool_w, src_h, src_w, + ) + n_slots = anchor_tokens.shape[0] + record_mask = torch.ones( + (n_slots,), + device=anchor_projected.device, + dtype=torch.bool, + ) + anchor_bank.add_prefix_anchors( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1], + slots_per_anchor=n_slots, + ) + + for source_pos in source_positions: + source_pos_i = int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + frame_tokens = self._spatial_pool_tokens( + revisit_projected[batch_idx, source_pos_i], + revisit_pool_h, revisit_pool_w, src_h, src_w, + ) + frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) + revisit_bank.add_frame_record( + frame_tokens, + frame_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=MemorySourceType.PREFIX_GT, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=False, + record_id=f"prefix_revisit_b{batch_idx}_f{frame}", + ) + + if generated is not None: + generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() + if generated_positions.numel() > 0: + for source_pos in generated_positions: + source_pos_i = int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + frame_tokens = self._spatial_pool_tokens( + revisit_projected[batch_idx, source_pos_i], + revisit_pool_h, revisit_pool_w, src_h, src_w, + ) + frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) + revisit_bank.add_frame_record( + frame_tokens, + frame_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=MemorySourceType.GENERATED, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=True, + record_id=f"generated_revisit_b{batch_idx}_f{frame}", + ) + if allow_generated_anchor: + for source_pos in generated_positions[:len(anchor_indices)]: + source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos) + anchor_tokens = self._spatial_pool_tokens( + anchor_projected[batch_idx, source_pos_i], + anchor_pool_h, anchor_pool_w, src_h, src_w, + ) + record_mask = torch.ones((anchor_tokens.shape[0],), device=anchor_projected.device, dtype=torch.bool) + anchor_bank.add_generated_records( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1], + source_type=MemorySourceType.GENERATED, + ) + + anchor_banks.append(anchor_bank) + revisit_banks.append(revisit_bank) + return anchor_banks, revisit_banks + + def _build_preselected_causal_memory_banks( + self, + committed_latents: torch.Tensor, + source_frame_indices: torch.Tensor, + source_is_generated: torch.Tensor | None, + pose: torch.Tensor | None, + action: torch.Tensor | None, + target_frame_indices: torch.Tensor, + target_pose: torch.Tensor | None, + target_action: torch.Tensor | None, + target_video_ids, + allow_generated_anchor: bool, + anchor_indices: list[int], + anchor_pool_h: int, + anchor_pool_w: int, + anchor_diverse: bool, + revisit_pool_h: int, + revisit_pool_w: int, + revisit_max_frames: int, + exclude_local_context_frames: int, + fov_overlap_threshold, + plucker_weight: float, + revisit_retrieval_kwargs: dict | None, + token_patch_size: int, + ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]: + if committed_latents.ndim != 5: + raise ValueError("committed_latents must have shape (T_src,B,C,H,W)") + T_src, B, _, H, W = committed_latents.shape + if source_frame_indices.shape != (T_src, B): + raise ValueError("source_frame_indices must have shape (T_src,B)") + if target_frame_indices.ndim == 1: + target_frame_indices = target_frame_indices[:, None] + if target_frame_indices.shape[1] != B: + raise ValueError("target_frame_indices must have batch dimension B") + T_tgt = target_frame_indices.shape[0] + stream_device = committed_latents.device + hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0]) + src_h, src_w = self._projected_spatial_grid_size( + H, + W, + self.dememwm_anchor_proj, + token_patch_size, + ) + tokens_per_frame = src_h * src_w + generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device) + anchor_banks: list[CausalMemoryBank] = [] + revisit_banks: list[CausalMemoryBank] = [] + dummy_tokens = committed_latents.new_zeros((1, hidden_size)) + dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool) + preselection_candidate_count = 0 + preselection_valid_candidate_label_count = 0 + preselection_selected_count = 0 + projected_anchor_frames = 0 + projected_revisit_frames = 0 + projected_revisit_records = 0 + retrieval_kwargs = dict(revisit_retrieval_kwargs or {}) + + # Pre-convert pose tensors to stream_device once so that the + # _tensor_subset / _target_tensor closures below never trigger a + # device transfer on every call. + if pose is not None: + pose = pose.to(device=stream_device) + if target_pose is not None: + target_pose = target_pose.to(device=stream_device) + + def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): + if tensor is None or tensor.ndim < 2: + return None + if tensor.shape[0] == T_src and tensor.shape[1] == B: + return tensor[positions, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_src: + return tensor[batch_idx, positions] + return None + + def _target_tensor(tensor: torch.Tensor | None, batch_idx: int, target_idx: int): + if tensor is None or tensor.ndim < 2: + return None + if tensor.shape[0] == T_tgt and tensor.shape[1] == B: + return tensor[target_idx, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_tgt: + return tensor[batch_idx, target_idx] + return None + + def _target_video_id(batch_idx: int, target_idx: int): + if target_video_ids is None: + return None + if torch.is_tensor(target_video_ids): + ids = target_video_ids.detach().cpu() + if ids.ndim == 0: + return ids.item() + if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B: + return ids[target_idx, batch_idx].item() + if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt: + return ids[batch_idx, target_idx].item() + return None + if isinstance(target_video_ids, (list, tuple)): + if len(target_video_ids) == B: + return target_video_ids[batch_idx] + if len(target_video_ids) == T_tgt: + row = target_video_ids[target_idx] + if isinstance(row, (list, tuple)) and len(row) == B: + return row[batch_idx] + return row + return target_video_ids + + def _metadata_subset(positions: torch.Tensor, batch_idx: int): + return {} + + def _pose_subset(positions: torch.Tensor, batch_idx: int): + return _tensor_subset(pose, positions, batch_idx) + + def _candidate_record( + *, + batch_idx: int, + frame_position: torch.Tensor, + source_type: MemorySourceType, + is_generated: bool, + record_id: str, + ) -> MemoryRecord: + frame_values = source_frame_indices[frame_position, batch_idx].to(device=stream_device) + frame = int(frame_values.reshape(-1)[0].item()) + return MemoryRecord( + tokens=dummy_tokens, + mask=dummy_mask, + source_start=frame, + source_end=frame + 1, + frame_indices=frame_values.reshape(1), + pose=_pose_subset(frame_position, batch_idx), + source_type=source_type, + is_generated=bool(is_generated), + chunk_id=record_id, + metadata=_metadata_subset(frame_position, batch_idx), + ) + + for batch_idx in range(B): + anchor_bank = CausalMemoryBank() + revisit_bank = CausalMemoryBank() + src_frames = source_frame_indices[:, batch_idx] + if generated is None: + non_generated = torch.ones_like(src_frames, dtype=torch.bool) + else: + non_generated = ~generated[:, batch_idx] + source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() + + anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long) + if anchor_indices and source_positions.numel() > 0: + if anchor_diverse: + anchor_source_positions = source_positions[source_positions < self._context_frame_count()] + if anchor_source_positions.numel() > 0: + anchor_pose = _pose_subset(anchor_source_positions, batch_idx) + anchor_positions = self._select_diverse_anchor_positions( + anchor_source_positions, anchor_pose, len(anchor_indices) + ).to(device=stream_device, dtype=torch.long) + else: + selected_anchor_positions = [] + for anchor_idx in anchor_indices: + if 0 <= int(anchor_idx) < source_positions.numel(): + selected_anchor_positions.append(source_positions[int(anchor_idx)]) + if selected_anchor_positions: + anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long) + if anchor_positions.numel() > 0: + projected_anchor_frames += int(anchor_positions.numel()) + anchor_projected = self._project_latent_patch_tokens( + committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1], + self.dememwm_anchor_proj, + token_patch_size, + )[0] + for local_idx, source_pos in enumerate(anchor_positions): + source_pos_i = int(source_pos.item()) + anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w) + n_slots = anchor_tokens.shape[0] + record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool) + anchor_bank.add_prefix_anchors( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1], + slots_per_anchor=n_slots, + ) + + candidate_records: list[MemoryRecord] = [] + candidate_positions: dict[str, torch.Tensor] = {} + src_frames_cpu = src_frames.detach().cpu() + target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long) + latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames) + for prefix, positions, source_type, is_generated in ( + ("prefix", source_positions, MemorySourceType.PREFIX_GT, False), + ( + "generated", + torch.empty(0, device=stream_device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(), + MemorySourceType.GENERATED, + True, + ), + ): + if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0: + continue + positions_cpu = positions.detach().cpu().to(dtype=torch.long) + for frame_position_cpu in positions_cpu: + frame = int(src_frames_cpu[int(frame_position_cpu.item())].item()) + if frame >= latest_valid_source_frame_exclusive: + continue + frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long) + record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}" + candidate_positions[record_id] = frame_position + candidate_records.append(_candidate_record( + batch_idx=batch_idx, + frame_position=frame_position, + source_type=source_type, + is_generated=is_generated, + record_id=record_id, + )) + + selected_frame_record_ids: set[str] = set() + selected_frame_metadata: dict[str, dict] = {} + for target_idx in range(T_tgt): + target_frame = int(target_frame_indices[target_idx, batch_idx].item()) + result = deterministic_revisit_retrieval( + candidate_records, + target_frame=target_frame, + target_pose=_target_tensor(target_pose, batch_idx, target_idx), + target_summary=None, + topk=revisit_max_frames, + exclude_local_context_frames=exclude_local_context_frames, + fov_overlap_threshold=fov_overlap_threshold, + plucker_weight=plucker_weight, + target_video_id=_target_video_id(batch_idx, target_idx), + **retrieval_kwargs, + ) + preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0))) + preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0)) + preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0))) + for selected_record in result.records: + if selected_record.chunk_id is None: + continue + record_id = str(selected_record.chunk_id) + selected_frame_record_ids.add(record_id) + selected_frame_metadata[record_id] = dict(selected_record.metadata) + + for record in candidate_records: + if record.chunk_id not in selected_frame_record_ids: + continue + record_id = str(record.chunk_id) + frame_position = candidate_positions[record_id] + projected_revisit_records += 1 + projected_revisit_frames += int(frame_position.numel()) + revisit_projected = self._project_latent_patch_tokens( + committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1], + self.dememwm_revisit_proj, + token_patch_size, + )[0] + frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, src_h, src_w) + frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool) + record_metadata = dict(record.metadata) + record_metadata.update(selected_frame_metadata.get(record_id, {})) + revisit_bank.add_frame_record( + frame_tokens, + frame_mask, + record.frame_indices.reshape(-1)[0], + pose=record.pose, + source_type=record.source_type, + metadata=record_metadata, + is_generated=record.is_generated, + record_id=record.chunk_id, + ) + + anchor_banks.append(anchor_bank) + revisit_banks.append(revisit_bank) + + diagnostics = { + "preselected_anchor_projected_frame_count": projected_anchor_frames, + "preselected_revisit_projected_frame_count": projected_revisit_frames, + "preselected_revisit_projected_frame_record_count": projected_revisit_records, + "preselected_revisit_candidate_frame_count": preselection_candidate_count, + "preselected_revisit_candidate_count": preselection_candidate_count, + "preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count, + "preselected_revisit_selected_frame_count": preselection_selected_count, + "preselected_revisit_selected_count": preselection_selected_count, + } + return anchor_banks, revisit_banks, tokens_per_frame, diagnostics + + def _records_to_stream( + self, + records, + max_tokens: int, + hidden_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, int]: + max_tokens = max(0, int(max_tokens)) + record_list = list(records) + stacked_tokens, stacked_mask = stack_record_tokens(record_list, max_slots=max_tokens) + max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1) + if stacked_tokens is None or stacked_mask is None or max_tokens == 0: + tokens = torch.zeros((max_tokens, hidden_size), device=device, dtype=dtype) + mask = torch.zeros((max_tokens,), device=device, dtype=torch.bool) + return tokens, mask, max_source_frame + n = min(max_tokens, stacked_tokens.shape[0]) + filled = stacked_tokens[:n].to(device=device, dtype=dtype) + filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool) + if n < max_tokens: + pad = filled.new_zeros(max_tokens - n, hidden_size) + pad_mask = torch.zeros(max_tokens - n, device=device, dtype=torch.bool) + tokens = torch.cat([filled, pad], dim=0) + mask = torch.cat([filled_mask, pad_mask], dim=0) + else: + tokens = filled + mask = filled_mask + return tokens, mask, max_source_frame + + def _project_streaming_revisit_records( + self, + *, + cache: StreamingCache, + batch_idx: int, + records: list[MemoryRecord], + device: torch.device, + dtype: torch.dtype, + token_patch_size: int, + revisit_pool_h: int, + revisit_pool_w: int, + projection_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord], + ) -> list[MemoryRecord]: + projected_records: list[MemoryRecord] = [] + for record in records: + if not bool(record.metadata.get("dememwm_revisit_metadata_only", False)): + projected_records.append(record) + continue + selected_frame_index = record.metadata.get("dememwm_selected_frame_index") + if selected_frame_index is None: + best_frame_idx = record.frame_indices[torch.argmax(record.frame_indices)].reshape(1) + else: + best_frame_idx = torch.as_tensor( + [int(selected_frame_index)], + device=record.frame_indices.device, + dtype=record.frame_indices.dtype, + ) + key = ( + int(batch_idx), + str(record.chunk_id or ""), + int(record.source_start), + int(record.source_end), + int(best_frame_idx.detach().cpu().reshape(-1)[0].item()), + bool(record.is_generated), + ) + cached = projection_cache.get(key) + if cached is not None: + projected_records.append(cached) + continue + + raw_latents = cache.raw_latents_for_frames( + batch_idx=batch_idx, + frame_indices=best_frame_idx, + device=device, + dtype=dtype, + ) + revisit_projected = self._project_latent_patch_tokens( + raw_latents, + self.dememwm_revisit_proj, + token_patch_size, + )[0] + _proj_src_h, _proj_src_w = self._projected_spatial_grid_size( + raw_latents.shape[3], + raw_latents.shape[4], + self.dememwm_revisit_proj, + token_patch_size, + ) + frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, _proj_src_h, _proj_src_w) + frame_mask = torch.ones((frame_tokens.shape[0],), device=device, dtype=torch.bool) + metadata = { + key: (value.to(device=device) if torch.is_tensor(value) else value) + for key, value in record.metadata.items() + } + metadata["dememwm_revisit_metadata_only"] = False + projected = MemoryRecord( + tokens=frame_tokens, + mask=frame_mask, + source_start=int(record.source_start), + source_end=int(record.source_end), + frame_indices=record.frame_indices.to(device=device), + pose=None if record.pose is None else record.pose.to(device=device), + source_type=record.source_type, + is_generated=bool(record.is_generated), + score=record.score, + chunk_id=record.chunk_id, + metadata=metadata, + ) + projection_cache[key] = projected + projected_records.append(projected) + return projected_records + + def build_memory_streams( + self, + committed_latents: torch.Tensor | None, + source_frame_indices: torch.Tensor | None, + target_frame_indices: torch.Tensor | None = None, + pose: torch.Tensor | None = None, + target_pose: torch.Tensor | None = None, + action: torch.Tensor | None = None, + target_action: torch.Tensor | None = None, + target_video_ids=None, + source_is_generated: torch.Tensor | None = None, + denoising_fraction: float | None = None, + noise_bucket: str | None = None, + noise_bucket_ids: torch.Tensor | None = None, + streaming_cache: StreamingCache | None = None, + ) -> MemoryStreamTensors: + if target_frame_indices is None: + if source_frame_indices is None: + raise ValueError("target_frame_indices or source_frame_indices is required") + target_frame_indices = source_frame_indices + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + injection_cfg = self._cfg_get(memory_cfg, "injection", None) + contract_diag = self._validate_config_contract() + gate_state = self._effective_gate_state( + denoising_fraction=denoising_fraction, + noise_bucket=noise_bucket, + ) + anchor_config_enabled = gate_state["anchor_config_enabled"] + dynamic_config_enabled = gate_state["dynamic_config_enabled"] + revisit_config_enabled = gate_state["revisit_config_enabled"] + curriculum_state = gate_state["curriculum_state"] + eval_ablation_enabled = gate_state["eval_ablation_enabled"] + eval_ablation_branch = gate_state["eval_ablation_branch"] + resolved_noise_bucket = gate_state["resolved_noise_bucket"] + gates = gate_state["gates"] + anchor_effective_enabled = gate_state["anchor_effective_enabled"] + dynamic_effective_enabled = gate_state["dynamic_effective_enabled"] + revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"] + force_revisit_off = gate_state["force_revisit_off"] + force_revisit_on = gate_state["force_revisit_on"] + token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) + anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] + anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) + pool_latent_h = int(committed_latents.shape[-2]) if committed_latents is not None else int(self.x_stacked_shape[-2]) + pool_latent_w = int(committed_latents.shape[-1]) if committed_latents is not None else int(self.x_stacked_shape[-1]) + anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( + pool_latent_h, + pool_latent_w, + self.dememwm_anchor_proj, + token_patch_size, + ) + anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( + anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 + ) + anchor_num_tokens = len(anchor_indices) * anchor_pool_h * anchor_pool_w + anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) + allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) + revisit_max_frames = int(self._cfg_get(revisit_cfg, "max_frames", 2)) + revisit_compress_cfg = self._cfg_get(revisit_cfg, "compress", None) + revisit_src_h, revisit_src_w = self._projected_spatial_grid_size( + pool_latent_h, + pool_latent_w, + self.dememwm_revisit_proj, + token_patch_size, + ) + revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size( + revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8 + ) + revisit_max_tokens = revisit_max_frames * revisit_pool_h * revisit_pool_w + recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8)) + exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) + local_context_exclusion_frames = self._local_context_exclusion_frames() + fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) + high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) + plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) + revisit_retrieval_kwargs = { + "high_quality_fov_threshold": high_quality_fov_threshold, + "fov_half_h": float(self._cfg_get(revisit_cfg, "fov_half_h", 52.5)), + "fov_half_v": float(self._cfg_get(revisit_cfg, "fov_half_v", 37.5)), + "fov_yaw_samples": int(self._cfg_get(revisit_cfg, "fov_yaw_samples", 25)), + "fov_pitch_samples": int(self._cfg_get(revisit_cfg, "fov_pitch_samples", 20)), + "fov_depth_samples": int(self._cfg_get(revisit_cfg, "fov_depth_samples", 20)), + "fov_radius": float(self._cfg_get(revisit_cfg, "fov_radius", 30.0)), + "pose_preselect_topk": self._cfg_get(revisit_cfg, "pose_preselect_topk", 64), + "plucker_grid_h": int(self._cfg_get(revisit_cfg, "plucker_grid_h", 4)), + "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)), + "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)), + } + preselection_diag = {} + use_cache_revisit_records = False + revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None + + cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None + cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0} + if committed_latents is not None: + stream_device = committed_latents.device + stream_dtype = committed_latents.dtype + else: + param = next(iter(self.dememwm_anchor_proj.parameters())) + stream_device = param.device + stream_dtype = param.dtype + target_frame_indices = target_frame_indices.to(device=stream_device) + if target_frame_indices.ndim == 1: + target_frame_indices = target_frame_indices[:, None] + + use_cache_records = cache is not None and cache.keep_compressed_records and cache.record_count > 0 + dynamic_latents = committed_latents if dynamic_effective_enabled else None + dynamic_frame_indices = source_frame_indices if dynamic_effective_enabled else None + dynamic_generated = source_is_generated if dynamic_effective_enabled else None + dynamic_pose = pose if dynamic_effective_enabled else None + if dynamic_effective_enabled and cache is not None and cache.raw_frame_slots > 0: + raw_latents, raw_frames, raw_generated, raw_pose = cache.materialize_raw_latents( + device=stream_device, + dtype=stream_dtype, + max_recent_frames=recent_frames, + target_frame_indices=target_frame_indices, + exclude_latest_local_frames=exclude_latest_local_frames, + ) + if raw_latents is not None: + dynamic_latents = raw_latents + dynamic_frame_indices = raw_frames + dynamic_generated = raw_generated + dynamic_pose = raw_pose + + if use_cache_records: + B = target_frame_indices.shape[1] + hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) + anchor_banks = ( + cache.memory_banks("anchor", device=stream_device, dtype=stream_dtype, batch_size=B) + if anchor_effective_enabled else [CausalMemoryBank() for _ in range(B)] + ) + revisit_banks = [CausalMemoryBank() for _ in range(B)] + revisit_record_batches = ( + [cache.records_for_batch("revisit", batch_idx) for batch_idx in range(B)] + if revisit_stage_config_enabled else [tuple() for _ in range(B)] + ) + use_cache_revisit_records = bool(revisit_stage_config_enabled) + if dynamic_latents is not None and dynamic_latents.ndim == 5 and dynamic_latents.shape[0] > 0: + tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( + dynamic_latents.shape[-2], + dynamic_latents.shape[-1], + self.dememwm_anchor_proj, + token_patch_size, + ) + tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w + else: + latent_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 0 + latent_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 0 + tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( + latent_h, + latent_w, + self.dememwm_anchor_proj, + token_patch_size, + ) + tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w + else: + if committed_latents is None or source_frame_indices is None: + raise ValueError("committed_latents/source_frame_indices are required when no streaming cache records are available") + B = committed_latents.shape[1] + hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) + target_pose_source = target_pose if target_pose is not None else pose + anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks( + committed_latents, + source_frame_indices.to(device=stream_device), + None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool), + None if pose is None else pose.to(device=stream_device), + None, + target_frame_indices, + None if target_pose_source is None else target_pose_source.to(device=stream_device), + None, + target_video_ids, + allow_generated_anchor, + anchor_indices, + anchor_pool_h, + anchor_pool_w, + anchor_diverse, + revisit_pool_h, + revisit_pool_w, + revisit_max_frames, + local_context_exclusion_frames, + fov_overlap_threshold, + plucker_weight, + revisit_retrieval_kwargs, + token_patch_size, + ) + revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] + + T_tgt = target_frame_indices.shape[0] + anchor_slots = max(0, anchor_num_tokens) + revisit_slots = max(0, revisit_max_tokens) + anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT + anchor_include_generated = allow_generated_anchor + anchor_token_rows = [] + anchor_mask_rows = [] + anchor_max_rows = [] + for batch_idx, anchor_bank in enumerate(anchor_banks): + batch_token_rows = [] + batch_mask_rows = [] + batch_max_rows = [] + for target_idx in range(T_tgt): + target_frame = int(target_frame_indices[target_idx, batch_idx].item()) + records = anchor_bank.query( + MemoryBankQuery( + target_frame=target_frame, + source_type=anchor_source_type, + include_generated=anchor_include_generated, + max_records=len(anchor_indices), + max_slots=anchor_slots, + ) + ) + anchor_bank.assert_causal(target_frame, records) + stream_tokens, stream_mask, max_source_frame = self._records_to_stream( + records, + anchor_slots, + hidden_size, + stream_device, + stream_dtype, + ) + batch_token_rows.append(stream_tokens) + batch_mask_rows.append(stream_mask) + batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long)) + anchor_token_rows.append(torch.stack(batch_token_rows, dim=0)) + anchor_mask_rows.append(torch.stack(batch_mask_rows, dim=0)) + anchor_max_rows.append(torch.stack(batch_max_rows, dim=0)) + anchor_tokens = torch.stack(anchor_token_rows, dim=0) + anchor_mask = torch.stack(anchor_mask_rows, dim=0) + anchor_max = torch.stack(anchor_max_rows, dim=0) + + if dynamic_latents is None or dynamic_frame_indices is None or dynamic_latents.shape[0] == 0: + _fallback_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 18 + _fallback_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 32 + dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w) + dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype) + dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool) + dynamic_diag = { + "selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), + "max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device), + "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + } + else: + # Pre-select dynamic source frame positions using only frame index metadata + # before touching latents, so we pass a small slice instead of the full + # 1000-frame tensor to the compressor. + _dfi = dynamic_frame_indices.to(device=stream_device) + _max_src = self.dememwm_dynamic_compressor.max_source_frames + _needed: list[int] = [] + for _b in range(B): + for _j in range(T_tgt): + _target = int(target_frame_indices[_j, _b].item()) + _valid = (_dfi[:, _b] < _target - exclude_latest_local_frames).nonzero(as_tuple=False).flatten() + _needed.extend(_valid[-_max_src:].tolist()) + if _needed: + _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long) + _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx) + _dynamic_fi_small = _dfi.index_select(0, _needed_idx) + _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None + _dynamic_gen_small = ( + dynamic_generated.to(device=stream_device, dtype=torch.bool).index_select(0, _needed_idx) + if dynamic_generated is not None else None + ) + else: + _dynamic_latents_small = dynamic_latents[:0] + _dynamic_fi_small = _dfi[:0] + _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None + _dynamic_gen_small = None + dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor( + _dynamic_latents_small, + _dynamic_fi_small, + _dynamic_pose_small, + target_frame_indices, + _dynamic_gen_small, + exclude_latest_local_frames=exclude_latest_local_frames, + ) + + dynamic_min_gap_tensor = torch.as_tensor( + dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), + device=stream_device, + ) + dynamic_max_gap_tensor = torch.as_tensor( + dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), + device=stream_device, + ) + dynamic_gap_valid = dynamic_min_gap_tensor >= 0 + dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1 + dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0 + dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1 + def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int): + if tensor is None or tensor.ndim < 2: + return None + tensor_dev = tensor.to(device=stream_device) + if tensor_dev.shape[0] == T_tgt and tensor_dev.shape[1] == B: + return tensor_dev[target_idx, batch_idx] + if tensor_dev.shape[0] == B and tensor_dev.shape[1] == T_tgt: + return tensor_dev[batch_idx, target_idx] + return None + + def _target_video_id_or_none(batch_idx: int, target_idx: int): + if target_video_ids is None: + return None + if torch.is_tensor(target_video_ids): + ids = target_video_ids.detach().cpu() + if ids.ndim == 0: + return ids.item() + if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B: + return ids[target_idx, batch_idx].item() + if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt: + return ids[batch_idx, target_idx].item() + return None + if isinstance(target_video_ids, (list, tuple)): + if len(target_video_ids) == B: + return target_video_ids[batch_idx] + if len(target_video_ids) == T_tgt: + row = target_video_ids[target_idx] + if isinstance(row, (list, tuple)) and len(row) == B: + return row[batch_idx] + return row + return target_video_ids + + target_pose_source = target_pose if target_pose is not None else pose + + revisit_token_rows = [] + revisit_mask_rows = [] + revisit_max_rows = [] + valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32) + valid_revisit_target_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long) + eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES) + revisit_result_diagnostics = [] + projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {} + if revisit_record_batches is None: + revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] + for batch_idx in range(B): + revisit_bank = revisit_banks[batch_idx] + batch_token_rows = [] + batch_mask_rows = [] + batch_max_rows = [] + for target_idx in range(T_tgt): + target_frame = int(target_frame_indices[target_idx, batch_idx].item()) + if use_cache_revisit_records: + candidate_records = list(revisit_record_batches[batch_idx]) + else: + candidate_records = revisit_bank.query( + MemoryBankQuery( + target_frame=target_frame, + include_generated=True, + ) + ) + result = deterministic_revisit_retrieval( + candidate_records, + target_frame=target_frame, + target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx), + target_summary=None, + topk=revisit_max_frames, + exclude_local_context_frames=local_context_exclusion_frames, + fov_overlap_threshold=fov_overlap_threshold, + plucker_weight=plucker_weight, + target_video_id=_target_video_id_or_none(batch_idx, target_idx), + **revisit_retrieval_kwargs, + ) + selected_records = result.records + if use_cache_revisit_records and selected_records: + selected_records = self._project_streaming_revisit_records( + cache=cache, + batch_idx=batch_idx, + records=selected_records, + device=stream_device, + dtype=stream_dtype, + token_patch_size=token_patch_size, + revisit_pool_h=revisit_pool_h, + revisit_pool_w=revisit_pool_w, + projection_cache=projected_revisit_record_cache, + ) + revisit_result_diagnostics.append(result.diagnostics) + revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0))) + revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0))) + revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0)) + revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0)) + revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1)) + valid_revisit_target_mask[batch_idx, target_idx] = bool(result.diagnostics.get("valid_revisit_target_count", 0)) + revisit_bank.assert_causal(target_frame, selected_records) + if selected_records: + valid_revisit_mask[batch_idx, target_idx] = True + stream_tokens, stream_mask, max_source_frame = self._records_to_stream( + selected_records, + revisit_slots, + hidden_size, + stream_device, + stream_dtype, + ) + revisit_causal_max[batch_idx, target_idx] = max_source_frame + if eval_corruption_enabled: + stream_tokens, was_corrupted = apply_revisit_eval_corruption( + tokens=stream_tokens, + mask=stream_mask, + branch=eval_ablation_branch, + target_frame=target_frame, + ) + eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted) + actual_max_source_frame = max((int(record.max_source_frame) for record in selected_records), default=max_source_frame) + batch_token_rows.append(stream_tokens) + batch_mask_rows.append(stream_mask) + batch_max_rows.append(torch.as_tensor(actual_max_source_frame, device=stream_device, dtype=torch.long)) + revisit_token_rows.append(torch.stack(batch_token_rows, dim=0)) + revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0)) + revisit_max_rows.append(torch.stack(batch_max_rows, dim=0)) + revisit_tokens = torch.stack(revisit_token_rows, dim=0) + revisit_mask = torch.stack(revisit_mask_rows, dim=0) + revisit_max = torch.stack(revisit_max_rows, dim=0) + + if anchor_tokens.shape[-2] != anchor_num_tokens: + raise AssertionError(f"anchor token budget mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}") + if dynamic_latents is not None and dynamic_latents.shape[0] > 0: + _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target( + int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1]) + ) + if dynamic_tokens.shape[-2] != _expected_dyn: + raise AssertionError(f"dynamic token budget mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}") + if revisit_tokens.shape[-2] > revisit_max_tokens: + raise AssertionError(f"revisit token cap exceeded: got {revisit_tokens.shape[-2]}, cap {revisit_max_tokens}") + anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0 + dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0 + gate_module = getattr(self, "dememwm_revisit_gate", None) + if gate_module is None: + revisit_gate_raw = torch.ones((B, T_tgt), device=stream_device, dtype=stream_dtype) + else: + revisit_gate_raw = gate_module( + valid_revisit_mask=valid_revisit_mask, + best_selected_fov_overlap=revisit_best_selected_fov_overlap, + best_selected_plucker_overlap=revisit_best_selected_plucker_overlap, + selected_gap_frames=revisit_selected_gap_frames, + ).to(device=stream_device, dtype=stream_dtype) + valid_revisit_eff_mask = valid_revisit_mask + if not revisit_stage_config_enabled or force_revisit_off: + revisit_gate = torch.zeros_like(revisit_gate_raw) + elif force_revisit_on: + revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * torch.ones_like(revisit_gate_raw) + else: + revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * revisit_gate_raw * float(gates.revisit_gate) + revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item()) + if not anchor_effective_enabled: + anchor_mask = torch.zeros_like(anchor_mask) + if not dynamic_effective_enabled: + dynamic_mask = torch.zeros_like(dynamic_mask) + if not revisit_stage_config_enabled: + revisit_mask = torch.zeros_like(revisit_mask) + valid_revisit_mask = torch.zeros_like(valid_revisit_mask) + revisit_candidate_count = torch.zeros_like(revisit_candidate_count) + revisit_selected_count = torch.zeros_like(revisit_selected_count) + revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap) + revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap) + revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0) + valid_revisit_target_mask = torch.zeros_like(valid_revisit_target_mask) + eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask) + valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask) + revisit_gate_raw = torch.zeros_like(revisit_gate_raw) + revisit_gate = torch.zeros_like(revisit_gate) + no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask) + revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask) + causal_violation_count = 0 + for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max): + if source_max is None: + continue + source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device) + valid = source_max_t >= 0 + if valid.any(): + causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item()) + diagnostics = { + **curriculum_state.diagnostics(), + **getattr(self, "_last_dememwm_freeze_diagnostics", {}), + **contract_diag, + **cache_diag, + **preselection_diag, + **revisit_diag, + "dememwm_stage": gates.stage, + "dememwm_gate_reason": gates.reason, + "anchor_config_enabled": anchor_config_enabled, + "dynamic_config_enabled": dynamic_config_enabled, + "revisit_config_enabled": revisit_config_enabled, + "anchor_effective_enabled": anchor_effective_enabled, + "dynamic_effective_enabled": dynamic_effective_enabled, + "revisit_effective_enabled": revisit_effective_enabled, + "revisit_stage_config_enabled": revisit_stage_config_enabled, + "revisit_gate_raw": revisit_gate_raw.detach(), + "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)), + "no_valid_revisit_mask": no_valid_revisit_mask, + "valid_revisit_eff_mask": valid_revisit_eff_mask, + "valid_revisit_target_mask": valid_revisit_target_mask, + "revisit_candidate_frame_count_per_target": revisit_candidate_count, + "revisit_selected_frame_count_per_target": revisit_selected_count, + "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap, + "revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap, + "revisit_selected_gap_frames_per_target": revisit_selected_gap_frames, + "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0, + "revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()), + **summarize_noise_bucket_diagnostics( + noise_bucket=resolved_noise_bucket, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + noise_bucket_ids=noise_bucket_ids, + ), + **summarize_eval_ablation_diagnostics( + enabled=eval_ablation_enabled, + branch=eval_ablation_branch, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None, + ), + "token_patch_size": token_patch_size, + "tokens_per_frame": tokens_per_frame, + "anchor_token_slots": int(anchor_tokens.shape[-2]), + "anchor_budget_tokens": anchor_num_tokens, + "anchor_pool_h": anchor_pool_h, + "anchor_pool_w": anchor_pool_w, + "dynamic_token_slots": int(dynamic_tokens.shape[-2]), + "dynamic_budget_tokens": int(dynamic_tokens.shape[-2]), + "dynamic_min_gap_to_target": dynamic_min_gap_to_target, + "dynamic_max_gap_to_target": dynamic_max_gap_to_target, + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + "revisit_token_slots": int(revisit_tokens.shape[-2]), + "revisit_max_tokens": revisit_max_tokens, + "revisit_local_context_exclusion_frames": local_context_exclusion_frames, + "revisit_high_quality_fov_threshold": high_quality_fov_threshold, + "revisit_pool_h": revisit_pool_h, + "revisit_pool_w": revisit_pool_w, + "revisit_max_frames": revisit_max_frames, + "anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0, + "dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0, + "revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0, + "causal_violation_count": causal_violation_count, + "anchor_max_source_frame": anchor_max, + "dynamic_max_source_frame": dynamic_diag.get("max_source_frame"), + "revisit_max_source_frame": revisit_max, + "dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"), + } + if eval_corruption_enabled: + diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask + + return MemoryStreamTensors( + anchor_tokens=anchor_tokens, + anchor_mask=anchor_mask, + dynamic_tokens=dynamic_tokens, + dynamic_mask=dynamic_mask, + revisit_tokens=revisit_tokens, + revisit_mask=revisit_mask, + anchor_gate=anchor_gate, + dynamic_gate=dynamic_gate, + revisit_gate=revisit_gate, + revisit_gate_raw=revisit_gate_raw, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + diagnostics=diagnostics, + ) + + def _refresh_stream_gates( + self, + streams: MemoryStreamTensors, + denoising_fraction: float | None = None, + noise_bucket: str | None = None, + ) -> MemoryStreamTensors: + gate_state = self._effective_gate_state( + denoising_fraction=denoising_fraction, + noise_bucket=noise_bucket, + ) + gates = gate_state["gates"] + device = streams.anchor_tokens.device + dtype = streams.anchor_tokens.dtype + B, T_tgt = streams.anchor_tokens.shape[:2] + valid_revisit_mask = streams.valid_revisit_mask + if valid_revisit_mask is None: + valid_revisit_mask = torch.zeros((B, T_tgt), device=device, dtype=torch.bool) + else: + valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool) + + diagnostics = dict(streams.diagnostics) + + def _diagnostic_tensor(name: str, fill_value: float = 0.0) -> torch.Tensor: + value = diagnostics.get(name) + if value is None: + return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32) + tensor = torch.as_tensor(value, device=device, dtype=torch.float32) + if tensor.ndim == 0: + return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32) + return tensor.expand((B, T_tgt)) + + revisit_best_selected_fov_overlap = _diagnostic_tensor("revisit_best_selected_fov_overlap_per_target") + revisit_best_selected_plucker_overlap = _diagnostic_tensor("revisit_best_selected_plucker_overlap_per_target") + revisit_selected_gap_frames = _diagnostic_tensor("revisit_selected_gap_frames_per_target", -1.0) + + anchor_effective_enabled = gate_state["anchor_effective_enabled"] + dynamic_effective_enabled = gate_state["dynamic_effective_enabled"] + revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"] + anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0 + dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0 + gate_module = getattr(self, "dememwm_revisit_gate", None) + if gate_module is None: + revisit_gate_raw = torch.ones((B, T_tgt), device=device, dtype=dtype) + else: + revisit_gate_raw = gate_module( + valid_revisit_mask=valid_revisit_mask, + best_selected_fov_overlap=revisit_best_selected_fov_overlap, + best_selected_plucker_overlap=revisit_best_selected_plucker_overlap, + selected_gap_frames=revisit_selected_gap_frames, + ).to(device=device, dtype=dtype) + valid_revisit_eff_mask = valid_revisit_mask + if not revisit_stage_config_enabled or gate_state["force_revisit_off"]: + revisit_gate = torch.zeros_like(revisit_gate_raw) + elif gate_state["force_revisit_on"]: + revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw) + else: + revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate) + if not revisit_stage_config_enabled: + valid_revisit_mask = torch.zeros_like(valid_revisit_mask) + valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask) + revisit_gate_raw = torch.zeros_like(revisit_gate_raw) + revisit_gate = torch.zeros_like(revisit_gate) + no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask) + eval_corrupted_revisit_mask = diagnostics.get("eval_corrupted_revisit_mask") + if eval_corrupted_revisit_mask is not None: + eval_corrupted_revisit_mask = torch.as_tensor(eval_corrupted_revisit_mask, device=device, dtype=torch.bool) + revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item()) + diagnostics.update(gate_state["curriculum_state"].diagnostics()) + diagnostics.update({ + "dememwm_stage": gates.stage, + "dememwm_gate_reason": gates.reason, + "anchor_config_enabled": gate_state["anchor_config_enabled"], + "dynamic_config_enabled": gate_state["dynamic_config_enabled"], + "revisit_config_enabled": gate_state["revisit_config_enabled"], + "anchor_effective_enabled": anchor_effective_enabled, + "dynamic_effective_enabled": dynamic_effective_enabled, + "revisit_effective_enabled": revisit_effective_enabled, + "revisit_stage_config_enabled": revisit_stage_config_enabled, + "revisit_gate_raw": revisit_gate_raw.detach(), + "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)), + "no_valid_revisit_mask": no_valid_revisit_mask, + "valid_revisit_eff_mask": valid_revisit_eff_mask, + "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0, + "revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0, + }) + diagnostics.update(summarize_noise_bucket_diagnostics( + noise_bucket=gate_state["resolved_noise_bucket"], + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + )) + diagnostics.update(summarize_eval_ablation_diagnostics( + enabled=gate_state["eval_ablation_enabled"], + branch=gate_state["eval_ablation_branch"], + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + eval_corrupted_revisit_mask=eval_corrupted_revisit_mask, + )) + return replace( + streams, + anchor_gate=anchor_gate, + dynamic_gate=dynamic_gate, + revisit_gate=revisit_gate, + revisit_gate_raw=revisit_gate_raw, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + diagnostics=diagnostics, + ) + + def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> tuple[dict, dict]: + memory_kwargs, diagnostics = self.dememwm_injection_adapter(streams, device=streams.anchor_tokens.device, dtype=streams.anchor_tokens.dtype) + return memory_kwargs, diagnostics + + def build_memory_kwargs(self, *args, **kwargs) -> tuple[dict, dict]: + streams = self.build_memory_streams(*args, **kwargs) + return self._streams_to_kwargs(streams) + + def _memory_adapter_delta_diagnostics(self) -> dict: + dit_model = getattr(getattr(self, "diffusion_model", None), "model", None) + diagnostics_fn = getattr(dit_model, "memory_adapter_delta_diagnostics", None) + if diagnostics_fn is None: + return {} + return diagnostics_fn() + + def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None: + if namespace == "training/dememwm": + allowed_keys = self._TRAIN_DIAGNOSTIC_LOG_KEYS + elif namespace.endswith("/dememwm"): + allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS + else: + allowed_keys = None + for key, value in diagnostics.items(): + if allowed_keys is not None and key not in allowed_keys: + continue + if isinstance(value, str) or value is None: + continue + if torch.is_tensor(value): + if value.numel() > 0: + self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True) + elif isinstance(value, (bool, int, float)): + self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True) + + def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx): + from ..df_video import convert_to_plucker + image_height, image_width = self._image_size(xs) + if self.use_plucker: + if self.relative_embedding: + input_pose_condition = [] + frame_idx_list = [] + ref_c2w = c2w_mat[-self.memory_condition_length:] if self.memory_condition_length else c2w_mat[:0] + ref_idx = frame_idx[-self.memory_condition_length:] if self.memory_condition_length else frame_idx[:0] + for i in range(c2w_mat.shape[0]): + input_pose_condition.append( + convert_to_plucker( + torch.cat([c2w_mat[i:i + 1], ref_c2w]).clone(), + 0, + focal_length=self.focal_length, + image_height=image_height, image_width=image_width + ).to(xs.dtype) + ) + frame_idx_list.append(torch.cat([frame_idx[i:i + 1] - frame_idx[i:i + 1], ref_idx - frame_idx[i:i + 1]]).clone()) + return torch.cat(input_pose_condition), torch.cat(frame_idx_list) + return convert_to_plucker( + c2w_mat, 0, focal_length=self.focal_length, + image_height=image_height, image_width=image_width + ).to(xs.dtype), frame_idx + return pose_conditions.to(xs.dtype), None + + def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]: + total_frames = max(0, int(total_frames)) + n_tokens = max(1, min(int(self.n_tokens), total_frames)) + max_start = max(0, total_frames - n_tokens) + if max_start == 0: + return 0, n_tokens + context_start = self._context_frame_count() + min_start = min(context_start, max_start) + if min_start == max_start: + return min_start, min_start + n_tokens + start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item()) + return start, start + n_tokens + + def training_step(self, batch, batch_idx): + xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) + xs = self._as_latents(xs) + + # Randomly select a contiguous n_tokens denoising window inside the long + # clip. DeMemWM memory streams are selected causally from frames before + # each target, then only those selected frames are projected. + total_frames = xs.shape[0] + start, end = self._training_window_bounds(total_frames, xs.device) + + xs_window = xs[start:end] + conditions_window = conditions[start:end].clone() + frame_idx_window = frame_idx[start:end] + + input_pose_condition, frame_idx_list = self._training_pose_condition( + xs_window, pose_conditions[start:end], c2w_mat[start:end], frame_idx_window + ) + + noise_levels = self._generate_noise_levels(xs_window) + if self.memory_condition_length: + noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level + conditions_window[-self.memory_condition_length:] *= 0 + source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool) + memory_source_latents, source_is_generated, proxy_diagnostics = self._apply_generated_history_proxy( + xs, + source_is_generated, + context_frame_count=self._context_frame_count(), + target_start_frame=start, + ) + timesteps = int(getattr(self, "timesteps", 0) or 0) + training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps) + training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps) + training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps) + memory_kwargs, diagnostics = self.build_memory_kwargs( + memory_source_latents, + frame_idx, + target_frame_indices=frame_idx_window, + pose=pose_conditions, + target_pose=pose_conditions[start:end], + action=conditions, + target_action=conditions_window, + source_is_generated=source_is_generated, + denoising_fraction=training_denoising_fraction, + noise_bucket=training_noise_bucket, + noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1), + ) + diagnostics.update(proxy_diagnostics) + _, loss = self.diffusion_model( + xs_window, + conditions_window, + input_pose_condition, + noise_levels=noise_levels, + reference_length=self.memory_condition_length, + frame_idx=frame_idx_list, + **memory_kwargs, + ) + diagnostics.update(self._memory_adapter_delta_diagnostics()) + if self.memory_condition_length: + loss = loss[:-self.memory_condition_length] + loss_denoise = self.reweight_loss(loss, None) + loss_total = loss_denoise + diagnostics["training_window_start"] = int(start) + diagnostics["training_window_end"] = int(end) + diagnostics["training_window_size"] = int(end - start) + diagnostics["loss_denoise"] = float(loss_denoise.detach().item()) + diagnostics["loss_total"] = float(loss_total.detach().item()) + if batch_idx % 20 == 0: + self.log("training/loss", loss_total.detach().cpu()) + self._log_memory_diagnostics("training/dememwm", diagnostics) + return {"loss": loss_total} + + def validation_step(self, batch, batch_idx, namespace="validation"): + import numpy as np + from tqdm import tqdm + + memory_condition_length = self.memory_condition_length + xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) + total_frame = xs_raw.shape[0] + if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): + xs = xs_raw.cpu() + elif total_frame > 10: + xs = torch.cat([self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu() for i in range(10)]) + else: + xs = self.encode(xs_raw).cpu() + n_frames, batch_size, *_ = xs.shape + curr_frame = 0 + n_context_frames = self.context_frames // self.frame_stack + xs_pred = xs[:n_context_frames].clone() + curr_frame += n_context_frames + streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}") + cached_until = 0 + pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling") + last_diagnostics = None + while curr_frame < n_frames: + if streaming_cache is not None and curr_frame > cached_until: + new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) + if curr_frame > n_context_frames: + rel_start = max(0, n_context_frames - cached_until) + new_generated[rel_start:] = True + self._update_streaming_cache( + streaming_cache, + xs_pred[cached_until:curr_frame], + frame_idx[cached_until:curr_frame], + pose=pose_conditions[cached_until:curr_frame], + source_is_generated=new_generated, + action=conditions[cached_until:curr_frame], + ) + cached_until = curr_frame + horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame + assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens." + scheduling_matrix = self._generate_scheduling_matrix(horizon) + chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])) + chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device) + xs_pred = torch.cat([xs_pred, chunk], 0) + start_frame = max(0, curr_frame + horizon - self.n_tokens) + pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon}) + if memory_condition_length: + random_idx = self._generate_condition_indices(curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon) + xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0) + else: + random_idx = torch.empty((0, batch_size), dtype=torch.long, device=frame_idx.device) + input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions( + start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx, + image_width=self._image_size(xs_raw)[1], image_height=self._image_size(xs_raw)[0] + ) + target_idx = frame_idx[start_frame:curr_frame + horizon].to(input_condition.device) + use_streaming_cache = streaming_cache is not None and streaming_cache.record_count > 0 + target_pose = pose_conditions[start_frame:curr_frame + horizon].to(input_condition.device) + target_action = conditions[start_frame:curr_frame + horizon].to(input_condition.device) + if use_streaming_cache: + committed_latents = None + committed_idx = None + generated_flags = None + source_pose = None + source_action = None + else: + committed_latents = xs_pred[:curr_frame].to(input_condition.device) + committed_idx = frame_idx[:curr_frame].to(input_condition.device) + generated_flags = torch.zeros(committed_idx.shape, device=input_condition.device, dtype=torch.bool) + if curr_frame > n_context_frames: + generated_flags[n_context_frames:] = True + source_pose = pose_conditions[:curr_frame].to(input_condition.device) + source_action = conditions[:curr_frame].to(input_condition.device) + memory_streams = self.build_memory_streams( + committed_latents, + committed_idx, + target_frame_indices=target_idx, + pose=source_pose, + target_pose=target_pose, + action=source_action, + target_action=target_action, + source_is_generated=generated_flags, + denoising_fraction=None, + streaming_cache=streaming_cache, + ) + for m in range(scheduling_matrix.shape[0] - 1): + from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length) + denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0) + step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac) + memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams) + xs_pred[start_frame:] = self.diffusion_model.sample_step( + xs_pred[start_frame:].to(input_condition.device), + input_condition, + input_pose_condition, + from_noise_levels[start_frame:], + to_noise_levels[start_frame:], + current_frame=curr_frame, + mode="validation", + reference_length=memory_condition_length, + frame_idx=frame_idx_list, + **memory_kwargs, + ).cpu() + if memory_condition_length: + xs_pred = xs_pred[:-memory_condition_length] + curr_frame += horizon + if streaming_cache is not None and curr_frame > cached_until: + new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) + if curr_frame > n_context_frames: + rel_start = max(0, n_context_frames - cached_until) + new_generated[rel_start:] = True + self._update_streaming_cache( + streaming_cache, + xs_pred[cached_until:curr_frame], + frame_idx[cached_until:curr_frame], + pose=pose_conditions[cached_until:curr_frame], + source_is_generated=new_generated, + action=conditions[cached_until:curr_frame], + ) + cached_until = curr_frame + if last_diagnostics is not None: + last_diagnostics.update(streaming_cache.diagnostics("cache")) + pbar.update(horizon) + pbar.close() + if last_diagnostics is not None: + self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics) + xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device)) + xs_decode = self.decode(xs[n_context_frames:].to(conditions.device)) + self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu())) + return + + def strict_checkpoint_key_check(self, state_dict: dict, required_prefixes: Iterable[str] | None = None) -> None: + prefixes = tuple(required_prefixes or self.strict_key_prefixes) + strip_prefixes = ("", "model.", "module.", "algo.") + normalized_keys = [] + for key in state_dict.keys(): + key = str(key) + for strip_prefix in strip_prefixes: + if not strip_prefix or key.startswith(strip_prefix): + normalized_keys.append(key.removeprefix(strip_prefix)) + missing_prefixes = [prefix for prefix in prefixes if not any(key.startswith(prefix) for key in normalized_keys)] + missing_substrings = [ + marker + for marker in self.strict_key_substrings + if not any(marker in key for key in normalized_keys) + ] + if missing_prefixes or missing_substrings: + raise RuntimeError( + "DeMemWM checkpoint is missing required DeMemWM key coverage: " + f"prefixes={missing_prefixes}, memory_adapter_markers={missing_substrings}" + ) + + # Compatibility aliases for old DeMemWM test and experiment call sites. + dememwm_strict_key_prefixes = strict_key_prefixes + dememwm_strict_key_substrings = strict_key_substrings + _DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS + _DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS + _dememwm_cfg = _memory_cfg + _dememwm_stage_policy_cfg = _stage_policy_cfg + _dememwm_eval_ablation_cfg = _eval_ablation_cfg + _dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg + _dememwm_eval_ablation_state = _eval_ablation_state + _dememwm_effective_gate_state = _effective_gate_state + _dememwm_validate_config_contract = _validate_config_contract + _dememwm_stream_enabled = _stream_enabled + _dememwm_context_frame_count = _context_frame_count + _dememwm_local_context_exclusion_frames = _local_context_exclusion_frames + _dememwm_curriculum_state = _curriculum_state + _dememwm_generated_history_proxy_prob = _generated_history_proxy_prob + _dememwm_apply_generated_history_proxy = _apply_generated_history_proxy + _dememwm_checkpoint_cfg = _checkpoint_cfg + _dememwm_strict_eval_load_enabled = _strict_eval_load_enabled + _dememwm_cache_cfg = _cache_cfg + _dememwm_cache_enabled = _cache_enabled + _dememwm_new_streaming_cache = _new_streaming_cache + _dememwm_is_memory_adapter_param = _is_memory_adapter_param + _dememwm_param_group_name = _param_group_name + _dememwm_group_trainable = _group_trainable + _dememwm_group_lr = _group_lr + _dememwm_apply_freeze_policy = _apply_freeze_policy + _dememwm_as_latents = _as_latents + _dememwm_image_size = _image_size + _dememwm_update_streaming_cache = _update_streaming_cache + _build_dememwm_streaming_cache_records = _build_streaming_cache_records + _build_dememwm_causal_memory_banks = _build_causal_memory_banks + _build_dememwm_preselected_causal_memory_banks = _build_preselected_causal_memory_banks + _dememwm_records_to_stream = _records_to_stream + build_dememwm_memory_streams = build_memory_streams + _dememwm_refresh_stream_gates = _refresh_stream_gates + _dememwm_streams_to_kwargs = _streams_to_kwargs + build_dememwm_memory_kwargs = build_memory_kwargs + _dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics + _log_dememwm_diagnostics = _log_memory_diagnostics + _dememwm_training_window_bounds = _training_window_bounds + strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check + + +DeMemWMMemoryDiTMixin = MemoryDiTMixin