|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import replace |
| from typing import Iterable |
|
|
| import torch |
| from einops import rearrange |
|
|
| from .cache import StreamingCache |
| from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens |
| from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics |
| from .injection import InjectionAdapter |
| from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens |
| from .negatives import apply_revisit_eval_corruption |
| from .retrieval import deterministic_revisit_retrieval |
| from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum |
| from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors |
|
|
|
|
| class MemoryDiTMixin: |
| """Standalone DeMemWM / Memory-DiT mixin. |
| |
| Reuses the base video-DiT infrastructure while keeping memory construction and |
| injection under the standalone `dememwm` package. Legacy memory-method files |
| are not part of this path. |
| """ |
|
|
| strict_key_prefixes = ( |
| "dememwm_dynamic_compressor.", |
| "dememwm_anchor_proj.", |
| "dememwm_revisit_proj.", |
| "dememwm_revisit_gate.", |
| ) |
| strict_key_substrings = ( |
| ".memory_token_cross_attn.", |
| ) |
| _TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({ |
| "revisit_candidate_frame_count", |
| "revisit_pose_preselect_input_count", |
| "revisit_pose_preselect_selected_count", |
| "revisit_exact_fov_candidate_count", |
| "valid_revisit_frame_count", |
| "no_valid_revisit_count", |
| "revisit_selected_frame_count", |
| "revisit_frame_fov_overlap_mean", |
| "revisit_best_selected_frame_fov_overlap_mean", |
| "revisit_best_selected_plucker_overlap_mean", |
| "revisit_best_selected_gap_frames_mean", |
| "revisit_gate_raw", |
| "revisit_gate_eff", |
| "revisit_learned_gate_mean", |
| "revisit_effective_gate_mean", |
| "generated_history_proxy_prob", |
| "noise_bucket_target_count", |
| "noise_bucket_high_target_count", |
| "noise_bucket_mid_target_count", |
| "noise_bucket_low_target_count", |
| }) |
| _VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({ |
| "cache_records", |
| "cache_slots", |
| }) |
|
|
| def _memory_cfg(self): |
| return getattr(self.cfg, "dememwm", None) |
|
|
| def _cfg_get(self, obj, name, default): |
| if obj is None: |
| return default |
| if isinstance(obj, dict): |
| return obj.get(name, default) |
| return getattr(obj, name, default) |
|
|
| def _cfg_has(self, obj, name: str) -> bool: |
| if obj is None: |
| return False |
| if isinstance(obj, dict): |
| return name in obj |
| try: |
| getattr(obj, name) |
| return True |
| except Exception: |
| return False |
|
|
| def _stage_policy_cfg(self): |
| return self._cfg_get(self._memory_cfg(), "stage_policy", None) |
|
|
| def _eval_ablation_cfg(self): |
| return self._cfg_get(self._memory_cfg(), "eval_ablation", None) |
|
|
| def _generated_history_proxy_cfg(self): |
| return self._cfg_get(self._memory_cfg(), "generated_history_proxy", None) |
|
|
| def _eval_ablation_state(self) -> tuple[bool, str]: |
| cfg = self._eval_ablation_cfg() |
| enabled = bool(self._cfg_get(cfg, "enabled", False)) |
| branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal")) |
| return enabled, branch |
|
|
| def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict: |
| memory_cfg = self._memory_cfg() |
| anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) |
| dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) |
| revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) |
| injection_cfg = self._cfg_get(memory_cfg, "injection", None) |
| anchor_config_enabled = self._stream_enabled(anchor_cfg) |
| dynamic_config_enabled = self._stream_enabled(dynamic_cfg) |
| revisit_config_enabled = self._stream_enabled(revisit_cfg) |
| curriculum_state = self._curriculum_state() |
| eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state() |
| debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False)) |
| resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction) |
| gates = compute_stream_gates( |
| curriculum_state.stage, |
| denoising_fraction=denoising_fraction, |
| debug_force_all_streams=debug_force, |
| anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)), |
| dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)), |
| revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)), |
| ) |
| anchor_effective_enabled = bool(gates.anchor_enabled and anchor_config_enabled) |
| dynamic_effective_enabled = bool(gates.dynamic_enabled and dynamic_config_enabled) |
| revisit_stage_config_enabled = bool(gates.revisit_enabled and revisit_config_enabled) |
| if eval_ablation_enabled: |
| if eval_ablation_branch == "memory_off": |
| anchor_effective_enabled = False |
| dynamic_effective_enabled = False |
| revisit_stage_config_enabled = False |
| elif eval_ablation_branch == "A_only": |
| dynamic_effective_enabled = False |
| revisit_stage_config_enabled = False |
| elif eval_ablation_branch == "D_only": |
| anchor_effective_enabled = False |
| revisit_stage_config_enabled = False |
| elif eval_ablation_branch == "A_plus_D": |
| revisit_stage_config_enabled = False |
| return { |
| "curriculum_state": curriculum_state, |
| "gates": gates, |
| "resolved_noise_bucket": resolved_noise_bucket, |
| "anchor_config_enabled": anchor_config_enabled, |
| "dynamic_config_enabled": dynamic_config_enabled, |
| "revisit_config_enabled": revisit_config_enabled, |
| "anchor_effective_enabled": anchor_effective_enabled, |
| "dynamic_effective_enabled": dynamic_effective_enabled, |
| "revisit_stage_config_enabled": revisit_stage_config_enabled, |
| "eval_ablation_enabled": eval_ablation_enabled, |
| "eval_ablation_branch": eval_ablation_branch, |
| "force_revisit_off": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_off"), |
| "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"), |
| } |
|
|
    def _validate_config_contract(self) -> dict:
        """Validate the ``dememwm`` config section against the final contract.

        Validation runs once per instance. The first call records the outcome in
        ``_dememwm_contract_validated`` / ``_last_dememwm_config_diagnostics``;
        later calls return the cached diagnostics without re-checking.

        Rejected configurations:
          * stale top-level sections (``ablation``, ``memory``, ``loss``,
            ``abstention``) and ``*_ratio`` fields;
          * stale nested ``anchor`` / ``dynamic`` / ``revisit`` fields;
          * negative ``dynamic.exclude_latest_local_frames``;
          * a false ``revisit.deterministic_pose_retrieval``;
          * negative revisit thresholds/weights;
          * non-positive FOV/Plucker geometry or sampling sizes;
          * a false ``stage_policy.noise_bucket_logging``;
          * ``generated_history_proxy`` probabilities outside [0, 1], or a
            negative ``noise_std`` / ``ramp_steps``;
          * eval-ablation branches that ``normalize_eval_ablation_branch``
            refuses.

        Returns:
            A diagnostics dict summarising the validated retrieval settings, or
            ``{}`` when the config has no ``dememwm`` section.

        Raises:
            ValueError: for the violations above. The eval-ablation branch check
                may raise whatever ``normalize_eval_ablation_branch`` raises,
                which is defined outside this file.
        """
        # Already validated: reuse the cached diagnostics.
        if bool(getattr(self, "_dememwm_contract_validated", False)):
            return getattr(self, "_last_dememwm_config_diagnostics", {})
        memory_cfg = self._memory_cfg()
        if memory_cfg is None:
            # No DeMemWM section at all: nothing to validate.
            self._dememwm_contract_validated = True
            self._last_dememwm_config_diagnostics = {}
            return {}

        # Older top-level sections and ratio fields are rejected rather than
        # silently ignored.
        stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
        if stale_sections:
            raise ValueError(f"stale DeMemWM config sections are not part of the final contract: {stale_sections}")
        ratio_fields = [
            name
            for name in ("anchor_ratio", "dynamic_ratio", "revisit_ratio", "revisit_max_ratio")
            if self._cfg_has(memory_cfg, name)
        ]
        if ratio_fields:
            raise ValueError(f"standalone DeMemWM derives stream slots from latent shape and compression settings, not ratio fields: {ratio_fields}")

        # Per-stream fields that are no longer part of the contract.
        anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
        dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
        revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
        stale_nested = []
        for section_name, section_cfg, field_names in (
            ("anchor", anchor_cfg, ("policy", "topk", "pin_prefix")),
            ("dynamic", dynamic_cfg, ("include_generated_recent",)),
            ("revisit", revisit_cfg, ("deterministic_only", "min_age_frames", "min_gap_frames", "topk", "max_chunks", "chunk_frames", "min_score", "time_weight", "pose_weight", "latent_weight", "pose_overlap_threshold", "action_overlap_threshold", "generated_penalty", "force_gate_zero_when_invalid")),
        ):
            stale_nested.extend(
                f"{section_name}.{field_name}" for field_name in field_names if self._cfg_has(section_cfg, field_name)
            )
        if stale_nested:
            raise ValueError(f"stale DeMemWM config fields are not part of the final contract: {stale_nested}")

        # Dynamic-stream and revisit-retrieval parameters.
        exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
        if exclude_latest_local_frames < 0:
            raise ValueError("dememwm.dynamic.exclude_latest_local_frames must be non-negative")
        if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
            raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
        # ``None`` is accepted here (the check is skipped) and is reported as
        # -1.0 in the diagnostics below.
        fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
        if fov_overlap_threshold is not None:
            fov_overlap_threshold = float(fov_overlap_threshold)
            if fov_overlap_threshold < 0.0:
                raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
        high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
        if high_quality_fov_threshold < 0.0:
            raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
        plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
        if plucker_weight < 0.0:
            raise ValueError("dememwm.revisit.plucker_weight must be non-negative")
        # Float geometry parameters that must be strictly positive.
        for field_name, default in (
            ("fov_half_h", 52.5),
            ("fov_half_v", 37.5),
            ("fov_radius", 30.0),
            ("plucker_focal_length", 0.35),
        ):
            value = float(self._cfg_get(revisit_cfg, field_name, default))
            if value <= 0.0:
                raise ValueError(f"dememwm.revisit.{field_name} must be positive")
        # Integer sampling / grid sizes that must be strictly positive.
        for field_name, default in (
            ("fov_yaw_samples", 25),
            ("fov_pitch_samples", 20),
            ("fov_depth_samples", 20),
            ("plucker_grid_h", 4),
            ("plucker_grid_w", 4),
        ):
            value = int(self._cfg_get(revisit_cfg, field_name, default))
            if value <= 0:
                raise ValueError(f"dememwm.revisit.{field_name} must be positive")
        # Noise-bucket logging must stay enabled.
        stage_policy_cfg = self._stage_policy_cfg()
        if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)):
            raise ValueError("final DeMemWM keeps noise_bucket logging enabled")
        # Generated-history proxy settings.
        proxy_cfg = self._generated_history_proxy_cfg()
        proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
        proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
        proxy_noise_std = float(self._cfg_get(proxy_cfg, "noise_std", 0.0))
        proxy_ramp_steps = int(self._cfg_get(proxy_cfg, "ramp_steps", 0))
        if proxy_max_prob < 0.0 or proxy_max_prob > 1.0:
            raise ValueError("dememwm.generated_history_proxy.max_prob must be in [0, 1]")
        if proxy_dropout_prob < 0.0 or proxy_dropout_prob > 1.0:
            raise ValueError("dememwm.generated_history_proxy.dropout_prob must be in [0, 1]")
        if proxy_noise_std < 0.0:
            raise ValueError("dememwm.generated_history_proxy.noise_std must be non-negative")
        if proxy_ramp_steps < 0:
            raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
        # The result is discarded: this call is made only for its validation
        # effect (presumably raising on an unknown branch name).
        eval_ablation_cfg = self._eval_ablation_cfg()
        normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))

        diagnostics = {
            "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
            "revisit_deterministic_fov_plucker_retrieval": True,
            "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
            "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
            "revisit_plucker_weight": plucker_weight,
            "stage_policy_noise_bucket_logging": True,
        }
        self._dememwm_contract_validated = True
        self._last_dememwm_config_diagnostics = diagnostics
        return diagnostics
|
|
| def _stream_enabled(self, stream_cfg) -> bool: |
| return bool(self._cfg_get(stream_cfg, "enabled", True)) |
|
|
| def _context_frame_count(self) -> int: |
| frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) |
| return max(0, int(getattr(self, "context_frames", 0) or 0) // frame_stack) |
|
|
| def _local_context_exclusion_frames(self) -> int: |
| n_tokens = max(0, int(getattr(self, "n_tokens", 0) or 0)) |
| frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) |
| return n_tokens * frame_stack |
|
|
| def _curriculum_state(self, step: int | None = None): |
| if step is None: |
| step = int(getattr(self, "global_step", 0) or 0) |
| return resolve_curriculum(self._memory_cfg(), step) |
|
|
| def _generated_history_proxy_prob(self, step: int | None = None) -> float: |
| cfg = self._generated_history_proxy_cfg() |
| if not bool(self._cfg_get(cfg, "enabled", False)): |
| return 0.0 |
| max_prob = min(max(float(self._cfg_get(cfg, "max_prob", 0.0)), 0.0), 1.0) |
| if max_prob <= 0.0: |
| return 0.0 |
| if step is None: |
| step = int(getattr(self, "global_step", 0) or 0) |
| start_step = int(self._cfg_get(cfg, "start_step", 0)) |
| if step < start_step: |
| return 0.0 |
| ramp_steps = int(self._cfg_get(cfg, "ramp_steps", 0)) |
| if ramp_steps <= 0: |
| return max_prob |
| ramp_fraction = min(max(float(step - start_step) / float(ramp_steps), 0.0), 1.0) |
| return max_prob * ramp_fraction |
|
|
    def _apply_generated_history_proxy(
        self,
        source_latents: torch.Tensor,
        source_is_generated: torch.Tensor | None,
        context_frame_count: int | None = None,
        target_start_frame: int | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """Randomly corrupt history entries so they mimic model-generated history.

        Each (dim-0, dim-1) entry is selected with probability
        ``_generated_history_proxy_prob()``. When ``context_frame_count`` /
        ``target_start_frame`` are given, only dim-0 positions in
        ``[context_frame_count, target_start_frame)`` are eligible. Selected
        entries get Gaussian noise (``noise_std``) and/or spatial dropout
        (``dropout_prob``, shared across dim 2), and are flagged in the returned
        generated mask.

        Args:
            source_latents: history latents, indexed as 5-D below. Presumably
                ``(time, batch, C, H, W)`` — TODO confirm.
            source_is_generated: mask over the first two dims of
                ``source_latents``. ``None`` means "nothing generated".
            context_frame_count: first eligible dim-0 position (optional).
            target_start_frame: exclusive upper bound on eligible positions (optional).

        Returns:
            ``(latents, source_is_generated, diagnostics)``. When nothing is
            selected, the latents are returned unchanged (same object) together
            with the boolean-normalised mask. Otherwise corrupted copies are
            returned and the inputs are not modified in place.
            ``generated_history_proxy_frame_fraction`` is measured over all
            entries, not only the eligible ones.
        """
        cfg = self._generated_history_proxy_cfg()
        prob = self._generated_history_proxy_prob()
        noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
        dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
        diagnostics = {
            "generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)),
            "generated_history_proxy_prob": float(prob),
            "generated_history_proxy_noise_std": float(noise_std),
            "generated_history_proxy_dropout_prob": float(dropout_prob),
            "generated_history_proxy_frame_count": 0,
            "generated_history_proxy_frame_fraction": 0.0,
        }
        # Normalise the generated mask to bool on the latents' device
        # (all False when not provided).
        if source_is_generated is None:
            source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
        else:
            source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
        if prob <= 0.0 or source_latents.numel() == 0:
            return source_latents, source_is_generated, diagnostics

        # Restrict eligibility to dim-0 positions between the context prefix
        # and the target start.
        eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
        if context_frame_count is not None or target_start_frame is not None:
            frame_positions = torch.arange(source_latents.shape[0], device=source_latents.device)[:, None]
            if context_frame_count is not None:
                eligible_mask &= frame_positions >= max(0, int(context_frame_count))
            if target_start_frame is not None:
                eligible_mask &= frame_positions < max(0, int(target_start_frame))
        # Independent Bernoulli(prob) selection per entry, limited to eligible ones.
        proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
        proxy_count = int(proxy_mask.detach().long().sum().item())
        total_count = max(1, int(proxy_mask.numel()))
        diagnostics["generated_history_proxy_frame_count"] = proxy_count
        diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count)
        if proxy_count == 0:
            return source_latents, source_is_generated, diagnostics

        corrupt_latents = source_latents.clone()
        # Broadcastable 0/1 mask so noise only touches selected entries.
        frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
        if noise_std > 0.0:
            corrupt_latents = corrupt_latents + torch.randn_like(corrupt_latents) * float(noise_std) * frame_mask
        if dropout_prob > 0.0:
            # Dropout mask has a singleton dim 2, so a dropped spatial position
            # is zeroed across that whole dimension.
            dropout_mask = torch.rand(
                (*source_latents.shape[:2], 1, source_latents.shape[-2], source_latents.shape[-1]),
                device=source_latents.device,
            ) < dropout_prob
            dropout_mask = dropout_mask & proxy_mask[:, :, None, None, None]
            corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
        # Flag corrupted entries as generated, without mutating the caller's mask.
        source_is_generated = source_is_generated.clone()
        source_is_generated |= proxy_mask
        return corrupt_latents, source_is_generated, diagnostics
