from __future__ import annotations import math from dataclasses import replace from typing import Iterable import torch from einops import rearrange from .cache import StreamingCache from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics from .injection import InjectionAdapter from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens from .negatives import apply_revisit_eval_corruption from .retrieval import deterministic_revisit_retrieval from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors class MemoryDiTMixin: """Standalone DeMemWM / Memory-DiT mixin. Reuses the base video-DiT infrastructure while keeping memory construction and injection under the standalone `dememwm` package. Legacy memory-method files are not part of this path. """ strict_key_prefixes = ( "dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.", "dememwm_revisit_gate.", ) strict_key_substrings = ( ".memory_token_cross_attn.", ) _TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({ "revisit_candidate_frame_count", "revisit_pose_preselect_input_count", "revisit_pose_preselect_selected_count", "revisit_exact_fov_candidate_count", "valid_revisit_frame_count", "no_valid_revisit_count", "revisit_selected_frame_count", "revisit_frame_fov_overlap_mean", "revisit_best_selected_frame_fov_overlap_mean", "revisit_best_selected_plucker_overlap_mean", "revisit_best_selected_gap_frames_mean", "revisit_gate_raw", "revisit_gate_eff", "revisit_learned_gate_mean", "revisit_effective_gate_mean", "generated_history_proxy_prob", "noise_bucket_target_count", "noise_bucket_high_target_count", "noise_bucket_mid_target_count", "noise_bucket_low_target_count", }) _VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({ "cache_records", "cache_slots", }) def _memory_cfg(self): return getattr(self.cfg, "dememwm", None) def _cfg_get(self, obj, name, default): if obj is None: return default if isinstance(obj, dict): return obj.get(name, default) return getattr(obj, name, default) def _cfg_has(self, obj, name: str) -> bool: if obj is None: return False if isinstance(obj, dict): return name in obj try: getattr(obj, name) return True except Exception: return False def _stage_policy_cfg(self): return self._cfg_get(self._memory_cfg(), "stage_policy", None) def _eval_ablation_cfg(self): return self._cfg_get(self._memory_cfg(), "eval_ablation", None) def _generated_history_proxy_cfg(self): return self._cfg_get(self._memory_cfg(), "generated_history_proxy", None) def _eval_ablation_state(self) -> tuple[bool, str]: cfg = self._eval_ablation_cfg() enabled = bool(self._cfg_get(cfg, "enabled", False)) branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal")) return enabled, branch def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict: memory_cfg = self._memory_cfg() anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) injection_cfg = self._cfg_get(memory_cfg, "injection", None) anchor_config_enabled = self._stream_enabled(anchor_cfg) dynamic_config_enabled = self._stream_enabled(dynamic_cfg) revisit_config_enabled = self._stream_enabled(revisit_cfg) curriculum_state = self._curriculum_state() eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state() debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False)) resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction) gates = compute_stream_gates( curriculum_state.stage, denoising_fraction=denoising_fraction, debug_force_all_streams=debug_force, anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)), dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)), revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)), ) anchor_effective_enabled = bool(gates.anchor_enabled and anchor_config_enabled) dynamic_effective_enabled = bool(gates.dynamic_enabled and dynamic_config_enabled) revisit_stage_config_enabled = bool(gates.revisit_enabled and revisit_config_enabled) if eval_ablation_enabled: if eval_ablation_branch == "memory_off": anchor_effective_enabled = False dynamic_effective_enabled = False revisit_stage_config_enabled = False elif eval_ablation_branch == "A_only": dynamic_effective_enabled = False revisit_stage_config_enabled = False elif eval_ablation_branch == "D_only": anchor_effective_enabled = False revisit_stage_config_enabled = False elif eval_ablation_branch == "A_plus_D": revisit_stage_config_enabled = False return { "curriculum_state": curriculum_state, "gates": gates, "resolved_noise_bucket": resolved_noise_bucket, "anchor_config_enabled": anchor_config_enabled, "dynamic_config_enabled": dynamic_config_enabled, "revisit_config_enabled": revisit_config_enabled, "anchor_effective_enabled": anchor_effective_enabled, "dynamic_effective_enabled": dynamic_effective_enabled, "revisit_stage_config_enabled": revisit_stage_config_enabled, "eval_ablation_enabled": eval_ablation_enabled, "eval_ablation_branch": eval_ablation_branch, "force_revisit_off": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_off"), "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"), } def _validate_config_contract(self) -> dict: if bool(getattr(self, "_dememwm_contract_validated", False)): return getattr(self, "_last_dememwm_config_diagnostics", {}) memory_cfg = self._memory_cfg() if memory_cfg is None: self._dememwm_contract_validated = True self._last_dememwm_config_diagnostics = {} return {} stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)] if stale_sections: raise ValueError(f"stale DeMemWM config sections are not part of the final contract: {stale_sections}") ratio_fields = [ name for name in ("anchor_ratio", "dynamic_ratio", "revisit_ratio", "revisit_max_ratio") if self._cfg_has(memory_cfg, name) ] if ratio_fields: raise ValueError(f"standalone DeMemWM derives stream slots from latent shape and compression settings, not ratio fields: {ratio_fields}") anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) stale_nested = [] for section_name, section_cfg, field_names in ( ("anchor", anchor_cfg, ("policy", "topk", "pin_prefix")), ("dynamic", dynamic_cfg, ("include_generated_recent",)), ("revisit", revisit_cfg, ("deterministic_only", "min_age_frames", "min_gap_frames", "topk", "max_chunks", "chunk_frames", "min_score", "time_weight", "pose_weight", "latent_weight", "pose_overlap_threshold", "action_overlap_threshold", "generated_penalty", "force_gate_zero_when_invalid")), ): stale_nested.extend( f"{section_name}.{field_name}" for field_name in field_names if self._cfg_has(section_cfg, field_name) ) if stale_nested: raise ValueError(f"stale DeMemWM config fields are not part of the final contract: {stale_nested}") exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) if exclude_latest_local_frames < 0: raise ValueError("dememwm.dynamic.exclude_latest_local_frames must be non-negative") if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)): raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval") fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) if fov_overlap_threshold is not None: fov_overlap_threshold = float(fov_overlap_threshold) if fov_overlap_threshold < 0.0: raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative") high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) if high_quality_fov_threshold < 0.0: raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative") plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) if plucker_weight < 0.0: raise ValueError("dememwm.revisit.plucker_weight must be non-negative") for field_name, default in ( ("fov_half_h", 52.5), ("fov_half_v", 37.5), ("fov_radius", 30.0), ("plucker_focal_length", 0.35), ): value = float(self._cfg_get(revisit_cfg, field_name, default)) if value <= 0.0: raise ValueError(f"dememwm.revisit.{field_name} must be positive") for field_name, default in ( ("fov_yaw_samples", 25), ("fov_pitch_samples", 20), ("fov_depth_samples", 20), ("plucker_grid_h", 4), ("plucker_grid_w", 4), ): value = int(self._cfg_get(revisit_cfg, field_name, default)) if value <= 0: raise ValueError(f"dememwm.revisit.{field_name} must be positive") stage_policy_cfg = self._stage_policy_cfg() if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)): raise ValueError("final DeMemWM keeps noise_bucket logging enabled") proxy_cfg = self._generated_history_proxy_cfg() proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0)) proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0)) proxy_noise_std = float(self._cfg_get(proxy_cfg, "noise_std", 0.0)) proxy_ramp_steps = int(self._cfg_get(proxy_cfg, "ramp_steps", 0)) if proxy_max_prob < 0.0 or proxy_max_prob > 1.0: raise ValueError("dememwm.generated_history_proxy.max_prob must be in [0, 1]") if proxy_dropout_prob < 0.0 or proxy_dropout_prob > 1.0: raise ValueError("dememwm.generated_history_proxy.dropout_prob must be in [0, 1]") if proxy_noise_std < 0.0: raise ValueError("dememwm.generated_history_proxy.noise_std must be non-negative") if proxy_ramp_steps < 0: raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative") eval_ablation_cfg = self._eval_ablation_cfg() normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal")) diagnostics = { "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, "revisit_deterministic_fov_plucker_retrieval": True, "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(), "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold, "revisit_plucker_weight": plucker_weight, "stage_policy_noise_bucket_logging": True, } self._dememwm_contract_validated = True self._last_dememwm_config_diagnostics = diagnostics return diagnostics def _stream_enabled(self, stream_cfg) -> bool: return bool(self._cfg_get(stream_cfg, "enabled", True)) def _context_frame_count(self) -> int: frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) return max(0, int(getattr(self, "context_frames", 0) or 0) // frame_stack) def _local_context_exclusion_frames(self) -> int: n_tokens = max(0, int(getattr(self, "n_tokens", 0) or 0)) frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) return n_tokens * frame_stack def _curriculum_state(self, step: int | None = None): if step is None: step = int(getattr(self, "global_step", 0) or 0) return resolve_curriculum(self._memory_cfg(), step) def _generated_history_proxy_prob(self, step: int | None = None) -> float: cfg = self._generated_history_proxy_cfg() if not bool(self._cfg_get(cfg, "enabled", False)): return 0.0 max_prob = min(max(float(self._cfg_get(cfg, "max_prob", 0.0)), 0.0), 1.0) if max_prob <= 0.0: return 0.0 if step is None: step = int(getattr(self, "global_step", 0) or 0) start_step = int(self._cfg_get(cfg, "start_step", 0)) if step < start_step: return 0.0 ramp_steps = int(self._cfg_get(cfg, "ramp_steps", 0)) if ramp_steps <= 0: return max_prob ramp_fraction = min(max(float(step - start_step) / float(ramp_steps), 0.0), 1.0) return max_prob * ramp_fraction def _apply_generated_history_proxy( self, source_latents: torch.Tensor, source_is_generated: torch.Tensor | None, context_frame_count: int | None = None, target_start_frame: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor, dict]: cfg = self._generated_history_proxy_cfg() prob = self._generated_history_proxy_prob() noise_std = float(self._cfg_get(cfg, "noise_std", 0.0)) dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0)) diagnostics = { "generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)), "generated_history_proxy_prob": float(prob), "generated_history_proxy_noise_std": float(noise_std), "generated_history_proxy_dropout_prob": float(dropout_prob), "generated_history_proxy_frame_count": 0, "generated_history_proxy_frame_fraction": 0.0, } if source_is_generated is None: source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) else: source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool) if prob <= 0.0 or source_latents.numel() == 0: return source_latents, source_is_generated, diagnostics eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) if context_frame_count is not None or target_start_frame is not None: frame_positions = torch.arange(source_latents.shape[0], device=source_latents.device)[:, None] if context_frame_count is not None: eligible_mask &= frame_positions >= max(0, int(context_frame_count)) if target_start_frame is not None: eligible_mask &= frame_positions < max(0, int(target_start_frame)) proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask proxy_count = int(proxy_mask.detach().long().sum().item()) total_count = max(1, int(proxy_mask.numel())) diagnostics["generated_history_proxy_frame_count"] = proxy_count diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count) if proxy_count == 0: return source_latents, source_is_generated, diagnostics corrupt_latents = source_latents.clone() frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype) if noise_std > 0.0: corrupt_latents = corrupt_latents + torch.randn_like(corrupt_latents) * float(noise_std) * frame_mask if dropout_prob > 0.0: dropout_mask = torch.rand( (*source_latents.shape[:2], 1, source_latents.shape[-2], source_latents.shape[-1]), device=source_latents.device, ) < dropout_prob dropout_mask = dropout_mask & proxy_mask[:, :, None, None, None] corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents) source_is_generated = source_is_generated.clone() source_is_generated |= proxy_mask return corrupt_latents, source_is_generated, diagnostics def _checkpoint_cfg(self): return self._cfg_get(self._memory_cfg(), "checkpoint", None) def _strict_eval_load_enabled(self) -> bool: return bool(self._cfg_get(self._checkpoint_cfg(), "strict_dememwm_eval_load", True)) def _cache_cfg(self): return self._cfg_get(self._memory_cfg(), "cache", None) def _cache_enabled(self) -> bool: return bool(self._cfg_get(self._cache_cfg(), "enabled", False)) def _new_streaming_cache(self, video_id=None) -> StreamingCache | None: if not self._cache_enabled(): return None cache = StreamingCache.from_config(self._cache_cfg(), enabled_default=True) if cache.clear_between_videos: cache.reset(video_id=video_id) return cache def _is_memory_adapter_param(self, name: str) -> bool: return ".memory_token_cross_attn." in name def _param_group_name(self, name: str, state=None) -> str: state = state or self._curriculum_state() if name.startswith("vae.") or name.startswith("validation_lpips_model."): return "excluded_frozen" if name.startswith(("dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.")): return "dememwm_modules" if self._is_memory_adapter_param(name): return "memory_adapters" if name.startswith("diffusion_model."): return "full_dit" return "dememwm_modules" def _group_trainable(self, group_name: str, state) -> bool: if group_name in {"dememwm_modules", "memory_adapters"}: return True if group_name == "full_dit": return state.dit_full_trainable return False def _group_lr(self, group_name: str, state) -> float: if group_name == "dememwm_modules": return state.dememwm_lr if group_name == "memory_adapters": return state.memory_adapter_lr if group_name == "full_dit": return state.full_dit_lr return 0.0 def _apply_freeze_policy(self, optimizer=None, step: int | None = None): state = self._curriculum_state(step) # Keep DDP's trainable graph stable: DiT params stay requires_grad=True # from step 0 and are frozen by optimizer LR=0 until the full stage. # Re-walk only when curriculum diagnostics can change. freeze_key = (state.stage, state.dit_train_state, state.freeze_vae) last_key = getattr(self, "_last_freeze_key", None) if last_key != freeze_key: trainable_tensors = { "dememwm_modules": 0, "memory_adapters": 0, "full_dit": 0, "excluded_frozen": 0, } trainable_scalars = {key: 0 for key in trainable_tensors} requires_grad_tensors = {key: 0 for key in trainable_tensors} requires_grad_scalars = {key: 0 for key in trainable_tensors} for name, param in self.named_parameters(): group_name = self._param_group_name(name, state) should_train = self._group_trainable(group_name, state) if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae): should_train = False should_require_grad = False else: should_require_grad = True param.requires_grad_(should_require_grad) if should_train: trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1 trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel()) if should_require_grad: requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1 requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel()) self._last_freeze_key = freeze_key self._last_trainable_tensors = trainable_tensors self._last_trainable_scalars = trainable_scalars self._last_requires_grad_tensors = requires_grad_tensors self._last_requires_grad_scalars = requires_grad_scalars else: trainable_tensors = getattr(self, "_last_trainable_tensors", {}) trainable_scalars = getattr(self, "_last_trainable_scalars", {}) requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {}) requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {}) if optimizer is not None: for param_group in optimizer.param_groups: group_name = param_group.get("name", "") trainable = self._group_trainable(group_name, state) param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0 diagnostics = state.diagnostics() for group_name in ("dememwm_modules", "memory_adapters", "full_dit"): diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0) diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0) diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0) diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0) diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0 self._last_dememwm_freeze_diagnostics = diagnostics return state def configure_optimizers(self): state = self._curriculum_state(0) self._apply_freeze_policy(step=0) grouped: dict[str, list[torch.nn.Parameter]] = { "dememwm_modules": [], "memory_adapters": [], "full_dit": [], } for name, param in self.named_parameters(): group_name = self._param_group_name(name, state) if group_name in grouped: grouped[group_name].append(param) param_groups = [] for group_name in ("dememwm_modules", "memory_adapters", "full_dit"): params = grouped[group_name] if params: trainable = self._group_trainable(group_name, state) param_groups.append({ "params": params, "lr": self._group_lr(group_name, state) if trainable else 0.0, "name": group_name, }) if not param_groups: raise RuntimeError("DeMemWM optimizer found no trainable parameter groups") return torch.optim.AdamW( param_groups, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta, ) def on_train_start(self): optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] for optimizer in optimizers: self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) def on_train_batch_start(self, batch, batch_idx): optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] for optimizer in optimizers: self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) def on_after_backward(self): step = int(getattr(self, "global_step", 0) or 0) state = self._apply_freeze_policy(step=step) for name, param in self.named_parameters(): if param.grad is None: continue group_name = self._param_group_name(name, state) if not self._group_trainable(group_name, state): param.grad = None def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) optimizer.step(closure=optimizer_closure) self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0) + 1) def on_load_checkpoint(self, checkpoint): super().on_load_checkpoint(checkpoint) if self._strict_eval_load_enabled(): state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint self.strict_checkpoint_key_check(state_dict) def _preprocess_batch(self, batch): """Preprocess RGB or precomputed-latent Minecraft batches for DeMemWM. MinecraftVideoLatentDataset returns an extra image_hw tensor. Keep the DeMemWM path on VAE latents while preserving RGB image size for Plucker pose embeddings. This mirrors the existing latent-dataset contract without routing through the legacy SSM memory implementation. """ from ..df_video import euler_to_camera_to_world_matrix if len(batch) == 5: xs, conditions, pose_conditions, frame_index, image_hw = batch self._last_dememwm_xs_are_latents = True self._last_dememwm_image_hw = image_hw else: xs, conditions, pose_conditions, frame_index = batch self._last_dememwm_xs_are_latents = False self._last_dememwm_image_hw = None if self.action_cond_dim: conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) conditions = rearrange(conditions, "b t d -> t b d").contiguous() else: raise NotImplementedError("Only support external cond.") pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous() c2w_mat = euler_to_camera_to_world_matrix(pose_conditions) xs = rearrange(xs, "b t c ... -> t b c ...").contiguous() frame_index = rearrange(frame_index, "b t -> t b").contiguous() return xs, conditions, pose_conditions, c2w_mat, frame_index def _as_latents(self, xs: torch.Tensor) -> torch.Tensor: if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): return xs return self.encode(xs) def _image_size(self, xs: torch.Tensor) -> tuple[int, int]: image_hw = getattr(self, "_last_dememwm_image_hw", None) if image_hw is not None: if torch.is_tensor(image_hw): values = image_hw.detach().cpu().reshape(-1).tolist() else: values = list(image_hw) if len(values) >= 2: return int(values[0]), int(values[1]) return int(xs.shape[-2]), int(xs.shape[-1]) def _update_streaming_cache( self, cache: StreamingCache | None, new_latents: torch.Tensor, frame_indices: torch.Tensor, pose: torch.Tensor | None = None, source_is_generated: torch.Tensor | None = None, action: torch.Tensor | None = None, ) -> None: if cache is None or not cache.enabled or new_latents is None or new_latents.shape[0] == 0: return cache.add_raw_latents(new_latents, frame_indices, source_is_generated, pose) if not cache.keep_compressed_records: return memory_cfg = self._memory_cfg() anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( int(new_latents.shape[-2]), int(new_latents.shape[-1]), self.dememwm_anchor_proj, token_patch_size, ) anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 ) anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) # Prefix anchors are a per-video prefix resource. Do not add new prefix # anchors for later committed segments unless explicitly generated anchors are allowed. if cache.records_count("anchor") > 0 and not allow_generated_anchor: anchor_indices = [] anchor_banks, revisit_banks = self._build_streaming_cache_records( new_latents, frame_indices, source_is_generated, pose, action, allow_generated_anchor, anchor_indices, anchor_pool_h, anchor_pool_w, anchor_diverse, token_patch_size, ) cache.add_memory_banks(anchor_banks, revisit_banks) def _build_model(self): from algorithms.common.metrics import LearnedPerceptualImagePatchSimilarity from .gates import RevisitRawGate from ..models.diffusion import Diffusion from ..models.pose_prediction import PosePredictionNet from ..models.vae import VAE_models self.diffusion_model = Diffusion( reference_length=self.memory_condition_length, x_shape=self.x_stacked_shape, action_cond_dim=self.action_cond_dim, pose_cond_dim=self.pose_cond_dim, is_causal=self.causal, cfg=self.cfg.diffusion, is_dit=True, use_plucker=self.use_plucker, relative_embedding=self.relative_embedding, state_embed_only_on_qk=self.state_embed_only_on_qk, use_memory_attention=False, add_timestamp_embedding=self.add_timestamp_embedding, memory_token_cross_attention=getattr(self.cfg, "memory_token_cross_attention", True), memory_cross_attn_layers=getattr(self.cfg, "memory_cross_attn_layers", None), ref_mode=self.ref_mode, ) memory_cfg = self._memory_cfg() self._validate_config_contract() injection_cfg = self._cfg_get(memory_cfg, "injection", None) dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) max_source_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8)) self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor( latent_channels=self.x_stacked_shape[0], dit_hidden_size=hidden_size, patch_size=token_patch_size, conv_kernel_t=int(self._cfg_get(dynamic_cfg, "conv_kernel_t", 3)), conv_stride_t=int(self._cfg_get(dynamic_cfg, "conv_stride_t", 2)), max_source_frames=max_source_frames, exclude_latest_local_frames=int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)), ) spatial_mid_channels = self.x_stacked_shape[0] * token_patch_size * token_patch_size self.dememwm_anchor_proj = SpatialConv2DMemoryProjector( latent_channels=self.x_stacked_shape[0], dit_hidden_size=hidden_size, mid_channels=spatial_mid_channels, kernel_size=3, ) self.dememwm_revisit_proj = SpatialConv2DMemoryProjector( latent_channels=self.x_stacked_shape[0], dit_hidden_size=hidden_size, mid_channels=spatial_mid_channels, kernel_size=3, ) self.dememwm_revisit_gate = RevisitRawGate() self.dememwm_injection_adapter = InjectionAdapter() self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity() self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval() for param in self.vae.parameters(): param.requires_grad_(False) if self.require_pose_prediction: self.pose_prediction_model = PosePredictionNet() def _project_latent_patch_tokens( self, latents: torch.Tensor, projection: torch.nn.Module, patch_size: int, ) -> torch.Tensor: # (T,B,C,H,W) -> (B,T,T_frame,D). Conv2D projectors keep T_frame=H*W. if bool(getattr(projection, "projects_spatial_latents", False)): return projection(latents) patch_vectors = latent_patch_tokens(latents, patch_size) return projection(patch_vectors).permute(1, 0, 2, 3).contiguous() def _projected_spatial_grid_size( self, latent_h: int, latent_w: int, projection: torch.nn.Module, patch_size: int, ) -> tuple[int, int]: if bool(getattr(projection, "projects_spatial_latents", False)): return int(latent_h), int(latent_w) return int(latent_h) // int(patch_size), int(latent_w) // int(patch_size) def _take_uniform_slots(self, tokens: torch.Tensor, num_slots: int) -> torch.Tensor: if tokens.ndim != 2: raise ValueError("tokens must have shape (N,D)") num_slots = max(0, int(num_slots)) if num_slots == 0: return tokens[:0] if tokens.shape[0] <= num_slots: return tokens idx = torch.linspace(0, tokens.shape[0] - 1, num_slots, device=tokens.device).round().long() return tokens.index_select(0, idx) def _spatial_pool_tokens( self, tokens: torch.Tensor, pool_h: int, pool_w: int, src_h: int, src_w: int, ) -> torch.Tensor: return spatial_pool_tokens(tokens, pool_h, pool_w, src_h, src_w) def _resolve_spatial_pool_size( self, compress_cfg, src_h: int, src_w: int, default_pool_h: int, default_pool_w: int, ) -> tuple[int, int]: ratio = self._cfg_get(compress_cfg, "downsample_ratio", None) ratio_h = self._cfg_get(compress_cfg, "downsample_h", ratio) ratio_w = self._cfg_get(compress_cfg, "downsample_w", ratio) if ratio_h is not None or ratio_w is not None: if ratio_h is None: ratio_h = ratio_w if ratio_w is None: ratio_w = ratio_h ratio_h = float(ratio_h) ratio_w = float(ratio_w) if ratio_h <= 0.0 or ratio_w <= 0.0: raise ValueError("DeMemWM compress downsample ratios must be positive") return ( max(1, int(math.ceil(float(src_h) / ratio_h))), max(1, int(math.ceil(float(src_w) / ratio_w))), ) pool_h = int(self._cfg_get(compress_cfg, "pool_h", default_pool_h)) pool_w = int(self._cfg_get(compress_cfg, "pool_w", default_pool_w)) if pool_h <= 0 or pool_w <= 0: raise ValueError("DeMemWM compress pool_h/pool_w must be positive") return pool_h, pool_w def _select_diverse_anchor_positions( self, source_positions: torch.Tensor, pose: torch.Tensor | None, num_anchors: int, ) -> torch.Tensor: num_anchors = max(0, int(num_anchors)) if num_anchors == 0: return source_positions[:0] if source_positions.numel() <= num_anchors or pose is None: return source_positions[:num_anchors] poses = pose.float() pairwise = torch.cdist(poses, poses) if not bool((pairwise > 0).any().item()): return source_positions[:num_anchors] available = torch.ones((int(source_positions.numel()),), device=poses.device, dtype=torch.bool) if num_anchors == 1: selected = [int(pairwise.mean(dim=1).argmax().item())] else: first, second = divmod(int(pairwise.argmax().item()), int(pairwise.shape[1])) selected = [int(first), int(second)] for idx in selected: available[idx] = False dists = pairwise[selected].min(dim=0).values dists = dists.masked_fill(~available, float("-inf")) for _ in range(num_anchors - len(selected)): farthest = int(dists.argmax().item()) if not bool(available[farthest].item()): break selected.append(farthest) available[farthest] = False d_new = pairwise[farthest] dists = torch.minimum(dists, d_new) dists = dists.masked_fill(~available, float("-inf")) return source_positions[torch.tensor(sorted(selected), device=source_positions.device)] def _build_streaming_cache_records( self, source_latents: torch.Tensor, source_frame_indices: torch.Tensor, source_is_generated: torch.Tensor | None, pose: torch.Tensor | None, action: torch.Tensor | None, allow_generated_anchor: bool, anchor_indices: list[int], anchor_pool_h: int, anchor_pool_w: int, anchor_diverse: bool, token_patch_size: int, ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: if source_latents.ndim != 5: raise ValueError("source_latents must have shape (T,B,C,H,W)") if source_frame_indices.ndim != 2: raise ValueError("source_frame_indices must have shape (T,B)") T_src, B = source_frame_indices.shape if source_latents.shape[:2] != (T_src, B): raise ValueError("source_latents and source_frame_indices must share T/B dimensions") _, _, _, latent_H, latent_W = source_latents.shape src_h, src_w = self._projected_spatial_grid_size( latent_H, latent_W, self.dememwm_anchor_proj, token_patch_size, ) param = next(iter(self.dememwm_anchor_proj.parameters())) project_device = param.device project_dtype = param.dtype hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0]) generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device) anchor_banks: list[CausalMemoryBank] = [] revisit_banks: list[CausalMemoryBank] = [] def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): if tensor is None or tensor.ndim < 2: return None pos = positions.to(device=tensor.device) if tensor.shape[0] == T_src and tensor.shape[1] == B: return tensor[pos, batch_idx] if tensor.shape[0] == B and tensor.shape[1] == T_src: return tensor[batch_idx, pos] return None def _metadata_subset(positions: torch.Tensor, batch_idx: int): return {"dememwm_revisit_metadata_only": True} def _pose_subset(positions: torch.Tensor, batch_idx: int): return _tensor_subset(pose, positions, batch_idx) def _add_anchor_records(bank: CausalMemoryBank, batch_idx: int, positions: torch.Tensor, generated_anchor: bool) -> None: if positions.numel() == 0: return projected = self._project_latent_patch_tokens( source_latents.index_select(0, positions.to(device=source_latents.device))[:, batch_idx:batch_idx + 1].to(device=project_device, dtype=project_dtype), self.dememwm_anchor_proj, token_patch_size, )[0] src_frames = source_frame_indices[:, batch_idx] for local_idx, source_pos in enumerate(positions): source_pos_i = int(source_pos.item()) anchor_tokens = self._spatial_pool_tokens(projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w) n_slots = anchor_tokens.shape[0] record_mask = torch.ones((n_slots,), device=anchor_tokens.device, dtype=torch.bool) if generated_anchor: bank.add_generated_records( anchor_tokens.unsqueeze(0), record_mask.unsqueeze(0), src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device), source_type=MemorySourceType.GENERATED, ) else: bank.add_prefix_anchors( anchor_tokens.unsqueeze(0), record_mask.unsqueeze(0), src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device), slots_per_anchor=n_slots, ) for batch_idx in range(B): anchor_bank = CausalMemoryBank() revisit_bank = CausalMemoryBank() src_frames = source_frame_indices[:, batch_idx] if generated is None: non_generated = torch.ones_like(src_frames, dtype=torch.bool) else: non_generated = ~generated[:, batch_idx] source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() if source_positions.numel() > 0: if anchor_diverse: anchor_source_positions = source_positions[source_positions < self._context_frame_count()] if anchor_source_positions.numel() > 0: anchor_pose = _pose_subset(anchor_source_positions, batch_idx) selected_anchor_positions = self._select_diverse_anchor_positions( anchor_source_positions, anchor_pose, len(anchor_indices) ) else: selected_anchor_positions = source_positions[:0] else: selected_list = [] for anchor_idx in anchor_indices: if 0 <= int(anchor_idx) < source_positions.numel(): selected_list.append(source_positions[int(anchor_idx)]) selected_anchor_positions = torch.stack(selected_list).long() if selected_list else source_positions[:0] if selected_anchor_positions.numel() > 0: _add_anchor_records(anchor_bank, batch_idx, selected_anchor_positions.long(), False) dummy_tokens = torch.zeros((1, hidden_size), device=source_frame_indices.device, dtype=project_dtype) dummy_mask = torch.ones((1,), device=source_frame_indices.device, dtype=torch.bool) for prefix, positions, source_type, is_generated in ( ("prefix", source_positions, MemorySourceType.PREFIX_GT, False), ( "generated", torch.empty(0, device=source_frame_indices.device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(), MemorySourceType.GENERATED, True, ), ): if positions.numel() == 0: continue for source_pos in positions.to(device=source_frame_indices.device, dtype=torch.long): source_pos_i = int(source_pos.item()) frame_index = src_frames[source_pos_i] frame = int(frame_index.detach().item()) frame_pos = source_pos.reshape(1) revisit_bank.add_frame_record( dummy_tokens, dummy_mask, frame_index, pose=_pose_subset(frame_pos, batch_idx), source_type=source_type, metadata=_metadata_subset(frame_pos, batch_idx), is_generated=is_generated, record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}", ) if allow_generated_anchor and generated is not None and anchor_indices: generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() _add_anchor_records(anchor_bank, batch_idx, generated_positions[:len(anchor_indices)].long(), True) anchor_banks.append(anchor_bank) revisit_banks.append(revisit_bank) return anchor_banks, revisit_banks def _build_causal_memory_banks( self, anchor_projected: torch.Tensor, revisit_projected: torch.Tensor, source_frame_indices: torch.Tensor, source_is_generated: torch.Tensor | None, pose: torch.Tensor | None, action: torch.Tensor | None, allow_generated_anchor: bool, anchor_indices: list[int], anchor_pool_h: int, anchor_pool_w: int, revisit_pool_h: int, revisit_pool_w: int, src_h: int, src_w: int, ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: # projected tensors use the same batch/source convention as # _project_latent_patch_tokens: (B, T_src, T_frame, D), while frame indices are # (T_src, B). Build separate banks because anchor and revisit records # come from different projections. if anchor_projected.ndim != 4 or revisit_projected.ndim != 4: raise ValueError("anchor/revisit projected tensors must have shape (B,T_src,T_frame,D)") B, T_src, _, _ = anchor_projected.shape if revisit_projected.shape[:3] != anchor_projected.shape[:3]: raise ValueError("anchor/revisit projected tensors must share batch/source/token dimensions") generated = None if source_is_generated is None else source_is_generated.bool().to(source_frame_indices.device) anchor_banks: list[CausalMemoryBank] = [] revisit_banks: list[CausalMemoryBank] = [] def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): if tensor is None or tensor.ndim < 2: return None if tensor.shape[0] == T_src and tensor.shape[1] == B: return tensor[positions, batch_idx] if tensor.shape[0] == B and tensor.shape[1] == T_src: return tensor[batch_idx, positions] return None def _metadata_subset(positions: torch.Tensor, batch_idx: int): return {} def _pose_subset(positions: torch.Tensor, batch_idx: int): return _tensor_subset(pose, positions, batch_idx) for batch_idx in range(B): anchor_bank = CausalMemoryBank() revisit_bank = CausalMemoryBank() src_frames = source_frame_indices[:, batch_idx] if generated is None: non_generated = torch.ones_like(src_frames, dtype=torch.bool) else: non_generated = ~generated[:, batch_idx] source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() if source_positions.numel() > 0: selected_anchor_positions = [] for anchor_idx in anchor_indices: if 0 <= int(anchor_idx) < source_positions.numel(): selected_anchor_positions.append(source_positions[int(anchor_idx)]) for source_pos in selected_anchor_positions: source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos) anchor_tokens = self._spatial_pool_tokens( anchor_projected[batch_idx, source_pos_i], anchor_pool_h, anchor_pool_w, src_h, src_w, ) n_slots = anchor_tokens.shape[0] record_mask = torch.ones( (n_slots,), device=anchor_projected.device, dtype=torch.bool, ) anchor_bank.add_prefix_anchors( anchor_tokens.unsqueeze(0), record_mask.unsqueeze(0), src_frames[source_pos_i:source_pos_i + 1], slots_per_anchor=n_slots, ) for source_pos in source_positions: source_pos_i = int(source_pos.item()) frame_index = src_frames[source_pos_i] frame = int(frame_index.detach().item()) frame_pos = source_pos.reshape(1) frame_tokens = self._spatial_pool_tokens( revisit_projected[batch_idx, source_pos_i], revisit_pool_h, revisit_pool_w, src_h, src_w, ) frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) revisit_bank.add_frame_record( frame_tokens, frame_mask, frame_index, pose=_pose_subset(frame_pos, batch_idx), source_type=MemorySourceType.PREFIX_GT, metadata=_metadata_subset(frame_pos, batch_idx), is_generated=False, record_id=f"prefix_revisit_b{batch_idx}_f{frame}", ) if generated is not None: generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() if generated_positions.numel() > 0: for source_pos in generated_positions: source_pos_i = int(source_pos.item()) frame_index = src_frames[source_pos_i] frame = int(frame_index.detach().item()) frame_pos = source_pos.reshape(1) frame_tokens = self._spatial_pool_tokens( revisit_projected[batch_idx, source_pos_i], revisit_pool_h, revisit_pool_w, src_h, src_w, ) frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) revisit_bank.add_frame_record( frame_tokens, frame_mask, frame_index, pose=_pose_subset(frame_pos, batch_idx), source_type=MemorySourceType.GENERATED, metadata=_metadata_subset(frame_pos, batch_idx), is_generated=True, record_id=f"generated_revisit_b{batch_idx}_f{frame}", ) if allow_generated_anchor: for source_pos in generated_positions[:len(anchor_indices)]: source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos) anchor_tokens = self._spatial_pool_tokens( anchor_projected[batch_idx, source_pos_i], anchor_pool_h, anchor_pool_w, src_h, src_w, ) record_mask = torch.ones((anchor_tokens.shape[0],), device=anchor_projected.device, dtype=torch.bool) anchor_bank.add_generated_records( anchor_tokens.unsqueeze(0), record_mask.unsqueeze(0), src_frames[source_pos_i:source_pos_i + 1], source_type=MemorySourceType.GENERATED, ) anchor_banks.append(anchor_bank) revisit_banks.append(revisit_bank) return anchor_banks, revisit_banks def _build_preselected_causal_memory_banks( self, committed_latents: torch.Tensor, source_frame_indices: torch.Tensor, source_is_generated: torch.Tensor | None, pose: torch.Tensor | None, action: torch.Tensor | None, target_frame_indices: torch.Tensor, target_pose: torch.Tensor | None, target_action: torch.Tensor | None, target_video_ids, allow_generated_anchor: bool, anchor_indices: list[int], anchor_pool_h: int, anchor_pool_w: int, anchor_diverse: bool, revisit_pool_h: int, revisit_pool_w: int, revisit_max_frames: int, exclude_local_context_frames: int, fov_overlap_threshold, plucker_weight: float, revisit_retrieval_kwargs: dict | None, token_patch_size: int, ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]: if committed_latents.ndim != 5: raise ValueError("committed_latents must have shape (T_src,B,C,H,W)") T_src, B, _, H, W = committed_latents.shape if source_frame_indices.shape != (T_src, B): raise ValueError("source_frame_indices must have shape (T_src,B)") if target_frame_indices.ndim == 1: target_frame_indices = target_frame_indices[:, None] if target_frame_indices.shape[1] != B: raise ValueError("target_frame_indices must have batch dimension B") T_tgt = target_frame_indices.shape[0] stream_device = committed_latents.device hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0]) src_h, src_w = self._projected_spatial_grid_size( H, W, self.dememwm_anchor_proj, token_patch_size, ) tokens_per_frame = src_h * src_w generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device) anchor_banks: list[CausalMemoryBank] = [] revisit_banks: list[CausalMemoryBank] = [] dummy_tokens = committed_latents.new_zeros((1, hidden_size)) dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool) preselection_candidate_count = 0 preselection_valid_candidate_label_count = 0 preselection_selected_count = 0 projected_anchor_frames = 0 projected_revisit_frames = 0 projected_revisit_records = 0 retrieval_kwargs = dict(revisit_retrieval_kwargs or {}) # Pre-convert pose tensors to stream_device once so that the # _tensor_subset / _target_tensor closures below never trigger a # device transfer on every call. if pose is not None: pose = pose.to(device=stream_device) if target_pose is not None: target_pose = target_pose.to(device=stream_device) def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): if tensor is None or tensor.ndim < 2: return None if tensor.shape[0] == T_src and tensor.shape[1] == B: return tensor[positions, batch_idx] if tensor.shape[0] == B and tensor.shape[1] == T_src: return tensor[batch_idx, positions] return None def _target_tensor(tensor: torch.Tensor | None, batch_idx: int, target_idx: int): if tensor is None or tensor.ndim < 2: return None if tensor.shape[0] == T_tgt and tensor.shape[1] == B: return tensor[target_idx, batch_idx] if tensor.shape[0] == B and tensor.shape[1] == T_tgt: return tensor[batch_idx, target_idx] return None def _target_video_id(batch_idx: int, target_idx: int): if target_video_ids is None: return None if torch.is_tensor(target_video_ids): ids = target_video_ids.detach().cpu() if ids.ndim == 0: return ids.item() if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B: return ids[target_idx, batch_idx].item() if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt: return ids[batch_idx, target_idx].item() return None if isinstance(target_video_ids, (list, tuple)): if len(target_video_ids) == B: return target_video_ids[batch_idx] if len(target_video_ids) == T_tgt: row = target_video_ids[target_idx] if isinstance(row, (list, tuple)) and len(row) == B: return row[batch_idx] return row return target_video_ids def _metadata_subset(positions: torch.Tensor, batch_idx: int): return {} def _pose_subset(positions: torch.Tensor, batch_idx: int): return _tensor_subset(pose, positions, batch_idx) def _candidate_record( *, batch_idx: int, frame_position: torch.Tensor, source_type: MemorySourceType, is_generated: bool, record_id: str, ) -> MemoryRecord: frame_values = source_frame_indices[frame_position, batch_idx].to(device=stream_device) frame = int(frame_values.reshape(-1)[0].item()) return MemoryRecord( tokens=dummy_tokens, mask=dummy_mask, source_start=frame, source_end=frame + 1, frame_indices=frame_values.reshape(1), pose=_pose_subset(frame_position, batch_idx), source_type=source_type, is_generated=bool(is_generated), chunk_id=record_id, metadata=_metadata_subset(frame_position, batch_idx), ) for batch_idx in range(B): anchor_bank = CausalMemoryBank() revisit_bank = CausalMemoryBank() src_frames = source_frame_indices[:, batch_idx] if generated is None: non_generated = torch.ones_like(src_frames, dtype=torch.bool) else: non_generated = ~generated[:, batch_idx] source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long) if anchor_indices and source_positions.numel() > 0: if anchor_diverse: anchor_source_positions = source_positions[source_positions < self._context_frame_count()] if anchor_source_positions.numel() > 0: anchor_pose = _pose_subset(anchor_source_positions, batch_idx) anchor_positions = self._select_diverse_anchor_positions( anchor_source_positions, anchor_pose, len(anchor_indices) ).to(device=stream_device, dtype=torch.long) else: selected_anchor_positions = [] for anchor_idx in anchor_indices: if 0 <= int(anchor_idx) < source_positions.numel(): selected_anchor_positions.append(source_positions[int(anchor_idx)]) if selected_anchor_positions: anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long) if anchor_positions.numel() > 0: projected_anchor_frames += int(anchor_positions.numel()) anchor_projected = self._project_latent_patch_tokens( committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1], self.dememwm_anchor_proj, token_patch_size, )[0] for local_idx, source_pos in enumerate(anchor_positions): source_pos_i = int(source_pos.item()) anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w) n_slots = anchor_tokens.shape[0] record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool) anchor_bank.add_prefix_anchors( anchor_tokens.unsqueeze(0), record_mask.unsqueeze(0), src_frames[source_pos_i:source_pos_i + 1], slots_per_anchor=n_slots, ) candidate_records: list[MemoryRecord] = [] candidate_positions: dict[str, torch.Tensor] = {} src_frames_cpu = src_frames.detach().cpu() target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long) latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames) for prefix, positions, source_type, is_generated in ( ("prefix", source_positions, MemorySourceType.PREFIX_GT, False), ( "generated", torch.empty(0, device=stream_device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(), MemorySourceType.GENERATED, True, ), ): if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0: continue positions_cpu = positions.detach().cpu().to(dtype=torch.long) for frame_position_cpu in positions_cpu: frame = int(src_frames_cpu[int(frame_position_cpu.item())].item()) if frame >= latest_valid_source_frame_exclusive: continue frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long) record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}" candidate_positions[record_id] = frame_position candidate_records.append(_candidate_record( batch_idx=batch_idx, frame_position=frame_position, source_type=source_type, is_generated=is_generated, record_id=record_id, )) selected_frame_record_ids: set[str] = set() selected_frame_metadata: dict[str, dict] = {} for target_idx in range(T_tgt): target_frame = int(target_frame_indices[target_idx, batch_idx].item()) result = deterministic_revisit_retrieval( candidate_records, target_frame=target_frame, target_pose=_target_tensor(target_pose, batch_idx, target_idx), target_summary=None, topk=revisit_max_frames, exclude_local_context_frames=exclude_local_context_frames, fov_overlap_threshold=fov_overlap_threshold, plucker_weight=plucker_weight, target_video_id=_target_video_id(batch_idx, target_idx), **retrieval_kwargs, ) preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0))) preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0)) preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0))) for selected_record in result.records: if selected_record.chunk_id is None: continue record_id = str(selected_record.chunk_id) selected_frame_record_ids.add(record_id) selected_frame_metadata[record_id] = dict(selected_record.metadata) for record in candidate_records: if record.chunk_id not in selected_frame_record_ids: continue record_id = str(record.chunk_id) frame_position = candidate_positions[record_id] projected_revisit_records += 1 projected_revisit_frames += int(frame_position.numel()) revisit_projected = self._project_latent_patch_tokens( committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1], self.dememwm_revisit_proj, token_patch_size, )[0] frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, src_h, src_w) frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool) record_metadata = dict(record.metadata) record_metadata.update(selected_frame_metadata.get(record_id, {})) revisit_bank.add_frame_record( frame_tokens, frame_mask, record.frame_indices.reshape(-1)[0], pose=record.pose, source_type=record.source_type, metadata=record_metadata, is_generated=record.is_generated, record_id=record.chunk_id, ) anchor_banks.append(anchor_bank) revisit_banks.append(revisit_bank) diagnostics = { "preselected_anchor_projected_frame_count": projected_anchor_frames, "preselected_revisit_projected_frame_count": projected_revisit_frames, "preselected_revisit_projected_frame_record_count": projected_revisit_records, "preselected_revisit_candidate_frame_count": preselection_candidate_count, "preselected_revisit_candidate_count": preselection_candidate_count, "preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count, "preselected_revisit_selected_frame_count": preselection_selected_count, "preselected_revisit_selected_count": preselection_selected_count, } return anchor_banks, revisit_banks, tokens_per_frame, diagnostics def _causal_cached_revisit_records( self, records: Iterable[MemoryRecord], target_frame: int, ) -> list[MemoryRecord]: target_frame = int(target_frame) return [record for record in records if int(record.source_end) <= target_frame] def _records_to_stream( self, records, target_slots: int, hidden_size: int, device: torch.device, dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor, int]: target_slots = max(0, int(target_slots)) record_list = list(records) stacked_tokens, stacked_mask = stack_record_tokens(record_list, target_slots=target_slots) max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1) if stacked_tokens is None or stacked_mask is None or target_slots == 0: tokens = torch.zeros((target_slots, hidden_size), device=device, dtype=dtype) mask = torch.zeros((target_slots,), device=device, dtype=torch.bool) return tokens, mask, max_source_frame n = min(target_slots, stacked_tokens.shape[0]) filled = stacked_tokens[:n].to(device=device, dtype=dtype) filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool) if n < target_slots: pad = filled.new_zeros(target_slots - n, hidden_size) pad_mask = torch.zeros(target_slots - n, device=device, dtype=torch.bool) tokens = torch.cat([filled, pad], dim=0) mask = torch.cat([filled_mask, pad_mask], dim=0) else: tokens = filled mask = filled_mask return tokens, mask, max_source_frame def _project_streaming_revisit_records( self, *, cache: StreamingCache, batch_idx: int, records: list[MemoryRecord], device: torch.device, dtype: torch.dtype, token_patch_size: int, revisit_pool_h: int, revisit_pool_w: int, projection_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord], ) -> list[MemoryRecord]: projected_records: list[MemoryRecord] = [] for record in records: if not bool(record.metadata.get("dememwm_revisit_metadata_only", False)): projected_records.append(record) continue selected_frame_index = record.metadata.get("dememwm_selected_frame_index") if selected_frame_index is None: best_frame_idx = record.frame_indices[torch.argmax(record.frame_indices)].reshape(1) else: best_frame_idx = torch.as_tensor( [int(selected_frame_index)], device=record.frame_indices.device, dtype=record.frame_indices.dtype, ) key = ( int(batch_idx), str(record.chunk_id or ""), int(record.source_start), int(record.source_end), int(best_frame_idx.detach().cpu().reshape(-1)[0].item()), bool(record.is_generated), ) cached = projection_cache.get(key) if cached is not None: projected_records.append(cached) continue raw_latents = cache.raw_latents_for_frames( batch_idx=batch_idx, frame_indices=best_frame_idx, device=device, dtype=dtype, ) revisit_projected = self._project_latent_patch_tokens( raw_latents, self.dememwm_revisit_proj, token_patch_size, )[0] _proj_src_h, _proj_src_w = self._projected_spatial_grid_size( raw_latents.shape[3], raw_latents.shape[4], self.dememwm_revisit_proj, token_patch_size, ) frame_tokens = self._spatial_pool_tokens(revisit_projected[0], revisit_pool_h, revisit_pool_w, _proj_src_h, _proj_src_w) frame_mask = torch.ones((frame_tokens.shape[0],), device=device, dtype=torch.bool) metadata = { key: (value.to(device=device) if torch.is_tensor(value) else value) for key, value in record.metadata.items() } metadata["dememwm_revisit_metadata_only"] = False projected = MemoryRecord( tokens=frame_tokens, mask=frame_mask, source_start=int(record.source_start), source_end=int(record.source_end), frame_indices=record.frame_indices.to(device=device), pose=None if record.pose is None else record.pose.to(device=device), source_type=record.source_type, is_generated=bool(record.is_generated), score=record.score, chunk_id=record.chunk_id, metadata=metadata, ) projection_cache[key] = projected projected_records.append(projected) return projected_records def build_memory_streams( self, committed_latents: torch.Tensor | None, source_frame_indices: torch.Tensor | None, target_frame_indices: torch.Tensor | None = None, pose: torch.Tensor | None = None, target_pose: torch.Tensor | None = None, action: torch.Tensor | None = None, target_action: torch.Tensor | None = None, target_video_ids=None, source_is_generated: torch.Tensor | None = None, denoising_fraction: float | None = None, noise_bucket: str | None = None, noise_bucket_ids: torch.Tensor | None = None, streaming_cache: StreamingCache | None = None, ) -> MemoryStreamTensors: if target_frame_indices is None: if source_frame_indices is None: raise ValueError("target_frame_indices or source_frame_indices is required") target_frame_indices = source_frame_indices memory_cfg = self._memory_cfg() anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) injection_cfg = self._cfg_get(memory_cfg, "injection", None) contract_diag = self._validate_config_contract() gate_state = self._effective_gate_state( denoising_fraction=denoising_fraction, noise_bucket=noise_bucket, ) anchor_config_enabled = gate_state["anchor_config_enabled"] dynamic_config_enabled = gate_state["dynamic_config_enabled"] revisit_config_enabled = gate_state["revisit_config_enabled"] curriculum_state = gate_state["curriculum_state"] eval_ablation_enabled = gate_state["eval_ablation_enabled"] eval_ablation_branch = gate_state["eval_ablation_branch"] resolved_noise_bucket = gate_state["resolved_noise_bucket"] gates = gate_state["gates"] anchor_effective_enabled = gate_state["anchor_effective_enabled"] dynamic_effective_enabled = gate_state["dynamic_effective_enabled"] revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"] force_revisit_off = gate_state["force_revisit_off"] force_revisit_on = gate_state["force_revisit_on"] token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) pool_latent_h = int(committed_latents.shape[-2]) if committed_latents is not None else int(self.x_stacked_shape[-2]) pool_latent_w = int(committed_latents.shape[-1]) if committed_latents is not None else int(self.x_stacked_shape[-1]) anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( pool_latent_h, pool_latent_w, self.dememwm_anchor_proj, token_patch_size, ) anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 ) anchor_num_tokens = len(anchor_indices) * anchor_pool_h * anchor_pool_w anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) revisit_max_frames = int(self._cfg_get(revisit_cfg, "max_frames", 2)) revisit_compress_cfg = self._cfg_get(revisit_cfg, "compress", None) revisit_src_h, revisit_src_w = self._projected_spatial_grid_size( pool_latent_h, pool_latent_w, self.dememwm_revisit_proj, token_patch_size, ) revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size( revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8 ) revisit_target_slots = revisit_max_frames * revisit_pool_h * revisit_pool_w recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8)) dynamic_recent_exclusion_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) revisit_context_window_exclusion_frames = self._local_context_exclusion_frames() fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) revisit_retrieval_kwargs = { "high_quality_fov_threshold": high_quality_fov_threshold, "fov_half_h": float(self._cfg_get(revisit_cfg, "fov_half_h", 52.5)), "fov_half_v": float(self._cfg_get(revisit_cfg, "fov_half_v", 37.5)), "fov_yaw_samples": int(self._cfg_get(revisit_cfg, "fov_yaw_samples", 25)), "fov_pitch_samples": int(self._cfg_get(revisit_cfg, "fov_pitch_samples", 20)), "fov_depth_samples": int(self._cfg_get(revisit_cfg, "fov_depth_samples", 20)), "fov_radius": float(self._cfg_get(revisit_cfg, "fov_radius", 30.0)), "pose_preselect_topk": self._cfg_get(revisit_cfg, "pose_preselect_topk", 64), "plucker_grid_h": int(self._cfg_get(revisit_cfg, "plucker_grid_h", 4)), "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)), "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)), } preselection_diag = {} use_cache_revisit_records = False revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0} if committed_latents is not None: stream_device = committed_latents.device stream_dtype = committed_latents.dtype else: param = next(iter(self.dememwm_anchor_proj.parameters())) stream_device = param.device stream_dtype = param.dtype target_frame_indices = target_frame_indices.to(device=stream_device) if target_frame_indices.ndim == 1: target_frame_indices = target_frame_indices[:, None] use_cache_records = cache is not None and cache.keep_compressed_records and cache.record_count > 0 dynamic_latents = committed_latents if dynamic_effective_enabled else None dynamic_frame_indices = source_frame_indices if dynamic_effective_enabled else None dynamic_generated = source_is_generated if dynamic_effective_enabled else None dynamic_pose = pose if dynamic_effective_enabled else None if dynamic_effective_enabled and cache is not None and cache.raw_frame_slots > 0: raw_latents, raw_frames, raw_generated, raw_pose = cache.materialize_raw_latents( device=stream_device, dtype=stream_dtype, max_recent_frames=recent_frames, target_frame_indices=target_frame_indices, exclude_latest_local_frames=dynamic_recent_exclusion_frames, ) if raw_latents is not None: dynamic_latents = raw_latents dynamic_frame_indices = raw_frames dynamic_generated = raw_generated dynamic_pose = raw_pose if use_cache_records: B = target_frame_indices.shape[1] hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) anchor_banks = ( cache.memory_banks("anchor", device=stream_device, dtype=stream_dtype, batch_size=B) if anchor_effective_enabled else [CausalMemoryBank() for _ in range(B)] ) revisit_banks = [CausalMemoryBank() for _ in range(B)] revisit_record_batches = ( [cache.records_for_batch("revisit", batch_idx) for batch_idx in range(B)] if revisit_stage_config_enabled else [tuple() for _ in range(B)] ) use_cache_revisit_records = bool(revisit_stage_config_enabled) if dynamic_latents is not None and dynamic_latents.ndim == 5 and dynamic_latents.shape[0] > 0: tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( dynamic_latents.shape[-2], dynamic_latents.shape[-1], self.dememwm_anchor_proj, token_patch_size, ) tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w else: latent_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 0 latent_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 0 tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( latent_h, latent_w, self.dememwm_anchor_proj, token_patch_size, ) tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w else: if committed_latents is None or source_frame_indices is None: raise ValueError("committed_latents/source_frame_indices are required when no streaming cache records are available") B = committed_latents.shape[1] hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) target_pose_source = target_pose if target_pose is not None else pose anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks( committed_latents, source_frame_indices.to(device=stream_device), None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool), None if pose is None else pose.to(device=stream_device), None, target_frame_indices, None if target_pose_source is None else target_pose_source.to(device=stream_device), None, target_video_ids, allow_generated_anchor, anchor_indices, anchor_pool_h, anchor_pool_w, anchor_diverse, revisit_pool_h, revisit_pool_w, revisit_max_frames, revisit_context_window_exclusion_frames, fov_overlap_threshold, plucker_weight, revisit_retrieval_kwargs, token_patch_size, ) revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] T_tgt = target_frame_indices.shape[0] anchor_slots = max(0, anchor_num_tokens) revisit_slots = max(0, revisit_target_slots) anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT anchor_include_generated = allow_generated_anchor anchor_token_rows = [] anchor_mask_rows = [] anchor_max_rows = [] for batch_idx, anchor_bank in enumerate(anchor_banks): batch_token_rows = [] batch_mask_rows = [] batch_max_rows = [] for target_idx in range(T_tgt): target_frame = int(target_frame_indices[target_idx, batch_idx].item()) records = anchor_bank.query( MemoryBankQuery( target_frame=target_frame, source_type=anchor_source_type, include_generated=anchor_include_generated, max_records=len(anchor_indices), ) ) anchor_bank.assert_causal(target_frame, records) stream_tokens, stream_mask, max_source_frame = self._records_to_stream( records, anchor_slots, hidden_size, stream_device, stream_dtype, ) batch_token_rows.append(stream_tokens) batch_mask_rows.append(stream_mask) batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long)) anchor_token_rows.append(torch.stack(batch_token_rows, dim=0)) anchor_mask_rows.append(torch.stack(batch_mask_rows, dim=0)) anchor_max_rows.append(torch.stack(batch_max_rows, dim=0)) anchor_tokens = torch.stack(anchor_token_rows, dim=0) anchor_mask = torch.stack(anchor_mask_rows, dim=0) anchor_max = torch.stack(anchor_max_rows, dim=0) if dynamic_latents is None or dynamic_frame_indices is None or dynamic_latents.shape[0] == 0: _fallback_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 18 _fallback_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 32 dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w) dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype) dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool) dynamic_diag = { "selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), "max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), "generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device), "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames, } else: # Pre-select dynamic source frame positions using only frame index metadata # before touching latents, so we pass a small slice instead of the full # 1000-frame tensor to the compressor. _dfi = dynamic_frame_indices.to(device=stream_device) _max_src = self.dememwm_dynamic_compressor.max_source_frames _needed: list[int] = [] for _b in range(B): for _j in range(T_tgt): _target = int(target_frame_indices[_j, _b].item()) _valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten() _needed.extend(_valid[-_max_src:].tolist()) if _needed: _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long) _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx) _dynamic_fi_small = _dfi.index_select(0, _needed_idx) _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None _dynamic_gen_small = ( dynamic_generated.to(device=stream_device, dtype=torch.bool).index_select(0, _needed_idx) if dynamic_generated is not None else None ) else: _dynamic_latents_small = dynamic_latents[:0] _dynamic_fi_small = _dfi[:0] _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None _dynamic_gen_small = None dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor( _dynamic_latents_small, _dynamic_fi_small, _dynamic_pose_small, target_frame_indices, _dynamic_gen_small, exclude_latest_local_frames=dynamic_recent_exclusion_frames, ) dynamic_min_gap_tensor = torch.as_tensor( dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), device=stream_device, ) dynamic_max_gap_tensor = torch.as_tensor( dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), device=stream_device, ) dynamic_gap_valid = dynamic_min_gap_tensor >= 0 dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1 dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0 dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1 def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int): if tensor is None or tensor.ndim < 2: return None tensor_dev = tensor.to(device=stream_device) if tensor_dev.shape[0] == T_tgt and tensor_dev.shape[1] == B: return tensor_dev[target_idx, batch_idx] if tensor_dev.shape[0] == B and tensor_dev.shape[1] == T_tgt: return tensor_dev[batch_idx, target_idx] return None def _target_video_id_or_none(batch_idx: int, target_idx: int): if target_video_ids is None: return None if torch.is_tensor(target_video_ids): ids = target_video_ids.detach().cpu() if ids.ndim == 0: return ids.item() if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B: return ids[target_idx, batch_idx].item() if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt: return ids[batch_idx, target_idx].item() return None if isinstance(target_video_ids, (list, tuple)): if len(target_video_ids) == B: return target_video_ids[batch_idx] if len(target_video_ids) == T_tgt: row = target_video_ids[target_idx] if isinstance(row, (list, tuple)) and len(row) == B: return row[batch_idx] return row return target_video_ids target_pose_source = target_pose if target_pose is not None else pose revisit_token_rows = [] revisit_mask_rows = [] revisit_max_rows = [] valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32) eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long) eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES) revisit_result_diagnostics = [] projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {} if revisit_record_batches is None: revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] for batch_idx in range(B): revisit_bank = revisit_banks[batch_idx] batch_token_rows = [] batch_mask_rows = [] batch_max_rows = [] for target_idx in range(T_tgt): target_frame = int(target_frame_indices[target_idx, batch_idx].item()) if use_cache_revisit_records: candidate_records = self._causal_cached_revisit_records( revisit_record_batches[batch_idx], target_frame, ) else: candidate_records = revisit_bank.query( MemoryBankQuery( target_frame=target_frame, include_generated=True, ) ) result = deterministic_revisit_retrieval( candidate_records, target_frame=target_frame, target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx), target_summary=None, topk=revisit_max_frames, exclude_local_context_frames=revisit_context_window_exclusion_frames, fov_overlap_threshold=fov_overlap_threshold, plucker_weight=plucker_weight, target_video_id=_target_video_id_or_none(batch_idx, target_idx), **revisit_retrieval_kwargs, ) selected_records = result.records if use_cache_revisit_records and selected_records: selected_records = self._project_streaming_revisit_records( cache=cache, batch_idx=batch_idx, records=selected_records, device=stream_device, dtype=stream_dtype, token_patch_size=token_patch_size, revisit_pool_h=revisit_pool_h, revisit_pool_w=revisit_pool_w, projection_cache=projected_revisit_record_cache, ) revisit_result_diagnostics.append(result.diagnostics) revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0))) revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0))) revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0)) revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0)) revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1)) revisit_bank.assert_causal(target_frame, selected_records) if selected_records: valid_revisit_mask[batch_idx, target_idx] = True stream_tokens, stream_mask, max_source_frame = self._records_to_stream( selected_records, revisit_slots, hidden_size, stream_device, stream_dtype, ) revisit_causal_max[batch_idx, target_idx] = max_source_frame if eval_corruption_enabled: stream_tokens, was_corrupted = apply_revisit_eval_corruption( tokens=stream_tokens, mask=stream_mask, branch=eval_ablation_branch, target_frame=target_frame, ) eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted) batch_token_rows.append(stream_tokens) batch_mask_rows.append(stream_mask) batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long)) revisit_token_rows.append(torch.stack(batch_token_rows, dim=0)) revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0)) revisit_max_rows.append(torch.stack(batch_max_rows, dim=0)) revisit_tokens = torch.stack(revisit_token_rows, dim=0) revisit_mask = torch.stack(revisit_mask_rows, dim=0) revisit_max = torch.stack(revisit_max_rows, dim=0) if anchor_tokens.shape[-2] != anchor_num_tokens: raise AssertionError(f"anchor slot count mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}") if dynamic_latents is not None and dynamic_latents.shape[0] > 0: _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target( int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1]) ) if dynamic_tokens.shape[-2] != _expected_dyn: raise AssertionError(f"dynamic slot count mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}") if revisit_tokens.shape[-2] != revisit_target_slots: raise AssertionError(f"revisit slot count mismatch: got {revisit_tokens.shape[-2]}, expected {revisit_target_slots}") anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0 dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0 gate_module = getattr(self, "dememwm_revisit_gate", None) if gate_module is None: revisit_gate_raw = torch.ones((B, T_tgt), device=stream_device, dtype=stream_dtype) else: revisit_gate_raw = gate_module( valid_revisit_mask=valid_revisit_mask, best_selected_fov_overlap=revisit_best_selected_fov_overlap, best_selected_plucker_overlap=revisit_best_selected_plucker_overlap, selected_gap_frames=revisit_selected_gap_frames, ).to(device=stream_device, dtype=stream_dtype) valid_revisit_eff_mask = valid_revisit_mask if not revisit_stage_config_enabled or force_revisit_off: revisit_gate = torch.zeros_like(revisit_gate_raw) elif force_revisit_on: revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * torch.ones_like(revisit_gate_raw) else: revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * revisit_gate_raw * float(gates.revisit_gate) revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item()) if not anchor_effective_enabled: anchor_mask = torch.zeros_like(anchor_mask) if not dynamic_effective_enabled: dynamic_mask = torch.zeros_like(dynamic_mask) if not revisit_stage_config_enabled: revisit_mask = torch.zeros_like(revisit_mask) valid_revisit_mask = torch.zeros_like(valid_revisit_mask) revisit_candidate_count = torch.zeros_like(revisit_candidate_count) revisit_selected_count = torch.zeros_like(revisit_selected_count) revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap) revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap) revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0) eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask) valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask) revisit_gate_raw = torch.zeros_like(revisit_gate_raw) revisit_gate = torch.zeros_like(revisit_gate) no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask) revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask) causal_violation_count = 0 for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max): if source_max is None: continue source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device) valid = source_max_t >= 0 if valid.any(): causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item()) diagnostics = { **curriculum_state.diagnostics(), **getattr(self, "_last_dememwm_freeze_diagnostics", {}), **contract_diag, **cache_diag, **preselection_diag, **revisit_diag, "dememwm_stage": gates.stage, "dememwm_gate_reason": gates.reason, "anchor_config_enabled": anchor_config_enabled, "dynamic_config_enabled": dynamic_config_enabled, "revisit_config_enabled": revisit_config_enabled, "anchor_effective_enabled": anchor_effective_enabled, "dynamic_effective_enabled": dynamic_effective_enabled, "revisit_effective_enabled": revisit_effective_enabled, "revisit_stage_config_enabled": revisit_stage_config_enabled, "revisit_gate_raw": revisit_gate_raw.detach(), "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)), "no_valid_revisit_mask": no_valid_revisit_mask, "valid_revisit_eff_mask": valid_revisit_eff_mask, "revisit_candidate_frame_count_per_target": revisit_candidate_count, "revisit_selected_frame_count_per_target": revisit_selected_count, "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap, "revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap, "revisit_selected_gap_frames_per_target": revisit_selected_gap_frames, "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0, "revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()), **summarize_noise_bucket_diagnostics( noise_bucket=resolved_noise_bucket, valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, noise_bucket_ids=noise_bucket_ids, ), **summarize_eval_ablation_diagnostics( enabled=eval_ablation_enabled, branch=eval_ablation_branch, valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None, ), "token_patch_size": token_patch_size, "tokens_per_frame": tokens_per_frame, "anchor_token_slots": int(anchor_tokens.shape[-2]), "anchor_target_slots": anchor_num_tokens, "anchor_pool_h": anchor_pool_h, "anchor_pool_w": anchor_pool_w, "dynamic_token_slots": int(dynamic_tokens.shape[-2]), "dynamic_target_slots": int(dynamic_tokens.shape[-2]), "dynamic_min_gap_to_target": dynamic_min_gap_to_target, "dynamic_max_gap_to_target": dynamic_max_gap_to_target, "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames, "revisit_token_slots": int(revisit_tokens.shape[-2]), "revisit_target_slots": revisit_target_slots, "revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames, "revisit_pool_h": revisit_pool_h, "revisit_pool_w": revisit_pool_w, "revisit_max_frames": revisit_max_frames, "anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0, "dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0, "revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0, "causal_violation_count": causal_violation_count, "anchor_max_source_frame": anchor_max, "dynamic_max_source_frame": dynamic_diag.get("max_source_frame"), "revisit_max_source_frame": revisit_max, "dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"), } if eval_corruption_enabled: diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask return MemoryStreamTensors( anchor_tokens=anchor_tokens, anchor_mask=anchor_mask, dynamic_tokens=dynamic_tokens, dynamic_mask=dynamic_mask, revisit_tokens=revisit_tokens, revisit_mask=revisit_mask, anchor_gate=anchor_gate, dynamic_gate=dynamic_gate, revisit_gate=revisit_gate, revisit_gate_raw=revisit_gate_raw, valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, diagnostics=diagnostics, ) def _refresh_stream_gates( self, streams: MemoryStreamTensors, denoising_fraction: float | None = None, noise_bucket: str | None = None, ) -> MemoryStreamTensors: gate_state = self._effective_gate_state( denoising_fraction=denoising_fraction, noise_bucket=noise_bucket, ) gates = gate_state["gates"] device = streams.anchor_tokens.device dtype = streams.anchor_tokens.dtype B, T_tgt = streams.anchor_tokens.shape[:2] valid_revisit_mask = streams.valid_revisit_mask if valid_revisit_mask is None: valid_revisit_mask = torch.zeros((B, T_tgt), device=device, dtype=torch.bool) else: valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool) diagnostics = dict(streams.diagnostics) def _diagnostic_tensor(name: str, fill_value: float = 0.0) -> torch.Tensor: value = diagnostics.get(name) if value is None: return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32) tensor = torch.as_tensor(value, device=device, dtype=torch.float32) if tensor.ndim == 0: return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32) return tensor.expand((B, T_tgt)) revisit_best_selected_fov_overlap = _diagnostic_tensor("revisit_best_selected_fov_overlap_per_target") revisit_best_selected_plucker_overlap = _diagnostic_tensor("revisit_best_selected_plucker_overlap_per_target") revisit_selected_gap_frames = _diagnostic_tensor("revisit_selected_gap_frames_per_target", -1.0) anchor_effective_enabled = gate_state["anchor_effective_enabled"] dynamic_effective_enabled = gate_state["dynamic_effective_enabled"] revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"] anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0 dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0 gate_module = getattr(self, "dememwm_revisit_gate", None) if gate_module is None: revisit_gate_raw = torch.ones((B, T_tgt), device=device, dtype=dtype) else: revisit_gate_raw = gate_module( valid_revisit_mask=valid_revisit_mask, best_selected_fov_overlap=revisit_best_selected_fov_overlap, best_selected_plucker_overlap=revisit_best_selected_plucker_overlap, selected_gap_frames=revisit_selected_gap_frames, ).to(device=device, dtype=dtype) valid_revisit_eff_mask = valid_revisit_mask if not revisit_stage_config_enabled or gate_state["force_revisit_off"]: revisit_gate = torch.zeros_like(revisit_gate_raw) elif gate_state["force_revisit_on"]: revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw) else: revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate) if not revisit_stage_config_enabled: valid_revisit_mask = torch.zeros_like(valid_revisit_mask) valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask) revisit_gate_raw = torch.zeros_like(revisit_gate_raw) revisit_gate = torch.zeros_like(revisit_gate) no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask) eval_corrupted_revisit_mask = diagnostics.get("eval_corrupted_revisit_mask") if eval_corrupted_revisit_mask is not None: eval_corrupted_revisit_mask = torch.as_tensor(eval_corrupted_revisit_mask, device=device, dtype=torch.bool) revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item()) diagnostics.update(gate_state["curriculum_state"].diagnostics()) diagnostics.update({ "dememwm_stage": gates.stage, "dememwm_gate_reason": gates.reason, "anchor_config_enabled": gate_state["anchor_config_enabled"], "dynamic_config_enabled": gate_state["dynamic_config_enabled"], "revisit_config_enabled": gate_state["revisit_config_enabled"], "anchor_effective_enabled": anchor_effective_enabled, "dynamic_effective_enabled": dynamic_effective_enabled, "revisit_effective_enabled": revisit_effective_enabled, "revisit_stage_config_enabled": revisit_stage_config_enabled, "revisit_gate_raw": revisit_gate_raw.detach(), "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)), "no_valid_revisit_mask": no_valid_revisit_mask, "valid_revisit_eff_mask": valid_revisit_eff_mask, "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0, "revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0, }) diagnostics.update(summarize_noise_bucket_diagnostics( noise_bucket=gate_state["resolved_noise_bucket"], valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, )) diagnostics.update(summarize_eval_ablation_diagnostics( enabled=gate_state["eval_ablation_enabled"], branch=gate_state["eval_ablation_branch"], valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, eval_corrupted_revisit_mask=eval_corrupted_revisit_mask, )) return replace( streams, anchor_gate=anchor_gate, dynamic_gate=dynamic_gate, revisit_gate=revisit_gate, revisit_gate_raw=revisit_gate_raw, valid_revisit_mask=valid_revisit_mask, no_valid_revisit_mask=no_valid_revisit_mask, diagnostics=diagnostics, ) def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> tuple[dict, dict]: memory_kwargs, diagnostics = self.dememwm_injection_adapter(streams, device=streams.anchor_tokens.device, dtype=streams.anchor_tokens.dtype) return memory_kwargs, diagnostics def build_memory_kwargs(self, *args, **kwargs) -> tuple[dict, dict]: streams = self.build_memory_streams(*args, **kwargs) return self._streams_to_kwargs(streams) def _memory_adapter_delta_diagnostics(self) -> dict: dit_model = getattr(getattr(self, "diffusion_model", None), "model", None) diagnostics_fn = getattr(dit_model, "memory_adapter_delta_diagnostics", None) if diagnostics_fn is None: return {} return diagnostics_fn() def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None: if namespace == "training/dememwm": allowed_keys = self._TRAIN_DIAGNOSTIC_LOG_KEYS elif namespace.endswith("/dememwm"): allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS else: allowed_keys = None for key, value in diagnostics.items(): if allowed_keys is not None and key not in allowed_keys: continue if isinstance(value, str) or value is None: continue if torch.is_tensor(value): if value.numel() > 0: self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True) elif isinstance(value, (bool, int, float)): self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True) def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx): from ..df_video import convert_to_plucker image_height, image_width = self._image_size(xs) if self.use_plucker: if self.relative_embedding: input_pose_condition = [] frame_idx_list = [] ref_c2w = c2w_mat[-self.memory_condition_length:] if self.memory_condition_length else c2w_mat[:0] ref_idx = frame_idx[-self.memory_condition_length:] if self.memory_condition_length else frame_idx[:0] for i in range(c2w_mat.shape[0]): input_pose_condition.append( convert_to_plucker( torch.cat([c2w_mat[i:i + 1], ref_c2w]).clone(), 0, focal_length=self.focal_length, image_height=image_height, image_width=image_width ).to(xs.dtype) ) frame_idx_list.append(torch.cat([frame_idx[i:i + 1] - frame_idx[i:i + 1], ref_idx - frame_idx[i:i + 1]]).clone()) return torch.cat(input_pose_condition), torch.cat(frame_idx_list) return convert_to_plucker( c2w_mat, 0, focal_length=self.focal_length, image_height=image_height, image_width=image_width ).to(xs.dtype), frame_idx return pose_conditions.to(xs.dtype), None def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]: total_frames = max(0, int(total_frames)) n_tokens = max(1, min(int(self.n_tokens), total_frames)) max_start = max(0, total_frames - n_tokens) if max_start == 0: return 0, n_tokens context_start = self._context_frame_count() min_start = min(context_start, max_start) if min_start == max_start: return min_start, min_start + n_tokens start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item()) return start, start + n_tokens def training_step(self, batch, batch_idx): xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) xs = self._as_latents(xs) # Randomly select a contiguous n_tokens denoising window inside the long # clip. DeMemWM memory streams are selected causally from frames before # each target, then only those selected frames are projected. total_frames = xs.shape[0] start, end = self._training_window_bounds(total_frames, xs.device) xs_window = xs[start:end] conditions_window = conditions[start:end].clone() frame_idx_window = frame_idx[start:end] input_pose_condition, frame_idx_list = self._training_pose_condition( xs_window, pose_conditions[start:end], c2w_mat[start:end], frame_idx_window ) noise_levels = self._generate_noise_levels(xs_window) if self.memory_condition_length: noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level conditions_window[-self.memory_condition_length:] *= 0 source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool) memory_source_latents, source_is_generated, proxy_diagnostics = self._apply_generated_history_proxy( xs, source_is_generated, context_frame_count=self._context_frame_count(), target_start_frame=start, ) timesteps = int(getattr(self, "timesteps", 0) or 0) training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps) training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps) training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps) memory_kwargs, diagnostics = self.build_memory_kwargs( memory_source_latents, frame_idx, target_frame_indices=frame_idx_window, pose=pose_conditions, target_pose=pose_conditions[start:end], action=conditions, target_action=conditions_window, source_is_generated=source_is_generated, denoising_fraction=training_denoising_fraction, noise_bucket=training_noise_bucket, noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1), ) diagnostics.update(proxy_diagnostics) _, loss = self.diffusion_model( xs_window, conditions_window, input_pose_condition, noise_levels=noise_levels, reference_length=self.memory_condition_length, frame_idx=frame_idx_list, **memory_kwargs, ) diagnostics.update(self._memory_adapter_delta_diagnostics()) if self.memory_condition_length: loss = loss[:-self.memory_condition_length] loss_denoise = self.reweight_loss(loss, None) loss_total = loss_denoise diagnostics["training_window_start"] = int(start) diagnostics["training_window_end"] = int(end) diagnostics["training_window_size"] = int(end - start) diagnostics["loss_denoise"] = float(loss_denoise.detach().item()) diagnostics["loss_total"] = float(loss_total.detach().item()) if batch_idx % 20 == 0: self.log("training/loss", loss_total.detach().cpu()) self._log_memory_diagnostics("training/dememwm", diagnostics) return {"loss": loss_total} def validation_step(self, batch, batch_idx, namespace="validation"): import numpy as np from tqdm import tqdm memory_condition_length = self.memory_condition_length xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) total_frame = xs_raw.shape[0] if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): xs = xs_raw.cpu() elif total_frame > 10: xs = torch.cat([self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu() for i in range(10)]) else: xs = self.encode(xs_raw).cpu() n_frames, batch_size, *_ = xs.shape curr_frame = 0 n_context_frames = self.context_frames // self.frame_stack xs_pred = xs[:n_context_frames].clone() curr_frame += n_context_frames streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}") cached_until = 0 pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling") last_diagnostics = None while curr_frame < n_frames: if streaming_cache is not None and curr_frame > cached_until: new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) if curr_frame > n_context_frames: rel_start = max(0, n_context_frames - cached_until) new_generated[rel_start:] = True self._update_streaming_cache( streaming_cache, xs_pred[cached_until:curr_frame], frame_idx[cached_until:curr_frame], pose=pose_conditions[cached_until:curr_frame], source_is_generated=new_generated, action=conditions[cached_until:curr_frame], ) cached_until = curr_frame horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens." scheduling_matrix = self._generate_scheduling_matrix(horizon) chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])) chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device) xs_pred = torch.cat([xs_pred, chunk], 0) start_frame = max(0, curr_frame + horizon - self.n_tokens) pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon}) if memory_condition_length: random_idx = self._generate_condition_indices(curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon) xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0) else: random_idx = torch.empty((0, batch_size), dtype=torch.long, device=frame_idx.device) input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions( start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx, image_width=self._image_size(xs_raw)[1], image_height=self._image_size(xs_raw)[0] ) target_idx = frame_idx[start_frame:curr_frame + horizon].to(input_condition.device) use_streaming_cache = streaming_cache is not None and streaming_cache.record_count > 0 target_pose = pose_conditions[start_frame:curr_frame + horizon].to(input_condition.device) target_action = conditions[start_frame:curr_frame + horizon].to(input_condition.device) if use_streaming_cache: committed_latents = None committed_idx = None generated_flags = None source_pose = None source_action = None else: committed_latents = xs_pred[:curr_frame].to(input_condition.device) committed_idx = frame_idx[:curr_frame].to(input_condition.device) generated_flags = torch.zeros(committed_idx.shape, device=input_condition.device, dtype=torch.bool) if curr_frame > n_context_frames: generated_flags[n_context_frames:] = True source_pose = pose_conditions[:curr_frame].to(input_condition.device) source_action = conditions[:curr_frame].to(input_condition.device) memory_streams = self.build_memory_streams( committed_latents, committed_idx, target_frame_indices=target_idx, pose=source_pose, target_pose=target_pose, action=source_action, target_action=target_action, source_is_generated=generated_flags, denoising_fraction=None, streaming_cache=streaming_cache, ) for m in range(scheduling_matrix.shape[0] - 1): from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length) denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0) step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac) memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams) xs_pred[start_frame:] = self.diffusion_model.sample_step( xs_pred[start_frame:].to(input_condition.device), input_condition, input_pose_condition, from_noise_levels[start_frame:], to_noise_levels[start_frame:], current_frame=curr_frame, mode="validation", reference_length=memory_condition_length, frame_idx=frame_idx_list, **memory_kwargs, ).cpu() if memory_condition_length: xs_pred = xs_pred[:-memory_condition_length] curr_frame += horizon if streaming_cache is not None and curr_frame > cached_until: new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) if curr_frame > n_context_frames: rel_start = max(0, n_context_frames - cached_until) new_generated[rel_start:] = True self._update_streaming_cache( streaming_cache, xs_pred[cached_until:curr_frame], frame_idx[cached_until:curr_frame], pose=pose_conditions[cached_until:curr_frame], source_is_generated=new_generated, action=conditions[cached_until:curr_frame], ) cached_until = curr_frame if last_diagnostics is not None: last_diagnostics.update(streaming_cache.diagnostics("cache")) pbar.update(horizon) pbar.close() if last_diagnostics is not None: self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics) xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device)) xs_decode = self.decode(xs[n_context_frames:].to(conditions.device)) self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu())) return def strict_checkpoint_key_check(self, state_dict: dict, required_prefixes: Iterable[str] | None = None) -> None: prefixes = tuple(required_prefixes or self.strict_key_prefixes) strip_prefixes = ("", "model.", "module.", "algo.") normalized_keys = [] for key in state_dict.keys(): key = str(key) for strip_prefix in strip_prefixes: if not strip_prefix or key.startswith(strip_prefix): normalized_keys.append(key.removeprefix(strip_prefix)) missing_prefixes = [prefix for prefix in prefixes if not any(key.startswith(prefix) for key in normalized_keys)] missing_substrings = [ marker for marker in self.strict_key_substrings if not any(marker in key for key in normalized_keys) ] if missing_prefixes or missing_substrings: raise RuntimeError( "DeMemWM checkpoint is missing required DeMemWM key coverage: " f"prefixes={missing_prefixes}, memory_adapter_markers={missing_substrings}" ) # Compatibility aliases for old DeMemWM test and experiment call sites. dememwm_strict_key_prefixes = strict_key_prefixes dememwm_strict_key_substrings = strict_key_substrings _DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS _DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS _dememwm_cfg = _memory_cfg _dememwm_stage_policy_cfg = _stage_policy_cfg _dememwm_eval_ablation_cfg = _eval_ablation_cfg _dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg _dememwm_eval_ablation_state = _eval_ablation_state _dememwm_effective_gate_state = _effective_gate_state _dememwm_validate_config_contract = _validate_config_contract _dememwm_stream_enabled = _stream_enabled _dememwm_context_frame_count = _context_frame_count _dememwm_local_context_exclusion_frames = _local_context_exclusion_frames _dememwm_curriculum_state = _curriculum_state _dememwm_generated_history_proxy_prob = _generated_history_proxy_prob _dememwm_apply_generated_history_proxy = _apply_generated_history_proxy _dememwm_checkpoint_cfg = _checkpoint_cfg _dememwm_strict_eval_load_enabled = _strict_eval_load_enabled _dememwm_cache_cfg = _cache_cfg _dememwm_cache_enabled = _cache_enabled _dememwm_new_streaming_cache = _new_streaming_cache _dememwm_is_memory_adapter_param = _is_memory_adapter_param _dememwm_param_group_name = _param_group_name _dememwm_group_trainable = _group_trainable _dememwm_group_lr = _group_lr _dememwm_apply_freeze_policy = _apply_freeze_policy _dememwm_as_latents = _as_latents _dememwm_image_size = _image_size _dememwm_update_streaming_cache = _update_streaming_cache _build_dememwm_streaming_cache_records = _build_streaming_cache_records _build_dememwm_causal_memory_banks = _build_causal_memory_banks _build_dememwm_preselected_causal_memory_banks = _build_preselected_causal_memory_banks _dememwm_records_to_stream = _records_to_stream build_dememwm_memory_streams = build_memory_streams _dememwm_refresh_stream_gates = _refresh_stream_gates _dememwm_streams_to_kwargs = _streams_to_kwargs build_dememwm_memory_kwargs = build_memory_kwargs _dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics _log_dememwm_diagnostics = _log_memory_diagnostics _dememwm_training_window_bounds = _training_window_bounds strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check DeMemWMMemoryDiTMixin = MemoryDiTMixin