|
|
| def _checkpoint_cfg(self): |
| return self._cfg_get(self._memory_cfg(), "checkpoint", None) |
|
|
| def _strict_eval_load_enabled(self) -> bool: |
| return bool(self._cfg_get(self._checkpoint_cfg(), "strict_dememwm_eval_load", True)) |
|
|
| def _cache_cfg(self): |
| return self._cfg_get(self._memory_cfg(), "cache", None) |
|
|
| def _cache_enabled(self) -> bool: |
| return bool(self._cfg_get(self._cache_cfg(), "enabled", False)) |
|
|
| def _new_streaming_cache(self, video_id=None) -> StreamingCache | None: |
| if not self._cache_enabled(): |
| return None |
| cache = StreamingCache.from_config(self._cache_cfg(), enabled_default=True) |
| if cache.clear_between_videos: |
| cache.reset(video_id=video_id) |
| return cache |
|
|
| def _is_memory_adapter_param(self, name: str) -> bool: |
| return ".memory_token_cross_attn." in name |
|
|
| def _param_group_name(self, name: str, state=None) -> str: |
| state = state or self._curriculum_state() |
| if name.startswith("vae.") or name.startswith("validation_lpips_model."): |
| return "excluded_frozen" |
| if name.startswith(("dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.")): |
| return "dememwm_modules" |
| if self._is_memory_adapter_param(name): |
| return "memory_adapters" |
| if name.startswith("diffusion_model."): |
| return "full_dit" |
| return "dememwm_modules" |
|
|
| def _group_trainable(self, group_name: str, state) -> bool: |
| if group_name in {"dememwm_modules", "memory_adapters"}: |
| return True |
| if group_name == "full_dit": |
| return state.dit_full_trainable |
| return False |
|
|
| def _group_lr(self, group_name: str, state) -> float: |
| if group_name == "dememwm_modules": |
| return state.dememwm_lr |
| if group_name == "memory_adapters": |
| return state.memory_adapter_lr |
| if group_name == "full_dit": |
| return state.full_dit_lr |
| return 0.0 |
|
|
    def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
        """Apply the curriculum's freeze / learning-rate policy for ``step``.

        ``requires_grad`` is recomputed only when the
        ``(stage, dit_train_state, freeze_vae)`` key changes; otherwise the
        cached per-group counts are reused. ``requires_grad`` is turned off
        only for the ``excluded_frozen`` group and, when ``state.freeze_vae``
        is set, for ``vae.`` parameters. Groups that are merely not trainable
        keep ``requires_grad`` and instead get learning rate 0.0.

        Args:
            optimizer: optional optimizer. The ``lr`` of each param group is set
                from its ``name`` key; unnamed or non-trainable groups get 0.0.
            step: curriculum step (defaults to ``self.global_step``).

        Returns:
            The resolved curriculum state. Per-group counts and learning rates
            are stored in ``self._last_dememwm_freeze_diagnostics``.
        """
        state = self._curriculum_state(step)

        # Only the freeze-relevant part of the state decides whether
        # requires_grad has to be recomputed.
        freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
        last_key = getattr(self, "_last_freeze_key", None)
        if last_key != freeze_key:
            # trainable_*: parameters the curriculum wants updated.
            # requires_grad_*: parameters that still receive gradients.
            # *_tensors count parameter tensors; *_scalars count elements.
            trainable_tensors = {
                "dememwm_modules": 0,
                "memory_adapters": 0,
                "full_dit": 0,
                "excluded_frozen": 0,
            }
            trainable_scalars = {key: 0 for key in trainable_tensors}
            requires_grad_tensors = {key: 0 for key in trainable_tensors}
            requires_grad_scalars = {key: 0 for key in trainable_tensors}
            for name, param in self.named_parameters():
                group_name = self._param_group_name(name, state)
                should_train = self._group_trainable(group_name, state)
                if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
                    should_train = False
                    should_require_grad = False
                else:
                    # Non-trainable groups keep requires_grad. Their updates are
                    # suppressed by lr 0.0 (below) and by gradient clearing in
                    # on_after_backward.
                    should_require_grad = True
                param.requires_grad_(should_require_grad)
                if should_train:
                    trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1
                    trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel())
                if should_require_grad:
                    requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1
                    requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel())
            self._last_freeze_key = freeze_key
            self._last_trainable_tensors = trainable_tensors
            self._last_trainable_scalars = trainable_scalars
            self._last_requires_grad_tensors = requires_grad_tensors
            self._last_requires_grad_scalars = requires_grad_scalars
        else:
            # Same freeze key as last time: reuse the cached counts.
            trainable_tensors = getattr(self, "_last_trainable_tensors", {})
            trainable_scalars = getattr(self, "_last_trainable_scalars", {})
            requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {})
            requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {})

        # Learning rates are refreshed on every call, since they may change
        # within a stage.
        if optimizer is not None:
            for param_group in optimizer.param_groups:
                group_name = param_group.get("name", "")
                trainable = self._group_trainable(group_name, state)
                param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0

        diagnostics = state.diagnostics()
        for group_name in ("dememwm_modules", "memory_adapters", "full_dit"):
            diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0)
            diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0)
            diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0)
            diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0)
            diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0
        self._last_dememwm_freeze_diagnostics = diagnostics
        return state
|
|
| def configure_optimizers(self): |
| state = self._curriculum_state(0) |
| self._apply_freeze_policy(step=0) |
| grouped: dict[str, list[torch.nn.Parameter]] = { |
| "dememwm_modules": [], |
| "memory_adapters": [], |
| "full_dit": [], |
| } |
| for name, param in self.named_parameters(): |
| group_name = self._param_group_name(name, state) |
| if group_name in grouped: |
| grouped[group_name].append(param) |
| param_groups = [] |
| for group_name in ("dememwm_modules", "memory_adapters", "full_dit"): |
| params = grouped[group_name] |
| if params: |
| trainable = self._group_trainable(group_name, state) |
| param_groups.append({ |
| "params": params, |
| "lr": self._group_lr(group_name, state) if trainable else 0.0, |
| "name": group_name, |
| }) |
| if not param_groups: |
| raise RuntimeError("DeMemWM optimizer found no trainable parameter groups") |
| return torch.optim.AdamW( |
| param_groups, |
| weight_decay=self.cfg.weight_decay, |
| betas=self.cfg.optimizer_beta, |
| ) |
|
|
| def on_train_start(self): |
| optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] |
| for optimizer in optimizers: |
| self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) |
|
|
| def on_train_batch_start(self, batch, batch_idx): |
| optimizers = getattr(getattr(self, "trainer", None), "optimizers", []) or [] |
| for optimizer in optimizers: |
| self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) |
|
|
| def on_after_backward(self): |
| step = int(getattr(self, "global_step", 0) or 0) |
| state = self._apply_freeze_policy(step=step) |
| for name, param in self.named_parameters(): |
| if param.grad is None: |
| continue |
| group_name = self._param_group_name(name, state) |
| if not self._group_trainable(group_name, state): |
| param.grad = None |
|
|
| def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): |
| self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) |
| optimizer.step(closure=optimizer_closure) |
| self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0) + 1) |
|
|
| def on_load_checkpoint(self, checkpoint): |
| super().on_load_checkpoint(checkpoint) |
| if self._strict_eval_load_enabled(): |
| state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint |
| self.strict_checkpoint_key_check(state_dict) |
|
|
| def _preprocess_batch(self, batch): |
| """Preprocess RGB or precomputed-latent Minecraft batches for DeMemWM. |
| |
| MinecraftVideoLatentDataset returns an extra image_hw tensor. Keep the |
| DeMemWM path on VAE latents while preserving RGB image size for Plucker |
| pose embeddings. This mirrors the existing latent-dataset contract |
| without routing through the legacy SSM memory implementation. |
| """ |
| from ..df_video import euler_to_camera_to_world_matrix |
|
|
| if len(batch) == 5: |
| xs, conditions, pose_conditions, frame_index, image_hw = batch |
| self._last_dememwm_xs_are_latents = True |
| self._last_dememwm_image_hw = image_hw |
| else: |
| xs, conditions, pose_conditions, frame_index = batch |
| self._last_dememwm_xs_are_latents = False |
| self._last_dememwm_image_hw = None |
|
|
| if self.action_cond_dim: |
| conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) |
| conditions = rearrange(conditions, "b t d -> t b d").contiguous() |
| else: |
| raise NotImplementedError("Only support external cond.") |
|
|
| pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous() |
| c2w_mat = euler_to_camera_to_world_matrix(pose_conditions) |
| xs = rearrange(xs, "b t c ... -> t b c ...").contiguous() |
| frame_index = rearrange(frame_index, "b t -> t b").contiguous() |
| return xs, conditions, pose_conditions, c2w_mat, frame_index |
|
|
| def _as_latents(self, xs: torch.Tensor) -> torch.Tensor: |
| if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): |
| return xs |
| return self.encode(xs) |
|
|
| def _image_size(self, xs: torch.Tensor) -> tuple[int, int]: |
| image_hw = getattr(self, "_last_dememwm_image_hw", None) |
| if image_hw is not None: |
| if torch.is_tensor(image_hw): |
| values = image_hw.detach().cpu().reshape(-1).tolist() |
| else: |
| values = list(image_hw) |
| if len(values) >= 2: |
| return int(values[0]), int(values[1]) |
| return int(xs.shape[-2]), int(xs.shape[-1]) |
|
|
| def _update_streaming_cache( |
| self, |
| cache: StreamingCache | None, |
| new_latents: torch.Tensor, |
| frame_indices: torch.Tensor, |
| pose: torch.Tensor | None = None, |
| source_is_generated: torch.Tensor | None = None, |
| action: torch.Tensor | None = None, |
| ) -> None: |
| if cache is None or not cache.enabled or new_latents is None or new_latents.shape[0] == 0: |
| return |
| cache.add_raw_latents(new_latents, frame_indices, source_is_generated, pose) |
| if not cache.keep_compressed_records: |
| return |
| memory_cfg = self._memory_cfg() |
| anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) |
| dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) |
| revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) |
| token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) |
| anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] |
| anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) |
| anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( |
| int(new_latents.shape[-2]), |
| int(new_latents.shape[-1]), |
| self.dememwm_anchor_proj, |
| token_patch_size, |
| ) |
| anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( |
| anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 |
| ) |
| anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) |
| allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) |
| |
| |
| if cache.records_count("anchor") > 0 and not allow_generated_anchor: |
| anchor_indices = [] |
| anchor_banks, revisit_banks = self._build_streaming_cache_records( |
| new_latents, |
| frame_indices, |
| source_is_generated, |
| pose, |
| action, |
| allow_generated_anchor, |
| anchor_indices, |
| anchor_pool_h, |
| anchor_pool_w, |
| anchor_diverse, |
| token_patch_size, |
| ) |
| cache.add_memory_banks(anchor_banks, revisit_banks) |
|
|
    def _build_model(self):
        """Instantiate the diffusion backbone, the DeMemWM modules, validation LPIPS and the frozen VAE.

        The ``dememwm`` config contract is validated (raising on invalid
        settings) after the backbone is built and before any DeMemWM module is
        created. A ``PosePredictionNet`` is added only when
        ``self.require_pose_prediction`` is truthy.

        NOTE(review): construction order fixes the parameter registration order,
        and so the ``named_parameters()`` / optimizer group order. Presumably it
        also fixes the RNG draws used for weight initialisation, so keep it
        stable.
        """
        # NOTE(review): local imports, presumably to avoid import cycles or
        # heavy imports at module load — confirm.
        from algorithms.common.metrics import LearnedPerceptualImagePatchSimilarity
        from .gates import RevisitRawGate
        from ..models.diffusion import Diffusion
        from ..models.pose_prediction import PosePredictionNet
        from ..models.vae import VAE_models

        # DiT backbone. The legacy memory attention is switched off.
        # Memory-token cross-attention is toggled by
        # ``cfg.memory_token_cross_attention`` (default True).
        self.diffusion_model = Diffusion(
            reference_length=self.memory_condition_length,
            x_shape=self.x_stacked_shape,
            action_cond_dim=self.action_cond_dim,
            pose_cond_dim=self.pose_cond_dim,
            is_causal=self.causal,
            cfg=self.cfg.diffusion,
            is_dit=True,
            use_plucker=self.use_plucker,
            relative_embedding=self.relative_embedding,
            state_embed_only_on_qk=self.state_embed_only_on_qk,
            use_memory_attention=False,
            add_timestamp_embedding=self.add_timestamp_embedding,
            memory_token_cross_attention=getattr(self.cfg, "memory_token_cross_attention", True),
            memory_cross_attn_layers=getattr(self.cfg, "memory_cross_attn_layers", None),
            ref_mode=self.ref_mode,
        )
        memory_cfg = self._memory_cfg()
        # Raises on stale/invalid DeMemWM settings; the returned diagnostics
        # are not needed here.
        self._validate_config_contract()
        injection_cfg = self._cfg_get(memory_cfg, "injection", None)
        dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
        hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
        token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
        # The dynamic compressor consumes at most ``dynamic.recent_frames`` source frames.
        max_source_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
        self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
            latent_channels=self.x_stacked_shape[0],
            dit_hidden_size=hidden_size,
            patch_size=token_patch_size,
            conv_kernel_t=int(self._cfg_get(dynamic_cfg, "conv_kernel_t", 3)),
            conv_stride_t=int(self._cfg_get(dynamic_cfg, "conv_stride_t", 2)),
            max_source_frames=max_source_frames,
            exclude_latest_local_frames=int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)),
        )
        # Anchor and revisit projectors share an architecture but are separate
        # instances with separate weights. Mid width = channels * patch area.
        spatial_mid_channels = self.x_stacked_shape[0] * token_patch_size * token_patch_size
        self.dememwm_anchor_proj = SpatialConv2DMemoryProjector(
            latent_channels=self.x_stacked_shape[0],
            dit_hidden_size=hidden_size,
            mid_channels=spatial_mid_channels,
            kernel_size=3,
        )
        self.dememwm_revisit_proj = SpatialConv2DMemoryProjector(
            latent_channels=self.x_stacked_shape[0],
            dit_hidden_size=hidden_size,
            mid_channels=spatial_mid_channels,
            kernel_size=3,
        )
        self.dememwm_revisit_gate = RevisitRawGate()
        self.dememwm_injection_adapter = InjectionAdapter()
        self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
        # The VAE is put in eval mode with all gradients disabled.
        self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval()
        for param in self.vae.parameters():
            param.requires_grad_(False)
        if self.require_pose_prediction:
            self.pose_prediction_model = PosePredictionNet()
|
|
| def _project_latent_patch_tokens( |
| self, |
| latents: torch.Tensor, |
| projection: torch.nn.Module, |
| patch_size: int, |
| ) -> torch.Tensor: |
| |
| if bool(getattr(projection, "projects_spatial_latents", False)): |
| return projection(latents) |
| patch_vectors = latent_patch_tokens(latents, patch_size) |
| return projection(patch_vectors).permute(1, 0, 2, 3).contiguous() |
|
|
| def _projected_spatial_grid_size( |
| self, |
| latent_h: int, |
| latent_w: int, |
| projection: torch.nn.Module, |
| patch_size: int, |
| ) -> tuple[int, int]: |
| if bool(getattr(projection, "projects_spatial_latents", False)): |
| return int(latent_h), int(latent_w) |
| return int(latent_h) // int(patch_size), int(latent_w) // int(patch_size) |
|
|
| def _take_uniform_slots(self, tokens: torch.Tensor, num_slots: int) -> torch.Tensor: |
| if tokens.ndim != 2: |
| raise ValueError("tokens must have shape (N,D)") |
| num_slots = max(0, int(num_slots)) |
| if num_slots == 0: |
| return tokens[:0] |
| if tokens.shape[0] <= num_slots: |
| return tokens |
| idx = torch.linspace(0, tokens.shape[0] - 1, num_slots, device=tokens.device).round().long() |
| return tokens.index_select(0, idx) |
|
|
| def _spatial_pool_tokens( |
| self, |
| tokens: torch.Tensor, |
| pool_h: int, |
| pool_w: int, |
| src_h: int, |
| src_w: int, |
| ) -> torch.Tensor: |
| return spatial_pool_tokens(tokens, pool_h, pool_w, src_h, src_w) |
|
|
def _resolve_spatial_pool_size(
    self,
    compress_cfg,
    src_h: int,
    src_w: int,
    default_pool_h: int,
    default_pool_w: int,
) -> tuple[int, int]:
    """Work out the pooled ``(height, width)`` of a memory stream.

    Downsample ratios take priority: ``downsample_h`` / ``downsample_w`` each
    default to ``downsample_ratio``, and if only one side is known it is
    reused for the other. The pooled size per side is then
    ``max(1, ceil(src / ratio))``. With no ratio configured at all,
    ``pool_h`` / ``pool_w`` (falling back to the given defaults) are used.

    Raises:
        ValueError: if a ratio or an explicit pool size is not positive.
    """
    shared_ratio = self._cfg_get(compress_cfg, "downsample_ratio", None)
    ratio_h = self._cfg_get(compress_cfg, "downsample_h", shared_ratio)
    ratio_w = self._cfg_get(compress_cfg, "downsample_w", shared_ratio)

    if ratio_h is None and ratio_w is None:
        pool_h = int(self._cfg_get(compress_cfg, "pool_h", default_pool_h))
        pool_w = int(self._cfg_get(compress_cfg, "pool_w", default_pool_w))
        if pool_h <= 0 or pool_w <= 0:
            raise ValueError("DeMemWM compress pool_h/pool_w must be positive")
        return pool_h, pool_w

    # Mirror a missing side from the other one.
    ratio_h = float(ratio_w if ratio_h is None else ratio_h)
    ratio_w = float(ratio_h if ratio_w is None else ratio_w)
    if ratio_h <= 0.0 or ratio_w <= 0.0:
        raise ValueError("DeMemWM compress downsample ratios must be positive")
    pooled_h = max(1, int(math.ceil(float(src_h) / ratio_h)))
    pooled_w = max(1, int(math.ceil(float(src_w) / ratio_w)))
    return pooled_h, pooled_w
|
|
def _select_diverse_anchor_positions(
    self,
    source_positions: torch.Tensor,
    pose: torch.Tensor | None,
    num_anchors: int,
) -> torch.Tensor:
    """Pick ``num_anchors`` entries of ``source_positions`` whose poses are spread out.

    Greedy farthest-point sampling on Euclidean distances between pose
    vectors. For a single anchor, the candidate with the largest mean
    distance to all others is chosen. Otherwise the farthest pair seeds the
    set, and each further anchor is the candidate farthest from its nearest
    already-chosen anchor.

    Falls back to the first ``num_anchors`` positions when there are no more
    candidates than anchors, when ``pose`` is None, or when all poses coincide.

    Args:
        source_positions: 1-D tensor of candidate positions.
        pose: per-candidate poses; row ``i`` belongs to ``source_positions[i]``.
            Any trailing shape (vector, matrix, scalar per row) is accepted
            and flattened into one vector per candidate.
        num_anchors: how many anchors to return (negative behaves like 0).

    Returns:
        The selected entries of ``source_positions``, in ascending candidate order.

    Raises:
        ValueError: if ``pose`` does not have exactly one row per candidate.
    """
    num_anchors = max(0, int(num_anchors))
    if num_anchors == 0:
        return source_positions[:0]
    num_candidates = int(source_positions.numel())
    if num_candidates <= num_anchors or pose is None:
        return source_positions[:num_anchors]
    if pose.ndim == 0 or int(pose.shape[0]) != num_candidates:
        raise ValueError("pose must have one row per source position")
    # Flatten every pose into a vector so cdist always gets a 2-D (N, F)
    # input; a 3-D input such as (N, 4, 4) matrices would otherwise be
    # interpreted as a batch and produce a meaningless (N, 4, 4) result.
    poses = pose.detach().float().reshape(num_candidates, -1)
    # Exact distances: cdist's matmul fast path (used for larger inputs) can
    # leave small non-zero values on the diagonal, which could let argmax pick
    # a self-pair and duplicate an anchor when poses are (nearly) identical.
    pairwise = torch.cdist(poses, poses, compute_mode="donot_use_mm_for_euclid_dist")
    pairwise.fill_diagonal_(0.0)
    if not bool((pairwise > 0).any().item()):
        return source_positions[:num_anchors]
    available = torch.ones((num_candidates,), device=poses.device, dtype=torch.bool)
    if num_anchors == 1:
        selected = [int(pairwise.mean(dim=1).argmax().item())]
    else:
        # Seed with the farthest pair (off-diagonal, since the diagonal is 0
        # and at least one distance is positive).
        first, second = divmod(int(pairwise.argmax().item()), num_candidates)
        selected = [int(first), int(second)]
    for idx in selected:
        available[idx] = False
    # Distance from each candidate to its nearest selected anchor; chosen
    # candidates are masked to -inf so argmax never re-picks them.
    dists = pairwise[selected].min(dim=0).values.masked_fill(~available, float("-inf"))
    for _ in range(num_anchors - len(selected)):
        farthest = int(dists.argmax().item())
        if not bool(available[farthest].item()):
            break
        selected.append(farthest)
        available[farthest] = False
        dists = torch.minimum(dists, pairwise[farthest]).masked_fill(~available, float("-inf"))
    index = torch.tensor(sorted(selected), device=source_positions.device, dtype=torch.long)
    return source_positions[index]
|
|
def _build_streaming_cache_records(
    self,
    source_latents: torch.Tensor,
    source_frame_indices: torch.Tensor,
    source_is_generated: torch.Tensor | None,
    pose: torch.Tensor | None,
    action: torch.Tensor | None,
    allow_generated_anchor: bool,
    anchor_indices: list[int],
    anchor_pool_h: int,
    anchor_pool_w: int,
    anchor_diverse: bool,
    token_patch_size: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]:
    """Build per-sample anchor banks and metadata-only revisit banks.

    Anchor records carry real tokens. The chosen source latents are projected
    with ``self.dememwm_anchor_proj`` and pooled to
    ``anchor_pool_h x anchor_pool_w`` slots.

    Revisit records carry only a single zero placeholder token of width
    ``hidden_size``, plus frame index, pose, and the metadata flag
    ``"dememwm_revisit_metadata_only": True``. That flag is what
    ``_project_streaming_revisit_records`` checks to decide whether a record
    still needs projecting. All non-generated frames are added as
    ``PREFIX_GT`` records first, then all generated frames as ``GENERATED``.

    Anchor selection per batch element, over the non-generated positions:

    * ``anchor_diverse``: positions below ``self._context_frame_count()`` are
      passed with their poses to ``_select_diverse_anchor_positions``, which
      is asked for ``len(anchor_indices)`` anchors.
    * otherwise: each entry of ``anchor_indices`` indexes into the list of
      non-generated positions; out-of-range entries are skipped.

    When ``allow_generated_anchor`` is set, a generated mask is given and
    ``anchor_indices`` is non-empty, the first ``len(anchor_indices)``
    generated positions (in source order) are also added as generated anchors.

    Args:
        source_latents: ``(T, B, C, H, W)`` latents.
        source_frame_indices: ``(T, B)`` frame indices.
        source_is_generated: optional mask of generated frames, indexed as
            ``[:, batch_idx]`` (presumably ``(T, B)`` — TODO confirm).
        pose: optional poses laid out ``(T, B, ...)`` or ``(B, T, ...)``; any
            other layout yields ``pose=None`` on the records.
        action: not used by this method.
        allow_generated_anchor: also add generated frames as anchors.
        anchor_indices: prefix anchor selectors (see above); its length is the
            anchor budget.
        anchor_pool_h: pooled anchor grid height.
        anchor_pool_w: pooled anchor grid width.
        anchor_diverse: use pose-diverse anchor selection.
        token_patch_size: patch size passed to the anchor projector helpers.

    Returns:
        ``(anchor_banks, revisit_banks)``, one bank of each kind per batch element.

    Raises:
        ValueError: if the latent and frame-index shapes are inconsistent.
    """
    if source_latents.ndim != 5:
        raise ValueError("source_latents must have shape (T,B,C,H,W)")
    if source_frame_indices.ndim != 2:
        raise ValueError("source_frame_indices must have shape (T,B)")
    T_src, B = source_frame_indices.shape
    if source_latents.shape[:2] != (T_src, B):
        raise ValueError("source_latents and source_frame_indices must share T/B dimensions")
    _, _, _, latent_H, latent_W = source_latents.shape
    # Token grid produced by the anchor projector; used as the pooling source size.
    src_h, src_w = self._projected_spatial_grid_size(
        latent_H,
        latent_W,
        self.dememwm_anchor_proj,
        token_patch_size,
    )

    # Projection inputs are cast to the anchor projector's parameter device/dtype.
    param = next(iter(self.dememwm_anchor_proj.parameters()))
    project_device = param.device
    project_dtype = param.dtype
    # Width of the placeholder token on metadata-only revisit records.
    hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0])
    generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device)
    anchor_banks: list[CausalMemoryBank] = []
    revisit_banks: list[CausalMemoryBank] = []

    def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
        # Select `positions` of one batch element from a (T_src, B, ...) or
        # (B, T_src, ...) tensor; None for any other layout.
        if tensor is None or tensor.ndim < 2:
            return None
        # Index on the tensor's own device (a CPU tensor cannot take CUDA indices).
        pos = positions.to(device=tensor.device)
        if tensor.shape[0] == T_src and tensor.shape[1] == B:
            return tensor[pos, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_src:
            return tensor[batch_idx, pos]
        return None

    def _metadata_subset(positions: torch.Tensor, batch_idx: int):
        # Revisit records built here are placeholders; they are flagged so
        # they can be projected later.
        return {"dememwm_revisit_metadata_only": True}

    def _pose_subset(positions: torch.Tensor, batch_idx: int):
        return _tensor_subset(pose, positions, batch_idx)

    def _add_anchor_records(bank: CausalMemoryBank, batch_idx: int, positions: torch.Tensor, generated_anchor: bool) -> None:
        # Project the selected source frames of one batch element, pool each
        # frame's tokens, and append them to `bank` as generated or prefix anchors.
        if positions.numel() == 0:
            return
        projected = self._project_latent_patch_tokens(
            source_latents.index_select(0, positions.to(device=source_latents.device))[:, batch_idx:batch_idx + 1].to(device=project_device, dtype=project_dtype),
            self.dememwm_anchor_proj,
            token_patch_size,
        )[0]
        src_frames = source_frame_indices[:, batch_idx]
        for local_idx, source_pos in enumerate(positions):
            source_pos_i = int(source_pos.item())
            anchor_tokens = self._spatial_pool_tokens(projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
            n_slots = anchor_tokens.shape[0]
            record_mask = torch.ones((n_slots,), device=anchor_tokens.device, dtype=torch.bool)
            if generated_anchor:
                bank.add_generated_records(
                    anchor_tokens.unsqueeze(0),
                    record_mask.unsqueeze(0),
                    src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device),
                    source_type=MemorySourceType.GENERATED,
                )
            else:
                bank.add_prefix_anchors(
                    anchor_tokens.unsqueeze(0),
                    record_mask.unsqueeze(0),
                    src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device),
                    slots_per_anchor=n_slots,
                )

    for batch_idx in range(B):
        anchor_bank = CausalMemoryBank()
        revisit_bank = CausalMemoryBank()
        src_frames = source_frame_indices[:, batch_idx]
        if generated is None:
            non_generated = torch.ones_like(src_frames, dtype=torch.bool)
        else:
            non_generated = ~generated[:, batch_idx]

        # ---- Prefix anchors (non-generated frames only) ----
        source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
        if source_positions.numel() > 0:
            if anchor_diverse:
                # Only positions inside the context window are anchor candidates.
                anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
                if anchor_source_positions.numel() > 0:
                    anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
                    selected_anchor_positions = self._select_diverse_anchor_positions(
                        anchor_source_positions, anchor_pose, len(anchor_indices)
                    )
                else:
                    selected_anchor_positions = source_positions[:0]
            else:
                # anchor_indices index into the list of non-generated positions.
                selected_list = []
                for anchor_idx in anchor_indices:
                    if 0 <= int(anchor_idx) < source_positions.numel():
                        selected_list.append(source_positions[int(anchor_idx)])
                selected_anchor_positions = torch.stack(selected_list).long() if selected_list else source_positions[:0]
            if selected_anchor_positions.numel() > 0:
                _add_anchor_records(anchor_bank, batch_idx, selected_anchor_positions.long(), False)

        # ---- Metadata-only revisit records: prefix frames, then generated frames ----
        dummy_tokens = torch.zeros((1, hidden_size), device=source_frame_indices.device, dtype=project_dtype)
        dummy_mask = torch.ones((1,), device=source_frame_indices.device, dtype=torch.bool)
        for prefix, positions, source_type, is_generated in (
            ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
            (
                "generated",
                torch.empty(0, device=source_frame_indices.device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(),
                MemorySourceType.GENERATED,
                True,
            ),
        ):
            if positions.numel() == 0:
                continue
            for source_pos in positions.to(device=source_frame_indices.device, dtype=torch.long):
                source_pos_i = int(source_pos.item())
                frame_index = src_frames[source_pos_i]
                frame = int(frame_index.detach().item())
                frame_pos = source_pos.reshape(1)
                revisit_bank.add_frame_record(
                    dummy_tokens,
                    dummy_mask,
                    frame_index,
                    pose=_pose_subset(frame_pos, batch_idx),
                    source_type=source_type,
                    metadata=_metadata_subset(frame_pos, batch_idx),
                    is_generated=is_generated,
                    record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}",
                )

        # ---- Optional generated anchors: the first len(anchor_indices) generated frames ----
        if allow_generated_anchor and generated is not None and anchor_indices:
            generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten()
            _add_anchor_records(anchor_bank, batch_idx, generated_positions[:len(anchor_indices)].long(), True)

        anchor_banks.append(anchor_bank)
        revisit_banks.append(revisit_bank)
    return anchor_banks, revisit_banks
|
|
|
|
def _build_causal_memory_banks(
    self,
    anchor_projected: torch.Tensor,
    revisit_projected: torch.Tensor,
    source_frame_indices: torch.Tensor,
    source_is_generated: torch.Tensor | None,
    pose: torch.Tensor | None,
    action: torch.Tensor | None,
    allow_generated_anchor: bool,
    anchor_indices: list[int],
    anchor_pool_h: int,
    anchor_pool_w: int,
    revisit_pool_h: int,
    revisit_pool_w: int,
    src_h: int,
    src_w: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]:
    """Build per-sample anchor and revisit banks from already-projected tokens.

    ``anchor_projected`` and ``revisit_projected`` are ``(B, T_src, T_frame, D)``
    and are indexed ``[batch_idx, source_position]``. Each frame's tokens are
    pooled from a ``src_h x src_w`` grid.

    Per batch element, records are added in this order:

    1. Prefix anchors: each entry of ``anchor_indices`` indexes into the list
       of non-generated source positions (out-of-range entries are skipped).
       The frame is pooled to ``anchor_pool_h x anchor_pool_w`` slots.
    2. Revisit records for every non-generated frame (``PREFIX_GT``), then for
       every generated frame (``GENERATED``). Each is pooled to
       ``revisit_pool_h x revisit_pool_w`` slots.
    3. If ``allow_generated_anchor`` is set, the first ``len(anchor_indices)``
       generated frames are also added as generated anchor records.

    ``pose`` may be laid out ``(T_src, B, ...)`` or ``(B, T_src, ...)``. Any
    other layout yields ``pose=None`` on revisit records. ``action`` is not
    used by this method.

    Returns:
        ``(anchor_banks, revisit_banks)``, one bank of each kind per batch element.

    Raises:
        ValueError: on inconsistent projected or frame-index shapes.
    """
    if anchor_projected.ndim != 4 or revisit_projected.ndim != 4:
        raise ValueError("anchor/revisit projected tensors must have shape (B,T_src,T_frame,D)")
    B, T_src, _, _ = anchor_projected.shape
    if revisit_projected.shape[:3] != anchor_projected.shape[:3]:
        raise ValueError("anchor/revisit projected tensors must share batch/source/token dimensions")
    if source_frame_indices.ndim != 2 or source_frame_indices.shape[1] != B:
        raise ValueError("source_frame_indices must have shape (T_src,B)")
    generated = None if source_is_generated is None else source_is_generated.bool().to(source_frame_indices.device)
    anchor_banks: list[CausalMemoryBank] = []
    revisit_banks: list[CausalMemoryBank] = []

    def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
        """Select ``positions`` of one batch element from a (T_src,B,...) or (B,T_src,...) tensor."""
        if tensor is None or tensor.ndim < 2:
            return None
        # Index on the tensor's own device: indexing a CPU tensor with CUDA
        # positions raises (same handling as _build_streaming_cache_records).
        pos = positions.to(device=tensor.device)
        if tensor.shape[0] == T_src and tensor.shape[1] == B:
            return tensor[pos, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_src:
            return tensor[batch_idx, pos]
        return None

    def _pooled_anchor(batch_idx: int, source_pos_i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Pool one frame's anchor tokens and build its all-True slot mask."""
        tokens = self._spatial_pool_tokens(
            anchor_projected[batch_idx, source_pos_i],
            anchor_pool_h, anchor_pool_w, src_h, src_w,
        )
        mask = torch.ones((tokens.shape[0],), device=anchor_projected.device, dtype=torch.bool)
        return tokens, mask

    for batch_idx in range(B):
        anchor_bank = CausalMemoryBank()
        revisit_bank = CausalMemoryBank()
        src_frames = source_frame_indices[:, batch_idx]
        if generated is None:
            non_generated = torch.ones_like(src_frames, dtype=torch.bool)
            generated_positions = torch.empty(0, device=src_frames.device, dtype=torch.long)
        else:
            non_generated = ~generated[:, batch_idx]
            generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten()
        source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()

        # 1. Prefix anchors: anchor_indices index into the non-generated positions.
        for anchor_idx in anchor_indices:
            if not 0 <= int(anchor_idx) < source_positions.numel():
                continue
            source_pos_i = int(source_positions[int(anchor_idx)].item())
            anchor_tokens, record_mask = _pooled_anchor(batch_idx, source_pos_i)
            anchor_bank.add_prefix_anchors(
                anchor_tokens.unsqueeze(0),
                record_mask.unsqueeze(0),
                src_frames[source_pos_i:source_pos_i + 1],
                slots_per_anchor=anchor_tokens.shape[0],
            )

        # 2. Revisit records: all ground-truth frames first, then generated frames.
        for prefix, positions, source_type, is_generated in (
            ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
            ("generated", generated_positions, MemorySourceType.GENERATED, True),
        ):
            for source_pos in positions:
                source_pos_i = int(source_pos.item())
                frame_index = src_frames[source_pos_i]
                frame = int(frame_index.detach().item())
                frame_tokens = self._spatial_pool_tokens(
                    revisit_projected[batch_idx, source_pos_i],
                    revisit_pool_h, revisit_pool_w, src_h, src_w,
                )
                frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool)
                revisit_bank.add_frame_record(
                    frame_tokens,
                    frame_mask,
                    frame_index,
                    pose=_tensor_subset(pose, source_pos.reshape(1), batch_idx),
                    source_type=source_type,
                    metadata={},
                    is_generated=is_generated,
                    record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}",
                )

        # 3. Optionally expose the first generated frames as anchors too.
        if allow_generated_anchor:
            for source_pos in generated_positions[:len(anchor_indices)]:
                source_pos_i = int(source_pos.item())
                anchor_tokens, record_mask = _pooled_anchor(batch_idx, source_pos_i)
                anchor_bank.add_generated_records(
                    anchor_tokens.unsqueeze(0),
                    record_mask.unsqueeze(0),
                    src_frames[source_pos_i:source_pos_i + 1],
                    source_type=MemorySourceType.GENERATED,
                )

        anchor_banks.append(anchor_bank)
        revisit_banks.append(revisit_bank)
    return anchor_banks, revisit_banks
|
|
def _build_preselected_causal_memory_banks(
    self,
    committed_latents: torch.Tensor,
    source_frame_indices: torch.Tensor,
    source_is_generated: torch.Tensor | None,
    pose: torch.Tensor | None,
    action: torch.Tensor | None,
    target_frame_indices: torch.Tensor,
    target_pose: torch.Tensor | None,
    target_action: torch.Tensor | None,
    target_video_ids,
    allow_generated_anchor: bool,
    anchor_indices: list[int],
    anchor_pool_h: int,
    anchor_pool_w: int,
    anchor_diverse: bool,
    revisit_pool_h: int,
    revisit_pool_w: int,
    revisit_max_frames: int,
    exclude_local_context_frames: int,
    fov_overlap_threshold,
    plucker_weight: float,
    revisit_retrieval_kwargs: dict | None,
    token_patch_size: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]:
    """Build anchor banks and retrieval-preselected revisit banks.

    Per batch element:

    * Anchors are taken from non-generated source positions only. With
      ``anchor_diverse``, positions below ``self._context_frame_count()`` go
      through ``_select_diverse_anchor_positions``. Otherwise
      ``anchor_indices`` index into the non-generated positions, skipping
      out-of-range entries. Selected frames are projected with
      ``dememwm_anchor_proj`` and pooled to ``anchor_pool_h x anchor_pool_w``.
    * Revisit candidates are the non-generated and generated frames whose
      frame index is below ``max(target frames) - exclude_local_context_frames``.
      Each becomes a lightweight ``MemoryRecord`` with a single zero token.
      ``deterministic_revisit_retrieval`` runs once per target frame with
      ``topk=revisit_max_frames``. Only candidates selected for at least one
      target are projected with ``dememwm_revisit_proj``, pooled to
      ``revisit_pool_h x revisit_pool_w``, and stored with the retrieval
      metadata merged in.

    ``action`` and ``target_action`` are not used here.
    NOTE(review): ``allow_generated_anchor`` is also unused here, unlike the
    other bank builders — confirm that is intended.

    Args:
        committed_latents: ``(T_src, B, C, H, W)`` latents.
        source_frame_indices: ``(T_src, B)`` frame indices.
        target_frame_indices: ``(T_tgt, B)``, or 1-D (treated as ``(T_tgt, 1)``).

    Returns:
        ``(anchor_banks, revisit_banks, tokens_per_frame, diagnostics)``.
        ``tokens_per_frame`` is the anchor projector's ``src_h * src_w``.
        ``diagnostics`` holds projection and retrieval counts summed over
        batch elements and targets.

    Raises:
        ValueError: on inconsistent input shapes.
    """
    if committed_latents.ndim != 5:
        raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
    T_src, B, _, H, W = committed_latents.shape
    if source_frame_indices.shape != (T_src, B):
        raise ValueError("source_frame_indices must have shape (T_src,B)")
    if target_frame_indices.ndim == 1:
        target_frame_indices = target_frame_indices[:, None]
    if target_frame_indices.shape[1] != B:
        raise ValueError("target_frame_indices must have batch dimension B")
    T_tgt = target_frame_indices.shape[0]
    stream_device = committed_latents.device
    hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0])
    # The anchor and revisit projectors can produce different token grids
    # (e.g. a spatial conv vs. a patch projector), so each stream pools from
    # the grid of its own projector.
    src_h, src_w = self._projected_spatial_grid_size(
        H,
        W,
        self.dememwm_anchor_proj,
        token_patch_size,
    )
    revisit_src_h, revisit_src_w = self._projected_spatial_grid_size(
        H,
        W,
        self.dememwm_revisit_proj,
        token_patch_size,
    )
    tokens_per_frame = src_h * src_w
    generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device)
    anchor_banks: list[CausalMemoryBank] = []
    revisit_banks: list[CausalMemoryBank] = []
    # Placeholder payload for candidate records; real tokens are projected
    # only for candidates that retrieval actually selects.
    dummy_tokens = committed_latents.new_zeros((1, hidden_size))
    dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
    preselection_candidate_count = 0
    preselection_valid_candidate_label_count = 0
    preselection_selected_count = 0
    projected_anchor_frames = 0
    projected_revisit_frames = 0
    projected_revisit_records = 0
    retrieval_kwargs = dict(revisit_retrieval_kwargs or {})

    if pose is not None:
        pose = pose.to(device=stream_device)
    if target_pose is not None:
        target_pose = target_pose.to(device=stream_device)

    def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
        """Select ``positions`` of one batch element from a (T_src,B,...) or (B,T_src,...) tensor."""
        if tensor is None or tensor.ndim < 2:
            return None
        # Positions may live on source_frame_indices' device; move them to the
        # indexed tensor's device (a CPU tensor cannot take CUDA indices).
        pos = positions.to(device=tensor.device)
        if tensor.shape[0] == T_src and tensor.shape[1] == B:
            return tensor[pos, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_src:
            return tensor[batch_idx, pos]
        return None

    def _target_tensor(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
        """Pick one target entry from a (T_tgt,B,...) or (B,T_tgt,...) tensor, else None."""
        if tensor is None or tensor.ndim < 2:
            return None
        if tensor.shape[0] == T_tgt and tensor.shape[1] == B:
            return tensor[target_idx, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_tgt:
            return tensor[batch_idx, target_idx]
        return None

    def _target_video_id(batch_idx: int, target_idx: int):
        """Resolve the video id for one target from a tensor, list/tuple, or scalar."""
        if target_video_ids is None:
            return None
        if torch.is_tensor(target_video_ids):
            ids = target_video_ids.detach().cpu()
            if ids.ndim == 0:
                return ids.item()
            if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B:
                return ids[target_idx, batch_idx].item()
            if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt:
                return ids[batch_idx, target_idx].item()
            return None
        if isinstance(target_video_ids, (list, tuple)):
            # A per-batch list takes precedence over a per-target list.
            if len(target_video_ids) == B:
                return target_video_ids[batch_idx]
            if len(target_video_ids) == T_tgt:
                row = target_video_ids[target_idx]
                if isinstance(row, (list, tuple)) and len(row) == B:
                    return row[batch_idx]
                return row
        return target_video_ids

    def _pose_subset(positions: torch.Tensor, batch_idx: int):
        return _tensor_subset(pose, positions, batch_idx)

    def _candidate_record(
        *,
        batch_idx: int,
        frame_position: torch.Tensor,
        source_type: MemorySourceType,
        is_generated: bool,
        record_id: str,
    ) -> MemoryRecord:
        """Build a token-less candidate record for one source frame."""
        # Index source_frame_indices on its own device, then move to the stream device.
        frame_values = source_frame_indices[frame_position.to(device=source_frame_indices.device), batch_idx].to(device=stream_device)
        frame = int(frame_values.reshape(-1)[0].item())
        return MemoryRecord(
            tokens=dummy_tokens,
            mask=dummy_mask,
            source_start=frame,
            source_end=frame + 1,
            frame_indices=frame_values.reshape(1),
            pose=_pose_subset(frame_position, batch_idx),
            source_type=source_type,
            is_generated=bool(is_generated),
            chunk_id=record_id,
            metadata={},
        )

    for batch_idx in range(B):
        anchor_bank = CausalMemoryBank()
        revisit_bank = CausalMemoryBank()
        src_frames = source_frame_indices[:, batch_idx]
        if generated is None:
            non_generated = torch.ones_like(src_frames, dtype=torch.bool)
        else:
            non_generated = ~generated[:, batch_idx]
        source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()

        # ---- Anchors (non-generated frames only) ----
        anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long)
        if anchor_indices and source_positions.numel() > 0:
            if anchor_diverse:
                anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
                if anchor_source_positions.numel() > 0:
                    anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
                    anchor_positions = self._select_diverse_anchor_positions(
                        anchor_source_positions, anchor_pose, len(anchor_indices)
                    ).to(device=stream_device, dtype=torch.long)
            else:
                selected_anchor_positions = [
                    source_positions[int(anchor_idx)]
                    for anchor_idx in anchor_indices
                    if 0 <= int(anchor_idx) < source_positions.numel()
                ]
                if selected_anchor_positions:
                    anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
        if anchor_positions.numel() > 0:
            projected_anchor_frames += int(anchor_positions.numel())
            anchor_projected = self._project_latent_patch_tokens(
                committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
                self.dememwm_anchor_proj,
                token_patch_size,
            )[0]
            for local_idx, source_pos in enumerate(anchor_positions):
                source_pos_i = int(source_pos.item())
                anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
                n_slots = anchor_tokens.shape[0]
                record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool)
                anchor_bank.add_prefix_anchors(
                    anchor_tokens.unsqueeze(0),
                    record_mask.unsqueeze(0),
                    src_frames[source_pos_i:source_pos_i + 1],
                    slots_per_anchor=n_slots,
                )

        # ---- Revisit candidates: frames strictly older than the local-context window ----
        candidate_records: list[MemoryRecord] = []
        candidate_positions: dict[str, torch.Tensor] = {}
        src_frames_cpu = src_frames.detach().cpu()
        target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long)
        if T_tgt > 0:
            latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames)
        else:
            # No targets: nothing can be retrieved, so build no candidates.
            latest_valid_source_frame_exclusive = 0
        generated_positions = (
            torch.empty(0, device=stream_device, dtype=torch.long)
            if generated is None
            else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten()
        )
        for prefix, positions, source_type, is_generated in (
            ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
            ("generated", generated_positions, MemorySourceType.GENERATED, True),
        ):
            if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0:
                continue
            for frame_position_cpu in positions.detach().cpu().to(dtype=torch.long):
                frame = int(src_frames_cpu[int(frame_position_cpu.item())].item())
                if frame >= latest_valid_source_frame_exclusive:
                    continue
                frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long)
                record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
                candidate_positions[record_id] = frame_position
                candidate_records.append(_candidate_record(
                    batch_idx=batch_idx,
                    frame_position=frame_position,
                    source_type=source_type,
                    is_generated=is_generated,
                    record_id=record_id,
                ))

        # ---- Per-target retrieval; union of selections across targets ----
        selected_frame_record_ids: set[str] = set()
        selected_frame_metadata: dict[str, dict] = {}
        for target_idx in range(T_tgt):
            target_frame = int(target_frame_indices[target_idx, batch_idx].item())
            result = deterministic_revisit_retrieval(
                candidate_records,
                target_frame=target_frame,
                target_pose=_target_tensor(target_pose, batch_idx, target_idx),
                target_summary=None,
                topk=revisit_max_frames,
                exclude_local_context_frames=exclude_local_context_frames,
                fov_overlap_threshold=fov_overlap_threshold,
                plucker_weight=plucker_weight,
                target_video_id=_target_video_id(batch_idx, target_idx),
                **retrieval_kwargs,
            )
            preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
            preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0))
            preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
            for selected_record in result.records:
                if selected_record.chunk_id is None:
                    continue
                record_id = str(selected_record.chunk_id)
                selected_frame_record_ids.add(record_id)
                selected_frame_metadata[record_id] = dict(selected_record.metadata)

        # ---- Project only the selected candidates, in candidate order ----
        for record in candidate_records:
            if record.chunk_id not in selected_frame_record_ids:
                continue
            record_id = str(record.chunk_id)
            frame_position = candidate_positions[record_id]
            projected_revisit_records += 1
            projected_revisit_frames += int(frame_position.numel())
            revisit_projected = self._project_latent_patch_tokens(
                committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
                self.dememwm_revisit_proj,
                token_patch_size,
            )[0]
            frame_tokens = self._spatial_pool_tokens(
                revisit_projected[0], revisit_pool_h, revisit_pool_w, revisit_src_h, revisit_src_w
            )
            frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool)
            record_metadata = dict(record.metadata)
            record_metadata.update(selected_frame_metadata.get(record_id, {}))
            revisit_bank.add_frame_record(
                frame_tokens,
                frame_mask,
                record.frame_indices.reshape(-1)[0],
                pose=record.pose,
                source_type=record.source_type,
                metadata=record_metadata,
                is_generated=record.is_generated,
                record_id=record.chunk_id,
            )

        anchor_banks.append(anchor_bank)
        revisit_banks.append(revisit_bank)

    diagnostics = {
        "preselected_anchor_projected_frame_count": projected_anchor_frames,
        "preselected_revisit_projected_frame_count": projected_revisit_frames,
        "preselected_revisit_projected_frame_record_count": projected_revisit_records,
        "preselected_revisit_candidate_frame_count": preselection_candidate_count,
        "preselected_revisit_candidate_count": preselection_candidate_count,
        "preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count,
        "preselected_revisit_selected_frame_count": preselection_selected_count,
        "preselected_revisit_selected_count": preselection_selected_count,
    }
    return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
|
|
def _causal_cached_revisit_records(
    self,
    records: Iterable[MemoryRecord],
    target_frame: int,
) -> list[MemoryRecord]:
    """Return the records whose ``source_end`` is at most ``target_frame``.

    Input order is preserved, and ``target_frame`` and each ``source_end``
    are compared as ints.
    """
    cutoff = int(target_frame)
    kept = []
    for record in records:
        if int(record.source_end) > cutoff:
            continue
        kept.append(record)
    return kept
|
|
def _records_to_stream(
    self,
    records,
    target_slots: int,
    hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pack memory records into a fixed number of token slots.

    Tokens come from ``stack_record_tokens``. At most ``target_slots`` rows
    are kept, and any shortfall is padded with zero tokens and False mask
    entries. When nothing was stacked, or ``target_slots`` is 0, an all-zero
    ``(target_slots, hidden_size)`` stream with an all-False mask is returned.

    Returns:
        ``(tokens, mask, max_source_frame)``. ``max_source_frame`` is the
        largest ``record.max_source_frame``, or -1 when there are no records.
    """
    slot_count = max(0, int(target_slots))
    record_list = list(records)
    stacked_tokens, stacked_mask = stack_record_tokens(record_list, target_slots=slot_count)
    max_source_frame = max((int(rec.max_source_frame) for rec in record_list), default=-1)

    has_content = stacked_tokens is not None and stacked_mask is not None and slot_count != 0
    if not has_content:
        empty_tokens = torch.zeros((slot_count, hidden_size), device=device, dtype=dtype)
        empty_mask = torch.zeros((slot_count,), device=device, dtype=torch.bool)
        return empty_tokens, empty_mask, max_source_frame

    used = min(slot_count, stacked_tokens.shape[0])
    tokens = stacked_tokens[:used].to(device=device, dtype=dtype)
    mask = stacked_mask[:used].to(device=device, dtype=torch.bool)
    missing = slot_count - used
    if missing > 0:
        tokens = torch.cat([tokens, tokens.new_zeros(missing, hidden_size)], dim=0)
        mask = torch.cat([mask, torch.zeros(missing, device=device, dtype=torch.bool)], dim=0)
    return tokens, mask, max_source_frame
|
|
def _project_streaming_revisit_records(
    self,
    *,
    cache: StreamingCache,
    batch_idx: int,
    records: list[MemoryRecord],
    device: torch.device,
    dtype: torch.dtype,
    token_patch_size: int,
    revisit_pool_h: int,
    revisit_pool_w: int,
    projection_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord],
) -> list[MemoryRecord]:
    """Replace metadata-only revisit records with projected token records.

    Records without ``metadata["dememwm_revisit_metadata_only"]`` pass
    through unchanged. For flagged records, one frame is chosen: the
    ``"dememwm_selected_frame_index"`` metadata entry if present, otherwise
    the largest value in ``record.frame_indices``. That frame's raw latents
    are fetched from ``cache``, projected with ``dememwm_revisit_proj``, and
    pooled to ``revisit_pool_h x revisit_pool_w`` tokens.

    Results are memoised in ``projection_cache`` (mutated in place), keyed by
    ``(batch_idx, chunk_id or "", source_start, source_end, frame,
    is_generated)``. The new record copies the original's fields with tensors
    moved to ``device``, and its metadata flag is set to False.
    """
    output: list[MemoryRecord] = []
    for record in records:
        if not bool(record.metadata.get("dememwm_revisit_metadata_only", False)):
            output.append(record)
            continue

        chosen_index = record.metadata.get("dememwm_selected_frame_index")
        if chosen_index is not None:
            frame_idx = torch.as_tensor(
                [int(chosen_index)],
                device=record.frame_indices.device,
                dtype=record.frame_indices.dtype,
            )
        else:
            frame_idx = record.frame_indices[torch.argmax(record.frame_indices)].reshape(1)

        cache_key = (
            int(batch_idx),
            str(record.chunk_id or ""),
            int(record.source_start),
            int(record.source_end),
            int(frame_idx.detach().cpu().reshape(-1)[0].item()),
            bool(record.is_generated),
        )
        hit = projection_cache.get(cache_key)
        if hit is not None:
            output.append(hit)
            continue

        raw_latents = cache.raw_latents_for_frames(
            batch_idx=batch_idx,
            frame_indices=frame_idx,
            device=device,
            dtype=dtype,
        )
        projected_tokens = self._project_latent_patch_tokens(
            raw_latents,
            self.dememwm_revisit_proj,
            token_patch_size,
        )[0]
        # Pool from the revisit projector's own token grid; presumably
        # raw_latents is (T, B, C, H, W), so dims 3/4 are H/W — TODO confirm.
        grid_h, grid_w = self._projected_spatial_grid_size(
            raw_latents.shape[3],
            raw_latents.shape[4],
            self.dememwm_revisit_proj,
            token_patch_size,
        )
        frame_tokens = self._spatial_pool_tokens(projected_tokens[0], revisit_pool_h, revisit_pool_w, grid_h, grid_w)
        frame_mask = torch.ones((frame_tokens.shape[0],), device=device, dtype=torch.bool)
        moved_metadata = {
            meta_key: (meta_value.to(device=device) if torch.is_tensor(meta_value) else meta_value)
            for meta_key, meta_value in record.metadata.items()
        }
        moved_metadata["dememwm_revisit_metadata_only"] = False
        projected = MemoryRecord(
            tokens=frame_tokens,
            mask=frame_mask,
            source_start=int(record.source_start),
            source_end=int(record.source_end),
            frame_indices=record.frame_indices.to(device=device),
            pose=None if record.pose is None else record.pose.to(device=device),
            source_type=record.source_type,
            is_generated=bool(record.is_generated),
            score=record.score,
            chunk_id=record.chunk_id,
            metadata=moved_metadata,
        )
        projection_cache[cache_key] = projected
        output.append(projected)
    return output
|
|
def build_memory_streams(
    self,
    committed_latents: torch.Tensor | None,
    source_frame_indices: torch.Tensor | None,
    target_frame_indices: torch.Tensor | None = None,
    pose: torch.Tensor | None = None,
    target_pose: torch.Tensor | None = None,
    action: torch.Tensor | None = None,
    target_action: torch.Tensor | None = None,
    target_video_ids=None,
    source_is_generated: torch.Tensor | None = None,
    denoising_fraction: float | None = None,
    noise_bucket: str | None = None,
    noise_bucket_ids: torch.Tensor | None = None,
    streaming_cache: StreamingCache | None = None,
) -> MemoryStreamTensors:
    """Build the anchor, dynamic and revisit memory streams for every target frame.

    Memory records come from ``streaming_cache`` when it is enabled, keeps
    compressed records and already holds at least one; otherwise they are
    built from ``committed_latents`` via ``_build_preselected_causal_memory_banks``.
    For each (batch, target) pair the method

    * queries the anchor bank (``PREFIX_GT`` records only unless
      ``anchor.allow_generated_as_anchor`` is set),
    * runs ``dememwm_dynamic_compressor`` over source frames whose index is
      below ``target - dynamic.exclude_latest_local_frames``,
    * runs ``deterministic_revisit_retrieval`` over causal revisit candidates,
      optionally applying eval-ablation corruption to the selected tokens,

    and then derives per-stream gates from ``_effective_gate_state``.

    Args:
        committed_latents: Source latents. ``shape[1]`` is used as the batch
            size and ``shape[-2:]`` as the latent grid. May be ``None`` only
            when the streaming cache supplies records.
        source_frame_indices: Frame index of each source latent; the dynamic
            prefilter indexes it as ``[:, batch]``.
        target_frame_indices: Frame indices being generated, ``(T_tgt, B)``;
            a 1-D tensor is promoted to ``(T_tgt, 1)``. Defaults to
            ``source_frame_indices``.
        pose: Source poses; also the fallback for ``target_pose``.
        target_pose: Per-target poses used by revisit retrieval.
        action: Not used in this method's body.
        target_action: Not used in this method's body.
        target_video_ids: Optional video ids passed to revisit retrieval.
            Accepts a scalar, a tensor shaped ``(B,)``, ``(T_tgt,)``,
            ``(T_tgt, B)`` or ``(B, T_tgt)``, or a list/tuple indexed by
            batch (length ``B``) or by target (length ``T_tgt``).
        source_is_generated: Boolean flag per source latent.
        denoising_fraction: Forwarded to ``_effective_gate_state``.
        noise_bucket: Forwarded to ``_effective_gate_state``.
        noise_bucket_ids: Forwarded to the noise-bucket diagnostics summary.
        streaming_cache: Optional cache; ignored unless its ``enabled``
            attribute is truthy.

    Returns:
        ``MemoryStreamTensors`` whose token/mask tensors are stacked as
        ``(B, T_tgt, slots, ...)``, plus gates and a diagnostics dict.

    Raises:
        ValueError: If neither index tensor is given, or if no cache records
            are available and ``committed_latents``/``source_frame_indices``
            are missing.
        AssertionError: If a stream's slot count differs from the configured
            layout.
    """
    if target_frame_indices is None:
        if source_frame_indices is None:
            raise ValueError("target_frame_indices or source_frame_indices is required")
        target_frame_indices = source_frame_indices

    # --- Configuration and gate state ---------------------------------------
    memory_cfg = self._memory_cfg()
    anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
    dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
    revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
    injection_cfg = self._cfg_get(memory_cfg, "injection", None)
    contract_diag = self._validate_config_contract()
    gate_state = self._effective_gate_state(
        denoising_fraction=denoising_fraction,
        noise_bucket=noise_bucket,
    )
    anchor_config_enabled = gate_state["anchor_config_enabled"]
    dynamic_config_enabled = gate_state["dynamic_config_enabled"]
    revisit_config_enabled = gate_state["revisit_config_enabled"]
    curriculum_state = gate_state["curriculum_state"]
    eval_ablation_enabled = gate_state["eval_ablation_enabled"]
    eval_ablation_branch = gate_state["eval_ablation_branch"]
    resolved_noise_bucket = gate_state["resolved_noise_bucket"]
    gates = gate_state["gates"]
    anchor_effective_enabled = gate_state["anchor_effective_enabled"]
    dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
    revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"]
    force_revisit_off = gate_state["force_revisit_off"]
    force_revisit_on = gate_state["force_revisit_on"]

    # --- Slot layout ----------------------------------------------------------
    token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
    anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])]
    anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None)
    # Pool sizes follow the committed latent grid when available, otherwise
    # the model's stacked-input shape.
    pool_latent_h = int(committed_latents.shape[-2]) if committed_latents is not None else int(self.x_stacked_shape[-2])
    pool_latent_w = int(committed_latents.shape[-1]) if committed_latents is not None else int(self.x_stacked_shape[-1])
    anchor_src_h, anchor_src_w = self._projected_spatial_grid_size(
        pool_latent_h,
        pool_latent_w,
        self.dememwm_anchor_proj,
        token_patch_size,
    )
    anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size(
        anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8
    )
    anchor_num_tokens = len(anchor_indices) * anchor_pool_h * anchor_pool_w
    anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False))
    allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False))
    revisit_max_frames = int(self._cfg_get(revisit_cfg, "max_frames", 2))
    revisit_compress_cfg = self._cfg_get(revisit_cfg, "compress", None)
    revisit_src_h, revisit_src_w = self._projected_spatial_grid_size(
        pool_latent_h,
        pool_latent_w,
        self.dememwm_revisit_proj,
        token_patch_size,
    )
    revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
        revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
    )
    revisit_target_slots = revisit_max_frames * revisit_pool_h * revisit_pool_w
    recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
    dynamic_recent_exclusion_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
    revisit_context_window_exclusion_frames = self._local_context_exclusion_frames()
    fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
    high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
    plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
    revisit_retrieval_kwargs = {
        "high_quality_fov_threshold": high_quality_fov_threshold,
        "fov_half_h": float(self._cfg_get(revisit_cfg, "fov_half_h", 52.5)),
        "fov_half_v": float(self._cfg_get(revisit_cfg, "fov_half_v", 37.5)),
        "fov_yaw_samples": int(self._cfg_get(revisit_cfg, "fov_yaw_samples", 25)),
        "fov_pitch_samples": int(self._cfg_get(revisit_cfg, "fov_pitch_samples", 20)),
        "fov_depth_samples": int(self._cfg_get(revisit_cfg, "fov_depth_samples", 20)),
        "fov_radius": float(self._cfg_get(revisit_cfg, "fov_radius", 30.0)),
        "pose_preselect_topk": self._cfg_get(revisit_cfg, "pose_preselect_topk", 64),
        "plucker_grid_h": int(self._cfg_get(revisit_cfg, "plucker_grid_h", 4)),
        "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
        "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
    }
    preselection_diag = {}
    use_cache_revisit_records = False
    revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None

    # --- Streaming cache, device and dtype ------------------------------------
    cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
    cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0}
    if committed_latents is not None:
        stream_device = committed_latents.device
        stream_dtype = committed_latents.dtype
    else:
        param = next(iter(self.dememwm_anchor_proj.parameters()))
        stream_device = param.device
        stream_dtype = param.dtype
    target_frame_indices = target_frame_indices.to(device=stream_device)
    if target_frame_indices.ndim == 1:
        target_frame_indices = target_frame_indices[:, None]
    # Host copy of the (T_tgt, B) target indices. The per-(batch, target)
    # loops below read Python ints from it instead of calling ``.item()`` on
    # a possibly-GPU tensor, which would synchronize on every access.
    target_frames_host = target_frame_indices.detach().cpu().tolist()
    target_pose_source = target_pose if target_pose is not None else pose

    # --- Dynamic source selection (optionally from cached raw latents) --------
    use_cache_records = cache is not None and cache.keep_compressed_records and cache.record_count > 0
    dynamic_latents = committed_latents if dynamic_effective_enabled else None
    dynamic_frame_indices = source_frame_indices if dynamic_effective_enabled else None
    dynamic_generated = source_is_generated if dynamic_effective_enabled else None
    dynamic_pose = pose if dynamic_effective_enabled else None
    if dynamic_effective_enabled and cache is not None and cache.raw_frame_slots > 0:
        raw_latents, raw_frames, raw_generated, raw_pose = cache.materialize_raw_latents(
            device=stream_device,
            dtype=stream_dtype,
            max_recent_frames=recent_frames,
            target_frame_indices=target_frame_indices,
            exclude_latest_local_frames=dynamic_recent_exclusion_frames,
        )
        if raw_latents is not None:
            dynamic_latents = raw_latents
            dynamic_frame_indices = raw_frames
            dynamic_generated = raw_generated
            dynamic_pose = raw_pose

    # --- Anchor/revisit banks: from the cache or from committed latents -------
    if use_cache_records:
        B = target_frame_indices.shape[1]
        hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
        anchor_banks = (
            cache.memory_banks("anchor", device=stream_device, dtype=stream_dtype, batch_size=B)
            if anchor_effective_enabled else [CausalMemoryBank() for _ in range(B)]
        )
        # Revisit candidates are read straight from cached records; the banks
        # stay empty and only serve the causality assertion below.
        revisit_banks = [CausalMemoryBank() for _ in range(B)]
        revisit_record_batches = (
            [cache.records_for_batch("revisit", batch_idx) for batch_idx in range(B)]
            if revisit_stage_config_enabled else [tuple() for _ in range(B)]
        )
        use_cache_revisit_records = bool(revisit_stage_config_enabled)
        if dynamic_latents is not None and dynamic_latents.ndim == 5 and dynamic_latents.shape[0] > 0:
            tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size(
                dynamic_latents.shape[-2],
                dynamic_latents.shape[-1],
                self.dememwm_anchor_proj,
                token_patch_size,
            )
            tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w
        else:
            latent_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 0
            latent_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 0
            tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size(
                latent_h,
                latent_w,
                self.dememwm_anchor_proj,
                token_patch_size,
            )
            tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w
    else:
        if committed_latents is None or source_frame_indices is None:
            raise ValueError("committed_latents/source_frame_indices are required when no streaming cache records are available")
        B = committed_latents.shape[1]
        hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
        anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks(
            committed_latents,
            source_frame_indices.to(device=stream_device),
            None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
            None if pose is None else pose.to(device=stream_device),
            None,
            target_frame_indices,
            None if target_pose_source is None else target_pose_source.to(device=stream_device),
            None,
            target_video_ids,
            allow_generated_anchor,
            anchor_indices,
            anchor_pool_h,
            anchor_pool_w,
            anchor_diverse,
            revisit_pool_h,
            revisit_pool_w,
            revisit_max_frames,
            revisit_context_window_exclusion_frames,
            fov_overlap_threshold,
            plucker_weight,
            revisit_retrieval_kwargs,
            token_patch_size,
        )
        revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]

    # --- Anchor stream ----------------------------------------------------------
    T_tgt = target_frame_indices.shape[0]
    anchor_slots = max(0, anchor_num_tokens)
    revisit_slots = max(0, revisit_target_slots)
    anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
    anchor_token_rows = []
    anchor_mask_rows = []
    anchor_max_rows = []
    for batch_idx, anchor_bank in enumerate(anchor_banks):
        batch_token_rows = []
        batch_mask_rows = []
        batch_max_rows = []
        for target_idx in range(T_tgt):
            target_frame = int(target_frames_host[target_idx][batch_idx])
            records = anchor_bank.query(
                MemoryBankQuery(
                    target_frame=target_frame,
                    source_type=anchor_source_type,
                    include_generated=allow_generated_anchor,
                    max_records=len(anchor_indices),
                )
            )
            anchor_bank.assert_causal(target_frame, records)
            stream_tokens, stream_mask, max_source_frame = self._records_to_stream(
                records,
                anchor_slots,
                hidden_size,
                stream_device,
                stream_dtype,
            )
            batch_token_rows.append(stream_tokens)
            batch_mask_rows.append(stream_mask)
            batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
        anchor_token_rows.append(torch.stack(batch_token_rows, dim=0))
        anchor_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
        anchor_max_rows.append(torch.stack(batch_max_rows, dim=0))
    anchor_tokens = torch.stack(anchor_token_rows, dim=0)
    anchor_mask = torch.stack(anchor_mask_rows, dim=0)
    anchor_max = torch.stack(anchor_max_rows, dim=0)

    # --- Dynamic stream ---------------------------------------------------------
    if dynamic_latents is None or dynamic_frame_indices is None or dynamic_latents.shape[0] == 0:
        # No dynamic sources: emit an all-masked stream sized from the model's
        # stacked-input grid (18x32 if that shape is too short).
        _fallback_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 18
        _fallback_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 32
        dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
        dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
        dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
        dynamic_diag = {
            "selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
            "max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
            "generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device),
            "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
            "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
            "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
            "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
        }
    else:
        # Prefilter: keep only source rows some (batch, target) pair can use —
        # the last ``max_source_frames`` indices below
        # ``target - exclude_latest_local_frames`` — so the compressor does not
        # process the whole history.
        # NOTE(review): assumes source rows are ordered by frame index so the
        # tail of ``_valid`` holds the most recent frames. A non-positive
        # ``max_source_frames`` keeps every valid row (``x[-0:]`` is the whole
        # sequence); confirm that matches the compressor's interpretation.
        _dfi = dynamic_frame_indices.to(device=stream_device)
        _max_src = self.dememwm_dynamic_compressor.max_source_frames
        _needed: list[int] = []
        for _b in range(B):
            for _j in range(T_tgt):
                _target = int(target_frames_host[_j][_b])
                _valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
                _needed.extend(_valid[-_max_src:].tolist())
        if _needed:
            _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
            _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
            _dynamic_fi_small = _dfi.index_select(0, _needed_idx)
            _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
            _dynamic_gen_small = (
                dynamic_generated.to(device=stream_device, dtype=torch.bool).index_select(0, _needed_idx)
                if dynamic_generated is not None else None
            )
        else:
            _dynamic_latents_small = dynamic_latents[:0]
            _dynamic_fi_small = _dfi[:0]
            _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
            _dynamic_gen_small = None
        dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor(
            _dynamic_latents_small,
            _dynamic_fi_small,
            _dynamic_pose_small,
            target_frame_indices,
            _dynamic_gen_small,
            exclude_latest_local_frames=dynamic_recent_exclusion_frames,
        )

    # Gap summaries; -1 marks "no dynamic source" and is excluded.
    dynamic_min_gap_tensor = torch.as_tensor(
        dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
        device=stream_device,
    )
    dynamic_max_gap_tensor = torch.as_tensor(
        dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
        device=stream_device,
    )
    dynamic_gap_valid = dynamic_min_gap_tensor >= 0
    dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1
    dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0
    dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1

    # --- Per-target retrieval inputs ---------------------------------------------
    # Target poses are moved to the stream device once rather than per target.
    if target_pose_source is None or target_pose_source.ndim < 2:
        target_pose_stream = None
    else:
        target_pose_stream = target_pose_source.to(device=stream_device)

    def _target_pose_or_none(batch_idx: int, target_idx: int):
        # Accepts (T_tgt, B, ...) first, then (B, T_tgt, ...); anything else -> None.
        if target_pose_stream is None:
            return None
        if target_pose_stream.shape[0] == T_tgt and target_pose_stream.shape[1] == B:
            return target_pose_stream[target_idx, batch_idx]
        if target_pose_stream.shape[0] == B and target_pose_stream.shape[1] == T_tgt:
            return target_pose_stream[batch_idx, target_idx]
        return None

    # Tensor video ids are copied to the host once rather than per target.
    video_ids_host = target_video_ids.detach().cpu() if torch.is_tensor(target_video_ids) else None

    def _target_video_id_or_none(batch_idx: int, target_idx: int):
        if target_video_ids is None:
            return None
        if video_ids_host is not None:
            ids = video_ids_host
            if ids.ndim == 0:
                return ids.item()
            if ids.ndim == 1:
                # Resolve 1-D tensors like the list/tuple branch below:
                # per-batch first, then per-target.
                if ids.shape[0] == B:
                    return ids[batch_idx].item()
                if ids.shape[0] == T_tgt:
                    return ids[target_idx].item()
                return None
            if ids.shape[0] == T_tgt and ids.shape[1] == B:
                return ids[target_idx, batch_idx].item()
            if ids.shape[0] == B and ids.shape[1] == T_tgt:
                return ids[batch_idx, target_idx].item()
            return None
        if isinstance(target_video_ids, (list, tuple)):
            if len(target_video_ids) == B:
                return target_video_ids[batch_idx]
            if len(target_video_ids) == T_tgt:
                row = target_video_ids[target_idx]
                if isinstance(row, (list, tuple)) and len(row) == B:
                    return row[batch_idx]
                return row
        return target_video_ids

    # --- Revisit stream -------------------------------------------------------
    revisit_token_rows = []
    revisit_mask_rows = []
    revisit_max_rows = []
    valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
    revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
    revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
    revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
    revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
    revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
    eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
    revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
    eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
    revisit_result_diagnostics = []
    # Shared across targets so a cached record selected twice is projected once.
    projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
    if revisit_record_batches is None:
        revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
    for batch_idx in range(B):
        revisit_bank = revisit_banks[batch_idx]
        batch_token_rows = []
        batch_mask_rows = []
        batch_max_rows = []
        for target_idx in range(T_tgt):
            target_frame = int(target_frames_host[target_idx][batch_idx])
            if use_cache_revisit_records:
                candidate_records = self._causal_cached_revisit_records(
                    revisit_record_batches[batch_idx],
                    target_frame,
                )
            else:
                candidate_records = revisit_bank.query(
                    MemoryBankQuery(
                        target_frame=target_frame,
                        include_generated=True,
                    )
                )
            result = deterministic_revisit_retrieval(
                candidate_records,
                target_frame=target_frame,
                target_pose=_target_pose_or_none(batch_idx, target_idx),
                target_summary=None,
                topk=revisit_max_frames,
                exclude_local_context_frames=revisit_context_window_exclusion_frames,
                fov_overlap_threshold=fov_overlap_threshold,
                plucker_weight=plucker_weight,
                target_video_id=_target_video_id_or_none(batch_idx, target_idx),
                **revisit_retrieval_kwargs,
            )
            selected_records = result.records
            if use_cache_revisit_records and selected_records:
                # Cached revisit records are projected only after selection.
                selected_records = self._project_streaming_revisit_records(
                    cache=cache,
                    batch_idx=batch_idx,
                    records=selected_records,
                    device=stream_device,
                    dtype=stream_dtype,
                    token_patch_size=token_patch_size,
                    revisit_pool_h=revisit_pool_h,
                    revisit_pool_w=revisit_pool_w,
                    projection_cache=projected_revisit_record_cache,
                )
            revisit_result_diagnostics.append(result.diagnostics)
            revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
            revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
            revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
            revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
            revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
            revisit_bank.assert_causal(target_frame, selected_records)
            if selected_records:
                valid_revisit_mask[batch_idx, target_idx] = True
            stream_tokens, stream_mask, max_source_frame = self._records_to_stream(
                selected_records,
                revisit_slots,
                hidden_size,
                stream_device,
                stream_dtype,
            )
            revisit_causal_max[batch_idx, target_idx] = max_source_frame
            if eval_corruption_enabled:
                stream_tokens, was_corrupted = apply_revisit_eval_corruption(
                    tokens=stream_tokens,
                    mask=stream_mask,
                    branch=eval_ablation_branch,
                    target_frame=target_frame,
                )
                eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
            batch_token_rows.append(stream_tokens)
            batch_mask_rows.append(stream_mask)
            batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
        revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
        revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
        revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
    revisit_tokens = torch.stack(revisit_token_rows, dim=0)
    revisit_mask = torch.stack(revisit_mask_rows, dim=0)
    revisit_max = torch.stack(revisit_max_rows, dim=0)

    # --- Layout checks --------------------------------------------------------
    if anchor_tokens.shape[-2] != anchor_num_tokens:
        raise AssertionError(f"anchor slot count mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}")
    if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
        _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
            int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
        )
        if dynamic_tokens.shape[-2] != _expected_dyn:
            raise AssertionError(f"dynamic slot count mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}")
    if revisit_tokens.shape[-2] != revisit_target_slots:
        raise AssertionError(f"revisit slot count mismatch: got {revisit_tokens.shape[-2]}, expected {revisit_target_slots}")

    # --- Gates ----------------------------------------------------------------
    anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
    dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
    gate_module = getattr(self, "dememwm_revisit_gate", None)
    if gate_module is None:
        revisit_gate_raw = torch.ones((B, T_tgt), device=stream_device, dtype=stream_dtype)
    else:
        revisit_gate_raw = gate_module(
            valid_revisit_mask=valid_revisit_mask,
            best_selected_fov_overlap=revisit_best_selected_fov_overlap,
            best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
            selected_gap_frames=revisit_selected_gap_frames,
        ).to(device=stream_device, dtype=stream_dtype)
    valid_revisit_eff_mask = valid_revisit_mask
    if not revisit_stage_config_enabled or force_revisit_off:
        revisit_gate = torch.zeros_like(revisit_gate_raw)
    elif force_revisit_on:
        revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * torch.ones_like(revisit_gate_raw)
    else:
        revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * revisit_gate_raw * float(gates.revisit_gate)
    revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item())
    # Disabled streams are fully masked; a disabled revisit stage also clears
    # every revisit diagnostic so summaries do not report stale retrievals.
    if not anchor_effective_enabled:
        anchor_mask = torch.zeros_like(anchor_mask)
    if not dynamic_effective_enabled:
        dynamic_mask = torch.zeros_like(dynamic_mask)
    if not revisit_stage_config_enabled:
        revisit_mask = torch.zeros_like(revisit_mask)
        valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
        revisit_candidate_count = torch.zeros_like(revisit_candidate_count)
        revisit_selected_count = torch.zeros_like(revisit_selected_count)
        revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
        revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
        revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
        eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
        valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
        revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
        revisit_gate = torch.zeros_like(revisit_gate)
    no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
    revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask)

    # --- Causality audit: a source frame at or after its target is a violation.
    causal_violation_count = 0
    for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max):
        if source_max is None:
            continue
        source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device)
        valid = source_max_t >= 0
        if valid.any():
            causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item())

    # --- Diagnostics ----------------------------------------------------------
    diagnostics = {
        **curriculum_state.diagnostics(),
        **getattr(self, "_last_dememwm_freeze_diagnostics", {}),
        **contract_diag,
        **cache_diag,
        **preselection_diag,
        **revisit_diag,
        "dememwm_stage": gates.stage,
        "dememwm_gate_reason": gates.reason,
        "anchor_config_enabled": anchor_config_enabled,
        "dynamic_config_enabled": dynamic_config_enabled,
        "revisit_config_enabled": revisit_config_enabled,
        "anchor_effective_enabled": anchor_effective_enabled,
        "dynamic_effective_enabled": dynamic_effective_enabled,
        "revisit_effective_enabled": revisit_effective_enabled,
        "revisit_stage_config_enabled": revisit_stage_config_enabled,
        "revisit_gate_raw": revisit_gate_raw.detach(),
        "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
        "no_valid_revisit_mask": no_valid_revisit_mask,
        "valid_revisit_eff_mask": valid_revisit_eff_mask,
        "revisit_candidate_frame_count_per_target": revisit_candidate_count,
        "revisit_selected_frame_count_per_target": revisit_selected_count,
        "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
        "revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap,
        "revisit_selected_gap_frames_per_target": revisit_selected_gap_frames,
        "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
        # Guard the empty case so the mean is 0.0 instead of NaN.
        "revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0,
        **summarize_noise_bucket_diagnostics(
            noise_bucket=resolved_noise_bucket,
            valid_revisit_mask=valid_revisit_mask,
            no_valid_revisit_mask=no_valid_revisit_mask,
            noise_bucket_ids=noise_bucket_ids,
        ),
        **summarize_eval_ablation_diagnostics(
            enabled=eval_ablation_enabled,
            branch=eval_ablation_branch,
            valid_revisit_mask=valid_revisit_mask,
            no_valid_revisit_mask=no_valid_revisit_mask,
            eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None,
        ),
        "token_patch_size": token_patch_size,
        "tokens_per_frame": tokens_per_frame,
        "anchor_token_slots": int(anchor_tokens.shape[-2]),
        "anchor_target_slots": anchor_num_tokens,
        "anchor_pool_h": anchor_pool_h,
        "anchor_pool_w": anchor_pool_w,
        "dynamic_token_slots": int(dynamic_tokens.shape[-2]),
        "dynamic_target_slots": int(dynamic_tokens.shape[-2]),
        "dynamic_min_gap_to_target": dynamic_min_gap_to_target,
        "dynamic_max_gap_to_target": dynamic_max_gap_to_target,
        "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
        "revisit_token_slots": int(revisit_tokens.shape[-2]),
        "revisit_target_slots": revisit_target_slots,
        "revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
        "revisit_pool_h": revisit_pool_h,
        "revisit_pool_w": revisit_pool_w,
        "revisit_max_frames": revisit_max_frames,
        "anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0,
        "dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0,
        "revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0,
        "causal_violation_count": causal_violation_count,
        "anchor_max_source_frame": anchor_max,
        "dynamic_max_source_frame": dynamic_diag.get("max_source_frame"),
        "revisit_max_source_frame": revisit_max,
        "dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"),
    }
    if eval_corruption_enabled:
        diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask

    return MemoryStreamTensors(
        anchor_tokens=anchor_tokens,
        anchor_mask=anchor_mask,
        dynamic_tokens=dynamic_tokens,
        dynamic_mask=dynamic_mask,
        revisit_tokens=revisit_tokens,
        revisit_mask=revisit_mask,
        anchor_gate=anchor_gate,
        dynamic_gate=dynamic_gate,
        revisit_gate=revisit_gate,
        revisit_gate_raw=revisit_gate_raw,
        valid_revisit_mask=valid_revisit_mask,
        no_valid_revisit_mask=no_valid_revisit_mask,
        diagnostics=diagnostics,
    )
|
|
def _refresh_stream_gates(
    self,
    streams: MemoryStreamTensors,
    denoising_fraction: float | None = None,
    noise_bucket: str | None = None,
) -> MemoryStreamTensors:
    """Recompute the gates and gate diagnostics of already-built streams.

    Re-evaluates ``_effective_gate_state`` for the given denoising fraction /
    noise bucket. Returns a copy of ``streams`` (via ``dataclasses.replace``)
    with new anchor/dynamic/revisit gates, revisit validity masks and an
    updated diagnostics dict. Token and mask tensors are reused unchanged.

    NOTE(review): token masks are not touched here. Masks zeroed at build
    time for then-disabled streams stay zero, and masks of streams disabled
    now are left intact; only the gates change. Confirm callers refresh
    within a compatible curriculum stage.

    Args:
        streams: Previously built streams. ``anchor_tokens`` provides the
            device, dtype and the ``(B, T_tgt)`` leading shape.
        denoising_fraction: Forwarded to ``_effective_gate_state``.
        noise_bucket: Forwarded to ``_effective_gate_state``.

    Returns:
        A new ``MemoryStreamTensors`` with refreshed gates and diagnostics.
    """
    gate_state = self._effective_gate_state(
        denoising_fraction=denoising_fraction,
        noise_bucket=noise_bucket,
    )
    gates = gate_state["gates"]
    device = streams.anchor_tokens.device
    dtype = streams.anchor_tokens.dtype
    B, T_tgt = streams.anchor_tokens.shape[:2]
    valid_revisit_mask = streams.valid_revisit_mask
    if valid_revisit_mask is None:
        valid_revisit_mask = torch.zeros((B, T_tgt), device=device, dtype=torch.bool)
    else:
        valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool)

    # Copy so the caller's diagnostics dict is not mutated.
    diagnostics = dict(streams.diagnostics)

    def _diagnostic_tensor(name: str, fill_value: float = 0.0) -> torch.Tensor:
        # Per-target diagnostic as a (B, T_tgt) float tensor; a missing key or
        # a scalar value is broadcast to that shape.
        value = diagnostics.get(name)
        if value is None:
            return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32)
        tensor = torch.as_tensor(value, device=device, dtype=torch.float32)
        if tensor.ndim == 0:
            return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32)
        return tensor.expand((B, T_tgt))

    revisit_best_selected_fov_overlap = _diagnostic_tensor("revisit_best_selected_fov_overlap_per_target")
    revisit_best_selected_plucker_overlap = _diagnostic_tensor("revisit_best_selected_plucker_overlap_per_target")
    revisit_selected_gap_frames = _diagnostic_tensor("revisit_selected_gap_frames_per_target", -1.0)

    anchor_effective_enabled = gate_state["anchor_effective_enabled"]
    dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
    revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"]
    anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
    dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
    gate_module = getattr(self, "dememwm_revisit_gate", None)
    if gate_module is None:
        revisit_gate_raw = torch.ones((B, T_tgt), device=device, dtype=dtype)
    else:
        revisit_gate_raw = gate_module(
            valid_revisit_mask=valid_revisit_mask,
            best_selected_fov_overlap=revisit_best_selected_fov_overlap,
            best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
            selected_gap_frames=revisit_selected_gap_frames,
        ).to(device=device, dtype=dtype)
    valid_revisit_eff_mask = valid_revisit_mask
    if not revisit_stage_config_enabled or gate_state["force_revisit_off"]:
        revisit_gate = torch.zeros_like(revisit_gate_raw)
    elif gate_state["force_revisit_on"]:
        revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw)
    else:
        revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate)
    if not revisit_stage_config_enabled:
        valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
        valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
        revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
        revisit_gate = torch.zeros_like(revisit_gate)
    no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
    eval_corrupted_revisit_mask = diagnostics.get("eval_corrupted_revisit_mask")
    if eval_corrupted_revisit_mask is not None:
        eval_corrupted_revisit_mask = torch.as_tensor(eval_corrupted_revisit_mask, device=device, dtype=torch.bool)
        if not revisit_stage_config_enabled:
            # With revisit disabled no target can count as corrupted; do not
            # feed a stale build-time mask into the ablation summary. This
            # mirrors how build_memory_streams zeroes the mask when revisit
            # is disabled.
            eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
            diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask
    revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item())
    diagnostics.update(gate_state["curriculum_state"].diagnostics())
    diagnostics.update({
        "dememwm_stage": gates.stage,
        "dememwm_gate_reason": gates.reason,
        "anchor_config_enabled": gate_state["anchor_config_enabled"],
        "dynamic_config_enabled": gate_state["dynamic_config_enabled"],
        "revisit_config_enabled": gate_state["revisit_config_enabled"],
        "anchor_effective_enabled": anchor_effective_enabled,
        "dynamic_effective_enabled": dynamic_effective_enabled,
        "revisit_effective_enabled": revisit_effective_enabled,
        "revisit_stage_config_enabled": revisit_stage_config_enabled,
        "revisit_gate_raw": revisit_gate_raw.detach(),
        "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
        "no_valid_revisit_mask": no_valid_revisit_mask,
        "valid_revisit_eff_mask": valid_revisit_eff_mask,
        "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
        "revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0,
    })
    diagnostics.update(summarize_noise_bucket_diagnostics(
        noise_bucket=gate_state["resolved_noise_bucket"],
        valid_revisit_mask=valid_revisit_mask,
        no_valid_revisit_mask=no_valid_revisit_mask,
    ))
    diagnostics.update(summarize_eval_ablation_diagnostics(
        enabled=gate_state["eval_ablation_enabled"],
        branch=gate_state["eval_ablation_branch"],
        valid_revisit_mask=valid_revisit_mask,
        no_valid_revisit_mask=no_valid_revisit_mask,
        eval_corrupted_revisit_mask=eval_corrupted_revisit_mask,
    ))
    return replace(
        streams,
        anchor_gate=anchor_gate,
        dynamic_gate=dynamic_gate,
        revisit_gate=revisit_gate,
        revisit_gate_raw=revisit_gate_raw,
        valid_revisit_mask=valid_revisit_mask,
        no_valid_revisit_mask=no_valid_revisit_mask,
        diagnostics=diagnostics,
    )
|
|
| def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> tuple[dict, dict]: |
| memory_kwargs, diagnostics = self.dememwm_injection_adapter(streams, device=streams.anchor_tokens.device, dtype=streams.anchor_tokens.dtype) |
| return memory_kwargs, diagnostics |
|
|
def build_memory_kwargs(self, *args, **kwargs) -> tuple[dict, dict]:
    """Build memory streams and immediately turn them into DiT kwargs.

    Every positional and keyword argument is forwarded untouched to
    ``build_memory_streams``; the resulting streams go through
    ``_streams_to_kwargs`` and its ``(memory_kwargs, diagnostics)`` result is
    returned.
    """
    return self._streams_to_kwargs(self.build_memory_streams(*args, **kwargs))
|
|
| def _memory_adapter_delta_diagnostics(self) -> dict: |
| dit_model = getattr(getattr(self, "diffusion_model", None), "model", None) |
| diagnostics_fn = getattr(dit_model, "memory_adapter_delta_diagnostics", None) |
| if diagnostics_fn is None: |
| return {} |
| return diagnostics_fn() |
|
|
| def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None: |
| if namespace == "training/dememwm": |
| allowed_keys = self._TRAIN_DIAGNOSTIC_LOG_KEYS |
| elif namespace.endswith("/dememwm"): |
| allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS |
| else: |
| allowed_keys = None |
| for key, value in diagnostics.items(): |
| if allowed_keys is not None and key not in allowed_keys: |
| continue |
| if isinstance(value, str) or value is None: |
| continue |
| if torch.is_tensor(value): |
| if value.numel() > 0: |
| self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True) |
| elif isinstance(value, (bool, int, float)): |
| self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True) |
|
|
def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx):
    """Return ``(pose_condition, frame_idx_for_dit)`` for a training window.

    * ``use_plucker`` off: ``pose_conditions`` cast to ``xs.dtype`` and ``None``.
    * ``use_plucker`` on, absolute: Plucker embedding of ``c2w_mat`` (cast to
      ``xs.dtype``) together with the unchanged ``frame_idx``.
    * ``use_plucker`` on with ``relative_embedding``: for every frame ``i`` the
      camera of frame ``i`` is stacked in front of the trailing
      ``memory_condition_length`` reference cameras and embedded jointly; the
      matching frame indices are expressed relative to frame ``i``. Per-frame
      results are concatenated along dim 0.
    """
    from ..df_video import convert_to_plucker

    # Queried before any branching, exactly as the original ordering requires.
    image_height, image_width = self._image_size(xs)
    if not self.use_plucker:
        return pose_conditions.to(xs.dtype), None

    if not self.relative_embedding:
        absolute_embedding = convert_to_plucker(
            c2w_mat,
            0,
            focal_length=self.focal_length,
            image_height=image_height,
            image_width=image_width,
        ).to(xs.dtype)
        return absolute_embedding, frame_idx

    ref_len = self.memory_condition_length
    ref_c2w = c2w_mat[-ref_len:] if ref_len else c2w_mat[:0]
    ref_idx = frame_idx[-ref_len:] if ref_len else frame_idx[:0]

    pose_chunks = []
    index_chunks = []
    for i in range(c2w_mat.shape[0]):
        own_pose = c2w_mat[i:i + 1]
        own_idx = frame_idx[i:i + 1]
        stacked_pose = torch.cat([own_pose, ref_c2w]).clone()
        pose_chunks.append(
            convert_to_plucker(
                stacked_pose,
                0,
                focal_length=self.focal_length,
                image_height=image_height,
                image_width=image_width,
            ).to(xs.dtype)
        )
        index_chunks.append(torch.cat([own_idx - own_idx, ref_idx - own_idx]).clone())
    return torch.cat(pose_chunks), torch.cat(index_chunks)
|
|
| def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]: |
| total_frames = max(0, int(total_frames)) |
| n_tokens = max(1, min(int(self.n_tokens), total_frames)) |
| max_start = max(0, total_frames - n_tokens) |
| if max_start == 0: |
| return 0, n_tokens |
| context_start = self._context_frame_count() |
| min_start = min(context_start, max_start) |
| if min_start == max_start: |
| return min_start, min_start + n_tokens |
| start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item()) |
| return start, start + n_tokens |
|
|
def training_step(self, batch, batch_idx):
    """Run one DeMemWM training step on a random frame window.

    Steps:
      1. Preprocess the batch and convert frames to latents.
      2. Choose a ``[start, end)`` window via ``_training_window_bounds`` and
         build the window's pose condition.
      3. Sample noise levels for the window. The trailing
         ``memory_condition_length`` frames are passed to the model as
         ``reference_length`` frames: they get the stabilization noise level and
         their action conditions are zeroed.
      4. Build memory kwargs from the whole sequence (after the
         generated-history proxy) and run the diffusion model.
      5. Drop the reference frames from the per-frame loss and reweight it.

    Diagnostic scalars are only materialized on logging steps
    (``batch_idx % 20 == 0``). The ``.item()`` calls force a host/device
    synchronization, and their values are only ever consumed by the logger.

    Returns:
        ``{"loss": loss_total}``.
    """
    xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
    xs = self._as_latents(xs)

    total_frames = xs.shape[0]
    start, end = self._training_window_bounds(total_frames, xs.device)

    xs_window = xs[start:end]
    # Cloned because the reference-frame actions are zeroed in place below.
    conditions_window = conditions[start:end].clone()
    frame_idx_window = frame_idx[start:end]

    input_pose_condition, frame_idx_list = self._training_pose_condition(
        xs_window, pose_conditions[start:end], c2w_mat[start:end], frame_idx_window
    )

    noise_levels = self._generate_noise_levels(xs_window)
    if self.memory_condition_length:
        noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
        conditions_window[-self.memory_condition_length:] *= 0

    # Memory is built from the full sequence; the proxy returns the latents to
    # use as memory source together with (possibly updated) generated-flags.
    source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
    memory_source_latents, source_is_generated, proxy_diagnostics = self._apply_generated_history_proxy(
        xs,
        source_is_generated,
        context_frame_count=self._context_frame_count(),
        target_start_frame=start,
    )

    timesteps = int(getattr(self, "timesteps", 0) or 0)
    training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps)
    training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps)
    training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
    memory_kwargs, diagnostics = self.build_memory_kwargs(
        memory_source_latents,
        frame_idx,
        target_frame_indices=frame_idx_window,
        pose=pose_conditions,
        target_pose=pose_conditions[start:end],
        action=conditions,
        target_action=conditions_window,
        source_is_generated=source_is_generated,
        denoising_fraction=training_denoising_fraction,
        noise_bucket=training_noise_bucket,
        noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1),
    )
    diagnostics.update(proxy_diagnostics)

    _, loss = self.diffusion_model(
        xs_window,
        conditions_window,
        input_pose_condition,
        noise_levels=noise_levels,
        reference_length=self.memory_condition_length,
        frame_idx=frame_idx_list,
        **memory_kwargs,
    )
    # Still queried every step: the hook lives on the DiT and may have side
    # effects that are not visible here (NOTE(review): confirm before gating it).
    diagnostics.update(self._memory_adapter_delta_diagnostics())

    if self.memory_condition_length:
        loss = loss[:-self.memory_condition_length]
    loss_denoise = self.reweight_loss(loss, None)
    loss_total = loss_denoise

    if batch_idx % 20 == 0:
        diagnostics["training_window_start"] = int(start)
        diagnostics["training_window_end"] = int(end)
        diagnostics["training_window_size"] = int(end - start)
        diagnostics["loss_denoise"] = float(loss_denoise.detach().item())
        diagnostics["loss_total"] = float(loss_total.detach().item())
        self.log("training/loss", loss_total.detach().cpu())
        self._log_memory_diagnostics("training/dememwm", diagnostics)
    return {"loss": loss_total}
|
|
def validation_step(self, batch, batch_idx, namespace="validation"):
    """Autoregressively roll out one validation video and store the result.

    The first ``context_frames // frame_stack`` latent frames are copied from
    the ground truth. The remaining frames are generated chunk by chunk
    (``chunk_size`` frames at a time, or everything left when
    ``chunk_size <= 0``) with ``diffusion_model.sample_step``. Memory streams
    come from the streaming cache once it holds records, otherwise from all
    frames committed so far.

    When a streaming cache exists, committed frames are pushed into it. Frames
    at or after the context are flagged as generated once the rollout is past
    the context.

    Diagnostics from the last sampling step are logged under
    ``f"{namespace}/dememwm"``. The decoded ``(prediction, ground_truth)`` pair,
    without context frames, is appended to ``self.validation_step_outputs``.
    """
    from tqdm import tqdm

    memory_condition_length = self.memory_condition_length
    xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
    total_frame = xs_raw.shape[0]
    if bool(getattr(self, "_last_dememwm_xs_are_latents", False)):
        xs = xs_raw.cpu()
    elif total_frame > 10:
        # Encoded in ten slices (presumably to bound peak encoder memory).
        xs = torch.cat([
            self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
            for i in range(10)
        ])
    else:
        xs = self.encode(xs_raw).cpu()
    n_frames, batch_size, *_ = xs.shape
    n_context_frames = self.context_frames // self.frame_stack
    xs_pred = xs[:n_context_frames].clone()
    curr_frame = n_context_frames
    # Loop-invariant: xs_raw never changes during the rollout.
    image_height, image_width = self._image_size(xs_raw)

    streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
    cached_until = 0
    last_diagnostics = None

    def commit_to_cache(latents, upto, cached):
        """Push frames ``[cached, upto)`` into the streaming cache; return the new high-water mark.

        No-op (returns ``cached``) when there is no cache or nothing new to commit.
        """
        if streaming_cache is None or upto <= cached:
            return cached
        new_generated = torch.zeros(frame_idx[cached:upto].shape, dtype=torch.bool, device=frame_idx.device)
        if upto > n_context_frames:
            # Everything from the first non-context frame on was sampled, not observed.
            new_generated[max(0, n_context_frames - cached):] = True
        self._update_streaming_cache(
            streaming_cache,
            latents[cached:upto],
            frame_idx[cached:upto],
            pose=pose_conditions[cached:upto],
            source_is_generated=new_generated,
            action=conditions[cached:upto],
        )
        return upto

    with tqdm(total=n_frames, initial=curr_frame, desc="Sampling") as pbar:
        while curr_frame < n_frames:
            cached_until = commit_to_cache(xs_pred, curr_frame, cached_until)
            horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
            assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
            scheduling_matrix = self._generate_scheduling_matrix(horizon)
            chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
            chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
            xs_pred = torch.cat([xs_pred, chunk], 0)
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})

            if memory_condition_length:
                # Append reference frames (selected per batch element) after the chunk.
                random_idx = self._generate_condition_indices(curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon)
                xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
            else:
                random_idx = torch.empty((0, batch_size), dtype=torch.long, device=frame_idx.device)

            input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
                start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
                image_width=image_width, image_height=image_height,
            )
            work_device = input_condition.device
            target_idx = frame_idx[start_frame:curr_frame + horizon].to(work_device)
            target_pose = pose_conditions[start_frame:curr_frame + horizon].to(work_device)
            target_action = conditions[start_frame:curr_frame + horizon].to(work_device)

            if streaming_cache is not None and streaming_cache.record_count > 0:
                # History is served from the cache, so no raw history is passed.
                committed_latents = None
                committed_idx = None
                generated_flags = None
                source_pose = None
                source_action = None
            else:
                committed_latents = xs_pred[:curr_frame].to(work_device)
                committed_idx = frame_idx[:curr_frame].to(work_device)
                generated_flags = torch.zeros(committed_idx.shape, device=work_device, dtype=torch.bool)
                if curr_frame > n_context_frames:
                    generated_flags[n_context_frames:] = True
                source_pose = pose_conditions[:curr_frame].to(work_device)
                source_action = conditions[:curr_frame].to(work_device)

            memory_streams = self.build_memory_streams(
                committed_latents,
                committed_idx,
                target_frame_indices=target_idx,
                pose=source_pose,
                target_pose=target_pose,
                action=source_action,
                target_action=target_action,
                source_is_generated=generated_flags,
                denoising_fraction=None,
                streaming_cache=streaming_cache,
            )

            n_steps = scheduling_matrix.shape[0] - 1
            for m in range(n_steps):
                from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
                denoise_frac = float(m + 1) / max(float(n_steps), 1.0)
                step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
                memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams)
                xs_pred[start_frame:] = self.diffusion_model.sample_step(
                    xs_pred[start_frame:].to(work_device),
                    input_condition,
                    input_pose_condition,
                    from_noise_levels[start_frame:],
                    to_noise_levels[start_frame:],
                    current_frame=curr_frame,
                    mode="validation",
                    reference_length=memory_condition_length,
                    frame_idx=frame_idx_list,
                    **memory_kwargs,
                ).cpu()

            if memory_condition_length:
                # Drop the appended reference frames before committing the chunk.
                xs_pred = xs_pred[:-memory_condition_length]
            curr_frame += horizon
            cached_until = commit_to_cache(xs_pred, curr_frame, cached_until)
            # The cache is optional (see _new_streaming_cache); only merge its stats when present.
            if last_diagnostics is not None and streaming_cache is not None:
                last_diagnostics.update(streaming_cache.diagnostics("cache"))
            pbar.update(horizon)

    if last_diagnostics is not None:
        self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics)
    xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
    xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
    self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
|
|
| def strict_checkpoint_key_check(self, state_dict: dict, required_prefixes: Iterable[str] | None = None) -> None: |
| prefixes = tuple(required_prefixes or self.strict_key_prefixes) |
| strip_prefixes = ("", "model.", "module.", "algo.") |
| normalized_keys = [] |
| for key in state_dict.keys(): |
| key = str(key) |
| for strip_prefix in strip_prefixes: |
| if not strip_prefix or key.startswith(strip_prefix): |
| normalized_keys.append(key.removeprefix(strip_prefix)) |
| missing_prefixes = [prefix for prefix in prefixes if not any(key.startswith(prefix) for key in normalized_keys)] |
| missing_substrings = [ |
| marker |
| for marker in self.strict_key_substrings |
| if not any(marker in key for key in normalized_keys) |
| ] |
| if missing_prefixes or missing_substrings: |
| raise RuntimeError( |
| "DeMemWM checkpoint is missing required DeMemWM key coverage: " |
| f"prefixes={missing_prefixes}, memory_adapter_markers={missing_substrings}" |
| ) |
|
|
| |
# ``dememwm``-prefixed alias names for members of this class (presumably kept so
# older call sites using the prefixed names keep working -- confirm).
# Each line binds a second name to the same object at class-creation time. A
# subclass that overrides the unprefixed method does NOT change the alias: the
# alias keeps pointing at this class's implementation.

# Key tables.
dememwm_strict_key_prefixes = strict_key_prefixes
dememwm_strict_key_substrings = strict_key_substrings
_DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS
_DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS
# Method aliases.
_dememwm_cfg = _memory_cfg
_dememwm_stage_policy_cfg = _stage_policy_cfg
_dememwm_eval_ablation_cfg = _eval_ablation_cfg
_dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
_dememwm_eval_ablation_state = _eval_ablation_state
_dememwm_effective_gate_state = _effective_gate_state
_dememwm_validate_config_contract = _validate_config_contract
_dememwm_stream_enabled = _stream_enabled
_dememwm_context_frame_count = _context_frame_count
_dememwm_local_context_exclusion_frames = _local_context_exclusion_frames
_dememwm_curriculum_state = _curriculum_state
_dememwm_generated_history_proxy_prob = _generated_history_proxy_prob
_dememwm_apply_generated_history_proxy = _apply_generated_history_proxy
_dememwm_checkpoint_cfg = _checkpoint_cfg
_dememwm_strict_eval_load_enabled = _strict_eval_load_enabled
_dememwm_cache_cfg = _cache_cfg
_dememwm_cache_enabled = _cache_enabled
_dememwm_new_streaming_cache = _new_streaming_cache
_dememwm_is_memory_adapter_param = _is_memory_adapter_param
_dememwm_param_group_name = _param_group_name
_dememwm_group_trainable = _group_trainable
_dememwm_group_lr = _group_lr
_dememwm_apply_freeze_policy = _apply_freeze_policy
_dememwm_as_latents = _as_latents
_dememwm_image_size = _image_size
_dememwm_update_streaming_cache = _update_streaming_cache
_build_dememwm_streaming_cache_records = _build_streaming_cache_records
_build_dememwm_causal_memory_banks = _build_causal_memory_banks
_build_dememwm_preselected_causal_memory_banks = _build_preselected_causal_memory_banks
_dememwm_records_to_stream = _records_to_stream
build_dememwm_memory_streams = build_memory_streams
_dememwm_refresh_stream_gates = _refresh_stream_gates
_dememwm_streams_to_kwargs = _streams_to_kwargs
build_dememwm_memory_kwargs = build_memory_kwargs
_dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics
_log_dememwm_diagnostics = _log_memory_diagnostics
_dememwm_training_window_bounds = _training_window_bounds
strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
|
|
|
|
# Alternate public name for ``MemoryDiTMixin``; both names refer to the same class object.
DeMemWMMemoryDiTMixin = MemoryDiTMixin
|
|