Optimize DeMemWM memory retrieval and remove diagnostics
Browse files
- algorithms/worldmem/dememwm/algorithm.py +358 -454
- algorithms/worldmem/dememwm/cache.py +0 -17
- algorithms/worldmem/dememwm/compression.py +70 -77
- algorithms/worldmem/dememwm/diagnostics.py +0 -172
- algorithms/worldmem/dememwm/injection.py +7 -19
- algorithms/worldmem/dememwm/retrieval.py +368 -121
- algorithms/worldmem/dememwm/schedules.py +1 -67
- algorithms/worldmem/dememwm/types.py +6 -4
- algorithms/worldmem/models/dit.py +0 -74
- configurations/algorithm/dememwm_memory_dit.yaml +0 -3
- scripts/dememwm_full_eval.slurm +0 -2
- scripts/dememwm_full_train.slurm +0 -2
- tests/test_dememwm_compression.py +126 -17
- tests/test_dememwm_config_static.py +33 -27
- tests/test_dememwm_dit_extension_static.py +5 -10
- tests/test_dememwm_eval_ablation.py +2 -25
- tests/test_dememwm_freeze_policy.py +0 -3
- tests/test_dememwm_generated_history_proxy.py +4 -6
- tests/test_dememwm_injection_static.py +4 -8
- tests/test_dememwm_noise_bucket.py +18 -101
- tests/test_dememwm_preselection.py +9 -8
- tests/test_dememwm_retrieval.py +25 -35
- tests/test_dememwm_schedules.py +3 -26
- train_dememwm_full_berzelius.sh +0 -2
algorithms/worldmem/dememwm/algorithm.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
import math
|
|
@@ -10,13 +9,13 @@ from einops import rearrange
|
|
| 10 |
|
| 11 |
from .cache import StreamingCache
|
| 12 |
from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens
|
| 13 |
-
from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics
|
| 14 |
from .injection import InjectionAdapter
|
| 15 |
from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
|
| 16 |
from .negatives import apply_revisit_eval_corruption
|
| 17 |
-
from .retrieval import deterministic_revisit_retrieval
|
| 18 |
-
from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels,
|
| 19 |
from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class MemoryDiTMixin:
|
|
@@ -36,35 +35,9 @@ class MemoryDiTMixin:
|
|
| 36 |
strict_key_substrings = (
|
| 37 |
".memory_token_cross_attn.",
|
| 38 |
)
|
| 39 |
-
_TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({
|
| 40 |
-
"revisit_candidate_frame_count",
|
| 41 |
-
"revisit_pose_preselect_input_count",
|
| 42 |
-
"revisit_pose_preselect_selected_count",
|
| 43 |
-
"revisit_exact_fov_candidate_count",
|
| 44 |
-
"valid_revisit_frame_count",
|
| 45 |
-
"no_valid_revisit_count",
|
| 46 |
-
"revisit_selected_frame_count",
|
| 47 |
-
"revisit_frame_fov_overlap_mean",
|
| 48 |
-
"revisit_best_selected_frame_fov_overlap_mean",
|
| 49 |
-
"revisit_best_selected_plucker_overlap_mean",
|
| 50 |
-
"revisit_best_selected_gap_frames_mean",
|
| 51 |
-
"revisit_gate_raw",
|
| 52 |
-
"revisit_gate_eff",
|
| 53 |
-
"revisit_learned_gate_mean",
|
| 54 |
-
"revisit_effective_gate_mean",
|
| 55 |
-
"generated_history_proxy_prob",
|
| 56 |
-
"noise_bucket_target_count",
|
| 57 |
-
"noise_bucket_high_target_count",
|
| 58 |
-
"noise_bucket_mid_target_count",
|
| 59 |
-
"noise_bucket_low_target_count",
|
| 60 |
-
})
|
| 61 |
-
_VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({
|
| 62 |
-
"cache_records",
|
| 63 |
-
"cache_slots",
|
| 64 |
-
})
|
| 65 |
|
| 66 |
def _memory_cfg(self):
|
| 67 |
-
return getattr(self
|
| 68 |
|
| 69 |
def _cfg_get(self, obj, name, default):
|
| 70 |
if obj is None:
|
|
@@ -84,8 +57,6 @@ class MemoryDiTMixin:
|
|
| 84 |
except Exception:
|
| 85 |
return False
|
| 86 |
|
| 87 |
-
def _stage_policy_cfg(self):
|
| 88 |
-
return self._cfg_get(self._memory_cfg(), "stage_policy", None)
|
| 89 |
|
| 90 |
def _eval_ablation_cfg(self):
|
| 91 |
return self._cfg_get(self._memory_cfg(), "eval_ablation", None)
|
|
@@ -99,7 +70,7 @@ class MemoryDiTMixin:
|
|
| 99 |
branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal"))
|
| 100 |
return enabled, branch
|
| 101 |
|
| 102 |
-
def _effective_gate_state(self, denoising_fraction: float | None = None
|
| 103 |
memory_cfg = self._memory_cfg()
|
| 104 |
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
|
| 105 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
|
@@ -110,12 +81,9 @@ class MemoryDiTMixin:
|
|
| 110 |
revisit_config_enabled = self._stream_enabled(revisit_cfg)
|
| 111 |
curriculum_state = self._curriculum_state()
|
| 112 |
eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state()
|
| 113 |
-
debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False))
|
| 114 |
-
resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction)
|
| 115 |
gates = compute_stream_gates(
|
| 116 |
curriculum_state.stage,
|
| 117 |
denoising_fraction=denoising_fraction,
|
| 118 |
-
debug_force_all_streams=debug_force,
|
| 119 |
anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)),
|
| 120 |
dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)),
|
| 121 |
revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)),
|
|
@@ -139,7 +107,6 @@ class MemoryDiTMixin:
|
|
| 139 |
return {
|
| 140 |
"curriculum_state": curriculum_state,
|
| 141 |
"gates": gates,
|
| 142 |
-
"resolved_noise_bucket": resolved_noise_bucket,
|
| 143 |
"anchor_config_enabled": anchor_config_enabled,
|
| 144 |
"dynamic_config_enabled": dynamic_config_enabled,
|
| 145 |
"revisit_config_enabled": revisit_config_enabled,
|
|
@@ -152,14 +119,13 @@ class MemoryDiTMixin:
|
|
| 152 |
"force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"),
|
| 153 |
}
|
| 154 |
|
| 155 |
-
def _validate_config_contract(self) ->
|
| 156 |
if bool(getattr(self, "_dememwm_contract_validated", False)):
|
| 157 |
-
return
|
| 158 |
memory_cfg = self._memory_cfg()
|
| 159 |
if memory_cfg is None:
|
| 160 |
self._dememwm_contract_validated = True
|
| 161 |
-
|
| 162 |
-
return {}
|
| 163 |
|
| 164 |
stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
|
| 165 |
if stale_sections:
|
|
@@ -193,10 +159,8 @@ class MemoryDiTMixin:
|
|
| 193 |
if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
|
| 194 |
raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
|
| 195 |
fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
|
| 196 |
-
if fov_overlap_threshold is not None:
|
| 197 |
-
|
| 198 |
-
if fov_overlap_threshold < 0.0:
|
| 199 |
-
raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
|
| 200 |
high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
|
| 201 |
if high_quality_fov_threshold < 0.0:
|
| 202 |
raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
|
|
@@ -222,9 +186,6 @@ class MemoryDiTMixin:
|
|
| 222 |
value = int(self._cfg_get(revisit_cfg, field_name, default))
|
| 223 |
if value <= 0:
|
| 224 |
raise ValueError(f"dememwm.revisit.{field_name} must be positive")
|
| 225 |
-
stage_policy_cfg = self._stage_policy_cfg()
|
| 226 |
-
if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)):
|
| 227 |
-
raise ValueError("final DeMemWM keeps noise_bucket logging enabled")
|
| 228 |
proxy_cfg = self._generated_history_proxy_cfg()
|
| 229 |
proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
|
| 230 |
proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
|
|
@@ -240,18 +201,7 @@ class MemoryDiTMixin:
|
|
| 240 |
raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
|
| 241 |
eval_ablation_cfg = self._eval_ablation_cfg()
|
| 242 |
normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))
|
| 243 |
-
|
| 244 |
-
diagnostics = {
|
| 245 |
-
"dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
|
| 246 |
-
"revisit_deterministic_fov_plucker_retrieval": True,
|
| 247 |
-
"revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
|
| 248 |
-
"revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
|
| 249 |
-
"revisit_plucker_weight": plucker_weight,
|
| 250 |
-
"stage_policy_noise_bucket_logging": True,
|
| 251 |
-
}
|
| 252 |
self._dememwm_contract_validated = True
|
| 253 |
-
self._last_dememwm_config_diagnostics = diagnostics
|
| 254 |
-
return diagnostics
|
| 255 |
|
| 256 |
def _stream_enabled(self, stream_cfg) -> bool:
|
| 257 |
return bool(self._cfg_get(stream_cfg, "enabled", True))
|
|
@@ -294,25 +244,17 @@ class MemoryDiTMixin:
|
|
| 294 |
source_is_generated: torch.Tensor | None,
|
| 295 |
context_frame_count: int | None = None,
|
| 296 |
target_start_frame: int | None = None,
|
| 297 |
-
) -> tuple[torch.Tensor, torch.Tensor
|
| 298 |
cfg = self._generated_history_proxy_cfg()
|
| 299 |
prob = self._generated_history_proxy_prob()
|
| 300 |
noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
|
| 301 |
dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
|
| 302 |
-
diagnostics = {
|
| 303 |
-
"generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)),
|
| 304 |
-
"generated_history_proxy_prob": float(prob),
|
| 305 |
-
"generated_history_proxy_noise_std": float(noise_std),
|
| 306 |
-
"generated_history_proxy_dropout_prob": float(dropout_prob),
|
| 307 |
-
"generated_history_proxy_frame_count": 0,
|
| 308 |
-
"generated_history_proxy_frame_fraction": 0.0,
|
| 309 |
-
}
|
| 310 |
if source_is_generated is None:
|
| 311 |
source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
|
| 312 |
else:
|
| 313 |
source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
|
| 314 |
if prob <= 0.0 or source_latents.numel() == 0:
|
| 315 |
-
return source_latents, source_is_generated
|
| 316 |
|
| 317 |
eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
|
| 318 |
if context_frame_count is not None or target_start_frame is not None:
|
|
@@ -322,12 +264,8 @@ class MemoryDiTMixin:
|
|
| 322 |
if target_start_frame is not None:
|
| 323 |
eligible_mask &= frame_positions < max(0, int(target_start_frame))
|
| 324 |
proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
diagnostics["generated_history_proxy_frame_count"] = proxy_count
|
| 328 |
-
diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count)
|
| 329 |
-
if proxy_count == 0:
|
| 330 |
-
return source_latents, source_is_generated, diagnostics
|
| 331 |
|
| 332 |
corrupt_latents = source_latents.clone()
|
| 333 |
frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
|
|
@@ -342,7 +280,7 @@ class MemoryDiTMixin:
|
|
| 342 |
corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
|
| 343 |
source_is_generated = source_is_generated.clone()
|
| 344 |
source_is_generated |= proxy_mask
|
| 345 |
-
return corrupt_latents, source_is_generated
|
| 346 |
|
| 347 |
def _checkpoint_cfg(self):
|
| 348 |
return self._cfg_get(self._memory_cfg(), "checkpoint", None)
|
|
@@ -397,62 +335,21 @@ class MemoryDiTMixin:
|
|
| 397 |
|
| 398 |
def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
|
| 399 |
state = self._curriculum_state(step)
|
| 400 |
-
|
| 401 |
-
# Keep DDP's trainable graph stable: DiT params stay requires_grad=True
|
| 402 |
-
# from step 0 and are frozen by optimizer LR=0 until the full stage.
|
| 403 |
-
# Re-walk only when curriculum diagnostics can change.
|
| 404 |
freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
|
| 405 |
-
|
| 406 |
-
if last_key != freeze_key:
|
| 407 |
-
trainable_tensors = {
|
| 408 |
-
"dememwm_modules": 0,
|
| 409 |
-
"memory_adapters": 0,
|
| 410 |
-
"full_dit": 0,
|
| 411 |
-
"excluded_frozen": 0,
|
| 412 |
-
}
|
| 413 |
-
trainable_scalars = {key: 0 for key in trainable_tensors}
|
| 414 |
-
requires_grad_tensors = {key: 0 for key in trainable_tensors}
|
| 415 |
-
requires_grad_scalars = {key: 0 for key in trainable_tensors}
|
| 416 |
for name, param in self.named_parameters():
|
| 417 |
group_name = self._param_group_name(name, state)
|
| 418 |
-
should_train = self._group_trainable(group_name, state)
|
| 419 |
if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
|
| 420 |
-
|
| 421 |
-
should_require_grad = False
|
| 422 |
else:
|
| 423 |
-
|
| 424 |
-
param.requires_grad_(should_require_grad)
|
| 425 |
-
if should_train:
|
| 426 |
-
trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1
|
| 427 |
-
trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel())
|
| 428 |
-
if should_require_grad:
|
| 429 |
-
requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1
|
| 430 |
-
requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel())
|
| 431 |
self._last_freeze_key = freeze_key
|
| 432 |
-
self._last_trainable_tensors = trainable_tensors
|
| 433 |
-
self._last_trainable_scalars = trainable_scalars
|
| 434 |
-
self._last_requires_grad_tensors = requires_grad_tensors
|
| 435 |
-
self._last_requires_grad_scalars = requires_grad_scalars
|
| 436 |
-
else:
|
| 437 |
-
trainable_tensors = getattr(self, "_last_trainable_tensors", {})
|
| 438 |
-
trainable_scalars = getattr(self, "_last_trainable_scalars", {})
|
| 439 |
-
requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {})
|
| 440 |
-
requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {})
|
| 441 |
|
| 442 |
if optimizer is not None:
|
| 443 |
for param_group in optimizer.param_groups:
|
| 444 |
group_name = param_group.get("name", "")
|
| 445 |
trainable = self._group_trainable(group_name, state)
|
| 446 |
param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0
|
| 447 |
-
|
| 448 |
-
diagnostics = state.diagnostics()
|
| 449 |
-
for group_name in ("dememwm_modules", "memory_adapters", "full_dit"):
|
| 450 |
-
diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0)
|
| 451 |
-
diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0)
|
| 452 |
-
diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0)
|
| 453 |
-
diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0)
|
| 454 |
-
diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0
|
| 455 |
-
self._last_dememwm_freeze_diagnostics = diagnostics
|
| 456 |
return state
|
| 457 |
|
| 458 |
def configure_optimizers(self):
|
|
@@ -1101,7 +998,7 @@ class MemoryDiTMixin:
|
|
| 1101 |
plucker_weight: float,
|
| 1102 |
revisit_retrieval_kwargs: dict | None,
|
| 1103 |
token_patch_size: int,
|
| 1104 |
-
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]:
|
| 1105 |
if committed_latents.ndim != 5:
|
| 1106 |
raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
|
| 1107 |
T_src, B, _, H, W = committed_latents.shape
|
|
@@ -1126,12 +1023,6 @@ class MemoryDiTMixin:
|
|
| 1126 |
revisit_banks: list[CausalMemoryBank] = []
|
| 1127 |
dummy_tokens = committed_latents.new_zeros((1, hidden_size))
|
| 1128 |
dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
|
| 1129 |
-
preselection_candidate_count = 0
|
| 1130 |
-
preselection_valid_candidate_label_count = 0
|
| 1131 |
-
preselection_selected_count = 0
|
| 1132 |
-
projected_anchor_frames = 0
|
| 1133 |
-
projected_revisit_frames = 0
|
| 1134 |
-
projected_revisit_records = 0
|
| 1135 |
retrieval_kwargs = dict(revisit_retrieval_kwargs or {})
|
| 1136 |
|
| 1137 |
# Pre-convert pose tensors to stream_device once so that the
|
|
@@ -1188,6 +1079,175 @@ class MemoryDiTMixin:
|
|
| 1188 |
def _pose_subset(positions: torch.Tensor, batch_idx: int):
|
| 1189 |
return _tensor_subset(pose, positions, batch_idx)
|
| 1190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1191 |
def _candidate_record(
|
| 1192 |
*,
|
| 1193 |
batch_idx: int,
|
|
@@ -1238,7 +1298,6 @@ class MemoryDiTMixin:
|
|
| 1238 |
if selected_anchor_positions:
|
| 1239 |
anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
|
| 1240 |
if anchor_positions.numel() > 0:
|
| 1241 |
-
projected_anchor_frames += int(anchor_positions.numel())
|
| 1242 |
anchor_projected = self._project_latent_patch_tokens(
|
| 1243 |
committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
|
| 1244 |
self.dememwm_anchor_proj,
|
|
@@ -1258,9 +1317,7 @@ class MemoryDiTMixin:
|
|
| 1258 |
|
| 1259 |
candidate_records: list[MemoryRecord] = []
|
| 1260 |
candidate_positions: dict[str, torch.Tensor] = {}
|
| 1261 |
-
|
| 1262 |
-
target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long)
|
| 1263 |
-
latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames)
|
| 1264 |
for prefix, positions, source_type, is_generated in (
|
| 1265 |
("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
|
| 1266 |
(
|
|
@@ -1270,14 +1327,15 @@ class MemoryDiTMixin:
|
|
| 1270 |
True,
|
| 1271 |
),
|
| 1272 |
):
|
| 1273 |
-
if positions.numel() == 0
|
| 1274 |
continue
|
| 1275 |
-
|
| 1276 |
-
|
| 1277 |
-
|
| 1278 |
-
|
| 1279 |
-
|
| 1280 |
-
|
|
|
|
| 1281 |
record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
|
| 1282 |
candidate_positions[record_id] = frame_position
|
| 1283 |
candidate_records.append(_candidate_record(
|
|
@@ -1301,12 +1359,9 @@ class MemoryDiTMixin:
|
|
| 1301 |
exclude_local_context_frames=exclude_local_context_frames,
|
| 1302 |
fov_overlap_threshold=fov_overlap_threshold,
|
| 1303 |
plucker_weight=plucker_weight,
|
| 1304 |
-
target_video_id=
|
| 1305 |
**retrieval_kwargs,
|
| 1306 |
)
|
| 1307 |
-
preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
|
| 1308 |
-
preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0))
|
| 1309 |
-
preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
|
| 1310 |
for selected_record in result.records:
|
| 1311 |
if selected_record.chunk_id is None:
|
| 1312 |
continue
|
|
@@ -1319,8 +1374,6 @@ class MemoryDiTMixin:
|
|
| 1319 |
continue
|
| 1320 |
record_id = str(record.chunk_id)
|
| 1321 |
frame_position = candidate_positions[record_id]
|
| 1322 |
-
projected_revisit_records += 1
|
| 1323 |
-
projected_revisit_frames += int(frame_position.numel())
|
| 1324 |
revisit_projected = self._project_latent_patch_tokens(
|
| 1325 |
committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
|
| 1326 |
self.dememwm_revisit_proj,
|
|
@@ -1344,17 +1397,7 @@ class MemoryDiTMixin:
|
|
| 1344 |
anchor_banks.append(anchor_bank)
|
| 1345 |
revisit_banks.append(revisit_bank)
|
| 1346 |
|
| 1347 |
-
|
| 1348 |
-
"preselected_anchor_projected_frame_count": projected_anchor_frames,
|
| 1349 |
-
"preselected_revisit_projected_frame_count": projected_revisit_frames,
|
| 1350 |
-
"preselected_revisit_projected_frame_record_count": projected_revisit_records,
|
| 1351 |
-
"preselected_revisit_candidate_frame_count": preselection_candidate_count,
|
| 1352 |
-
"preselected_revisit_candidate_count": preselection_candidate_count,
|
| 1353 |
-
"preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count,
|
| 1354 |
-
"preselected_revisit_selected_frame_count": preselection_selected_count,
|
| 1355 |
-
"preselected_revisit_selected_count": preselection_selected_count,
|
| 1356 |
-
}
|
| 1357 |
-
return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
|
| 1358 |
|
| 1359 |
def _causal_cached_revisit_records(
|
| 1360 |
self,
|
|
@@ -1486,8 +1529,6 @@ class MemoryDiTMixin:
|
|
| 1486 |
target_video_ids=None,
|
| 1487 |
source_is_generated: torch.Tensor | None = None,
|
| 1488 |
denoising_fraction: float | None = None,
|
| 1489 |
-
noise_bucket: str | None = None,
|
| 1490 |
-
noise_bucket_ids: torch.Tensor | None = None,
|
| 1491 |
streaming_cache: StreamingCache | None = None,
|
| 1492 |
) -> MemoryStreamTensors:
|
| 1493 |
if target_frame_indices is None:
|
|
@@ -1499,10 +1540,9 @@ class MemoryDiTMixin:
|
|
| 1499 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
| 1500 |
revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
|
| 1501 |
injection_cfg = self._cfg_get(memory_cfg, "injection", None)
|
| 1502 |
-
|
| 1503 |
gate_state = self._effective_gate_state(
|
| 1504 |
denoising_fraction=denoising_fraction,
|
| 1505 |
-
noise_bucket=noise_bucket,
|
| 1506 |
)
|
| 1507 |
anchor_config_enabled = gate_state["anchor_config_enabled"]
|
| 1508 |
dynamic_config_enabled = gate_state["dynamic_config_enabled"]
|
|
@@ -1510,7 +1550,6 @@ class MemoryDiTMixin:
|
|
| 1510 |
curriculum_state = gate_state["curriculum_state"]
|
| 1511 |
eval_ablation_enabled = gate_state["eval_ablation_enabled"]
|
| 1512 |
eval_ablation_branch = gate_state["eval_ablation_branch"]
|
| 1513 |
-
resolved_noise_bucket = gate_state["resolved_noise_bucket"]
|
| 1514 |
gates = gate_state["gates"]
|
| 1515 |
anchor_effective_enabled = gate_state["anchor_effective_enabled"]
|
| 1516 |
dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
|
|
@@ -1565,12 +1604,12 @@ class MemoryDiTMixin:
|
|
| 1565 |
"plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
|
| 1566 |
"plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
|
| 1567 |
}
|
| 1568 |
-
|
|
|
|
| 1569 |
use_cache_revisit_records = False
|
| 1570 |
revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None
|
| 1571 |
|
| 1572 |
cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
|
| 1573 |
-
cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0}
|
| 1574 |
if committed_latents is not None:
|
| 1575 |
stream_device = committed_latents.device
|
| 1576 |
stream_dtype = committed_latents.dtype
|
|
@@ -1638,7 +1677,7 @@ class MemoryDiTMixin:
|
|
| 1638 |
B = committed_latents.shape[1]
|
| 1639 |
hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
|
| 1640 |
target_pose_source = target_pose if target_pose is not None else pose
|
| 1641 |
-
anchor_banks, revisit_banks, tokens_per_frame,
|
| 1642 |
committed_latents,
|
| 1643 |
source_frame_indices.to(device=stream_device),
|
| 1644 |
None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
|
|
@@ -1710,29 +1749,21 @@ class MemoryDiTMixin:
|
|
| 1710 |
dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
|
| 1711 |
dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
|
| 1712 |
dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
|
| 1713 |
-
dynamic_diag = {
|
| 1714 |
-
"selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
|
| 1715 |
-
"max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1716 |
-
"generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device),
|
| 1717 |
-
"dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1718 |
-
"dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1719 |
-
"dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
|
| 1720 |
-
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
|
| 1721 |
-
}
|
| 1722 |
else:
|
| 1723 |
# Pre-select dynamic source frame positions using only frame index metadata
|
| 1724 |
# before touching latents, so we pass a small slice instead of the full
|
| 1725 |
# 1000-frame tensor to the compressor.
|
| 1726 |
_dfi = dynamic_frame_indices.to(device=stream_device)
|
| 1727 |
_max_src = self.dememwm_dynamic_compressor.max_source_frames
|
| 1728 |
-
|
| 1729 |
for _b in range(B):
|
| 1730 |
for _j in range(T_tgt):
|
| 1731 |
-
_target =
|
| 1732 |
_valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
|
| 1733 |
-
|
| 1734 |
-
|
| 1735 |
-
|
|
|
|
| 1736 |
_dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
|
| 1737 |
_dynamic_fi_small = _dfi.index_select(0, _needed_idx)
|
| 1738 |
_dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
|
|
@@ -1745,7 +1776,7 @@ class MemoryDiTMixin:
|
|
| 1745 |
_dynamic_fi_small = _dfi[:0]
|
| 1746 |
_dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
|
| 1747 |
_dynamic_gen_small = None
|
| 1748 |
-
dynamic_tokens, dynamic_mask
|
| 1749 |
_dynamic_latents_small,
|
| 1750 |
_dynamic_fi_small,
|
| 1751 |
_dynamic_pose_small,
|
|
@@ -1754,18 +1785,6 @@ class MemoryDiTMixin:
|
|
| 1754 |
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
|
| 1755 |
)
|
| 1756 |
|
| 1757 |
-
dynamic_min_gap_tensor = torch.as_tensor(
|
| 1758 |
-
dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
|
| 1759 |
-
device=stream_device,
|
| 1760 |
-
)
|
| 1761 |
-
dynamic_max_gap_tensor = torch.as_tensor(
|
| 1762 |
-
dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
|
| 1763 |
-
device=stream_device,
|
| 1764 |
-
)
|
| 1765 |
-
dynamic_gap_valid = dynamic_min_gap_tensor >= 0
|
| 1766 |
-
dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1
|
| 1767 |
-
dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0
|
| 1768 |
-
dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1
|
| 1769 |
def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
|
| 1770 |
if tensor is None or tensor.ndim < 2:
|
| 1771 |
return None
|
|
@@ -1804,15 +1823,11 @@ class MemoryDiTMixin:
|
|
| 1804 |
revisit_mask_rows = []
|
| 1805 |
revisit_max_rows = []
|
| 1806 |
valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
| 1807 |
-
revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1808 |
-
revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1809 |
revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1810 |
revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1811 |
revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
|
| 1812 |
eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
| 1813 |
-
revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
|
| 1814 |
eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
|
| 1815 |
-
revisit_result_diagnostics = []
|
| 1816 |
projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
|
| 1817 |
if revisit_record_batches is None:
|
| 1818 |
revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
|
|
@@ -1823,49 +1838,65 @@ class MemoryDiTMixin:
|
|
| 1823 |
batch_max_rows = []
|
| 1824 |
for target_idx in range(T_tgt):
|
| 1825 |
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
|
| 1826 |
-
if
|
| 1827 |
-
|
| 1828 |
-
|
| 1829 |
-
|
| 1830 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1831 |
else:
|
| 1832 |
-
|
| 1833 |
-
|
| 1834 |
-
|
| 1835 |
-
|
| 1836 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1837 |
)
|
| 1838 |
-
|
| 1839 |
-
|
| 1840 |
-
|
| 1841 |
-
|
| 1842 |
-
|
| 1843 |
-
|
| 1844 |
-
|
| 1845 |
-
|
| 1846 |
-
|
| 1847 |
-
|
| 1848 |
-
|
| 1849 |
-
|
| 1850 |
-
|
| 1851 |
-
|
| 1852 |
-
|
| 1853 |
-
|
| 1854 |
-
batch_idx=batch_idx,
|
| 1855 |
-
records=selected_records,
|
| 1856 |
-
device=stream_device,
|
| 1857 |
-
dtype=stream_dtype,
|
| 1858 |
-
token_patch_size=token_patch_size,
|
| 1859 |
-
revisit_pool_h=revisit_pool_h,
|
| 1860 |
-
revisit_pool_w=revisit_pool_w,
|
| 1861 |
-
projection_cache=projected_revisit_record_cache,
|
| 1862 |
-
)
|
| 1863 |
-
revisit_result_diagnostics.append(result.diagnostics)
|
| 1864 |
-
revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
|
| 1865 |
-
revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
|
| 1866 |
-
revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
|
| 1867 |
-
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
|
| 1868 |
-
revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
|
| 1869 |
revisit_bank.assert_causal(target_frame, selected_records)
|
| 1870 |
if selected_records:
|
| 1871 |
valid_revisit_mask[batch_idx, target_idx] = True
|
|
@@ -1876,7 +1907,6 @@ class MemoryDiTMixin:
|
|
| 1876 |
stream_device,
|
| 1877 |
stream_dtype,
|
| 1878 |
)
|
| 1879 |
-
revisit_causal_max[batch_idx, target_idx] = max_source_frame
|
| 1880 |
if eval_corruption_enabled:
|
| 1881 |
stream_tokens, was_corrupted = apply_revisit_eval_corruption(
|
| 1882 |
tokens=stream_tokens,
|
|
@@ -1932,8 +1962,6 @@ class MemoryDiTMixin:
|
|
| 1932 |
if not revisit_stage_config_enabled:
|
| 1933 |
revisit_mask = torch.zeros_like(revisit_mask)
|
| 1934 |
valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
|
| 1935 |
-
revisit_candidate_count = torch.zeros_like(revisit_candidate_count)
|
| 1936 |
-
revisit_selected_count = torch.zeros_like(revisit_selected_count)
|
| 1937 |
revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
|
| 1938 |
revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
|
| 1939 |
revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
|
|
@@ -1941,85 +1969,6 @@ class MemoryDiTMixin:
|
|
| 1941 |
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
|
| 1942 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
| 1943 |
revisit_gate = torch.zeros_like(revisit_gate)
|
| 1944 |
-
no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
|
| 1945 |
-
revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask)
|
| 1946 |
-
causal_violation_count = 0
|
| 1947 |
-
for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max):
|
| 1948 |
-
if source_max is None:
|
| 1949 |
-
continue
|
| 1950 |
-
source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device)
|
| 1951 |
-
valid = source_max_t >= 0
|
| 1952 |
-
if valid.any():
|
| 1953 |
-
causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item())
|
| 1954 |
-
diagnostics = {
|
| 1955 |
-
**curriculum_state.diagnostics(),
|
| 1956 |
-
**getattr(self, "_last_dememwm_freeze_diagnostics", {}),
|
| 1957 |
-
**contract_diag,
|
| 1958 |
-
**cache_diag,
|
| 1959 |
-
**preselection_diag,
|
| 1960 |
-
**revisit_diag,
|
| 1961 |
-
"dememwm_stage": gates.stage,
|
| 1962 |
-
"dememwm_gate_reason": gates.reason,
|
| 1963 |
-
"anchor_config_enabled": anchor_config_enabled,
|
| 1964 |
-
"dynamic_config_enabled": dynamic_config_enabled,
|
| 1965 |
-
"revisit_config_enabled": revisit_config_enabled,
|
| 1966 |
-
"anchor_effective_enabled": anchor_effective_enabled,
|
| 1967 |
-
"dynamic_effective_enabled": dynamic_effective_enabled,
|
| 1968 |
-
"revisit_effective_enabled": revisit_effective_enabled,
|
| 1969 |
-
"revisit_stage_config_enabled": revisit_stage_config_enabled,
|
| 1970 |
-
"revisit_gate_raw": revisit_gate_raw.detach(),
|
| 1971 |
-
"revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
|
| 1972 |
-
"no_valid_revisit_mask": no_valid_revisit_mask,
|
| 1973 |
-
"valid_revisit_eff_mask": valid_revisit_eff_mask,
|
| 1974 |
-
"revisit_candidate_frame_count_per_target": revisit_candidate_count,
|
| 1975 |
-
"revisit_selected_frame_count_per_target": revisit_selected_count,
|
| 1976 |
-
"revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
|
| 1977 |
-
"revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap,
|
| 1978 |
-
"revisit_selected_gap_frames_per_target": revisit_selected_gap_frames,
|
| 1979 |
-
"revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
|
| 1980 |
-
"revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()),
|
| 1981 |
-
**summarize_noise_bucket_diagnostics(
|
| 1982 |
-
noise_bucket=resolved_noise_bucket,
|
| 1983 |
-
valid_revisit_mask=valid_revisit_mask,
|
| 1984 |
-
no_valid_revisit_mask=no_valid_revisit_mask,
|
| 1985 |
-
noise_bucket_ids=noise_bucket_ids,
|
| 1986 |
-
),
|
| 1987 |
-
**summarize_eval_ablation_diagnostics(
|
| 1988 |
-
enabled=eval_ablation_enabled,
|
| 1989 |
-
branch=eval_ablation_branch,
|
| 1990 |
-
valid_revisit_mask=valid_revisit_mask,
|
| 1991 |
-
no_valid_revisit_mask=no_valid_revisit_mask,
|
| 1992 |
-
eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None,
|
| 1993 |
-
),
|
| 1994 |
-
"token_patch_size": token_patch_size,
|
| 1995 |
-
"tokens_per_frame": tokens_per_frame,
|
| 1996 |
-
"anchor_token_slots": int(anchor_tokens.shape[-2]),
|
| 1997 |
-
"anchor_target_slots": anchor_num_tokens,
|
| 1998 |
-
"anchor_pool_h": anchor_pool_h,
|
| 1999 |
-
"anchor_pool_w": anchor_pool_w,
|
| 2000 |
-
"dynamic_token_slots": int(dynamic_tokens.shape[-2]),
|
| 2001 |
-
"dynamic_target_slots": int(dynamic_tokens.shape[-2]),
|
| 2002 |
-
"dynamic_min_gap_to_target": dynamic_min_gap_to_target,
|
| 2003 |
-
"dynamic_max_gap_to_target": dynamic_max_gap_to_target,
|
| 2004 |
-
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
|
| 2005 |
-
"revisit_token_slots": int(revisit_tokens.shape[-2]),
|
| 2006 |
-
"revisit_target_slots": revisit_target_slots,
|
| 2007 |
-
"revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
|
| 2008 |
-
"revisit_pool_h": revisit_pool_h,
|
| 2009 |
-
"revisit_pool_w": revisit_pool_w,
|
| 2010 |
-
"revisit_max_frames": revisit_max_frames,
|
| 2011 |
-
"anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0,
|
| 2012 |
-
"dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0,
|
| 2013 |
-
"revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0,
|
| 2014 |
-
"causal_violation_count": causal_violation_count,
|
| 2015 |
-
"anchor_max_source_frame": anchor_max,
|
| 2016 |
-
"dynamic_max_source_frame": dynamic_diag.get("max_source_frame"),
|
| 2017 |
-
"revisit_max_source_frame": revisit_max,
|
| 2018 |
-
"dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"),
|
| 2019 |
-
}
|
| 2020 |
-
if eval_corruption_enabled:
|
| 2021 |
-
diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask
|
| 2022 |
-
|
| 2023 |
return MemoryStreamTensors(
|
| 2024 |
anchor_tokens=anchor_tokens,
|
| 2025 |
anchor_mask=anchor_mask,
|
|
@@ -2032,19 +1981,18 @@ class MemoryDiTMixin:
|
|
| 2032 |
revisit_gate=revisit_gate,
|
| 2033 |
revisit_gate_raw=revisit_gate_raw,
|
| 2034 |
valid_revisit_mask=valid_revisit_mask,
|
| 2035 |
-
|
| 2036 |
-
|
|
|
|
| 2037 |
)
|
| 2038 |
|
| 2039 |
def _refresh_stream_gates(
|
| 2040 |
self,
|
| 2041 |
streams: MemoryStreamTensors,
|
| 2042 |
denoising_fraction: float | None = None,
|
| 2043 |
-
noise_bucket: str | None = None,
|
| 2044 |
) -> MemoryStreamTensors:
|
| 2045 |
gate_state = self._effective_gate_state(
|
| 2046 |
denoising_fraction=denoising_fraction,
|
| 2047 |
-
noise_bucket=noise_bucket,
|
| 2048 |
)
|
| 2049 |
gates = gate_state["gates"]
|
| 2050 |
device = streams.anchor_tokens.device
|
|
@@ -2056,20 +2004,17 @@ class MemoryDiTMixin:
|
|
| 2056 |
else:
|
| 2057 |
valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool)
|
| 2058 |
|
| 2059 |
-
|
| 2060 |
-
|
| 2061 |
-
def _diagnostic_tensor(name: str, fill_value: float = 0.0) -> torch.Tensor:
|
| 2062 |
-
value = diagnostics.get(name)
|
| 2063 |
if value is None:
|
| 2064 |
return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32)
|
| 2065 |
-
tensor =
|
| 2066 |
if tensor.ndim == 0:
|
| 2067 |
return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32)
|
| 2068 |
return tensor.expand((B, T_tgt))
|
| 2069 |
|
| 2070 |
-
revisit_best_selected_fov_overlap =
|
| 2071 |
-
revisit_best_selected_plucker_overlap =
|
| 2072 |
-
revisit_selected_gap_frames =
|
| 2073 |
|
| 2074 |
anchor_effective_enabled = gate_state["anchor_effective_enabled"]
|
| 2075 |
dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
|
|
@@ -2086,53 +2031,16 @@ class MemoryDiTMixin:
|
|
| 2086 |
best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
|
| 2087 |
selected_gap_frames=revisit_selected_gap_frames,
|
| 2088 |
).to(device=device, dtype=dtype)
|
| 2089 |
-
valid_revisit_eff_mask = valid_revisit_mask
|
| 2090 |
if not revisit_stage_config_enabled or gate_state["force_revisit_off"]:
|
| 2091 |
revisit_gate = torch.zeros_like(revisit_gate_raw)
|
| 2092 |
elif gate_state["force_revisit_on"]:
|
| 2093 |
-
revisit_gate =
|
| 2094 |
else:
|
| 2095 |
-
revisit_gate =
|
| 2096 |
if not revisit_stage_config_enabled:
|
| 2097 |
valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
|
| 2098 |
-
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
|
| 2099 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
| 2100 |
revisit_gate = torch.zeros_like(revisit_gate)
|
| 2101 |
-
no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
|
| 2102 |
-
eval_corrupted_revisit_mask = diagnostics.get("eval_corrupted_revisit_mask")
|
| 2103 |
-
if eval_corrupted_revisit_mask is not None:
|
| 2104 |
-
eval_corrupted_revisit_mask = torch.as_tensor(eval_corrupted_revisit_mask, device=device, dtype=torch.bool)
|
| 2105 |
-
revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item())
|
| 2106 |
-
diagnostics.update(gate_state["curriculum_state"].diagnostics())
|
| 2107 |
-
diagnostics.update({
|
| 2108 |
-
"dememwm_stage": gates.stage,
|
| 2109 |
-
"dememwm_gate_reason": gates.reason,
|
| 2110 |
-
"anchor_config_enabled": gate_state["anchor_config_enabled"],
|
| 2111 |
-
"dynamic_config_enabled": gate_state["dynamic_config_enabled"],
|
| 2112 |
-
"revisit_config_enabled": gate_state["revisit_config_enabled"],
|
| 2113 |
-
"anchor_effective_enabled": anchor_effective_enabled,
|
| 2114 |
-
"dynamic_effective_enabled": dynamic_effective_enabled,
|
| 2115 |
-
"revisit_effective_enabled": revisit_effective_enabled,
|
| 2116 |
-
"revisit_stage_config_enabled": revisit_stage_config_enabled,
|
| 2117 |
-
"revisit_gate_raw": revisit_gate_raw.detach(),
|
| 2118 |
-
"revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
|
| 2119 |
-
"no_valid_revisit_mask": no_valid_revisit_mask,
|
| 2120 |
-
"valid_revisit_eff_mask": valid_revisit_eff_mask,
|
| 2121 |
-
"revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
|
| 2122 |
-
"revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0,
|
| 2123 |
-
})
|
| 2124 |
-
diagnostics.update(summarize_noise_bucket_diagnostics(
|
| 2125 |
-
noise_bucket=gate_state["resolved_noise_bucket"],
|
| 2126 |
-
valid_revisit_mask=valid_revisit_mask,
|
| 2127 |
-
no_valid_revisit_mask=no_valid_revisit_mask,
|
| 2128 |
-
))
|
| 2129 |
-
diagnostics.update(summarize_eval_ablation_diagnostics(
|
| 2130 |
-
enabled=gate_state["eval_ablation_enabled"],
|
| 2131 |
-
branch=gate_state["eval_ablation_branch"],
|
| 2132 |
-
valid_revisit_mask=valid_revisit_mask,
|
| 2133 |
-
no_valid_revisit_mask=no_valid_revisit_mask,
|
| 2134 |
-
eval_corrupted_revisit_mask=eval_corrupted_revisit_mask,
|
| 2135 |
-
))
|
| 2136 |
return replace(
|
| 2137 |
streams,
|
| 2138 |
anchor_gate=anchor_gate,
|
|
@@ -2140,84 +2048,90 @@ class MemoryDiTMixin:
|
|
| 2140 |
revisit_gate=revisit_gate,
|
| 2141 |
revisit_gate_raw=revisit_gate_raw,
|
| 2142 |
valid_revisit_mask=valid_revisit_mask,
|
| 2143 |
-
|
| 2144 |
-
|
|
|
|
| 2145 |
)
|
| 2146 |
|
| 2147 |
-
def
|
| 2148 |
-
|
| 2149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2150 |
|
| 2151 |
-
def build_memory_kwargs(self, *args, **kwargs) ->
|
| 2152 |
streams = self.build_memory_streams(*args, **kwargs)
|
| 2153 |
return self._streams_to_kwargs(streams)
|
| 2154 |
|
| 2155 |
-
def
|
| 2156 |
-
|
| 2157 |
-
|
| 2158 |
-
|
| 2159 |
-
|
| 2160 |
-
|
| 2161 |
-
|
| 2162 |
-
|
| 2163 |
-
|
| 2164 |
-
|
| 2165 |
-
elif namespace.endswith("/dememwm"):
|
| 2166 |
-
allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS
|
| 2167 |
-
else:
|
| 2168 |
-
allowed_keys = None
|
| 2169 |
-
for key, value in diagnostics.items():
|
| 2170 |
-
if allowed_keys is not None and key not in allowed_keys:
|
| 2171 |
-
continue
|
| 2172 |
-
if isinstance(value, str) or value is None:
|
| 2173 |
-
continue
|
| 2174 |
-
if torch.is_tensor(value):
|
| 2175 |
-
if value.numel() > 0:
|
| 2176 |
-
self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True)
|
| 2177 |
-
elif isinstance(value, (bool, int, float)):
|
| 2178 |
-
self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True)
|
| 2179 |
-
|
| 2180 |
-
def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx):
|
| 2181 |
-
from ..df_video import convert_to_plucker
|
| 2182 |
-
image_height, image_width = self._image_size(xs)
|
| 2183 |
if self.use_plucker:
|
| 2184 |
if self.relative_embedding:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2185 |
input_pose_condition = []
|
| 2186 |
frame_idx_list = []
|
| 2187 |
-
|
| 2188 |
-
ref_idx = frame_idx[-self.memory_condition_length:] if self.memory_condition_length else frame_idx[:0]
|
| 2189 |
-
for i in range(c2w_mat.shape[0]):
|
| 2190 |
input_pose_condition.append(
|
| 2191 |
convert_to_plucker(
|
| 2192 |
-
torch.cat([c2w_mat[
|
| 2193 |
0,
|
| 2194 |
focal_length=self.focal_length,
|
| 2195 |
-
|
| 2196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2197 |
)
|
| 2198 |
-
frame_idx_list.append(torch.cat([frame_idx[i:i + 1] - frame_idx[i:i + 1], ref_idx - frame_idx[i:i + 1]]).clone())
|
| 2199 |
return torch.cat(input_pose_condition), torch.cat(frame_idx_list)
|
| 2200 |
-
return
|
| 2201 |
-
|
| 2202 |
-
|
| 2203 |
-
|
| 2204 |
-
|
| 2205 |
-
|
| 2206 |
-
|
| 2207 |
-
|
| 2208 |
-
|
| 2209 |
-
|
| 2210 |
-
|
| 2211 |
-
return 0, n_tokens
|
| 2212 |
-
context_start = self._context_frame_count()
|
| 2213 |
-
min_start = min(context_start, max_start)
|
| 2214 |
-
if min_start == max_start:
|
| 2215 |
-
return min_start, min_start + n_tokens
|
| 2216 |
-
start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item())
|
| 2217 |
-
return start, start + n_tokens
|
| 2218 |
|
| 2219 |
def training_step(self, batch, batch_idx):
|
| 2220 |
xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
|
|
|
| 2221 |
xs = self._as_latents(xs)
|
| 2222 |
|
| 2223 |
# Randomly select a contiguous n_tokens denoising window inside the long
|
|
@@ -2231,7 +2145,12 @@ class MemoryDiTMixin:
|
|
| 2231 |
frame_idx_window = frame_idx[start:end]
|
| 2232 |
|
| 2233 |
input_pose_condition, frame_idx_list = self._training_pose_condition(
|
| 2234 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2235 |
)
|
| 2236 |
|
| 2237 |
noise_levels = self._generate_noise_levels(xs_window)
|
|
@@ -2239,17 +2158,15 @@ class MemoryDiTMixin:
|
|
| 2239 |
noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
|
| 2240 |
conditions_window[-self.memory_condition_length:] *= 0
|
| 2241 |
source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
|
| 2242 |
-
memory_source_latents, source_is_generated
|
| 2243 |
xs,
|
| 2244 |
source_is_generated,
|
| 2245 |
context_frame_count=self._context_frame_count(),
|
| 2246 |
target_start_frame=start,
|
| 2247 |
)
|
| 2248 |
timesteps = int(getattr(self, "timesteps", 0) or 0)
|
| 2249 |
-
training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps)
|
| 2250 |
-
training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps)
|
| 2251 |
training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
|
| 2252 |
-
memory_kwargs
|
| 2253 |
memory_source_latents,
|
| 2254 |
frame_idx,
|
| 2255 |
target_frame_indices=frame_idx_window,
|
|
@@ -2259,10 +2176,7 @@ class MemoryDiTMixin:
|
|
| 2259 |
target_action=conditions_window,
|
| 2260 |
source_is_generated=source_is_generated,
|
| 2261 |
denoising_fraction=training_denoising_fraction,
|
| 2262 |
-
noise_bucket=training_noise_bucket,
|
| 2263 |
-
noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1),
|
| 2264 |
)
|
| 2265 |
-
diagnostics.update(proxy_diagnostics)
|
| 2266 |
_, loss = self.diffusion_model(
|
| 2267 |
xs_window,
|
| 2268 |
conditions_window,
|
|
@@ -2272,19 +2186,19 @@ class MemoryDiTMixin:
|
|
| 2272 |
frame_idx=frame_idx_list,
|
| 2273 |
**memory_kwargs,
|
| 2274 |
)
|
| 2275 |
-
diagnostics.update(self._memory_adapter_delta_diagnostics())
|
| 2276 |
if self.memory_condition_length:
|
| 2277 |
loss = loss[:-self.memory_condition_length]
|
| 2278 |
loss_denoise = self.reweight_loss(loss, None)
|
| 2279 |
loss_total = loss_denoise
|
| 2280 |
-
diagnostics["training_window_start"] = int(start)
|
| 2281 |
-
diagnostics["training_window_end"] = int(end)
|
| 2282 |
-
diagnostics["training_window_size"] = int(end - start)
|
| 2283 |
-
diagnostics["loss_denoise"] = float(loss_denoise.detach().item())
|
| 2284 |
-
diagnostics["loss_total"] = float(loss_total.detach().item())
|
| 2285 |
if batch_idx % 20 == 0:
|
| 2286 |
-
|
| 2287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2288 |
return {"loss": loss_total}
|
| 2289 |
|
| 2290 |
def validation_step(self, batch, batch_idx, namespace="validation"):
|
|
@@ -2308,7 +2222,6 @@ class MemoryDiTMixin:
|
|
| 2308 |
streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
|
| 2309 |
cached_until = 0
|
| 2310 |
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 2311 |
-
last_diagnostics = None
|
| 2312 |
while curr_frame < n_frames:
|
| 2313 |
if streaming_cache is not None and curr_frame > cached_until:
|
| 2314 |
new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
|
|
@@ -2375,7 +2288,7 @@ class MemoryDiTMixin:
|
|
| 2375 |
from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
|
| 2376 |
denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0)
|
| 2377 |
step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
|
| 2378 |
-
memory_kwargs
|
| 2379 |
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 2380 |
xs_pred[start_frame:].to(input_condition.device),
|
| 2381 |
input_condition,
|
|
@@ -2405,12 +2318,8 @@ class MemoryDiTMixin:
|
|
| 2405 |
action=conditions[cached_until:curr_frame],
|
| 2406 |
)
|
| 2407 |
cached_until = curr_frame
|
| 2408 |
-
if last_diagnostics is not None:
|
| 2409 |
-
last_diagnostics.update(streaming_cache.diagnostics("cache"))
|
| 2410 |
pbar.update(horizon)
|
| 2411 |
pbar.close()
|
| 2412 |
-
if last_diagnostics is not None:
|
| 2413 |
-
self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics)
|
| 2414 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
|
| 2415 |
xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
|
| 2416 |
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
|
|
@@ -2440,10 +2349,7 @@ class MemoryDiTMixin:
|
|
| 2440 |
# Compatibility aliases for old DeMemWM test and experiment call sites.
|
| 2441 |
dememwm_strict_key_prefixes = strict_key_prefixes
|
| 2442 |
dememwm_strict_key_substrings = strict_key_substrings
|
| 2443 |
-
_DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS
|
| 2444 |
-
_DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS
|
| 2445 |
_dememwm_cfg = _memory_cfg
|
| 2446 |
-
_dememwm_stage_policy_cfg = _stage_policy_cfg
|
| 2447 |
_dememwm_eval_ablation_cfg = _eval_ablation_cfg
|
| 2448 |
_dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
|
| 2449 |
_dememwm_eval_ablation_state = _eval_ablation_state
|
|
@@ -2476,8 +2382,6 @@ class MemoryDiTMixin:
|
|
| 2476 |
_dememwm_refresh_stream_gates = _refresh_stream_gates
|
| 2477 |
_dememwm_streams_to_kwargs = _streams_to_kwargs
|
| 2478 |
build_dememwm_memory_kwargs = build_memory_kwargs
|
| 2479 |
-
_dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics
|
| 2480 |
-
_log_dememwm_diagnostics = _log_memory_diagnostics
|
| 2481 |
_dememwm_training_window_bounds = _training_window_bounds
|
| 2482 |
strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
|
| 2483 |
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import math
|
|
|
|
| 9 |
|
| 10 |
from .cache import StreamingCache
|
| 11 |
from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens
|
|
|
|
| 12 |
from .injection import InjectionAdapter
|
| 13 |
from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
|
| 14 |
from .negatives import apply_revisit_eval_corruption
|
| 15 |
+
from .retrieval import batched_revisit_select_positions, deterministic_revisit_retrieval
|
| 16 |
+
from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum
|
| 17 |
from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors
|
| 18 |
+
from ..df_video import convert_to_plucker
|
| 19 |
|
| 20 |
|
| 21 |
class MemoryDiTMixin:
|
|
|
|
| 35 |
strict_key_substrings = (
|
| 36 |
".memory_token_cross_attn.",
|
| 37 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def _memory_cfg(self):
|
| 40 |
+
return getattr(getattr(self, "cfg", None), "dememwm", None)
|
| 41 |
|
| 42 |
def _cfg_get(self, obj, name, default):
|
| 43 |
if obj is None:
|
|
|
|
| 57 |
except Exception:
|
| 58 |
return False
|
| 59 |
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def _eval_ablation_cfg(self):
|
| 62 |
return self._cfg_get(self._memory_cfg(), "eval_ablation", None)
|
|
|
|
| 70 |
branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal"))
|
| 71 |
return enabled, branch
|
| 72 |
|
| 73 |
+
def _effective_gate_state(self, denoising_fraction: float | None = None) -> dict:
|
| 74 |
memory_cfg = self._memory_cfg()
|
| 75 |
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
|
| 76 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
|
|
|
| 81 |
revisit_config_enabled = self._stream_enabled(revisit_cfg)
|
| 82 |
curriculum_state = self._curriculum_state()
|
| 83 |
eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state()
|
|
|
|
|
|
|
| 84 |
gates = compute_stream_gates(
|
| 85 |
curriculum_state.stage,
|
| 86 |
denoising_fraction=denoising_fraction,
|
|
|
|
| 87 |
anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)),
|
| 88 |
dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)),
|
| 89 |
revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)),
|
|
|
|
| 107 |
return {
|
| 108 |
"curriculum_state": curriculum_state,
|
| 109 |
"gates": gates,
|
|
|
|
| 110 |
"anchor_config_enabled": anchor_config_enabled,
|
| 111 |
"dynamic_config_enabled": dynamic_config_enabled,
|
| 112 |
"revisit_config_enabled": revisit_config_enabled,
|
|
|
|
| 119 |
"force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"),
|
| 120 |
}
|
| 121 |
|
| 122 |
+
def _validate_config_contract(self) -> None:
|
| 123 |
if bool(getattr(self, "_dememwm_contract_validated", False)):
|
| 124 |
+
return
|
| 125 |
memory_cfg = self._memory_cfg()
|
| 126 |
if memory_cfg is None:
|
| 127 |
self._dememwm_contract_validated = True
|
| 128 |
+
return
|
|
|
|
| 129 |
|
| 130 |
stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
|
| 131 |
if stale_sections:
|
|
|
|
| 159 |
if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
|
| 160 |
raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
|
| 161 |
fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
|
| 162 |
+
if fov_overlap_threshold is not None and float(fov_overlap_threshold) < 0.0:
|
| 163 |
+
raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
|
|
|
|
|
|
|
| 164 |
high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
|
| 165 |
if high_quality_fov_threshold < 0.0:
|
| 166 |
raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
|
|
|
|
| 186 |
value = int(self._cfg_get(revisit_cfg, field_name, default))
|
| 187 |
if value <= 0:
|
| 188 |
raise ValueError(f"dememwm.revisit.{field_name} must be positive")
|
|
|
|
|
|
|
|
|
|
| 189 |
proxy_cfg = self._generated_history_proxy_cfg()
|
| 190 |
proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
|
| 191 |
proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
|
|
|
|
| 201 |
raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
|
| 202 |
eval_ablation_cfg = self._eval_ablation_cfg()
|
| 203 |
normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
self._dememwm_contract_validated = True
|
|
|
|
|
|
|
| 205 |
|
| 206 |
def _stream_enabled(self, stream_cfg) -> bool:
|
| 207 |
return bool(self._cfg_get(stream_cfg, "enabled", True))
|
|
|
|
| 244 |
source_is_generated: torch.Tensor | None,
|
| 245 |
context_frame_count: int | None = None,
|
| 246 |
target_start_frame: int | None = None,
|
| 247 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 248 |
cfg = self._generated_history_proxy_cfg()
|
| 249 |
prob = self._generated_history_proxy_prob()
|
| 250 |
noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
|
| 251 |
dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
if source_is_generated is None:
|
| 253 |
source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
|
| 254 |
else:
|
| 255 |
source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
|
| 256 |
if prob <= 0.0 or source_latents.numel() == 0:
|
| 257 |
+
return source_latents, source_is_generated
|
| 258 |
|
| 259 |
eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
|
| 260 |
if context_frame_count is not None or target_start_frame is not None:
|
|
|
|
| 264 |
if target_start_frame is not None:
|
| 265 |
eligible_mask &= frame_positions < max(0, int(target_start_frame))
|
| 266 |
proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
|
| 267 |
+
if not proxy_mask.any():
|
| 268 |
+
return source_latents, source_is_generated
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
corrupt_latents = source_latents.clone()
|
| 271 |
frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
|
|
|
|
| 280 |
corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
|
| 281 |
source_is_generated = source_is_generated.clone()
|
| 282 |
source_is_generated |= proxy_mask
|
| 283 |
+
return corrupt_latents, source_is_generated
|
| 284 |
|
| 285 |
def _checkpoint_cfg(self):
|
| 286 |
return self._cfg_get(self._memory_cfg(), "checkpoint", None)
|
|
|
|
| 335 |
|
| 336 |
def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
|
| 337 |
state = self._curriculum_state(step)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
|
| 339 |
+
if getattr(self, "_last_freeze_key", None) != freeze_key:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
for name, param in self.named_parameters():
|
| 341 |
group_name = self._param_group_name(name, state)
|
|
|
|
| 342 |
if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
|
| 343 |
+
param.requires_grad_(False)
|
|
|
|
| 344 |
else:
|
| 345 |
+
param.requires_grad_(True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
self._last_freeze_key = freeze_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
if optimizer is not None:
|
| 349 |
for param_group in optimizer.param_groups:
|
| 350 |
group_name = param_group.get("name", "")
|
| 351 |
trainable = self._group_trainable(group_name, state)
|
| 352 |
param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
return state
|
| 354 |
|
| 355 |
def configure_optimizers(self):
|
|
|
|
| 998 |
plucker_weight: float,
|
| 999 |
revisit_retrieval_kwargs: dict | None,
|
| 1000 |
token_patch_size: int,
|
| 1001 |
+
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, list[list[tuple[MemoryRecord, ...]]] | None, dict | None]:
|
| 1002 |
if committed_latents.ndim != 5:
|
| 1003 |
raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
|
| 1004 |
T_src, B, _, H, W = committed_latents.shape
|
|
|
|
| 1023 |
revisit_banks: list[CausalMemoryBank] = []
|
| 1024 |
dummy_tokens = committed_latents.new_zeros((1, hidden_size))
|
| 1025 |
dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1026 |
retrieval_kwargs = dict(revisit_retrieval_kwargs or {})
|
| 1027 |
|
| 1028 |
# Pre-convert pose tensors to stream_device once so that the
|
|
|
|
| 1079 |
def _pose_subset(positions: torch.Tensor, batch_idx: int):
|
| 1080 |
return _tensor_subset(pose, positions, batch_idx)
|
| 1081 |
|
| 1082 |
+
fast_source_pose_shape = (
|
| 1083 |
+
pose is not None
|
| 1084 |
+
and pose.ndim >= 3
|
| 1085 |
+
and ((pose.shape[0] == T_src and pose.shape[1] == B) or (pose.shape[0] == B and pose.shape[1] == T_src))
|
| 1086 |
+
)
|
| 1087 |
+
fast_target_pose_shape = (
|
| 1088 |
+
target_pose is not None
|
| 1089 |
+
and target_pose.ndim >= 3
|
| 1090 |
+
and ((target_pose.shape[0] == T_tgt and target_pose.shape[1] == B) or (target_pose.shape[0] == B and target_pose.shape[1] == T_tgt))
|
| 1091 |
+
)
|
| 1092 |
+
if fast_source_pose_shape and fast_target_pose_shape:
|
| 1093 |
+
selection = batched_revisit_select_positions(
|
| 1094 |
+
source_frame_indices,
|
| 1095 |
+
pose,
|
| 1096 |
+
target_frame_indices,
|
| 1097 |
+
target_pose,
|
| 1098 |
+
topk=revisit_max_frames,
|
| 1099 |
+
exclude_local_context_frames=exclude_local_context_frames,
|
| 1100 |
+
fov_overlap_threshold=fov_overlap_threshold,
|
| 1101 |
+
plucker_weight=plucker_weight,
|
| 1102 |
+
fov_half_h=float(retrieval_kwargs.get("fov_half_h", 52.5)),
|
| 1103 |
+
fov_half_v=float(retrieval_kwargs.get("fov_half_v", 37.5)),
|
| 1104 |
+
fov_yaw_samples=int(retrieval_kwargs.get("fov_yaw_samples", 25)),
|
| 1105 |
+
fov_pitch_samples=int(retrieval_kwargs.get("fov_pitch_samples", 20)),
|
| 1106 |
+
fov_depth_samples=int(retrieval_kwargs.get("fov_depth_samples", 20)),
|
| 1107 |
+
fov_radius=float(retrieval_kwargs.get("fov_radius", 30.0)),
|
| 1108 |
+
plucker_grid_h=int(retrieval_kwargs.get("plucker_grid_h", 4)),
|
| 1109 |
+
plucker_grid_w=int(retrieval_kwargs.get("plucker_grid_w", 4)),
|
| 1110 |
+
plucker_focal_length=float(retrieval_kwargs.get("plucker_focal_length", 0.35)),
|
| 1111 |
+
pose_preselect_topk=retrieval_kwargs.get("pose_preselect_topk", 64),
|
| 1112 |
+
)
|
| 1113 |
+
selected_records_by_target: list[list[tuple[MemoryRecord, ...]]] = [
|
| 1114 |
+
[tuple() for _ in range(T_tgt)] for _ in range(B)
|
| 1115 |
+
]
|
| 1116 |
+
|
| 1117 |
+
for batch_idx in range(B):
|
| 1118 |
+
anchor_bank = CausalMemoryBank()
|
| 1119 |
+
revisit_bank = CausalMemoryBank()
|
| 1120 |
+
src_frames = source_frame_indices[:, batch_idx]
|
| 1121 |
+
if generated is None:
|
| 1122 |
+
non_generated = torch.ones_like(src_frames, dtype=torch.bool)
|
| 1123 |
+
else:
|
| 1124 |
+
non_generated = ~generated[:, batch_idx]
|
| 1125 |
+
|
| 1126 |
+
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
|
| 1127 |
+
anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long)
|
| 1128 |
+
if anchor_indices and source_positions.numel() > 0:
|
| 1129 |
+
if anchor_diverse:
|
| 1130 |
+
anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
|
| 1131 |
+
if anchor_source_positions.numel() > 0:
|
| 1132 |
+
anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
|
| 1133 |
+
anchor_positions = self._select_diverse_anchor_positions(
|
| 1134 |
+
anchor_source_positions, anchor_pose, len(anchor_indices)
|
| 1135 |
+
).to(device=stream_device, dtype=torch.long)
|
| 1136 |
+
else:
|
| 1137 |
+
selected_anchor_positions = []
|
| 1138 |
+
for anchor_idx in anchor_indices:
|
| 1139 |
+
if 0 <= int(anchor_idx) < source_positions.numel():
|
| 1140 |
+
selected_anchor_positions.append(source_positions[int(anchor_idx)])
|
| 1141 |
+
if selected_anchor_positions:
|
| 1142 |
+
anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
|
| 1143 |
+
if anchor_positions.numel() > 0:
|
| 1144 |
+
anchor_projected = self._project_latent_patch_tokens(
|
| 1145 |
+
committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
|
| 1146 |
+
self.dememwm_anchor_proj,
|
| 1147 |
+
token_patch_size,
|
| 1148 |
+
)[0]
|
| 1149 |
+
for local_idx, source_pos in enumerate(anchor_positions):
|
| 1150 |
+
source_pos_i = int(source_pos.item())
|
| 1151 |
+
anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
|
| 1152 |
+
n_slots = anchor_tokens.shape[0]
|
| 1153 |
+
record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool)
|
| 1154 |
+
anchor_bank.add_prefix_anchors(
|
| 1155 |
+
anchor_tokens.unsqueeze(0),
|
| 1156 |
+
record_mask.unsqueeze(0),
|
| 1157 |
+
src_frames[source_pos_i:source_pos_i + 1],
|
| 1158 |
+
slots_per_anchor=n_slots,
|
| 1159 |
+
)
|
| 1160 |
+
|
| 1161 |
+
selected_b = selection.selected_positions[batch_idx]
|
| 1162 |
+
selected_mask_b = selection.selected_mask[batch_idx]
|
| 1163 |
+
selected_fov_b = selection.selected_fov_overlap[batch_idx]
|
| 1164 |
+
selected_plucker_b = selection.selected_plucker_overlap[batch_idx]
|
| 1165 |
+
selected_gap_b = selection.selected_gap_frames[batch_idx]
|
| 1166 |
+
high_quality_fov_threshold = float(retrieval_kwargs.get("high_quality_fov_threshold", 0.70))
|
| 1167 |
+
metadata_by_position: dict[int, dict] = {}
|
| 1168 |
+
for target_idx in range(T_tgt):
|
| 1169 |
+
for slot_idx in range(selected_b.shape[1]):
|
| 1170 |
+
if not bool(selected_mask_b[target_idx, slot_idx].detach().item()):
|
| 1171 |
+
continue
|
| 1172 |
+
source_pos_i = int(selected_b[target_idx, slot_idx].detach().item())
|
| 1173 |
+
if source_pos_i < 0:
|
| 1174 |
+
continue
|
| 1175 |
+
frame = int(src_frames[source_pos_i].detach().item())
|
| 1176 |
+
fov_overlap = float(selected_fov_b[target_idx, slot_idx].detach().item())
|
| 1177 |
+
plucker_overlap = float(selected_plucker_b[target_idx, slot_idx].detach().item())
|
| 1178 |
+
gap_frames = float(selected_gap_b[target_idx, slot_idx].detach().item())
|
| 1179 |
+
existing = metadata_by_position.get(source_pos_i)
|
| 1180 |
+
if existing is not None:
|
| 1181 |
+
existing_rank = (
|
| 1182 |
+
float(existing.get("dememwm_selected_frame_fov_overlap", 0.0)),
|
| 1183 |
+
-float(existing.get("dememwm_selected_gap_frames", 1.0e9)),
|
| 1184 |
+
)
|
| 1185 |
+
new_rank = (fov_overlap, -gap_frames)
|
| 1186 |
+
if existing_rank >= new_rank:
|
| 1187 |
+
continue
|
| 1188 |
+
metadata_by_position[source_pos_i] = {
|
| 1189 |
+
"dememwm_selected_revisit_fov_overlap": fov_overlap,
|
| 1190 |
+
"dememwm_selected_revisit_plucker_overlap": plucker_overlap,
|
| 1191 |
+
"dememwm_selected_gap_frames": gap_frames,
|
| 1192 |
+
"dememwm_selected_frame_index": frame,
|
| 1193 |
+
"dememwm_selected_frame_fov_overlap": fov_overlap,
|
| 1194 |
+
"dememwm_selected_frame_fov_threshold": high_quality_fov_threshold,
|
| 1195 |
+
"dememwm_selected_frame_passes_high_quality": bool(fov_overlap >= high_quality_fov_threshold),
|
| 1196 |
+
}
|
| 1197 |
+
flat_selected = selected_b[selected_b >= 0].to(device=stream_device, dtype=torch.long)
|
| 1198 |
+
unique_positions = torch.unique(flat_selected, sorted=True) if flat_selected.numel() > 0 else flat_selected
|
| 1199 |
+
records_by_position: dict[int, MemoryRecord] = {}
|
| 1200 |
+
if unique_positions.numel() > 0:
|
| 1201 |
+
revisit_projected = self._project_latent_patch_tokens(
|
| 1202 |
+
committed_latents.index_select(0, unique_positions)[:, batch_idx:batch_idx + 1],
|
| 1203 |
+
self.dememwm_revisit_proj,
|
| 1204 |
+
token_patch_size,
|
| 1205 |
+
)[0]
|
| 1206 |
+
for local_idx, source_pos in enumerate(unique_positions):
|
| 1207 |
+
source_pos_i = int(source_pos.item())
|
| 1208 |
+
frame_index = src_frames[source_pos_i]
|
| 1209 |
+
frame = int(frame_index.detach().item())
|
| 1210 |
+
is_generated = False if generated is None else bool(generated[source_pos_i, batch_idx].detach().item())
|
| 1211 |
+
source_type = MemorySourceType.GENERATED if is_generated else MemorySourceType.PREFIX_GT
|
| 1212 |
+
prefix = "generated" if is_generated else "prefix"
|
| 1213 |
+
frame_tokens = self._spatial_pool_tokens(revisit_projected[local_idx], revisit_pool_h, revisit_pool_w, src_h, src_w)
|
| 1214 |
+
frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool)
|
| 1215 |
+
record = MemoryRecord(
|
| 1216 |
+
tokens=frame_tokens,
|
| 1217 |
+
mask=frame_mask,
|
| 1218 |
+
source_start=frame,
|
| 1219 |
+
source_end=frame + 1,
|
| 1220 |
+
frame_indices=frame_index.reshape(1).to(device=stream_device),
|
| 1221 |
+
pose=_pose_subset(source_pos.reshape(1), batch_idx),
|
| 1222 |
+
source_type=source_type,
|
| 1223 |
+
is_generated=is_generated,
|
| 1224 |
+
chunk_id=f"{prefix}_revisit_b{batch_idx}_f{frame}",
|
| 1225 |
+
metadata=metadata_by_position.get(source_pos_i, {}),
|
| 1226 |
+
)
|
| 1227 |
+
revisit_bank.add_record(record)
|
| 1228 |
+
records_by_position[source_pos_i] = record
|
| 1229 |
+
|
| 1230 |
+
for target_idx in range(T_tgt):
|
| 1231 |
+
target_records: list[MemoryRecord] = []
|
| 1232 |
+
for source_pos in selected_b[target_idx]:
|
| 1233 |
+
source_pos_i = int(source_pos.detach().item())
|
| 1234 |
+
if source_pos_i < 0:
|
| 1235 |
+
continue
|
| 1236 |
+
record = records_by_position.get(source_pos_i)
|
| 1237 |
+
if record is not None:
|
| 1238 |
+
target_records.append(record)
|
| 1239 |
+
selected_records_by_target[batch_idx][target_idx] = tuple(target_records)
|
| 1240 |
+
|
| 1241 |
+
anchor_banks.append(anchor_bank)
|
| 1242 |
+
revisit_banks.append(revisit_bank)
|
| 1243 |
+
|
| 1244 |
+
fast_revisit_stats = {
|
| 1245 |
+
"best_selected_fov_overlap": selection.best_selected_fov_overlap.to(device=stream_device),
|
| 1246 |
+
"best_selected_plucker_overlap": selection.best_selected_plucker_overlap.to(device=stream_device),
|
| 1247 |
+
"best_selected_gap_frames": selection.best_selected_gap_frames.to(device=stream_device),
|
| 1248 |
+
}
|
| 1249 |
+
return anchor_banks, revisit_banks, tokens_per_frame, selected_records_by_target, fast_revisit_stats
|
| 1250 |
+
|
| 1251 |
def _candidate_record(
|
| 1252 |
*,
|
| 1253 |
batch_idx: int,
|
|
|
|
| 1298 |
if selected_anchor_positions:
|
| 1299 |
anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
|
| 1300 |
if anchor_positions.numel() > 0:
|
|
|
|
| 1301 |
anchor_projected = self._project_latent_patch_tokens(
|
| 1302 |
committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
|
| 1303 |
self.dememwm_anchor_proj,
|
|
|
|
| 1317 |
|
| 1318 |
candidate_records: list[MemoryRecord] = []
|
| 1319 |
candidate_positions: dict[str, torch.Tensor] = {}
|
| 1320 |
+
latest_valid_source_frame_exclusive = target_frame_indices[:, batch_idx].amax() - int(exclude_local_context_frames)
|
|
|
|
|
|
|
| 1321 |
for prefix, positions, source_type, is_generated in (
|
| 1322 |
("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
|
| 1323 |
(
|
|
|
|
| 1327 |
True,
|
| 1328 |
),
|
| 1329 |
):
|
| 1330 |
+
if positions.numel() == 0:
|
| 1331 |
continue
|
| 1332 |
+
positions = positions.to(device=stream_device, dtype=torch.long)
|
| 1333 |
+
frame_values = src_frames.index_select(0, positions).to(device=stream_device)
|
| 1334 |
+
valid_positions = positions[frame_values < latest_valid_source_frame_exclusive]
|
| 1335 |
+
valid_frames = src_frames.index_select(0, valid_positions) if valid_positions.numel() else src_frames[:0]
|
| 1336 |
+
for frame_position, frame_tensor in zip(valid_positions.unbind(0), valid_frames.unbind(0)):
|
| 1337 |
+
frame = int(frame_tensor.item())
|
| 1338 |
+
frame_position = frame_position.reshape(1)
|
| 1339 |
record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
|
| 1340 |
candidate_positions[record_id] = frame_position
|
| 1341 |
candidate_records.append(_candidate_record(
|
|
|
|
| 1359 |
exclude_local_context_frames=exclude_local_context_frames,
|
| 1360 |
fov_overlap_threshold=fov_overlap_threshold,
|
| 1361 |
plucker_weight=plucker_weight,
|
| 1362 |
+
target_video_id=None,
|
| 1363 |
**retrieval_kwargs,
|
| 1364 |
)
|
|
|
|
|
|
|
|
|
|
| 1365 |
for selected_record in result.records:
|
| 1366 |
if selected_record.chunk_id is None:
|
| 1367 |
continue
|
|
|
|
| 1374 |
continue
|
| 1375 |
record_id = str(record.chunk_id)
|
| 1376 |
frame_position = candidate_positions[record_id]
|
|
|
|
|
|
|
| 1377 |
revisit_projected = self._project_latent_patch_tokens(
|
| 1378 |
committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
|
| 1379 |
self.dememwm_revisit_proj,
|
|
|
|
| 1397 |
anchor_banks.append(anchor_bank)
|
| 1398 |
revisit_banks.append(revisit_bank)
|
| 1399 |
|
| 1400 |
+
return anchor_banks, revisit_banks, tokens_per_frame, None, None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1401 |
|
| 1402 |
def _causal_cached_revisit_records(
|
| 1403 |
self,
|
|
|
|
| 1529 |
target_video_ids=None,
|
| 1530 |
source_is_generated: torch.Tensor | None = None,
|
| 1531 |
denoising_fraction: float | None = None,
|
|
|
|
|
|
|
| 1532 |
streaming_cache: StreamingCache | None = None,
|
| 1533 |
) -> MemoryStreamTensors:
|
| 1534 |
if target_frame_indices is None:
|
|
|
|
| 1540 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
| 1541 |
revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
|
| 1542 |
injection_cfg = self._cfg_get(memory_cfg, "injection", None)
|
| 1543 |
+
self._validate_config_contract()
|
| 1544 |
gate_state = self._effective_gate_state(
|
| 1545 |
denoising_fraction=denoising_fraction,
|
|
|
|
| 1546 |
)
|
| 1547 |
anchor_config_enabled = gate_state["anchor_config_enabled"]
|
| 1548 |
dynamic_config_enabled = gate_state["dynamic_config_enabled"]
|
|
|
|
| 1550 |
curriculum_state = gate_state["curriculum_state"]
|
| 1551 |
eval_ablation_enabled = gate_state["eval_ablation_enabled"]
|
| 1552 |
eval_ablation_branch = gate_state["eval_ablation_branch"]
|
|
|
|
| 1553 |
gates = gate_state["gates"]
|
| 1554 |
anchor_effective_enabled = gate_state["anchor_effective_enabled"]
|
| 1555 |
dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
|
|
|
|
| 1604 |
"plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
|
| 1605 |
"plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
|
| 1606 |
}
|
| 1607 |
+
preselected_revisit_records_by_target: list[list[tuple[MemoryRecord, ...]]] | None = None
|
| 1608 |
+
preselected_revisit_stats: dict | None = None
|
| 1609 |
use_cache_revisit_records = False
|
| 1610 |
revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None
|
| 1611 |
|
| 1612 |
cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
|
|
|
|
| 1613 |
if committed_latents is not None:
|
| 1614 |
stream_device = committed_latents.device
|
| 1615 |
stream_dtype = committed_latents.dtype
|
|
|
|
| 1677 |
B = committed_latents.shape[1]
|
| 1678 |
hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
|
| 1679 |
target_pose_source = target_pose if target_pose is not None else pose
|
| 1680 |
+
anchor_banks, revisit_banks, tokens_per_frame, preselected_revisit_records_by_target, preselected_revisit_stats = self._build_preselected_causal_memory_banks(
|
| 1681 |
committed_latents,
|
| 1682 |
source_frame_indices.to(device=stream_device),
|
| 1683 |
None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
|
|
|
|
| 1749 |
dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
|
| 1750 |
dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
|
| 1751 |
dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1752 |
else:
|
| 1753 |
# Pre-select dynamic source frame positions using only frame index metadata
|
| 1754 |
# before touching latents, so we pass a small slice instead of the full
|
| 1755 |
# 1000-frame tensor to the compressor.
|
| 1756 |
_dfi = dynamic_frame_indices.to(device=stream_device)
|
| 1757 |
_max_src = self.dememwm_dynamic_compressor.max_source_frames
|
| 1758 |
+
_needed_tensors: list[torch.Tensor] = []
|
| 1759 |
for _b in range(B):
|
| 1760 |
for _j in range(T_tgt):
|
| 1761 |
+
_target = target_frame_indices[_j, _b]
|
| 1762 |
_valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
|
| 1763 |
+
if _valid.numel() > 0:
|
| 1764 |
+
_needed_tensors.append(_valid[-_max_src:])
|
| 1765 |
+
if _needed_tensors:
|
| 1766 |
+
_needed_idx = torch.unique(torch.cat(_needed_tensors, dim=0), sorted=True).to(device=stream_device, dtype=torch.long)
|
| 1767 |
_dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
|
| 1768 |
_dynamic_fi_small = _dfi.index_select(0, _needed_idx)
|
| 1769 |
_dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
|
|
|
|
| 1776 |
_dynamic_fi_small = _dfi[:0]
|
| 1777 |
_dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
|
| 1778 |
_dynamic_gen_small = None
|
| 1779 |
+
dynamic_tokens, dynamic_mask = self.dememwm_dynamic_compressor(
|
| 1780 |
_dynamic_latents_small,
|
| 1781 |
_dynamic_fi_small,
|
| 1782 |
_dynamic_pose_small,
|
|
|
|
| 1785 |
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
|
| 1786 |
)
|
| 1787 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1788 |
def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
|
| 1789 |
if tensor is None or tensor.ndim < 2:
|
| 1790 |
return None
|
|
|
|
| 1823 |
revisit_mask_rows = []
|
| 1824 |
revisit_max_rows = []
|
| 1825 |
valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
|
|
|
|
|
|
| 1826 |
revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1827 |
revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1828 |
revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
|
| 1829 |
eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
|
|
|
| 1830 |
eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
|
|
|
|
| 1831 |
projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
|
| 1832 |
if revisit_record_batches is None:
|
| 1833 |
revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
|
|
|
|
| 1838 |
batch_max_rows = []
|
| 1839 |
for target_idx in range(T_tgt):
|
| 1840 |
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
|
| 1841 |
+
if preselected_revisit_records_by_target is not None:
|
| 1842 |
+
selected_records = list(preselected_revisit_records_by_target[batch_idx][target_idx])
|
| 1843 |
+
if preselected_revisit_stats is not None:
|
| 1844 |
+
revisit_best_selected_fov_overlap[batch_idx, target_idx] = torch.as_tensor(
|
| 1845 |
+
preselected_revisit_stats["best_selected_fov_overlap"][batch_idx, target_idx],
|
| 1846 |
+
device=stream_device,
|
| 1847 |
+
dtype=torch.float32,
|
| 1848 |
+
)
|
| 1849 |
+
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = torch.as_tensor(
|
| 1850 |
+
preselected_revisit_stats["best_selected_plucker_overlap"][batch_idx, target_idx],
|
| 1851 |
+
device=stream_device,
|
| 1852 |
+
dtype=torch.float32,
|
| 1853 |
+
)
|
| 1854 |
+
revisit_selected_gap_frames[batch_idx, target_idx] = torch.as_tensor(
|
| 1855 |
+
preselected_revisit_stats["best_selected_gap_frames"][batch_idx, target_idx],
|
| 1856 |
+
device=stream_device,
|
| 1857 |
+
dtype=torch.float32,
|
| 1858 |
+
)
|
| 1859 |
else:
|
| 1860 |
+
if use_cache_revisit_records:
|
| 1861 |
+
candidate_records = self._causal_cached_revisit_records(
|
| 1862 |
+
revisit_record_batches[batch_idx],
|
| 1863 |
+
target_frame,
|
| 1864 |
)
|
| 1865 |
+
else:
|
| 1866 |
+
candidate_records = revisit_bank.query(
|
| 1867 |
+
MemoryBankQuery(
|
| 1868 |
+
target_frame=target_frame,
|
| 1869 |
+
include_generated=True,
|
| 1870 |
+
)
|
| 1871 |
+
)
|
| 1872 |
+
result = deterministic_revisit_retrieval(
|
| 1873 |
+
candidate_records,
|
| 1874 |
+
target_frame=target_frame,
|
| 1875 |
+
target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
|
| 1876 |
+
target_summary=None,
|
| 1877 |
+
topk=revisit_max_frames,
|
| 1878 |
+
exclude_local_context_frames=revisit_context_window_exclusion_frames,
|
| 1879 |
+
fov_overlap_threshold=fov_overlap_threshold,
|
| 1880 |
+
plucker_weight=plucker_weight,
|
| 1881 |
+
target_video_id=None,
|
| 1882 |
+
**revisit_retrieval_kwargs,
|
| 1883 |
)
|
| 1884 |
+
selected_records = result.records
|
| 1885 |
+
if use_cache_revisit_records and selected_records:
|
| 1886 |
+
selected_records = self._project_streaming_revisit_records(
|
| 1887 |
+
cache=cache,
|
| 1888 |
+
batch_idx=batch_idx,
|
| 1889 |
+
records=selected_records,
|
| 1890 |
+
device=stream_device,
|
| 1891 |
+
dtype=stream_dtype,
|
| 1892 |
+
token_patch_size=token_patch_size,
|
| 1893 |
+
revisit_pool_h=revisit_pool_h,
|
| 1894 |
+
revisit_pool_w=revisit_pool_w,
|
| 1895 |
+
projection_cache=projected_revisit_record_cache,
|
| 1896 |
+
)
|
| 1897 |
+
revisit_best_selected_fov_overlap[batch_idx, target_idx] = result.best_selected_fov_overlap.to(device=stream_device, dtype=torch.float32)
|
| 1898 |
+
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = result.best_selected_plucker_overlap.to(device=stream_device, dtype=torch.float32)
|
| 1899 |
+
revisit_selected_gap_frames[batch_idx, target_idx] = result.best_selected_gap_frames.to(device=stream_device, dtype=torch.float32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1900 |
revisit_bank.assert_causal(target_frame, selected_records)
|
| 1901 |
if selected_records:
|
| 1902 |
valid_revisit_mask[batch_idx, target_idx] = True
|
|
|
|
| 1907 |
stream_device,
|
| 1908 |
stream_dtype,
|
| 1909 |
)
|
|
|
|
| 1910 |
if eval_corruption_enabled:
|
| 1911 |
stream_tokens, was_corrupted = apply_revisit_eval_corruption(
|
| 1912 |
tokens=stream_tokens,
|
|
|
|
| 1962 |
if not revisit_stage_config_enabled:
|
| 1963 |
revisit_mask = torch.zeros_like(revisit_mask)
|
| 1964 |
valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
|
|
|
|
|
|
|
| 1965 |
revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
|
| 1966 |
revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
|
| 1967 |
revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
|
|
|
|
| 1969 |
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
|
| 1970 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
| 1971 |
revisit_gate = torch.zeros_like(revisit_gate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1972 |
return MemoryStreamTensors(
|
| 1973 |
anchor_tokens=anchor_tokens,
|
| 1974 |
anchor_mask=anchor_mask,
|
|
|
|
| 1981 |
revisit_gate=revisit_gate,
|
| 1982 |
revisit_gate_raw=revisit_gate_raw,
|
| 1983 |
valid_revisit_mask=valid_revisit_mask,
|
| 1984 |
+
revisit_best_selected_fov_overlap=revisit_best_selected_fov_overlap,
|
| 1985 |
+
revisit_best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
|
| 1986 |
+
revisit_selected_gap_frames=revisit_selected_gap_frames,
|
| 1987 |
)
|
| 1988 |
|
| 1989 |
def _refresh_stream_gates(
|
| 1990 |
self,
|
| 1991 |
streams: MemoryStreamTensors,
|
| 1992 |
denoising_fraction: float | None = None,
|
|
|
|
| 1993 |
) -> MemoryStreamTensors:
|
| 1994 |
gate_state = self._effective_gate_state(
|
| 1995 |
denoising_fraction=denoising_fraction,
|
|
|
|
| 1996 |
)
|
| 1997 |
gates = gate_state["gates"]
|
| 1998 |
device = streams.anchor_tokens.device
|
|
|
|
| 2004 |
else:
|
| 2005 |
valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool)
|
| 2006 |
|
| 2007 |
+
def _gate_feature(value: torch.Tensor | None, fill_value: float = 0.0) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
| 2008 |
if value is None:
|
| 2009 |
return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32)
|
| 2010 |
+
tensor = value.to(device=device, dtype=torch.float32)
|
| 2011 |
if tensor.ndim == 0:
|
| 2012 |
return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32)
|
| 2013 |
return tensor.expand((B, T_tgt))
|
| 2014 |
|
| 2015 |
+
revisit_best_selected_fov_overlap = _gate_feature(streams.revisit_best_selected_fov_overlap)
|
| 2016 |
+
revisit_best_selected_plucker_overlap = _gate_feature(streams.revisit_best_selected_plucker_overlap)
|
| 2017 |
+
revisit_selected_gap_frames = _gate_feature(streams.revisit_selected_gap_frames, -1.0)
|
| 2018 |
|
| 2019 |
anchor_effective_enabled = gate_state["anchor_effective_enabled"]
|
| 2020 |
dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
|
|
|
|
| 2031 |
best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
|
| 2032 |
selected_gap_frames=revisit_selected_gap_frames,
|
| 2033 |
).to(device=device, dtype=dtype)
|
|
|
|
| 2034 |
if not revisit_stage_config_enabled or gate_state["force_revisit_off"]:
|
| 2035 |
revisit_gate = torch.zeros_like(revisit_gate_raw)
|
| 2036 |
elif gate_state["force_revisit_on"]:
|
| 2037 |
+
revisit_gate = valid_revisit_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw)
|
| 2038 |
else:
|
| 2039 |
+
revisit_gate = valid_revisit_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate)
|
| 2040 |
if not revisit_stage_config_enabled:
|
| 2041 |
valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
|
|
|
|
| 2042 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
| 2043 |
revisit_gate = torch.zeros_like(revisit_gate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2044 |
return replace(
|
| 2045 |
streams,
|
| 2046 |
anchor_gate=anchor_gate,
|
|
|
|
| 2048 |
revisit_gate=revisit_gate,
|
| 2049 |
revisit_gate_raw=revisit_gate_raw,
|
| 2050 |
valid_revisit_mask=valid_revisit_mask,
|
| 2051 |
+
revisit_best_selected_fov_overlap=revisit_best_selected_fov_overlap,
|
| 2052 |
+
revisit_best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
|
| 2053 |
+
revisit_selected_gap_frames=revisit_selected_gap_frames,
|
| 2054 |
)
|
| 2055 |
|
| 2056 |
+
def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]:
    """Choose the half-open ``(start, end)`` frame span of the training window.

    The window length is ``self.n_tokens``; when that attribute is missing or
    falsy the full sequence length is used instead.

    Args:
        total_frames: Number of frames available in the sequence.
        device: Device on which the random start index is drawn.

    Returns:
        ``(0, 0)`` when there are no frames, ``(0, total_frames)`` when the
        sequence is not longer than the window, otherwise
        ``(start, start + window)``. ``start`` is drawn uniformly between the
        context frame count (capped at the last feasible start) and the last
        feasible start, both inclusive.
    """
    n_frames = int(total_frames)
    span = int(getattr(self, "n_tokens", n_frames) or n_frames)
    if n_frames <= 0:
        return 0, 0
    if span <= 0 or n_frames <= span:
        return 0, n_frames

    latest_start = max(0, n_frames - span)
    # The earliest start is the context frame count, capped at the last feasible start.
    earliest_start = min(max(0, self._context_frame_count()), latest_start)
    if latest_start > earliest_start:
        # torch.randint's upper bound is exclusive, hence the +1.
        drawn = torch.randint(earliest_start, latest_start + 1, (1,), device=device)
        begin = int(drawn.item())
    else:
        begin = earliest_start
    return begin, begin + span
|
| 2070 |
+
|
| 2071 |
+
def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> dict:
    """Run ``streams`` through the injection adapter and return its result.

    The adapter is called with the device and dtype of ``streams.anchor_tokens``.
    """
    adapter = self.dememwm_injection_adapter
    reference = streams.anchor_tokens
    return adapter(streams, device=reference.device, dtype=reference.dtype)
|
| 2077 |
|
| 2078 |
+
def build_memory_kwargs(self, *args, **kwargs) -> dict:
    """Build memory streams and convert them to keyword arguments in one call.

    All positional and keyword arguments are forwarded unchanged to
    ``build_memory_streams``; its result is passed to ``_streams_to_kwargs``.
    """
    return self._streams_to_kwargs(self.build_memory_streams(*args, **kwargs))
|
| 2081 |
|
| 2082 |
+
def _training_pose_condition(
    self,
    pose_conditions: torch.Tensor,
    c2w_mat: torch.Tensor,
    frame_idx: torch.Tensor,
    *,
    dtype: torch.dtype,
    image_width: int,
    image_height: int,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Build the pose conditioning (and frame offsets) for a training window.

    Three modes, selected by ``self.use_plucker`` and ``self.relative_embedding``:

    * Plucker disabled: ``pose_conditions`` cast to ``dtype``; no frame indices.
    * Plucker, absolute: ``convert_to_plucker`` over the whole ``c2w_mat``,
      paired with ``frame_idx`` unchanged.
    * Plucker, relative: for every index along dim 0 of ``c2w_mat``, that
      entry is stacked in front of the trailing ``memory_condition_length``
      entries and converted with reference index 0. Frame indices become
      offsets relative to that entry (the entry's own offset is zero).
      Per-entry results are concatenated along dim 0.

    NOTE(review): dim 0 of ``c2w_mat`` / ``frame_idx`` is presumably time —
    confirm against the caller.

    Returns:
        ``(pose_condition, frame_indices_or_None)``.
    """
    if not self.use_plucker:
        return pose_conditions.to(dtype), None

    if not self.relative_embedding:
        plucker = convert_to_plucker(
            c2w_mat,
            0,
            focal_length=self.focal_length,
            image_width=image_width,
            image_height=image_height,
        )
        return plucker.to(dtype), frame_idx

    tail_len = max(0, int(self.memory_condition_length or 0))
    # ``[-0:]`` would select everything, so the empty case needs its own slice.
    tail = slice(-tail_len, None) if tail_len else slice(0, 0)
    memory_c2w = c2w_mat[tail]
    memory_frame_idx = frame_idx[tail]

    plucker_chunks = []
    offset_chunks = []
    for row in range(c2w_mat.shape[0]):
        own_pose = c2w_mat[row:row + 1]
        own_frame = frame_idx[row:row + 1]
        stacked_pose = torch.cat([own_pose, memory_c2w]).clone()
        plucker_chunks.append(
            convert_to_plucker(
                stacked_pose,
                0,
                focal_length=self.focal_length,
                image_width=image_width,
                image_height=image_height,
            ).to(dtype)
        )
        # Offsets are measured from the current entry's own frame index.
        offsets = torch.cat([own_frame - own_frame, memory_frame_idx - own_frame]).clone()
        offset_chunks.append(offsets)
    return torch.cat(plucker_chunks), torch.cat(offset_chunks)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2131 |
|
| 2132 |
def training_step(self, batch, batch_idx):
|
| 2133 |
xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
| 2134 |
+
image_height, image_width = self._image_size(xs)
|
| 2135 |
xs = self._as_latents(xs)
|
| 2136 |
|
| 2137 |
# Randomly select a contiguous n_tokens denoising window inside the long
|
|
|
|
| 2145 |
frame_idx_window = frame_idx[start:end]
|
| 2146 |
|
| 2147 |
input_pose_condition, frame_idx_list = self._training_pose_condition(
|
| 2148 |
+
pose_conditions[start:end],
|
| 2149 |
+
c2w_mat[start:end],
|
| 2150 |
+
frame_idx_window,
|
| 2151 |
+
dtype=xs_window.dtype,
|
| 2152 |
+
image_width=image_width,
|
| 2153 |
+
image_height=image_height,
|
| 2154 |
)
|
| 2155 |
|
| 2156 |
noise_levels = self._generate_noise_levels(xs_window)
|
|
|
|
| 2158 |
noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
|
| 2159 |
conditions_window[-self.memory_condition_length:] *= 0
|
| 2160 |
source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
|
| 2161 |
+
memory_source_latents, source_is_generated = self._apply_generated_history_proxy(
|
| 2162 |
xs,
|
| 2163 |
source_is_generated,
|
| 2164 |
context_frame_count=self._context_frame_count(),
|
| 2165 |
target_start_frame=start,
|
| 2166 |
)
|
| 2167 |
timesteps = int(getattr(self, "timesteps", 0) or 0)
|
|
|
|
|
|
|
| 2168 |
training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
|
| 2169 |
+
memory_kwargs = self.build_memory_kwargs(
|
| 2170 |
memory_source_latents,
|
| 2171 |
frame_idx,
|
| 2172 |
target_frame_indices=frame_idx_window,
|
|
|
|
| 2176 |
target_action=conditions_window,
|
| 2177 |
source_is_generated=source_is_generated,
|
| 2178 |
denoising_fraction=training_denoising_fraction,
|
|
|
|
|
|
|
| 2179 |
)
|
|
|
|
| 2180 |
_, loss = self.diffusion_model(
|
| 2181 |
xs_window,
|
| 2182 |
conditions_window,
|
|
|
|
| 2186 |
frame_idx=frame_idx_list,
|
| 2187 |
**memory_kwargs,
|
| 2188 |
)
|
|
|
|
| 2189 |
if self.memory_condition_length:
|
| 2190 |
loss = loss[:-self.memory_condition_length]
|
| 2191 |
loss_denoise = self.reweight_loss(loss, None)
|
| 2192 |
loss_total = loss_denoise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2193 |
if batch_idx % 20 == 0:
|
| 2194 |
+
revisit_gate = memory_kwargs.get("memory_retrieval_gate")
|
| 2195 |
+
if torch.is_tensor(revisit_gate):
|
| 2196 |
+
revisit_gate_value = revisit_gate.detach().float().mean()
|
| 2197 |
+
else:
|
| 2198 |
+
revisit_gate_value = loss_total.detach().new_tensor(0.0 if revisit_gate is None else float(revisit_gate))
|
| 2199 |
+
self.log("training/loss", loss_total.detach(), prog_bar=True, sync_dist=True)
|
| 2200 |
+
self.log("training/denoise_loss", loss_denoise.detach(), prog_bar=False, sync_dist=True)
|
| 2201 |
+
self.log("training/revisit_gate", revisit_gate_value, prog_bar=False, sync_dist=True)
|
| 2202 |
return {"loss": loss_total}
|
| 2203 |
|
| 2204 |
def validation_step(self, batch, batch_idx, namespace="validation"):
|
|
|
|
| 2222 |
streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
|
| 2223 |
cached_until = 0
|
| 2224 |
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
|
|
|
| 2225 |
while curr_frame < n_frames:
|
| 2226 |
if streaming_cache is not None and curr_frame > cached_until:
|
| 2227 |
new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
|
|
|
|
| 2288 |
from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
|
| 2289 |
denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0)
|
| 2290 |
step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
|
| 2291 |
+
memory_kwargs = self._streams_to_kwargs(step_streams)
|
| 2292 |
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 2293 |
xs_pred[start_frame:].to(input_condition.device),
|
| 2294 |
input_condition,
|
|
|
|
| 2318 |
action=conditions[cached_until:curr_frame],
|
| 2319 |
)
|
| 2320 |
cached_until = curr_frame
|
|
|
|
|
|
|
| 2321 |
pbar.update(horizon)
|
| 2322 |
pbar.close()
|
|
|
|
|
|
|
| 2323 |
xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
|
| 2324 |
xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
|
| 2325 |
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
|
|
|
|
| 2349 |
# Compatibility aliases for old DeMemWM test and experiment call sites.
|
| 2350 |
dememwm_strict_key_prefixes = strict_key_prefixes
|
| 2351 |
dememwm_strict_key_substrings = strict_key_substrings
|
|
|
|
|
|
|
| 2352 |
_dememwm_cfg = _memory_cfg
|
|
|
|
| 2353 |
_dememwm_eval_ablation_cfg = _eval_ablation_cfg
|
| 2354 |
_dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
|
| 2355 |
_dememwm_eval_ablation_state = _eval_ablation_state
|
|
|
|
| 2382 |
_dememwm_refresh_stream_gates = _refresh_stream_gates
|
| 2383 |
_dememwm_streams_to_kwargs = _streams_to_kwargs
|
| 2384 |
build_dememwm_memory_kwargs = build_memory_kwargs
|
|
|
|
|
|
|
| 2385 |
_dememwm_training_window_bounds = _training_window_bounds
|
| 2386 |
strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
|
| 2387 |
|
algorithms/worldmem/dememwm/cache.py
CHANGED
|
@@ -486,23 +486,6 @@ class StreamingCache:
|
|
| 486 |
).to(device=device)
|
| 487 |
return latents, frame_indices, generated, pose
|
| 488 |
|
| 489 |
-
def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
|
| 490 |
-
return {
|
| 491 |
-
f"{prefix}_enabled": bool(self.enabled),
|
| 492 |
-
f"{prefix}_records": int(self.record_count),
|
| 493 |
-
f"{prefix}_anchor_records": int(self.records_count("anchor")),
|
| 494 |
-
f"{prefix}_revisit_records": int(self.records_count("revisit")),
|
| 495 |
-
f"{prefix}_slots": int(self.slot_count),
|
| 496 |
-
f"{prefix}_raw_frame_slots": int(self.raw_frame_slots),
|
| 497 |
-
f"{prefix}_raw_segments": int(self.raw_segment_count),
|
| 498 |
-
f"{prefix}_evictions": int(self.evictions),
|
| 499 |
-
f"{prefix}_resets": int(self.reset_count),
|
| 500 |
-
f"{prefix}_capacity_exceeded": int(self.capacity_exceeded_count),
|
| 501 |
-
f"{prefix}_device": self.device,
|
| 502 |
-
f"{prefix}_current_video_id": self.current_video_id,
|
| 503 |
-
f"{prefix}_clear_between_videos": bool(self.clear_between_videos),
|
| 504 |
-
f"{prefix}_no_evict": bool(self.no_evict),
|
| 505 |
-
}
|
| 506 |
|
| 507 |
|
| 508 |
DeMemWMStreamingCache = StreamingCache
|
|
|
|
| 486 |
).to(device=device)
|
| 487 |
return latents, frame_indices, generated, pose
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
|
| 491 |
DeMemWMStreamingCache = StreamingCache
|
algorithms/worldmem/dememwm/compression.py
CHANGED
|
@@ -163,7 +163,8 @@ class CausalConv3DDynamicCompressor(nn.Module):
|
|
| 163 |
target_frame_indices: torch.Tensor,
|
| 164 |
source_is_generated: Optional[torch.Tensor] = None,
|
| 165 |
exclude_latest_local_frames: Optional[int] = None,
|
| 166 |
-
) -> tuple[torch.Tensor, torch.Tensor
|
|
|
|
| 167 |
if latents.ndim != 5:
|
| 168 |
raise ValueError("latents must have shape (T_src,B,C,H,W)")
|
| 169 |
exclude_latest_local_frames = (
|
|
@@ -175,86 +176,78 @@ class CausalConv3DDynamicCompressor(nn.Module):
|
|
| 175 |
p = self.patch_size
|
| 176 |
if H % p != 0 or W % p != 0:
|
| 177 |
raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
|
|
|
|
|
|
|
| 178 |
if target_frame_indices.ndim == 1:
|
| 179 |
target_frame_indices = target_frame_indices[:, None].expand(-1, B)
|
| 180 |
-
|
|
|
|
|
|
|
| 181 |
device = latents.device
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
n_spatial = (H // p) * (W // p)
|
| 184 |
T_out = self._temporal_output_count()
|
| 185 |
num_slots = T_out * n_spatial
|
| 186 |
output_time_idx = self._output_time_indices(device)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
diagnostics = {
|
| 248 |
-
"num_dynamic_slots": num_slots,
|
| 249 |
-
"dynamic_T_out": T_out,
|
| 250 |
-
"dynamic_n_spatial": n_spatial,
|
| 251 |
-
"dynamic_temporal_left_pad": self.causal_pad,
|
| 252 |
-
"dynamic_output_time_indices": output_time_idx,
|
| 253 |
-
"selected_source_count": selected_source_count,
|
| 254 |
-
"max_source_frame": max_source_frame,
|
| 255 |
-
"generated_source_fraction": generated_source_fraction,
|
| 256 |
-
"dynamic_min_gap_to_target_per_target": min_gap,
|
| 257 |
-
"dynamic_max_gap_to_target_per_target": max_gap,
|
| 258 |
-
"dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
|
| 259 |
-
}
|
| 260 |
-
return out_tokens, out_mask, diagnostics
|
|
|
|
| 163 |
target_frame_indices: torch.Tensor,
|
| 164 |
source_is_generated: Optional[torch.Tensor] = None,
|
| 165 |
exclude_latest_local_frames: Optional[int] = None,
|
| 166 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 167 |
+
del pose, source_is_generated
|
| 168 |
if latents.ndim != 5:
|
| 169 |
raise ValueError("latents must have shape (T_src,B,C,H,W)")
|
| 170 |
exclude_latest_local_frames = (
|
|
|
|
| 176 |
p = self.patch_size
|
| 177 |
if H % p != 0 or W % p != 0:
|
| 178 |
raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
|
| 179 |
+
if frame_indices.shape != (T_src, B):
|
| 180 |
+
raise ValueError("frame_indices must have shape (T_src,B)")
|
| 181 |
if target_frame_indices.ndim == 1:
|
| 182 |
target_frame_indices = target_frame_indices[:, None].expand(-1, B)
|
| 183 |
+
if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
|
| 184 |
+
raise ValueError("target_frame_indices must have shape (T_tgt,B)")
|
| 185 |
+
|
| 186 |
device = latents.device
|
| 187 |
+
frame_indices = frame_indices.to(device=device)
|
| 188 |
+
target_frame_indices = target_frame_indices.to(device=device)
|
| 189 |
+
T_tgt = target_frame_indices.shape[0]
|
| 190 |
n_spatial = (H // p) * (W // p)
|
| 191 |
T_out = self._temporal_output_count()
|
| 192 |
num_slots = T_out * n_spatial
|
| 193 |
output_time_idx = self._output_time_indices(device)
|
| 194 |
+
if T_src == 0:
|
| 195 |
+
out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
|
| 196 |
+
out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
|
| 197 |
+
return out_tokens, out_mask
|
| 198 |
+
|
| 199 |
+
source_frames = frame_indices.transpose(0, 1).contiguous()
|
| 200 |
+
target_frames = target_frame_indices.transpose(0, 1).contiguous()
|
| 201 |
+
valid = source_frames[:, None, :] < (target_frames[:, :, None] - int(exclude_latest_local_frames))
|
| 202 |
+
valid_flat = valid.reshape(B * T_tgt, T_src)
|
| 203 |
+
source_frames_flat = source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(B * T_tgt, T_src)
|
| 204 |
+
|
| 205 |
+
topk = min(int(self.max_source_frames), T_src)
|
| 206 |
+
rank = source_frames_flat.to(dtype=torch.float64).masked_fill(~valid_flat, -float("inf"))
|
| 207 |
+
top = torch.topk(rank, k=topk, dim=1, largest=True, sorted=True)
|
| 208 |
+
selected_idx = top.indices.flip(dims=(1,))
|
| 209 |
+
selected_valid = torch.isfinite(top.values).flip(dims=(1,))
|
| 210 |
+
if topk < self.max_source_frames:
|
| 211 |
+
pad_count = self.max_source_frames - topk
|
| 212 |
+
selected_idx = torch.cat([
|
| 213 |
+
torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.long),
|
| 214 |
+
selected_idx,
|
| 215 |
+
], dim=1)
|
| 216 |
+
selected_valid = torch.cat([
|
| 217 |
+
torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.bool),
|
| 218 |
+
selected_valid,
|
| 219 |
+
], dim=1)
|
| 220 |
+
|
| 221 |
+
selected_idx_clamped = selected_idx.to(device=device, dtype=torch.long).clamp(min=0, max=max(0, T_src - 1))
|
| 222 |
+
has_valid = selected_valid.any(dim=1)
|
| 223 |
+
batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
|
| 224 |
+
latents_by_batch = latents.permute(1, 0, 2, 3, 4).contiguous()
|
| 225 |
+
latents_per_query = latents_by_batch.index_select(0, batch_ids)
|
| 226 |
+
gather_idx = selected_idx_clamped.reshape(B * T_tgt, self.max_source_frames, 1, 1, 1).expand(
|
| 227 |
+
-1, -1, C, H, W
|
| 228 |
+
)
|
| 229 |
+
chunk = torch.gather(latents_per_query, 1, gather_idx)
|
| 230 |
+
chunk = torch.where(
|
| 231 |
+
selected_valid[:, :, None, None, None],
|
| 232 |
+
chunk,
|
| 233 |
+
torch.zeros((), device=device, dtype=latents.dtype),
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
inp = chunk.clone()
|
| 237 |
+
inp[:, 1:] = chunk[:, 1:] - chunk[:, :-1]
|
| 238 |
+
x = inp.permute(0, 2, 1, 3, 4)
|
| 239 |
+
x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))
|
| 240 |
+
x = self.conv3d(x)
|
| 241 |
+
x = self.out_norm(x.permute(0, 2, 3, 4, 1))
|
| 242 |
+
tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
|
| 243 |
+
tokens_flat = torch.where(has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat))
|
| 244 |
+
out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)
|
| 245 |
+
|
| 246 |
+
clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
|
| 247 |
+
temporal_mask = (
|
| 248 |
+
(output_time_idx >= 0)
|
| 249 |
+
& (output_time_idx < self.max_source_frames)
|
| 250 |
+
& selected_valid.index_select(1, clamped_time_idx)
|
| 251 |
+
)
|
| 252 |
+
out_mask = temporal_mask[:, :, None].expand(B * T_tgt, T_out, n_spatial).reshape(B, T_tgt, num_slots)
|
| 253 |
+
return out_tokens, out_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
algorithms/worldmem/dememwm/diagnostics.py
DELETED
|
@@ -1,172 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
from typing import Any
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
|
| 14 |
-
if mask is None or mask.numel() == 0:
|
| 15 |
-
return 0.0
|
| 16 |
-
return float(mask.detach().bool().float().mean().item())
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
|
| 20 |
-
if gate is None:
|
| 21 |
-
return {"mean": 0.0, "min": 0.0, "max": 0.0}
|
| 22 |
-
if not torch.is_tensor(gate):
|
| 23 |
-
value = float(gate)
|
| 24 |
-
return {"mean": value, "min": value, "max": value}
|
| 25 |
-
g = gate.detach().float()
|
| 26 |
-
return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
|
| 30 |
-
return {f"{name}_tokens_shape": None if tokens is None else tuple(tokens.shape), f"{name}_valid_fraction": tensor_valid_fraction(mask), f"{name}_valid_tokens": 0 if mask is None else int(mask.detach().bool().sum().item()), f"{name}_gate": gate_stats(gate)}
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
|
| 34 |
-
max_src = int(max_source_frame.detach().max().item()) if torch.is_tensor(max_source_frame) else int(max_source_frame)
|
| 35 |
-
if max_src >= int(target_frame):
|
| 36 |
-
raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
|
| 40 |
-
values: list[float] = []
|
| 41 |
-
for diag in result_diagnostics:
|
| 42 |
-
for value in diag.get(key, []) or []:
|
| 43 |
-
values.append(float(value))
|
| 44 |
-
return values
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
|
| 48 |
-
if not values:
|
| 49 |
-
return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
|
| 50 |
-
return {
|
| 51 |
-
f"{prefix}_mean": float(sum(values) / len(values)),
|
| 52 |
-
f"{prefix}_min": float(min(values)),
|
| 53 |
-
f"{prefix}_max": float(max(values)),
|
| 54 |
-
}
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
|
| 58 |
-
target_count = len(result_diagnostics)
|
| 59 |
-
candidate_count = sum(int(diag.get("revisit_candidate_frame_count", diag.get("revisit_candidate_count", diag.get("candidate_count", 0)))) for diag in result_diagnostics)
|
| 60 |
-
candidate_count_mean = float(candidate_count / target_count) if target_count else 0.0
|
| 61 |
-
valid_candidate_label_count = sum(int(diag.get("valid_candidate_label_count", diag.get("valid_candidate_count", 0))) for diag in result_diagnostics)
|
| 62 |
-
pose_preselect_input_count = sum(int(diag.get("revisit_pose_preselect_input_count", 0)) for diag in result_diagnostics)
|
| 63 |
-
pose_preselect_selected_count = sum(int(diag.get("revisit_pose_preselect_selected_count", 0)) for diag in result_diagnostics)
|
| 64 |
-
exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
|
| 65 |
-
valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
|
| 66 |
-
valid_count_mean = float(valid_count / target_count) if target_count else 0.0
|
| 67 |
-
selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
|
| 68 |
-
no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
|
| 69 |
-
abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
|
| 70 |
-
selected_gaps = [int(diag["revisit_min_gap_to_target"]) for diag in result_diagnostics if int(diag.get("revisit_min_gap_to_target", -1)) >= 0]
|
| 71 |
-
diagnostics: dict[str, Any] = {
|
| 72 |
-
"revisit_candidate_frame_count": candidate_count_mean,
|
| 73 |
-
"revisit_candidate_count": candidate_count_mean,
|
| 74 |
-
"valid_candidate_label_count": int(valid_candidate_label_count),
|
| 75 |
-
"revisit_pose_preselect_input_count": float(pose_preselect_input_count / target_count) if target_count else 0.0,
|
| 76 |
-
"revisit_pose_preselect_selected_count": float(pose_preselect_selected_count / target_count) if target_count else 0.0,
|
| 77 |
-
"revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
|
| 78 |
-
"valid_revisit_frame_count": valid_count_mean,
|
| 79 |
-
"valid_revisit_count": valid_count_mean,
|
| 80 |
-
"no_valid_revisit_count": int(no_valid_count),
|
| 81 |
-
"valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
|
| 82 |
-
"revisit_selected_frame_count": int(selected_count),
|
| 83 |
-
"revisit_selected_count": int(selected_count),
|
| 84 |
-
"revisit_abstained_count": int(abstained_count),
|
| 85 |
-
"revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
|
| 86 |
-
"revisit_label_source": _REVISIT_LABEL_SOURCE,
|
| 87 |
-
}
|
| 88 |
-
frame_fov_values = _collect_values(result_diagnostics, "frame_fov_overlap_values")
|
| 89 |
-
if not frame_fov_values:
|
| 90 |
-
frame_fov_values = _collect_values(result_diagnostics, "fov_overlap_values")
|
| 91 |
-
diagnostics.update(_value_stats(frame_fov_values, "revisit_frame_fov_overlap"))
|
| 92 |
-
diagnostics.update(_value_stats(frame_fov_values, "revisit_fov_overlap"))
|
| 93 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "plucker_overlap_values"), "revisit_plucker_overlap"))
|
| 94 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_fov_overlap_values"), "revisit_best_selected_fov_overlap"))
|
| 95 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_plucker_overlap_values"), "revisit_best_selected_plucker_overlap"))
|
| 96 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_gap_frame_values"), "revisit_best_selected_gap_frames"))
|
| 97 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_frame_fov_overlap_values"), "revisit_best_selected_frame_fov_overlap"))
|
| 98 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_frame_fov_overlap_values"), "revisit_selected_frame_fov_overlap"))
|
| 99 |
-
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_incremental_fov_overlap_values"), "revisit_incremental_fov_overlap"))
|
| 100 |
-
return diagnostics
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def summarize_noise_bucket_diagnostics(
|
| 104 |
-
*,
|
| 105 |
-
noise_bucket: str | None,
|
| 106 |
-
valid_revisit_mask: torch.Tensor | None,
|
| 107 |
-
no_valid_revisit_mask: torch.Tensor | None,
|
| 108 |
-
noise_bucket_ids: torch.Tensor | None = None,
|
| 109 |
-
) -> dict[str, Any]:
|
| 110 |
-
bucket = normalize_noise_bucket(noise_bucket)
|
| 111 |
-
diagnostics: dict[str, Any] = {
|
| 112 |
-
"noise_bucket": bucket,
|
| 113 |
-
"noise_bucket_id": int(NOISE_BUCKET_TO_ID[bucket]),
|
| 114 |
-
}
|
| 115 |
-
for candidate in NOISE_BUCKETS:
|
| 116 |
-
diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)
|
| 117 |
-
|
| 118 |
-
valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 119 |
-
no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 120 |
-
target_count = int(valid.numel())
|
| 121 |
-
diagnostics["noise_bucket_target_count"] = target_count
|
| 122 |
-
if noise_bucket_ids is None:
|
| 123 |
-
target_bucket_ids = torch.full((target_count,), int(NOISE_BUCKET_TO_ID[bucket]), dtype=torch.long)
|
| 124 |
-
else:
|
| 125 |
-
target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
|
| 126 |
-
if int(target_bucket_ids.numel()) != target_count:
|
| 127 |
-
raise ValueError(
|
| 128 |
-
f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
for bucket_name in NOISE_BUCKETS:
|
| 132 |
-
bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
|
| 133 |
-
diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())
|
| 134 |
-
|
| 135 |
-
mask_specs = (
|
| 136 |
-
("valid_revisit", valid),
|
| 137 |
-
("no_valid_revisit", no_valid),
|
| 138 |
-
)
|
| 139 |
-
for mask_name, mask in mask_specs:
|
| 140 |
-
for bucket_name in NOISE_BUCKETS:
|
| 141 |
-
bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
|
| 142 |
-
count = int((mask & bucket_mask).long().sum().item()) if mask.numel() else 0
|
| 143 |
-
diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
|
| 144 |
-
return diagnostics
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def summarize_eval_ablation_diagnostics(
|
| 148 |
-
*,
|
| 149 |
-
enabled: bool,
|
| 150 |
-
branch: str | None,
|
| 151 |
-
valid_revisit_mask: torch.Tensor | None,
|
| 152 |
-
no_valid_revisit_mask: torch.Tensor | None,
|
| 153 |
-
eval_corrupted_revisit_mask: torch.Tensor | None,
|
| 154 |
-
) -> dict[str, Any]:
|
| 155 |
-
branch = normalize_eval_ablation_branch(branch)
|
| 156 |
-
valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 157 |
-
no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 158 |
-
corrupted = torch.zeros_like(valid) if eval_corrupted_revisit_mask is None else eval_corrupted_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 159 |
-
true_revisit = valid & (~corrupted)
|
| 160 |
-
diagnostics: dict[str, Any] = {
|
| 161 |
-
"eval_ablation_enabled": bool(enabled),
|
| 162 |
-
"eval_ablation_branch": branch,
|
| 163 |
-
"eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch]),
|
| 164 |
-
"eval_bucket_true_revisit_count": int(true_revisit.long().sum().item()),
|
| 165 |
-
"eval_bucket_no_valid_revisit_count": int(no_valid.long().sum().item()),
|
| 166 |
-
"eval_bucket_corrupted_memory_count": int(corrupted.long().sum().item()),
|
| 167 |
-
}
|
| 168 |
-
total = max(int(valid.numel()), 1)
|
| 169 |
-
diagnostics["eval_bucket_true_revisit_fraction"] = float(diagnostics["eval_bucket_true_revisit_count"] / total)
|
| 170 |
-
diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(diagnostics["eval_bucket_no_valid_revisit_count"] / total)
|
| 171 |
-
diagnostics["eval_bucket_corrupted_memory_fraction"] = float(diagnostics["eval_bucket_corrupted_memory_count"] / total)
|
| 172 |
-
return diagnostics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
algorithms/worldmem/dememwm/injection.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
from dataclasses import dataclass
|
|
@@ -6,7 +5,6 @@ from typing import Any
|
|
| 6 |
|
| 7 |
import torch
|
| 8 |
|
| 9 |
-
from .diagnostics import summarize_stream
|
| 10 |
from .types import MemoryStreamTensors
|
| 11 |
|
| 12 |
|
|
@@ -31,7 +29,12 @@ class InjectionAdapter:
|
|
| 31 |
return gate.to(device=device, dtype=dtype)
|
| 32 |
return torch.tensor(float(gate), device=device, dtype=dtype)
|
| 33 |
|
| 34 |
-
def __call__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
ref = streams.anchor_tokens
|
| 36 |
device = device or ref.device
|
| 37 |
dtype = dtype or ref.dtype
|
|
@@ -62,22 +65,7 @@ class InjectionAdapter:
|
|
| 62 |
if not revisit_mask.any():
|
| 63 |
kwargs["memory_retrieval_tokens"] = None
|
| 64 |
kwargs["memory_retrieval_mask"] = None
|
| 65 |
-
|
| 66 |
-
diagnostics.update(summarize_stream("anchor", anchor_tokens, anchor_mask, kwargs["memory_anchor_gate"]))
|
| 67 |
-
diagnostics.update(summarize_stream("dynamic", dynamic_tokens, dynamic_mask, kwargs["memory_dynamic_gate"]))
|
| 68 |
-
diagnostics.update(summarize_stream("revisit", revisit_tokens, revisit_mask, kwargs["memory_retrieval_gate"]))
|
| 69 |
-
if streams.revisit_gate_raw is not None:
|
| 70 |
-
raw_gate = streams.revisit_gate_raw.to(device=device, dtype=dtype)
|
| 71 |
-
diagnostics["revisit_gate_raw"] = raw_gate
|
| 72 |
-
diagnostics["revisit_gate_raw_mean"] = float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0
|
| 73 |
-
diagnostics["revisit_gate_raw_min"] = float(raw_gate.detach().float().min().item()) if raw_gate.numel() else 0.0
|
| 74 |
-
diagnostics["revisit_gate_raw_max"] = float(raw_gate.detach().float().max().item()) if raw_gate.numel() else 0.0
|
| 75 |
-
if streams.no_valid_revisit_mask is not None:
|
| 76 |
-
diagnostics["no_valid_revisit_mask"] = streams.no_valid_revisit_mask.to(device=device, dtype=torch.bool)
|
| 77 |
-
max_sources = [v for k, v in streams.diagnostics.items() if k.endswith("max_source_frame")]
|
| 78 |
-
if max_sources:
|
| 79 |
-
diagnostics["max_source_frame"] = max(int(torch.as_tensor(v).max().item()) for v in max_sources)
|
| 80 |
-
return kwargs, diagnostics
|
| 81 |
|
| 82 |
|
| 83 |
DeMemWMInjectionAdapter = InjectionAdapter
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from dataclasses import dataclass
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
|
|
|
|
| 8 |
from .types import MemoryStreamTensors
|
| 9 |
|
| 10 |
|
|
|
|
| 29 |
return gate.to(device=device, dtype=dtype)
|
| 30 |
return torch.tensor(float(gate), device=device, dtype=dtype)
|
| 31 |
|
| 32 |
+
def __call__(
|
| 33 |
+
self,
|
| 34 |
+
streams: MemoryStreamTensors,
|
| 35 |
+
device=None,
|
| 36 |
+
dtype=None,
|
| 37 |
+
) -> dict[str, Any]:
|
| 38 |
ref = streams.anchor_tokens
|
| 39 |
device = device or ref.device
|
| 40 |
dtype = dtype or ref.dtype
|
|
|
|
| 65 |
if not revisit_mask.any():
|
| 66 |
kwargs["memory_retrieval_tokens"] = None
|
| 67 |
kwargs["memory_retrieval_mask"] = None
|
| 68 |
+
return kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
DeMemWMInjectionAdapter = InjectionAdapter
|
algorithms/worldmem/dememwm/retrieval.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import math
|
| 4 |
-
from dataclasses import replace
|
| 5 |
from typing import Any, Optional
|
| 6 |
|
| 7 |
import torch
|
|
@@ -9,6 +9,7 @@ import torch
|
|
| 9 |
from .labels import (
|
| 10 |
LABEL_SOURCE,
|
| 11 |
RevisitCandidateLabel,
|
|
|
|
| 12 |
_inside_fov_3d_hv,
|
| 13 |
_plucker_descriptor,
|
| 14 |
_target_fov_points,
|
|
@@ -16,24 +17,19 @@ from .labels import (
|
|
| 16 |
from .types import MemoryRecord, RevisitRetrievalResult
|
| 17 |
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
|
| 28 |
-
def _overlap_stats(values: list[float], prefix: str) -> dict[str, float]:
|
| 29 |
-
if not values:
|
| 30 |
-
return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
|
| 31 |
-
return {
|
| 32 |
-
f"{prefix}_mean": float(sum(values) / len(values)),
|
| 33 |
-
f"{prefix}_min": float(min(values)),
|
| 34 |
-
f"{prefix}_max": float(max(values)),
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
|
| 38 |
def _pose_rows(pose) -> torch.Tensor | None:
|
| 39 |
if pose is None:
|
|
@@ -58,6 +54,348 @@ def _pose_forward(poses: torch.Tensor) -> torch.Tensor:
|
|
| 58 |
)
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
|
| 62 |
if int(record.frame_indices.numel()) != 1:
|
| 63 |
return None
|
|
@@ -83,19 +421,9 @@ def _vectorized_frame_candidate_labels(
|
|
| 83 |
plucker_grid_w: int,
|
| 84 |
plucker_focal_length: float,
|
| 85 |
pose_preselect_topk: Optional[int],
|
| 86 |
-
) ->
|
| 87 |
-
diagnostics: dict[str, float | int] = {
|
| 88 |
-
"revisit_pose_preselect_input_count": len(records),
|
| 89 |
-
"revisit_pose_preselect_scored_count": len(records),
|
| 90 |
-
"revisit_pose_preselect_unscored_count": 0,
|
| 91 |
-
"revisit_pose_preselect_selected_count": len(records),
|
| 92 |
-
"revisit_pose_preselect_min_distance": 0.0,
|
| 93 |
-
"revisit_pose_preselect_max_distance": 0.0,
|
| 94 |
-
"revisit_exact_fov_candidate_count": len(records),
|
| 95 |
-
"revisit_vectorized_frame_scorer_used": 1,
|
| 96 |
-
}
|
| 97 |
if not records:
|
| 98 |
-
return []
|
| 99 |
|
| 100 |
target_poses = _pose_rows(target_pose)
|
| 101 |
if target_poses is None:
|
|
@@ -137,9 +465,6 @@ def _vectorized_frame_candidate_labels(
|
|
| 137 |
]
|
| 138 |
ranked.sort()
|
| 139 |
selected_indices = [idx for *_, idx in ranked[:topk]]
|
| 140 |
-
diagnostics["revisit_pose_preselect_selected_count"] = len(selected_indices)
|
| 141 |
-
diagnostics["revisit_pose_preselect_min_distance"] = float(min(distance_values))
|
| 142 |
-
diagnostics["revisit_pose_preselect_max_distance"] = float(max(distance_values))
|
| 143 |
|
| 144 |
selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
|
| 145 |
selected_records = [records[idx] for idx in selected_indices]
|
|
@@ -175,7 +500,6 @@ def _vectorized_frame_candidate_labels(
|
|
| 175 |
if fov_overlap_threshold is not None:
|
| 176 |
valid_mask = fov_values >= float(fov_overlap_threshold)
|
| 177 |
|
| 178 |
-
diagnostics["revisit_exact_fov_candidate_count"] = len(selected_records)
|
| 179 |
fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
|
| 180 |
plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
|
| 181 |
valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]
|
|
@@ -200,7 +524,7 @@ def _vectorized_frame_candidate_labels(
|
|
| 200 |
best_frame_fov_overlap=fov_overlap,
|
| 201 |
)
|
| 202 |
)
|
| 203 |
-
return labels
|
| 204 |
|
| 205 |
|
| 206 |
def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
|
|
@@ -297,23 +621,6 @@ def _best_selected_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidat
|
|
| 297 |
)
|
| 298 |
|
| 299 |
|
| 300 |
-
def _best_selected_frame_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
|
| 301 |
-
frame_labels = [label for label in labels if label.best_frame_fov_overlap is not None]
|
| 302 |
-
if not frame_labels:
|
| 303 |
-
return None
|
| 304 |
-
return max(
|
| 305 |
-
frame_labels,
|
| 306 |
-
key=lambda label: (
|
| 307 |
-
float(label.best_frame_fov_overlap),
|
| 308 |
-
0.0 if label.fov_overlap is None else float(label.fov_overlap),
|
| 309 |
-
0.0 if label.plucker_overlap is None else float(label.plucker_overlap),
|
| 310 |
-
-int(label.gap_to_target),
|
| 311 |
-
-int(label.record.source_start),
|
| 312 |
-
str(label.record.chunk_id or ""),
|
| 313 |
-
),
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
|
| 317 |
def _record_with_selected_frame_metadata(
|
| 318 |
label: RevisitCandidateLabel,
|
| 319 |
*,
|
|
@@ -334,6 +641,11 @@ def _record_with_selected_frame_metadata(
|
|
| 334 |
return replace(label.record, metadata=metadata)
|
| 335 |
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
def deterministic_revisit_retrieval(
|
| 338 |
records: list[MemoryRecord],
|
| 339 |
target_frame: int,
|
|
@@ -367,7 +679,7 @@ def deterministic_revisit_retrieval(
|
|
| 367 |
for record in causal_records
|
| 368 |
if int(record.source_end) <= target_frame - exclude_local_context_frames
|
| 369 |
]
|
| 370 |
-
labels
|
| 371 |
score_records,
|
| 372 |
target_frame=target_frame,
|
| 373 |
target_pose=target_pose,
|
|
@@ -383,21 +695,16 @@ def deterministic_revisit_retrieval(
|
|
| 383 |
plucker_focal_length=plucker_focal_length,
|
| 384 |
pose_preselect_topk=pose_preselect_topk,
|
| 385 |
)
|
| 386 |
-
exact_fov_candidate_count = int(pose_preselect_diagnostics["revisit_exact_fov_candidate_count"])
|
| 387 |
valid_labels = [label for label in labels if label.valid]
|
| 388 |
-
selected_labels, selected_scores,
|
| 389 |
valid_labels,
|
| 390 |
topk=topk,
|
| 391 |
plucker_weight=float(plucker_weight),
|
| 392 |
)
|
| 393 |
best_selected = _best_selected_label(selected_labels)
|
| 394 |
-
best_selected_frame = _best_selected_frame_label(selected_labels)
|
| 395 |
best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
|
| 396 |
best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
|
| 397 |
best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
|
| 398 |
-
best_selected_frame_fov = 0.0 if best_selected_frame is None else float(best_selected_frame.best_frame_fov_overlap)
|
| 399 |
-
best_selected_frame_index = -1 if best_selected_frame is None or best_selected_frame.best_frame_index is None else int(best_selected_frame.best_frame_index)
|
| 400 |
-
high_quality_selected = int(best_selected_frame is not None and best_selected_frame_fov >= float(high_quality_fov_threshold))
|
| 401 |
selected_records = [
|
| 402 |
_record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
|
| 403 |
for label in selected_labels
|
|
@@ -405,71 +712,11 @@ def deterministic_revisit_retrieval(
|
|
| 405 |
score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
|
| 406 |
scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)
|
| 407 |
|
| 408 |
-
fov_values = _overlap_values(valid_labels, "fov_overlap")
|
| 409 |
-
plucker_values = _overlap_values(valid_labels, "plucker_overlap")
|
| 410 |
-
selected_gaps = [label.gap_to_target for label in selected_labels]
|
| 411 |
-
selected_frame_fov_values = [
|
| 412 |
-
float(label.best_frame_fov_overlap)
|
| 413 |
-
for label in selected_labels
|
| 414 |
-
if label.best_frame_fov_overlap is not None
|
| 415 |
-
]
|
| 416 |
-
diagnostics = {
|
| 417 |
-
"target_frame": int(target_frame),
|
| 418 |
-
"candidate_count": len(causal_records),
|
| 419 |
-
"candidate_frame_count": len(causal_records),
|
| 420 |
-
"valid_candidate_count": len(valid_labels),
|
| 421 |
-
"revisit_exact_fov_candidate_count": exact_fov_candidate_count,
|
| 422 |
-
"valid_candidate_frame_count": len(valid_labels),
|
| 423 |
-
"valid_candidate_label_count": len(valid_labels),
|
| 424 |
-
"selected_count": len(selected_records),
|
| 425 |
-
"selected_frame_count": len(selected_records),
|
| 426 |
-
"revisit_candidate_frame_count": len(causal_records),
|
| 427 |
-
"revisit_candidate_count": len(causal_records),
|
| 428 |
-
"valid_revisit_frame_count": len(valid_labels),
|
| 429 |
-
"valid_revisit_count": len(valid_labels),
|
| 430 |
-
"no_valid_revisit_count": int(len(valid_labels) == 0),
|
| 431 |
-
"valid_revisit_mask": int(len(valid_labels) > 0),
|
| 432 |
-
"revisit_abstained_count": int(len(selected_records) == 0),
|
| 433 |
-
"abstained": bool(len(selected_records) == 0),
|
| 434 |
-
"revisit_selected_frame_count": len(selected_records),
|
| 435 |
-
"revisit_selected_count": len(selected_records),
|
| 436 |
-
"revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
|
| 437 |
-
"best_selected_fov_overlap": best_selected_fov,
|
| 438 |
-
"best_selected_plucker_overlap": best_selected_plucker,
|
| 439 |
-
"best_selected_gap_frames": best_selected_gap,
|
| 440 |
-
"best_selected_frame_index": best_selected_frame_index,
|
| 441 |
-
"best_selected_frame_fov_overlap": best_selected_frame_fov,
|
| 442 |
-
"best_selected_frame_passes_high_quality": high_quality_selected,
|
| 443 |
-
"high_quality_selected_revisit": high_quality_selected,
|
| 444 |
-
"high_quality_fov_threshold": float(high_quality_fov_threshold),
|
| 445 |
-
"revisit_label_source": LABEL_SOURCE,
|
| 446 |
-
"selected_frame_ids": [int(record.max_source_frame) for record in selected_records],
|
| 447 |
-
"selected_frame_record_ids": [record.chunk_id for record in selected_records],
|
| 448 |
-
"selected_ranges": [(record.source_start, record.source_end) for record in selected_records],
|
| 449 |
-
"frame_fov_overlap_values": fov_values,
|
| 450 |
-
"fov_overlap_values": fov_values,
|
| 451 |
-
"plucker_overlap_values": plucker_values,
|
| 452 |
-
"best_selected_fov_overlap_values": [] if best_selected is None else [best_selected_fov],
|
| 453 |
-
"best_selected_plucker_overlap_values": [] if best_selected is None else [best_selected_plucker],
|
| 454 |
-
"best_selected_gap_frame_values": [] if best_selected is None else [best_selected_gap],
|
| 455 |
-
"best_selected_frame_fov_overlap_values": [] if best_selected_frame is None else [best_selected_frame_fov],
|
| 456 |
-
"selected_frame_fov_overlap_values": selected_frame_fov_values,
|
| 457 |
-
"selected_incremental_fov_overlap_values": selected_gains,
|
| 458 |
-
"selected_revisit_scores": selected_scores,
|
| 459 |
-
**pose_preselect_diagnostics,
|
| 460 |
-
}
|
| 461 |
-
diagnostics.update(_overlap_stats(fov_values, "revisit_frame_fov_overlap"))
|
| 462 |
-
diagnostics.update(_overlap_stats(fov_values, "revisit_fov_overlap"))
|
| 463 |
-
diagnostics.update(_overlap_stats(plucker_values, "revisit_plucker_overlap"))
|
| 464 |
-
diagnostics.update(_overlap_stats(diagnostics["best_selected_fov_overlap_values"], "revisit_best_selected_fov_overlap"))
|
| 465 |
-
diagnostics.update(_overlap_stats(diagnostics["best_selected_plucker_overlap_values"], "revisit_best_selected_plucker_overlap"))
|
| 466 |
-
diagnostics.update(_overlap_stats(diagnostics["best_selected_gap_frame_values"], "revisit_best_selected_gap_frames"))
|
| 467 |
-
diagnostics.update(_overlap_stats(diagnostics["best_selected_frame_fov_overlap_values"], "revisit_best_selected_frame_fov_overlap"))
|
| 468 |
-
diagnostics.update(_overlap_stats(selected_frame_fov_values, "revisit_selected_frame_fov_overlap"))
|
| 469 |
-
diagnostics.update(_overlap_stats(selected_gains, "revisit_incremental_fov_overlap"))
|
| 470 |
return RevisitRetrievalResult(
|
| 471 |
records=selected_records,
|
| 472 |
scores=scores,
|
| 473 |
selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
|
| 474 |
-
|
|
|
|
|
|
|
| 475 |
)
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import math
|
| 4 |
+
from dataclasses import dataclass, replace
|
| 5 |
from typing import Any, Optional
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 9 |
from .labels import (
|
| 10 |
LABEL_SOURCE,
|
| 11 |
RevisitCandidateLabel,
|
| 12 |
+
_angle_diff_degrees,
|
| 13 |
_inside_fov_3d_hv,
|
| 14 |
_plucker_descriptor,
|
| 15 |
_target_fov_points,
|
|
|
|
| 17 |
from .types import MemoryRecord, RevisitRetrievalResult
|
| 18 |
|
| 19 |
|
| 20 |
+
@dataclass
class BatchedRevisitSelectionResult:
    """Tensor outputs of ``batched_revisit_select_positions``.

    The ``selected_*`` fields are allocated with shape ``(B, T_tgt, topk)``
    and the ``best_selected_*`` fields with shape ``(B, T_tgt)``, where ``B``
    is the batch size and ``T_tgt`` the number of target frames per batch
    element.
    """

    # Index along the source time axis of each pick (long); -1 in empty slots.
    selected_positions: torch.Tensor
    # Bool mask, True where the matching ``selected_positions`` slot is filled.
    selected_mask: torch.Tensor
    # Incremental FOV coverage gain of each pick at the moment it was chosen.
    selected_scores: torch.Tensor
    # FOV overlap of each pick with its target; 0 in empty slots.
    selected_fov_overlap: torch.Tensor
    # Plucker similarity of each pick with its target; 0 in empty slots.
    selected_plucker_overlap: torch.Tensor
    # Target frame index minus source frame index (float32); -1 in empty slots.
    selected_gap_frames: torch.Tensor
    # FOV overlap of the top-ranked pick per query; 0 when nothing was picked.
    best_selected_fov_overlap: torch.Tensor
    # Plucker similarity of the top-ranked pick; 0 when nothing was picked.
    best_selected_plucker_overlap: torch.Tensor
    # Frame gap of the top-ranked pick; -1 when nothing was picked.
    best_selected_gap_frames: torch.Tensor
|
| 31 |
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def _pose_rows(pose) -> torch.Tensor | None:
|
| 35 |
if pose is None:
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
|
| 57 |
+
|
| 58 |
+
def _time_batch_pose_rows(
|
| 59 |
+
pose: torch.Tensor,
|
| 60 |
+
*,
|
| 61 |
+
time: int,
|
| 62 |
+
batch: int,
|
| 63 |
+
name: str,
|
| 64 |
+
) -> torch.Tensor:
|
| 65 |
+
if pose is None:
|
| 66 |
+
raise ValueError(f"{name} is required for batched DeMemWM revisit retrieval")
|
| 67 |
+
pose_tensor = pose if torch.is_tensor(pose) else torch.as_tensor(pose, dtype=torch.float32)
|
| 68 |
+
if pose_tensor.ndim < 3 or pose_tensor.shape[-1] < 5:
|
| 69 |
+
raise ValueError(f"{name} must have shape (T,B,D) or (B,T,D) with D >= 5")
|
| 70 |
+
if pose_tensor.shape[0] == time and pose_tensor.shape[1] == batch:
|
| 71 |
+
pose_tb = pose_tensor
|
| 72 |
+
elif pose_tensor.shape[0] == batch and pose_tensor.shape[1] == time:
|
| 73 |
+
pose_tb = pose_tensor.transpose(0, 1)
|
| 74 |
+
else:
|
| 75 |
+
raise ValueError(f"{name} must match time/batch dimensions ({time},{batch})")
|
| 76 |
+
return pose_tb[..., :5].detach().to(dtype=torch.float32)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _time_batch_mask(
|
| 80 |
+
mask: torch.Tensor | None,
|
| 81 |
+
*,
|
| 82 |
+
time: int,
|
| 83 |
+
batch: int,
|
| 84 |
+
device: torch.device,
|
| 85 |
+
) -> torch.Tensor:
|
| 86 |
+
if mask is None:
|
| 87 |
+
return torch.ones((time, batch), device=device, dtype=torch.bool)
|
| 88 |
+
mask_tensor = mask if torch.is_tensor(mask) else torch.as_tensor(mask)
|
| 89 |
+
if mask_tensor.ndim != 2:
|
| 90 |
+
raise ValueError("source_candidate_mask must have shape (T,B) or (B,T)")
|
| 91 |
+
if mask_tensor.shape == (time, batch):
|
| 92 |
+
mask_tb = mask_tensor
|
| 93 |
+
elif mask_tensor.shape == (batch, time):
|
| 94 |
+
mask_tb = mask_tensor.transpose(0, 1)
|
| 95 |
+
else:
|
| 96 |
+
raise ValueError(f"source_candidate_mask must match time/batch dimensions ({time},{batch})")
|
| 97 |
+
return mask_tb.to(device=device, dtype=torch.bool)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _target_fov_points_batched(
|
| 101 |
+
target_poses: torch.Tensor,
|
| 102 |
+
*,
|
| 103 |
+
fov_half_h: float,
|
| 104 |
+
fov_half_v: float,
|
| 105 |
+
yaw_samples: int,
|
| 106 |
+
pitch_samples: int,
|
| 107 |
+
depth_samples: int,
|
| 108 |
+
radius: float,
|
| 109 |
+
) -> torch.Tensor:
|
| 110 |
+
yaw_samples = max(1, int(yaw_samples))
|
| 111 |
+
pitch_samples = max(1, int(pitch_samples))
|
| 112 |
+
depth_samples = max(1, int(depth_samples))
|
| 113 |
+
device = target_poses.device
|
| 114 |
+
dtype = target_poses.dtype
|
| 115 |
+
if yaw_samples == 1:
|
| 116 |
+
yaw_offsets = torch.zeros((1,), device=device, dtype=dtype)
|
| 117 |
+
else:
|
| 118 |
+
yaw_offsets = torch.linspace(-float(fov_half_h), float(fov_half_h), yaw_samples + 2, device=device, dtype=dtype)[1:-1]
|
| 119 |
+
if pitch_samples == 1:
|
| 120 |
+
pitch_offsets = torch.zeros((1,), device=device, dtype=dtype)
|
| 121 |
+
else:
|
| 122 |
+
pitch_offsets = torch.linspace(-float(fov_half_v), float(fov_half_v), pitch_samples + 2, device=device, dtype=dtype)[1:-1]
|
| 123 |
+
if depth_samples == 1:
|
| 124 |
+
depths = torch.full((1,), float(radius), device=device, dtype=dtype)
|
| 125 |
+
else:
|
| 126 |
+
depths = torch.linspace(float(radius) / float(depth_samples), float(radius), depth_samples, device=device, dtype=dtype)
|
| 127 |
+
depth_grid, pitch_grid, yaw_grid = torch.meshgrid(depths, pitch_offsets, yaw_offsets, indexing="ij")
|
| 128 |
+
pitch_offsets_flat = pitch_grid.reshape(1, -1)
|
| 129 |
+
yaw_offsets_flat = yaw_grid.reshape(1, -1)
|
| 130 |
+
depth = depth_grid.reshape(1, -1)
|
| 131 |
+
pitch = torch.deg2rad(target_poses[:, 3:4] + pitch_offsets_flat)
|
| 132 |
+
yaw = torch.deg2rad(target_poses[:, 4:5] + yaw_offsets_flat)
|
| 133 |
+
cos_pitch = torch.cos(pitch)
|
| 134 |
+
vectors = torch.stack(
|
| 135 |
+
[
|
| 136 |
+
depth * cos_pitch * torch.sin(yaw),
|
| 137 |
+
depth * torch.sin(pitch),
|
| 138 |
+
depth * cos_pitch * torch.cos(yaw),
|
| 139 |
+
],
|
| 140 |
+
dim=-1,
|
| 141 |
+
)
|
| 142 |
+
return target_poses[:, None, :3] + vectors
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _inside_fov_3d_hv_batched(
    points: torch.Tensor,
    poses: torch.Tensor,
    *,
    fov_half_h: float,
    fov_half_v: float,
) -> torch.Tensor:
    """Test which sample points fall inside each candidate pose's frustum.

    Args:
        points: Sample points indexed as ``(Q, P, 3)``.
        poses: Candidate poses indexed as ``(Q, K, D)``; columns 0-2 are read
            as position, column 3 as pitch and column 4 as yaw, in degrees.
        fov_half_h: Half horizontal field of view in degrees.
        fov_half_v: Half vertical field of view in degrees.

    Returns:
        Bool tensor of shape ``(Q, K, P)``; entry ``[q, k, p]`` is True when
        point ``p`` of query ``q`` lies inside candidate ``k``'s frustum.
    """
    # Offset from every candidate camera position to every point: (Q, K, P, 3).
    vectors = points[:, None, :, :] - poses[:, :, None, :3]
    x = vectors[..., 0]
    y = vectors[..., 1]
    z = vectors[..., 2]
    # Azimuth is measured from +z towards +x, elevation from the x-z plane.
    azimuth = torch.atan2(x, z) * (180.0 / math.pi)
    # The clamp keeps atan2 away from the (0, 0) case for points straight
    # above or below the camera.
    elevation = torch.atan2(y, torch.sqrt(x.square() + z.square()).clamp_min(1e-8)) * (180.0 / math.pi)
    # NOTE(review): the strict `<` comparisons below only bound the angle on
    # both sides if _angle_diff_degrees returns a wrapped *absolute*
    # difference - confirm against labels.py.
    diff_azimuth = _angle_diff_degrees(azimuth - poses[:, :, None, 4])
    diff_elevation = _angle_diff_degrees(elevation - poses[:, :, None, 3])
    return (diff_azimuth < float(fov_half_h)) & (diff_elevation < float(fov_half_v))
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _batched_tie_mask(mask: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
|
| 164 |
+
neg_inf = torch.full_like(values, -float("inf"))
|
| 165 |
+
best = torch.where(mask, values, neg_inf).max(dim=1).values
|
| 166 |
+
return mask & torch.isclose(values, best[:, None], rtol=0.0, atol=1e-12)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def batched_revisit_select_positions(
    source_frame_indices: torch.Tensor,
    source_pose: torch.Tensor,
    target_frame_indices: torch.Tensor,
    target_pose: torch.Tensor,
    *,
    source_candidate_mask: torch.Tensor | None = None,
    topk: int = 2,
    exclude_local_context_frames: int = 0,
    fov_overlap_threshold: Optional[float] = 0.30,
    plucker_weight: float = 0.1,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    fov_yaw_samples: int = 25,
    fov_pitch_samples: int = 20,
    fov_depth_samples: int = 20,
    fov_radius: float = 30.0,
    plucker_grid_h: int = 4,
    plucker_grid_w: int = 4,
    plucker_focal_length: float = 0.35,
    pose_preselect_topk: Optional[int] = 64,
    query_chunk_size: int = 16,
) -> BatchedRevisitSelectionResult:
    """Pick up to ``topk`` revisit source frames for every (batch, target) pair.

    Every target frame of every batch element is one query. For each query
    the source frames of the same batch element are filtered to those that
    are flagged as candidates and strictly earlier than
    ``target_frame - exclude_local_context_frames``, optionally narrowed to
    the ``pose_preselect_topk`` nearest poses, scored by sampled FOV overlap
    and Plucker similarity, and then chosen greedily by incremental FOV
    coverage. Queries are processed ``query_chunk_size`` at a time.

    Args:
        source_frame_indices: ``(T_src, B)`` frame index of every source slot.
        source_pose: Source poses laid out ``(T_src, B, D)`` or
            ``(B, T_src, D)`` with ``D >= 5``.
        target_frame_indices: ``(T_tgt, B)`` target frame indices; a 1-D
            tensor is treated as ``(T_tgt, 1)``.
        target_pose: Target poses laid out ``(T_tgt, B, D)`` or
            ``(B, T_tgt, D)`` with ``D >= 5``.
        source_candidate_mask: Optional ``(T_src, B)`` or ``(B, T_src)`` mask
            of source slots that may be selected; ``None`` allows all.
        topk: Maximum number of picks per query; negative values become 0.
        exclude_local_context_frames: Size of the window immediately before
            each target frame that is excluded from selection.
        fov_overlap_threshold: Minimum FOV overlap for a candidate to be
            valid; ``None`` disables the threshold.
        plucker_weight: Scale applied to the Plucker similarity when it is
            used as a tie-breaker.
        fov_half_h: Half horizontal field of view in degrees.
        fov_half_v: Half vertical field of view in degrees.
        fov_yaw_samples: Yaw resolution of the FOV sample grid.
        fov_pitch_samples: Pitch resolution of the FOV sample grid.
        fov_depth_samples: Depth resolution of the FOV sample grid.
        fov_radius: Largest sampling depth; also normalises translation in
            the pose-preselection distance.
        plucker_grid_h: Grid height passed to ``_plucker_descriptor``.
        plucker_grid_w: Grid width passed to ``_plucker_descriptor``.
        plucker_focal_length: Focal length passed to ``_plucker_descriptor``.
        pose_preselect_topk: Number of pose-nearest candidates kept before
            exact scoring; ``None`` or a non-positive value disables
            preselection, as does ``T_src <= pose_preselect_topk``.
        query_chunk_size: Number of queries scored per iteration; values
            below 1 become 1.

    Returns:
        A ``BatchedRevisitSelectionResult`` whose ``selected_*`` tensors have
        shape ``(B, T_tgt, topk)`` and whose ``best_selected_*`` tensors have
        shape ``(B, T_tgt)``. Unfilled slots hold position -1, mask False,
        score/overlap 0 and gap -1.

    Raises:
        ValueError: If a pose is ``None``, the index tensors have the wrong
            rank or disagree on the batch size, or a pose/mask does not match
            the expected time/batch dimensions.
    """
    if source_frame_indices.ndim != 2:
        raise ValueError("source_frame_indices must have shape (T_src,B)")
    if target_frame_indices.ndim == 1:
        target_frame_indices = target_frame_indices[:, None]
    if target_frame_indices.ndim != 2:
        raise ValueError("target_frame_indices must have shape (T_tgt,B)")

    T_src, B = source_frame_indices.shape
    T_tgt, B_tgt = target_frame_indices.shape
    if B_tgt != B:
        raise ValueError("source_frame_indices and target_frame_indices must share batch dimension")

    # Reject missing poses here: torch.as_tensor(None) would otherwise fail
    # with an unrelated conversion error before any shape validation runs.
    if source_pose is None:
        raise ValueError("source_pose is required for batched DeMemWM revisit retrieval")
    if target_pose is None:
        raise ValueError("target_pose is required for batched DeMemWM revisit retrieval")

    source_pose_tensor = source_pose if torch.is_tensor(source_pose) else torch.as_tensor(source_pose, dtype=torch.float32)
    target_pose_tensor = target_pose if torch.is_tensor(target_pose) else torch.as_tensor(target_pose, dtype=torch.float32)
    # Run on the target pose's device when it is on CUDA, else on the source's.
    device = target_pose_tensor.device if target_pose_tensor.is_cuda else source_pose_tensor.device
    source_frames_tb = source_frame_indices.to(device=device)
    target_frames_tb = target_frame_indices.to(device=device)
    source_pose_tb = _time_batch_pose_rows(source_pose_tensor.to(device=device), time=T_src, batch=B, name="source_pose")
    target_pose_tb = _time_batch_pose_rows(target_pose_tensor.to(device=device), time=T_tgt, batch=B, name="target_pose")
    candidate_mask_tb = _time_batch_mask(source_candidate_mask, time=T_src, batch=B, device=device)

    topk = max(0, int(topk))
    selected_positions = torch.full((B, T_tgt, topk), -1, device=device, dtype=torch.long)
    selected_mask = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.bool)
    selected_scores = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
    selected_fov_overlap = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
    selected_plucker_overlap = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
    selected_gap_frames = torch.full((B, T_tgt, topk), -1.0, device=device, dtype=torch.float32)
    best_fov = torch.zeros((B, T_tgt), device=device, dtype=torch.float32)
    best_plucker = torch.zeros((B, T_tgt), device=device, dtype=torch.float32)
    best_gap = torch.full((B, T_tgt), -1.0, device=device, dtype=torch.float32)
    # The output tensors are filled in place below, so one result object
    # serves both the early exit and the normal return.
    result = BatchedRevisitSelectionResult(
        selected_positions=selected_positions,
        selected_mask=selected_mask,
        selected_scores=selected_scores,
        selected_fov_overlap=selected_fov_overlap,
        selected_plucker_overlap=selected_plucker_overlap,
        selected_gap_frames=selected_gap_frames,
        best_selected_fov_overlap=best_fov,
        best_selected_plucker_overlap=best_plucker,
        best_selected_gap_frames=best_gap,
    )
    if topk == 0 or T_src == 0 or T_tgt == 0 or B == 0:
        return result

    source_pose_flat = source_pose_tb.reshape(-1, source_pose_tb.shape[-1])
    source_forward_tb = _pose_forward(source_pose_flat).reshape(T_src, B, 3)
    source_desc_tb = _plucker_descriptor(
        source_pose_flat,
        grid_h=plucker_grid_h,
        grid_w=plucker_grid_w,
        focal_length=plucker_focal_length,
    ).reshape(T_src, B, -1)
    # Queries are flattened batch-major so that row b * T_tgt + t lines up
    # with the (B, T_tgt, ...) output tensors.
    target_frames_flat = target_frames_tb.transpose(0, 1).contiguous().reshape(-1)
    target_pose_flat = target_pose_tb.transpose(0, 1).contiguous().reshape(-1, target_pose_tb.shape[-1])
    target_forward_flat = _pose_forward(target_pose_flat)
    target_desc_flat = _plucker_descriptor(
        target_pose_flat,
        grid_h=plucker_grid_h,
        grid_w=plucker_grid_w,
        focal_length=plucker_focal_length,
    )
    batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
    Q = int(target_frames_flat.numel())
    chunk_size = max(1, int(query_chunk_size))
    source_positions = torch.arange(T_src, device=device, dtype=torch.long)
    pose_topk = None if pose_preselect_topk is None else int(pose_preselect_topk)
    skip_preselect = pose_topk is None or pose_topk <= 0 or T_src <= pose_topk

    # The outputs were just allocated and are contiguous, so these reshapes
    # are views: writing a chunk into them updates the result tensors.
    selected_positions_flat_view = selected_positions.reshape(-1, topk)
    selected_mask_flat_view = selected_mask.reshape(-1, topk)
    selected_scores_flat_view = selected_scores.reshape(-1, topk)
    selected_fov_flat_view = selected_fov_overlap.reshape(-1, topk)
    selected_plucker_flat_view = selected_plucker_overlap.reshape(-1, topk)
    selected_gap_flat_view = selected_gap_frames.reshape(-1, topk)
    best_fov_flat_view = best_fov.reshape(-1)
    best_plucker_flat_view = best_plucker.reshape(-1)
    best_gap_flat_view = best_gap.reshape(-1)

    for start in range(0, Q, chunk_size):
        end = min(Q, start + chunk_size)
        b_idx = batch_ids[start:end]
        target_frames = target_frames_flat[start:end]
        target_poses = target_pose_flat[start:end]
        q = int(end - start)

        # Per-query copies of the source data of each query's batch element.
        source_frames = source_frames_tb.index_select(1, b_idx).transpose(0, 1).contiguous()
        source_candidates = candidate_mask_tb.index_select(1, b_idx).transpose(0, 1).contiguous()
        source_poses_q = source_pose_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
        score_valid = source_candidates & (source_frames < (target_frames[:, None] - int(exclude_local_context_frames)))

        if skip_preselect:
            preselect_idx = source_positions.reshape(1, -1).expand(q, -1)
            preselected_valid = score_valid
        else:
            source_forward_q = source_forward_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
            translation_norm = torch.linalg.vector_norm(source_poses_q[:, :, :3] - target_poses[:, None, :3], dim=-1) / max(float(fov_radius), 1e-6)
            dot = (
                source_forward_q * target_forward_flat[start:end].reshape(q, 1, 3)
            ).sum(dim=-1).clamp(-1.0, 1.0)
            pose_distance = translation_norm + (torch.acos(dot) / math.pi)
            # Tiny float64 frame-index terms only break ties between equal
            # pose distances; they are far below any real distance gap.
            rank = (
                pose_distance.to(dtype=torch.float64)
                - source_frames.to(dtype=torch.float64) * 1e-12
                + source_frames.to(dtype=torch.float64) * 1e-15
            )
            rank = rank.masked_fill(~score_valid, float("inf"))
            k_pre = min(max(1, pose_topk), T_src)
            top = torch.topk(rank, k=k_pre, largest=False, sorted=True)
            preselect_idx = top.indices
            # topk pads with invalid (+inf) entries when a query has fewer
            # than k_pre valid candidates; keep those masked out.
            preselected_valid = torch.isfinite(top.values) & torch.gather(score_valid, 1, preselect_idx)

        K = int(preselect_idx.shape[1])
        if K == 0:
            continue

        gather_pose_idx = preselect_idx.unsqueeze(-1).expand(-1, -1, source_pose_tb.shape[-1])
        selected_poses = torch.gather(source_poses_q, 1, gather_pose_idx)
        selected_frames = torch.gather(source_frames, 1, preselect_idx)

        points = _target_fov_points_batched(
            target_poses,
            fov_half_h=fov_half_h,
            fov_half_v=fov_half_v,
            yaw_samples=fov_yaw_samples,
            pitch_samples=fov_pitch_samples,
            depth_samples=fov_depth_samples,
            radius=fov_radius,
        )
        inside = _inside_fov_3d_hv_batched(points, selected_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
        # Fraction of the target's sample points each candidate can see.
        fov_values = inside.float().mean(dim=2)

        source_desc_q = source_desc_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
        gather_desc_idx = preselect_idx.unsqueeze(-1).expand(-1, -1, source_desc_tb.shape[-1])
        selected_desc = torch.gather(source_desc_q, 1, gather_desc_idx)
        diff = selected_desc - target_desc_flat[start:end, None, :]
        plucker_distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
        plucker_values = 1.0 / (1.0 + plucker_distance.clamp_min(0.0))

        valid_mask = preselected_valid
        if fov_overlap_threshold is not None:
            valid_mask = valid_mask & (fov_values >= float(fov_overlap_threshold))

        remaining = valid_mask.clone()
        covered = torch.zeros((q, inside.shape[2]), device=device, dtype=torch.bool)
        chosen_positions = torch.full((q, topk), -1, device=device, dtype=torch.long)
        chosen_scores = torch.zeros((q, topk), device=device, dtype=torch.float32)
        chosen_fov = torch.zeros((q, topk), device=device, dtype=torch.float32)
        chosen_plucker = torch.zeros((q, topk), device=device, dtype=torch.float32)
        chosen_gap = torch.full((q, topk), -1.0, device=device, dtype=torch.float32)
        row_idx = torch.arange(q, device=device, dtype=torch.long)
        gap_values = target_frames[:, None] - selected_frames
        for slot in range(topk):
            active = remaining.any(dim=1)
            # Coverage a candidate would add on top of the points already
            # covered by earlier picks.
            gains = (inside & ~covered[:, None, :]).float().mean(dim=2)
            # Lexicographic tie-break: gain, FOV overlap, weighted Plucker,
            # smaller gap, lower frame index, then the first surviving column.
            tied = _batched_tie_mask(remaining, gains)
            tied = _batched_tie_mask(tied, fov_values)
            tied = _batched_tie_mask(tied, plucker_values * float(plucker_weight))
            tied = _batched_tie_mask(tied, -gap_values.to(dtype=torch.float32))
            tied = _batched_tie_mask(tied, -selected_frames.to(dtype=torch.float32))
            best_idx = tied.to(dtype=torch.long).argmax(dim=1)
            # Rows without remaining candidates keep their fill values.
            chosen_positions[active, slot] = preselect_idx[row_idx[active], best_idx[active]]
            chosen_scores[active, slot] = gains[row_idx[active], best_idx[active]]
            chosen_fov[active, slot] = fov_values[row_idx[active], best_idx[active]]
            chosen_plucker[active, slot] = plucker_values[row_idx[active], best_idx[active]]
            chosen_gap[active, slot] = gap_values[row_idx[active], best_idx[active]].to(dtype=torch.float32)
            covered[active] = covered[active] | inside[row_idx[active], best_idx[active]]
            remaining[row_idx[active], best_idx[active]] = False

        chosen_mask = chosen_positions >= 0
        # Rank the picks by FOV overlap, with Plucker similarity and a
        # smaller gap as vanishingly small tie-breakers.
        chosen_rank = (
            chosen_fov.to(dtype=torch.float64)
            + chosen_plucker.to(dtype=torch.float64) * 1e-9
            - chosen_gap.to(dtype=torch.float64) * 1e-12
        ).masked_fill(~chosen_mask, -float("inf"))
        has_choice = chosen_mask.any(dim=1)
        best_slot = chosen_rank.argmax(dim=1)
        best_fov_flat = torch.where(has_choice, chosen_fov[row_idx, best_slot], torch.zeros((q,), device=device, dtype=torch.float32))
        best_plucker_flat = torch.where(has_choice, chosen_plucker[row_idx, best_slot], torch.zeros((q,), device=device, dtype=torch.float32))
        best_gap_flat = torch.where(has_choice, chosen_gap[row_idx, best_slot], torch.full((q,), -1.0, device=device, dtype=torch.float32))

        selected_positions_flat_view[start:end] = chosen_positions
        selected_mask_flat_view[start:end] = chosen_mask
        selected_scores_flat_view[start:end] = chosen_scores
        selected_fov_flat_view[start:end] = chosen_fov
        selected_plucker_flat_view[start:end] = chosen_plucker
        selected_gap_flat_view[start:end] = chosen_gap
        best_fov_flat_view[start:end] = best_fov_flat
        best_plucker_flat_view[start:end] = best_plucker_flat
        best_gap_flat_view[start:end] = best_gap_flat

    return result
|
| 397 |
+
|
| 398 |
+
|
| 399 |
def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
|
| 400 |
if int(record.frame_indices.numel()) != 1:
|
| 401 |
return None
|
|
|
|
| 421 |
plucker_grid_w: int,
|
| 422 |
plucker_focal_length: float,
|
| 423 |
pose_preselect_topk: Optional[int],
|
| 424 |
+
) -> list[RevisitCandidateLabel]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
if not records:
|
| 426 |
+
return []
|
| 427 |
|
| 428 |
target_poses = _pose_rows(target_pose)
|
| 429 |
if target_poses is None:
|
|
|
|
| 465 |
]
|
| 466 |
ranked.sort()
|
| 467 |
selected_indices = [idx for *_, idx in ranked[:topk]]
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
|
| 470 |
selected_records = [records[idx] for idx in selected_indices]
|
|
|
|
| 500 |
if fov_overlap_threshold is not None:
|
| 501 |
valid_mask = fov_values >= float(fov_overlap_threshold)
|
| 502 |
|
|
|
|
| 503 |
fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
|
| 504 |
plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
|
| 505 |
valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]
|
|
|
|
| 524 |
best_frame_fov_overlap=fov_overlap,
|
| 525 |
)
|
| 526 |
)
|
| 527 |
+
return labels
|
| 528 |
|
| 529 |
|
| 530 |
def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
|
|
|
|
| 621 |
)
|
| 622 |
|
| 623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 624 |
def _record_with_selected_frame_metadata(
|
| 625 |
label: RevisitCandidateLabel,
|
| 626 |
*,
|
|
|
|
| 641 |
return replace(label.record, metadata=metadata)
|
| 642 |
|
| 643 |
|
| 644 |
+
|
| 645 |
+
def _record_frame_id(record: MemoryRecord) -> int:
|
| 646 |
+
return int(record.source_end) - 1
|
| 647 |
+
|
| 648 |
+
|
| 649 |
def deterministic_revisit_retrieval(
|
| 650 |
records: list[MemoryRecord],
|
| 651 |
target_frame: int,
|
|
|
|
| 679 |
for record in causal_records
|
| 680 |
if int(record.source_end) <= target_frame - exclude_local_context_frames
|
| 681 |
]
|
| 682 |
+
labels = _vectorized_frame_candidate_labels(
|
| 683 |
score_records,
|
| 684 |
target_frame=target_frame,
|
| 685 |
target_pose=target_pose,
|
|
|
|
| 695 |
plucker_focal_length=plucker_focal_length,
|
| 696 |
pose_preselect_topk=pose_preselect_topk,
|
| 697 |
)
|
|
|
|
| 698 |
valid_labels = [label for label in labels if label.valid]
|
| 699 |
+
selected_labels, selected_scores, _ = _select_greedy_coverage(
|
| 700 |
valid_labels,
|
| 701 |
topk=topk,
|
| 702 |
plucker_weight=float(plucker_weight),
|
| 703 |
)
|
| 704 |
best_selected = _best_selected_label(selected_labels)
|
|
|
|
| 705 |
best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
|
| 706 |
best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
|
| 707 |
best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
|
|
|
|
|
|
|
|
|
|
| 708 |
selected_records = [
|
| 709 |
_record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
|
| 710 |
for label in selected_labels
|
|
|
|
| 712 |
score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
|
| 713 |
scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)
|
| 714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
return RevisitRetrievalResult(
|
| 716 |
records=selected_records,
|
| 717 |
scores=scores,
|
| 718 |
selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
|
| 719 |
+
best_selected_fov_overlap=torch.as_tensor(best_selected_fov, dtype=torch.float32, device=score_device),
|
| 720 |
+
best_selected_plucker_overlap=torch.as_tensor(best_selected_plucker, dtype=torch.float32, device=score_device),
|
| 721 |
+
best_selected_gap_frames=torch.as_tensor(float(best_selected_gap), dtype=torch.float32, device=score_device),
|
| 722 |
)
|
algorithms/worldmem/dememwm/schedules.py
CHANGED
|
@@ -8,8 +8,6 @@ import torch
|
|
| 8 |
|
| 9 |
from .types import StreamGateState
|
| 10 |
|
| 11 |
-
NOISE_BUCKETS = ("high", "mid", "low")
|
| 12 |
-
NOISE_BUCKET_TO_ID = {name: idx for idx, name in enumerate(NOISE_BUCKETS)}
|
| 13 |
EVAL_ABLATION_BRANCHES = (
|
| 14 |
"memory_off",
|
| 15 |
"A_only",
|
|
@@ -41,46 +39,6 @@ def _clamp01(value: float) -> float:
|
|
| 41 |
return max(0.0, min(1.0, float(value)))
|
| 42 |
|
| 43 |
|
| 44 |
-
def noise_bucket_from_denoising_fraction(denoising_fraction: float | None) -> str:
|
| 45 |
-
if denoising_fraction is None:
|
| 46 |
-
return "mid"
|
| 47 |
-
frac = _clamp01(float(denoising_fraction))
|
| 48 |
-
if frac < (1.0 / 3.0):
|
| 49 |
-
return "high"
|
| 50 |
-
if frac < (2.0 / 3.0):
|
| 51 |
-
return "mid"
|
| 52 |
-
return "low"
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
def noise_bucket_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> str:
|
| 56 |
-
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 57 |
-
return "mid"
|
| 58 |
-
noise_fraction = _clamp01(float(noise_levels.detach().float().mean().item()) / float(int(timesteps) - 1))
|
| 59 |
-
if noise_fraction >= (2.0 / 3.0):
|
| 60 |
-
return "high"
|
| 61 |
-
if noise_fraction >= (1.0 / 3.0):
|
| 62 |
-
return "mid"
|
| 63 |
-
return "low"
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def noise_bucket_ids_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> torch.Tensor | None:
|
| 67 |
-
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 68 |
-
return None
|
| 69 |
-
noise_fraction = noise_levels.detach().float() / float(int(timesteps) - 1)
|
| 70 |
-
bucket_ids = torch.full_like(noise_levels, NOISE_BUCKET_TO_ID["mid"], dtype=torch.long)
|
| 71 |
-
bucket_ids = torch.where(
|
| 72 |
-
noise_fraction >= (2.0 / 3.0),
|
| 73 |
-
torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["high"]),
|
| 74 |
-
bucket_ids,
|
| 75 |
-
)
|
| 76 |
-
bucket_ids = torch.where(
|
| 77 |
-
noise_fraction < (1.0 / 3.0),
|
| 78 |
-
torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["low"]),
|
| 79 |
-
bucket_ids,
|
| 80 |
-
)
|
| 81 |
-
return bucket_ids
|
| 82 |
-
|
| 83 |
-
|
| 84 |
def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
|
| 85 |
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 86 |
return None
|
|
@@ -97,12 +55,6 @@ def normalize_eval_ablation_branch(branch: str | None) -> str:
|
|
| 97 |
return branch
|
| 98 |
|
| 99 |
|
| 100 |
-
def normalize_noise_bucket(noise_bucket: str | None) -> str:
|
| 101 |
-
if noise_bucket in NOISE_BUCKET_TO_ID:
|
| 102 |
-
return str(noise_bucket)
|
| 103 |
-
return "mid"
|
| 104 |
-
|
| 105 |
-
|
| 106 |
_STAGE_ENABLES = {
|
| 107 |
'stage_1': (True, True, True),
|
| 108 |
'stage_2': (True, True, True),
|
|
@@ -131,22 +83,6 @@ class CurriculumState:
|
|
| 131 |
def dit_full_trainable(self) -> bool:
|
| 132 |
return self.dit_train_state == "full"
|
| 133 |
|
| 134 |
-
def diagnostics(self) -> dict[str, Any]:
|
| 135 |
-
return {
|
| 136 |
-
"dememwm_global_step": self.global_step,
|
| 137 |
-
"dememwm_curriculum_enabled": self.enabled,
|
| 138 |
-
"dememwm_stage": self.stage,
|
| 139 |
-
"curriculum_anchor_enabled": self.anchor_enabled,
|
| 140 |
-
"curriculum_dynamic_enabled": self.dynamic_enabled,
|
| 141 |
-
"curriculum_revisit_enabled": self.revisit_enabled,
|
| 142 |
-
"dit_train_state": self.dit_train_state,
|
| 143 |
-
"dit_full_trainable": self.dit_full_trainable,
|
| 144 |
-
"freeze_vae": self.freeze_vae,
|
| 145 |
-
"lr_dememwm_modules": self.dememwm_lr,
|
| 146 |
-
"lr_memory_adapters": self.memory_adapter_lr,
|
| 147 |
-
"lr_full_dit": self.full_dit_lr,
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
|
| 151 |
def _cfg_get(obj: Any, name: str, default: Any) -> Any:
|
| 152 |
return getattr(obj, name, default) if obj is not None else default
|
|
@@ -209,9 +145,7 @@ DeMemWMCurriculumState = CurriculumState
|
|
| 209 |
resolve_dememwm_curriculum = resolve_curriculum
|
| 210 |
|
| 211 |
|
| 212 |
-
def compute_stream_gates(stage: str, denoising_fraction: float | None = None,
|
| 213 |
-
if debug_force_all_streams:
|
| 214 |
-
return StreamGateState(True, True, True, float(anchor_gate), float(dynamic_gate), float(revisit_gate), stage, "debug_force_all_streams")
|
| 215 |
if stage not in _STAGE_ENABLES:
|
| 216 |
raise ValueError(f"unknown DeMemWM stage: {stage}")
|
| 217 |
a_on, d_on, r_on = _STAGE_ENABLES[stage]
|
|
|
|
| 8 |
|
| 9 |
from .types import StreamGateState
|
| 10 |
|
|
|
|
|
|
|
| 11 |
EVAL_ABLATION_BRANCHES = (
|
| 12 |
"memory_off",
|
| 13 |
"A_only",
|
|
|
|
| 39 |
return max(0.0, min(1.0, float(value)))
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
|
| 43 |
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 44 |
return None
|
|
|
|
| 55 |
return branch
|
| 56 |
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
_STAGE_ENABLES = {
|
| 59 |
'stage_1': (True, True, True),
|
| 60 |
'stage_2': (True, True, True),
|
|
|
|
| 83 |
def dit_full_trainable(self) -> bool:
|
| 84 |
return self.dit_train_state == "full"
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
def _cfg_get(obj: Any, name: str, default: Any) -> Any:
|
| 88 |
return getattr(obj, name, default) if obj is not None else default
|
|
|
|
| 145 |
resolve_dememwm_curriculum = resolve_curriculum
|
| 146 |
|
| 147 |
|
| 148 |
+
def compute_stream_gates(stage: str, denoising_fraction: float | None = None, anchor_gate: float = 1.0, dynamic_gate: float = 1.0, revisit_gate: float = 1.0) -> StreamGateState:
|
|
|
|
|
|
|
| 149 |
if stage not in _STAGE_ENABLES:
|
| 150 |
raise ValueError(f"unknown DeMemWM stage: {stage}")
|
| 151 |
a_on, d_on, r_on = _STAGE_ENABLES[stage]
|
algorithms/worldmem/dememwm/types.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
from __future__ import annotations
|
| 3 |
|
| 4 |
from dataclasses import dataclass, field
|
|
@@ -74,8 +73,9 @@ class MemoryStreamTensors:
|
|
| 74 |
revisit_gate: torch.Tensor | float
|
| 75 |
revisit_gate_raw: torch.Tensor | None = None
|
| 76 |
valid_revisit_mask: torch.Tensor | None = None
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
@dataclass(frozen=True)
|
|
@@ -95,4 +95,6 @@ class RevisitRetrievalResult:
|
|
| 95 |
records: list[MemoryRecord]
|
| 96 |
scores: torch.Tensor
|
| 97 |
selected_frame_ids: list[int]
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from dataclasses import dataclass, field
|
|
|
|
| 73 |
revisit_gate: torch.Tensor | float
|
| 74 |
revisit_gate_raw: torch.Tensor | None = None
|
| 75 |
valid_revisit_mask: torch.Tensor | None = None
|
| 76 |
+
revisit_best_selected_fov_overlap: torch.Tensor | None = None
|
| 77 |
+
revisit_best_selected_plucker_overlap: torch.Tensor | None = None
|
| 78 |
+
revisit_selected_gap_frames: torch.Tensor | None = None
|
| 79 |
|
| 80 |
|
| 81 |
@dataclass(frozen=True)
|
|
|
|
| 95 |
records: list[MemoryRecord]
|
| 96 |
scores: torch.Tensor
|
| 97 |
selected_frame_ids: list[int]
|
| 98 |
+
best_selected_fov_overlap: torch.Tensor
|
| 99 |
+
best_selected_plucker_overlap: torch.Tensor
|
| 100 |
+
best_selected_gap_frames: torch.Tensor
|
algorithms/worldmem/models/dit.py
CHANGED
|
@@ -166,12 +166,6 @@ class MemoryTokenCrossAttention(nn.Module):
|
|
| 166 |
self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
|
| 167 |
self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
|
| 168 |
self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
|
| 169 |
-
self.last_gate_mean = None
|
| 170 |
-
self.last_delta_ratio = None
|
| 171 |
-
self.last_valid_fraction = None
|
| 172 |
-
self.last_type_gate_mean = None
|
| 173 |
-
for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
|
| 174 |
-
setattr(self, f"last_type_gate_{type_name}_mean", None)
|
| 175 |
nn.init.normal_(self.memory_type_embed.weight, std=0.02)
|
| 176 |
self.reset_identity_init()
|
| 177 |
|
|
@@ -236,19 +230,11 @@ class MemoryTokenCrossAttention(nn.Module):
|
|
| 236 |
type_scale = type_scale.unsqueeze(0)
|
| 237 |
return memory_tokens * type_scale + type_embed
|
| 238 |
|
| 239 |
-
def _store_type_gate_diagnostics(self, stage_gate):
|
| 240 |
-
with torch.no_grad():
|
| 241 |
-
detached = stage_gate.detach().float()
|
| 242 |
-
self.last_type_gate_mean = detached.mean()
|
| 243 |
-
for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
|
| 244 |
-
setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())
|
| 245 |
-
|
| 246 |
def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
|
| 247 |
if memory_type_ids is None:
|
| 248 |
return None
|
| 249 |
memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
|
| 250 |
stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
|
| 251 |
-
self._store_type_gate_diagnostics(stage_gate)
|
| 252 |
if memory_tokens.dim() == 4:
|
| 253 |
batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
|
| 254 |
if memory_type_ids.dim() == 1:
|
|
@@ -329,33 +315,6 @@ class MemoryTokenCrossAttention(nn.Module):
|
|
| 329 |
gate_tensor = gate_tensor.unsqueeze(-1)
|
| 330 |
return gate_tensor
|
| 331 |
|
| 332 |
-
def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
|
| 333 |
-
with torch.no_grad():
|
| 334 |
-
batch_size, num_frames = base.shape[:2]
|
| 335 |
-
gate_values = torch.cat(
|
| 336 |
-
[gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
|
| 337 |
-
dim=-1,
|
| 338 |
-
)
|
| 339 |
-
gate_mask = self._gate_valid_mask(
|
| 340 |
-
valid_rows,
|
| 341 |
-
batch_size,
|
| 342 |
-
num_frames,
|
| 343 |
-
dtype=gate_values.dtype,
|
| 344 |
-
device=gate_values.device,
|
| 345 |
-
)
|
| 346 |
-
if gate_mask is not None:
|
| 347 |
-
gate_values = gate_values * gate_mask
|
| 348 |
-
self.last_valid_fraction = valid_rows.detach().float().mean()
|
| 349 |
-
valid_count = (gate_mask.sum() * gate_values.shape[-1]).clamp_min(1.0)
|
| 350 |
-
self.last_gate_mean = gate_values.sum() / valid_count
|
| 351 |
-
else:
|
| 352 |
-
self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
|
| 353 |
-
self.last_gate_mean = gate_values.mean()
|
| 354 |
-
|
| 355 |
-
delta_norm = (output.detach().float() - base.detach().float()).norm()
|
| 356 |
-
base_norm = base.detach().float().norm()
|
| 357 |
-
self.last_delta_ratio = delta_norm / (base_norm + 1e-6)
|
| 358 |
-
|
| 359 |
def forward(
|
| 360 |
self,
|
| 361 |
x,
|
|
@@ -437,7 +396,6 @@ class MemoryTokenCrossAttention(nn.Module):
|
|
| 437 |
if residual_gate_tensor is not None:
|
| 438 |
mlp_delta = mlp_delta * residual_gate_tensor
|
| 439 |
output = output + mlp_delta
|
| 440 |
-
self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
|
| 441 |
if return_delta:
|
| 442 |
return attn_delta + mlp_delta
|
| 443 |
return output
|
|
@@ -767,38 +725,6 @@ class DiT(nn.Module):
|
|
| 767 |
if memory_adapter is not None:
|
| 768 |
memory_adapter.reset_identity_init()
|
| 769 |
|
| 770 |
-
def memory_adapter_delta_diagnostics(self):
|
| 771 |
-
diagnostics = {}
|
| 772 |
-
ratios = []
|
| 773 |
-
type_gate_values = {type_name: [] for type_name in MEMORY_TYPE_NAMES}
|
| 774 |
-
shared_type_gate_values = []
|
| 775 |
-
for block in self.blocks:
|
| 776 |
-
adapter = getattr(block, "memory_token_cross_attn", None)
|
| 777 |
-
if adapter is None:
|
| 778 |
-
continue
|
| 779 |
-
ratio = getattr(adapter, "last_delta_ratio", None)
|
| 780 |
-
if ratio is not None:
|
| 781 |
-
ratios.append(torch.as_tensor(ratio).detach().float())
|
| 782 |
-
type_gate = getattr(adapter, "last_type_gate_mean", None)
|
| 783 |
-
if type_gate is not None:
|
| 784 |
-
shared_type_gate_values.append(torch.as_tensor(type_gate).detach().float())
|
| 785 |
-
for type_name in MEMORY_TYPE_NAMES:
|
| 786 |
-
value = getattr(adapter, f"last_type_gate_{type_name}_mean", None)
|
| 787 |
-
if value is not None:
|
| 788 |
-
type_gate_values[type_name].append(torch.as_tensor(value).detach().float())
|
| 789 |
-
if ratios:
|
| 790 |
-
values = torch.stack(ratios)
|
| 791 |
-
diagnostics["memory_adapter_delta_ratio_max"] = float(values.max().item())
|
| 792 |
-
diagnostics["memory_adapter_delta_ratio_mean"] = float(values.mean().item())
|
| 793 |
-
if shared_type_gate_values:
|
| 794 |
-
values = torch.stack(shared_type_gate_values)
|
| 795 |
-
diagnostics["memory_adapter_type_gate_mean"] = float(values.mean().item())
|
| 796 |
-
for type_name, values_list in type_gate_values.items():
|
| 797 |
-
if values_list:
|
| 798 |
-
values = torch.stack(values_list)
|
| 799 |
-
diagnostics[f"memory_adapter_type_gate_{type_name}_mean"] = float(values.mean().item())
|
| 800 |
-
return diagnostics
|
| 801 |
-
|
| 802 |
def unpatchify(self, x):
|
| 803 |
"""
|
| 804 |
x: (N, H, W, patch_size**2 * C)
|
|
|
|
| 166 |
self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
|
| 167 |
self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
|
| 168 |
self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
nn.init.normal_(self.memory_type_embed.weight, std=0.02)
|
| 170 |
self.reset_identity_init()
|
| 171 |
|
|
|
|
| 230 |
type_scale = type_scale.unsqueeze(0)
|
| 231 |
return memory_tokens * type_scale + type_embed
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
|
| 234 |
if memory_type_ids is None:
|
| 235 |
return None
|
| 236 |
memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
|
| 237 |
stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
|
|
|
|
| 238 |
if memory_tokens.dim() == 4:
|
| 239 |
batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
|
| 240 |
if memory_type_ids.dim() == 1:
|
|
|
|
| 315 |
gate_tensor = gate_tensor.unsqueeze(-1)
|
| 316 |
return gate_tensor
|
| 317 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
def forward(
|
| 319 |
self,
|
| 320 |
x,
|
|
|
|
| 396 |
if residual_gate_tensor is not None:
|
| 397 |
mlp_delta = mlp_delta * residual_gate_tensor
|
| 398 |
output = output + mlp_delta
|
|
|
|
| 399 |
if return_delta:
|
| 400 |
return attn_delta + mlp_delta
|
| 401 |
return output
|
|
|
|
| 725 |
if memory_adapter is not None:
|
| 726 |
memory_adapter.reset_identity_init()
|
| 727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
def unpatchify(self, x):
|
| 729 |
"""
|
| 730 |
x: (N, H, W, patch_size**2 * C)
|
configurations/algorithm/dememwm_memory_dit.yaml
CHANGED
|
@@ -15,7 +15,6 @@ log_video: false
|
|
| 15 |
dememwm:
|
| 16 |
enabled: true
|
| 17 |
training_stage: stage_1 # fallback only when curriculum.enabled=false
|
| 18 |
-
debug_force_all_streams: false
|
| 19 |
curriculum:
|
| 20 |
enabled: true
|
| 21 |
full_stage_start_step: 60000
|
|
@@ -66,8 +65,6 @@ dememwm:
|
|
| 66 |
plucker_focal_length: 0.35
|
| 67 |
compress:
|
| 68 |
downsample_ratio: 4
|
| 69 |
-
stage_policy:
|
| 70 |
-
noise_bucket_logging: true
|
| 71 |
eval_ablation:
|
| 72 |
enabled: false
|
| 73 |
branch: A_plus_D_plus_R_normal
|
|
|
|
| 15 |
dememwm:
|
| 16 |
enabled: true
|
| 17 |
training_stage: stage_1 # fallback only when curriculum.enabled=false
|
|
|
|
| 18 |
curriculum:
|
| 19 |
enabled: true
|
| 20 |
full_stage_start_step: 60000
|
|
|
|
| 65 |
plucker_focal_length: 0.35
|
| 66 |
compress:
|
| 67 |
downsample_ratio: 4
|
|
|
|
|
|
|
| 68 |
eval_ablation:
|
| 69 |
enabled: false
|
| 70 |
branch: A_plus_D_plus_R_normal
|
scripts/dememwm_full_eval.slurm
CHANGED
|
@@ -118,7 +118,6 @@ EVAL_ARGS=(
|
|
| 118 |
"++algorithm.context_frames=${CONTEXT_FRAMES}"
|
| 119 |
"++algorithm.log_video=${LOG_VIDEO}"
|
| 120 |
"++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}"
|
| 121 |
-
"++algorithm.dememwm.debug_force_all_streams=false"
|
| 122 |
"++algorithm.dememwm.training_stage=stage_2"
|
| 123 |
"++algorithm.dememwm.anchor.enabled=true"
|
| 124 |
"++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
|
|
@@ -139,7 +138,6 @@ EVAL_ARGS=(
|
|
| 139 |
"++algorithm.dememwm.revisit.plucker_weight=0.10"
|
| 140 |
"++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}"
|
| 141 |
"++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}"
|
| 142 |
-
"++algorithm.dememwm.stage_policy.noise_bucket_logging=true"
|
| 143 |
"++algorithm.dememwm.eval_ablation.enabled=true"
|
| 144 |
"++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}"
|
| 145 |
"++algorithm.dememwm.cache.enabled=true"
|
|
|
|
| 118 |
"++algorithm.context_frames=${CONTEXT_FRAMES}"
|
| 119 |
"++algorithm.log_video=${LOG_VIDEO}"
|
| 120 |
"++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}"
|
|
|
|
| 121 |
"++algorithm.dememwm.training_stage=stage_2"
|
| 122 |
"++algorithm.dememwm.anchor.enabled=true"
|
| 123 |
"++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
|
|
|
|
| 138 |
"++algorithm.dememwm.revisit.plucker_weight=0.10"
|
| 139 |
"++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}"
|
| 140 |
"++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}"
|
|
|
|
| 141 |
"++algorithm.dememwm.eval_ablation.enabled=true"
|
| 142 |
"++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}"
|
| 143 |
"++algorithm.dememwm.cache.enabled=true"
|
scripts/dememwm_full_train.slurm
CHANGED
|
@@ -52,7 +52,6 @@ srun python -m main \
|
|
| 52 |
++algorithm.context_frames=100 \
|
| 53 |
++algorithm.log_video=true \
|
| 54 |
++algorithm.diffusion.sampling_timesteps=20 \
|
| 55 |
-
++algorithm.dememwm.debug_force_all_streams=false \
|
| 56 |
++algorithm.dememwm.generated_history_proxy.enabled=true \
|
| 57 |
++algorithm.dememwm.generated_history_proxy.start_step=40000 \
|
| 58 |
++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
|
|
@@ -77,7 +76,6 @@ srun python -m main \
|
|
| 77 |
++algorithm.dememwm.revisit.plucker_weight=0.10 \
|
| 78 |
++algorithm.dememwm.revisit.max_frames=2 \
|
| 79 |
++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
|
| 80 |
-
++algorithm.dememwm.stage_policy.noise_bucket_logging=true \
|
| 81 |
++algorithm.dememwm.cache.enabled=true \
|
| 82 |
++algorithm.dememwm.cache.device=cpu \
|
| 83 |
++algorithm.dememwm.cache.keep_raw_latents=all \
|
|
|
|
| 52 |
++algorithm.context_frames=100 \
|
| 53 |
++algorithm.log_video=true \
|
| 54 |
++algorithm.diffusion.sampling_timesteps=20 \
|
|
|
|
| 55 |
++algorithm.dememwm.generated_history_proxy.enabled=true \
|
| 56 |
++algorithm.dememwm.generated_history_proxy.start_step=40000 \
|
| 57 |
++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
|
|
|
|
| 76 |
++algorithm.dememwm.revisit.plucker_weight=0.10 \
|
| 77 |
++algorithm.dememwm.revisit.max_frames=2 \
|
| 78 |
++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
|
|
|
|
| 79 |
++algorithm.dememwm.cache.enabled=true \
|
| 80 |
++algorithm.dememwm.cache.device=cpu \
|
| 81 |
++algorithm.dememwm.cache.keep_raw_latents=all \
|
tests/test_dememwm_compression.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from dememwm_import_helper import install_dememwm_namespace
|
| 3 |
|
| 4 |
install_dememwm_namespace()
|
|
@@ -18,16 +19,77 @@ def small_compressor(**kwargs):
|
|
| 18 |
)
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def test_dynamic_compressor_shapes_and_budget():
|
| 22 |
comp = small_compressor(exclude_latest_local_frames=0)
|
| 23 |
latents = torch.randn(4, 2, 3, 2, 2)
|
| 24 |
frame_indices = torch.arange(4)[:, None].repeat(1, 2)
|
| 25 |
target = torch.tensor([[1, 2], [4, 4]])
|
| 26 |
-
tokens, mask
|
| 27 |
assert tokens.shape == (2, 2, 2, 8)
|
| 28 |
assert mask.shape == (2, 2, 2)
|
| 29 |
assert mask[0, 0].any()
|
| 30 |
-
assert
|
| 31 |
|
| 32 |
|
| 33 |
def test_dynamic_compressor_abstains_without_old_enough_sources():
|
|
@@ -35,23 +97,23 @@ def test_dynamic_compressor_abstains_without_old_enough_sources():
|
|
| 35 |
latents = torch.randn(2, 1, 3, 2, 2)
|
| 36 |
frame_indices = torch.tensor([[5], [6]])
|
| 37 |
target = torch.tensor([[8]])
|
| 38 |
-
tokens, mask
|
| 39 |
assert tokens.shape == (1, 1, 2, 8)
|
|
|
|
| 40 |
assert not mask.any()
|
| 41 |
-
assert diag["max_source_frame"].item() == -1
|
| 42 |
-
assert diag["dynamic_min_gap_to_target_per_target"].item() == -1
|
| 43 |
|
| 44 |
|
| 45 |
-
def
|
| 46 |
comp = small_compressor(exclude_latest_local_frames=0)
|
| 47 |
latents = torch.randn(3, 1, 3, 2, 2)
|
| 48 |
frame_indices = torch.tensor([[0], [2], [5]])
|
| 49 |
generated = torch.tensor([[False], [True], [True]])
|
| 50 |
target = torch.tensor([[3]])
|
| 51 |
-
|
|
|
|
| 52 |
assert mask.any()
|
| 53 |
-
assert
|
| 54 |
-
assert
|
| 55 |
|
| 56 |
|
| 57 |
def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
|
|
@@ -59,13 +121,12 @@ def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
|
|
| 59 |
latents = torch.randn(5, 1, 3, 2, 2)
|
| 60 |
frame_indices = torch.tensor([[0], [1], [2], [3], [4]])
|
| 61 |
target = torch.tensor([[5]])
|
| 62 |
-
tokens, mask
|
|
|
|
| 63 |
assert tokens.shape == (1, 1, 2, 8)
|
| 64 |
assert mask.any()
|
| 65 |
-
assert
|
| 66 |
-
assert
|
| 67 |
-
assert diag["dynamic_max_gap_to_target_per_target"].item() == 5
|
| 68 |
-
assert diag["dynamic_exclude_latest_local_frames"] == 2
|
| 69 |
|
| 70 |
|
| 71 |
def test_cache_materialize_raw_latents_excludes_c_short_overlap():
|
|
@@ -91,7 +152,7 @@ def test_dynamic_compressor_preserves_grad_to_trainable_parts():
|
|
| 91 |
latents = torch.randn(4, 1, 3, 2, 2)
|
| 92 |
frame_indices = torch.arange(4)[:, None]
|
| 93 |
target = torch.tensor([[4]])
|
| 94 |
-
tokens, mask
|
| 95 |
assert mask.any()
|
| 96 |
tokens.square().sum().backward()
|
| 97 |
grads = [
|
|
@@ -107,6 +168,54 @@ def test_dynamic_compressor_selects_only_recent_valid_sources():
|
|
| 107 |
latents = torch.randn(20, 1, 3, 2, 2)
|
| 108 |
frame_indices = torch.arange(20)[:, None]
|
| 109 |
target = torch.tensor([[10]])
|
| 110 |
-
|
|
|
|
| 111 |
assert mask.any()
|
| 112 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
from dememwm_import_helper import install_dememwm_namespace
|
| 4 |
|
| 5 |
install_dememwm_namespace()
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
|
| 22 |
+
def legacy_dynamic_forward(comp, latents, frame_indices, target_frame_indices, source_is_generated=None, exclude_latest_local_frames=None):
|
| 23 |
+
del source_is_generated
|
| 24 |
+
exclude_latest_local_frames = (
|
| 25 |
+
comp.exclude_latest_local_frames
|
| 26 |
+
if exclude_latest_local_frames is None
|
| 27 |
+
else int(exclude_latest_local_frames)
|
| 28 |
+
)
|
| 29 |
+
T_src, B, C, H, W = latents.shape
|
| 30 |
+
if target_frame_indices.ndim == 1:
|
| 31 |
+
target_frame_indices = target_frame_indices[:, None].expand(-1, B)
|
| 32 |
+
T_tgt = target_frame_indices.shape[0]
|
| 33 |
+
device = latents.device
|
| 34 |
+
n_spatial = (H // comp.patch_size) * (W // comp.patch_size)
|
| 35 |
+
T_out = comp._temporal_output_count()
|
| 36 |
+
num_slots = T_out * n_spatial
|
| 37 |
+
output_time_idx = comp._output_time_indices(device)
|
| 38 |
+
output_rows, mask_rows = [], []
|
| 39 |
+
for b in range(B):
|
| 40 |
+
src_frames_b = frame_indices[:, b]
|
| 41 |
+
tgt_outputs, tgt_masks = [], []
|
| 42 |
+
for j in range(T_tgt):
|
| 43 |
+
target = int(target_frame_indices[j, b].item())
|
| 44 |
+
valid_idx = (src_frames_b < target - exclude_latest_local_frames).nonzero(as_tuple=False).flatten()
|
| 45 |
+
if valid_idx.numel() == 0:
|
| 46 |
+
tgt_outputs.append(latents.new_zeros(num_slots, comp.dit_hidden_size))
|
| 47 |
+
tgt_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
|
| 48 |
+
continue
|
| 49 |
+
selected_frames = src_frames_b.index_select(0, valid_idx)
|
| 50 |
+
order = torch.argsort(selected_frames)
|
| 51 |
+
valid_idx = valid_idx.index_select(0, order)[-comp.max_source_frames:]
|
| 52 |
+
chunk = latents[valid_idx, b]
|
| 53 |
+
real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
|
| 54 |
+
if chunk.shape[0] < comp.max_source_frames:
|
| 55 |
+
pad = chunk.new_zeros(comp.max_source_frames - chunk.shape[0], C, H, W)
|
| 56 |
+
chunk = torch.cat([pad, chunk], dim=0)
|
| 57 |
+
real_mask = torch.cat([
|
| 58 |
+
torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
|
| 59 |
+
real_mask,
|
| 60 |
+
])
|
| 61 |
+
inp = chunk.clone()
|
| 62 |
+
inp[1:] = chunk[1:] - chunk[:-1]
|
| 63 |
+
x = inp.permute(1, 0, 2, 3).unsqueeze(0)
|
| 64 |
+
x = F.pad(x, (0, 0, 0, 0, comp.causal_pad, 0))
|
| 65 |
+
x = comp.conv3d(x)
|
| 66 |
+
x = x.squeeze(0).permute(1, 2, 3, 0)
|
| 67 |
+
x = comp.out_norm(x)
|
| 68 |
+
tokens = x.reshape(num_slots, comp.dit_hidden_size)
|
| 69 |
+
clamped_time_idx = output_time_idx.clamp(min=0, max=comp.max_source_frames - 1)
|
| 70 |
+
temporal_mask = (
|
| 71 |
+
(output_time_idx >= 0)
|
| 72 |
+
& (output_time_idx < comp.max_source_frames)
|
| 73 |
+
& real_mask.index_select(0, clamped_time_idx)
|
| 74 |
+
)
|
| 75 |
+
mask = temporal_mask[:, None].expand(T_out, n_spatial).reshape(num_slots)
|
| 76 |
+
tgt_outputs.append(tokens)
|
| 77 |
+
tgt_masks.append(mask)
|
| 78 |
+
output_rows.append(torch.stack(tgt_outputs))
|
| 79 |
+
mask_rows.append(torch.stack(tgt_masks))
|
| 80 |
+
return torch.stack(output_rows), torch.stack(mask_rows)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
def test_dynamic_compressor_shapes_and_budget():
|
| 84 |
comp = small_compressor(exclude_latest_local_frames=0)
|
| 85 |
latents = torch.randn(4, 2, 3, 2, 2)
|
| 86 |
frame_indices = torch.arange(4)[:, None].repeat(1, 2)
|
| 87 |
target = torch.tensor([[1, 2], [4, 4]])
|
| 88 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 89 |
assert tokens.shape == (2, 2, 2, 8)
|
| 90 |
assert mask.shape == (2, 2, 2)
|
| 91 |
assert mask[0, 0].any()
|
| 92 |
+
assert mask.sum(dim=-1).max().item() <= tokens.shape[2]
|
| 93 |
|
| 94 |
|
| 95 |
def test_dynamic_compressor_abstains_without_old_enough_sources():
|
|
|
|
| 97 |
latents = torch.randn(2, 1, 3, 2, 2)
|
| 98 |
frame_indices = torch.tensor([[5], [6]])
|
| 99 |
target = torch.tensor([[8]])
|
| 100 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 101 |
assert tokens.shape == (1, 1, 2, 8)
|
| 102 |
+
assert not tokens.any()
|
| 103 |
assert not mask.any()
|
|
|
|
|
|
|
| 104 |
|
| 105 |
|
| 106 |
+
def test_dynamic_compressor_ignores_generated_flags_and_excludes_future_sources():
|
| 107 |
comp = small_compressor(exclude_latest_local_frames=0)
|
| 108 |
latents = torch.randn(3, 1, 3, 2, 2)
|
| 109 |
frame_indices = torch.tensor([[0], [2], [5]])
|
| 110 |
generated = torch.tensor([[False], [True], [True]])
|
| 111 |
target = torch.tensor([[3]])
|
| 112 |
+
tokens, mask = comp(latents, frame_indices, None, target, generated)
|
| 113 |
+
expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target, generated)
|
| 114 |
assert mask.any()
|
| 115 |
+
assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
|
| 116 |
+
assert torch.equal(mask, expected_mask)
|
| 117 |
|
| 118 |
|
| 119 |
def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
|
|
|
|
| 121 |
latents = torch.randn(5, 1, 3, 2, 2)
|
| 122 |
frame_indices = torch.tensor([[0], [1], [2], [3], [4]])
|
| 123 |
target = torch.tensor([[5]])
|
| 124 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 125 |
+
expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target)
|
| 126 |
assert tokens.shape == (1, 1, 2, 8)
|
| 127 |
assert mask.any()
|
| 128 |
+
assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
|
| 129 |
+
assert torch.equal(mask, expected_mask)
|
|
|
|
|
|
|
| 130 |
|
| 131 |
|
| 132 |
def test_cache_materialize_raw_latents_excludes_c_short_overlap():
|
|
|
|
| 152 |
latents = torch.randn(4, 1, 3, 2, 2)
|
| 153 |
frame_indices = torch.arange(4)[:, None]
|
| 154 |
target = torch.tensor([[4]])
|
| 155 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 156 |
assert mask.any()
|
| 157 |
tokens.square().sum().backward()
|
| 158 |
grads = [
|
|
|
|
| 168 |
latents = torch.randn(20, 1, 3, 2, 2)
|
| 169 |
frame_indices = torch.arange(20)[:, None]
|
| 170 |
target = torch.tensor([[10]])
|
| 171 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 172 |
+
expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target)
|
| 173 |
assert mask.any()
|
| 174 |
+
assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
|
| 175 |
+
assert torch.equal(mask, expected_mask)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def test_dynamic_compressor_batched_matches_legacy_loop():
|
| 179 |
+
torch.manual_seed(11)
|
| 180 |
+
comp = small_compressor(exclude_latest_local_frames=2)
|
| 181 |
+
latents = torch.randn(6, 3, 3, 2, 2)
|
| 182 |
+
frame_indices = torch.tensor([
|
| 183 |
+
[0, 4, 1],
|
| 184 |
+
[3, 1, 5],
|
| 185 |
+
[7, 8, 9],
|
| 186 |
+
[2, 6, 11],
|
| 187 |
+
[12, 3, 13],
|
| 188 |
+
[15, 10, 4],
|
| 189 |
+
])
|
| 190 |
+
generated = torch.tensor([
|
| 191 |
+
[False, True, False],
|
| 192 |
+
[True, False, False],
|
| 193 |
+
[False, True, True],
|
| 194 |
+
[True, True, False],
|
| 195 |
+
[False, False, True],
|
| 196 |
+
[True, False, True],
|
| 197 |
+
])
|
| 198 |
+
targets = torch.tensor([
|
| 199 |
+
[5, 5, 6],
|
| 200 |
+
[9, 9, 12],
|
| 201 |
+
[20, 11, 3],
|
| 202 |
+
])
|
| 203 |
+
|
| 204 |
+
expected_tokens, expected_mask = legacy_dynamic_forward(
|
| 205 |
+
comp, latents, frame_indices, targets, generated
|
| 206 |
+
)
|
| 207 |
+
tokens, mask = comp(latents, frame_indices, None, targets, generated)
|
| 208 |
+
|
| 209 |
+
assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
|
| 210 |
+
assert torch.equal(mask, expected_mask)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def test_dynamic_compressor_handles_empty_source_tensor():
|
| 214 |
+
comp = small_compressor(exclude_latest_local_frames=2)
|
| 215 |
+
latents = torch.randn(0, 2, 3, 2, 2)
|
| 216 |
+
frame_indices = torch.empty(0, 2, dtype=torch.long)
|
| 217 |
+
target = torch.tensor([[5, 6], [7, 8]])
|
| 218 |
+
tokens, mask = comp(latents, frame_indices, None, target)
|
| 219 |
+
assert tokens.shape == (2, 2, 2, 8)
|
| 220 |
+
assert not tokens.any()
|
| 221 |
+
assert not mask.any()
|
tests/test_dememwm_config_static.py
CHANGED
|
@@ -7,7 +7,8 @@ def test_config_is_distinct_standalone_memory_dit_path():
|
|
| 7 |
assert "base_video_dit" in text
|
| 8 |
assert "memory_token_cross_attention: true" in text
|
| 9 |
assert "dememwm:" in text
|
| 10 |
-
assert "
|
|
|
|
| 11 |
assert "ssm_memory" not in text
|
| 12 |
assert "ssm_memory_ckpt_path" not in text
|
| 13 |
|
|
@@ -42,7 +43,6 @@ def test_current_config_contract_is_explicit_and_has_no_stale_sections():
|
|
| 42 |
"plucker_grid_h: 4",
|
| 43 |
"plucker_grid_w: 4",
|
| 44 |
"plucker_focal_length: 0.35",
|
| 45 |
-
"noise_bucket_logging: true",
|
| 46 |
"eval_ablation:",
|
| 47 |
"branch: A_plus_D_plus_R_normal",
|
| 48 |
"generated_history_proxy:",
|
|
@@ -64,6 +64,8 @@ def test_current_config_contract_is_explicit_and_has_no_stale_sections():
|
|
| 64 |
"min_gap_frames",
|
| 65 |
"max_chunks",
|
| 66 |
"chunk_frames",
|
|
|
|
|
|
|
| 67 |
):
|
| 68 |
assert forbidden not in text
|
| 69 |
|
|
@@ -77,7 +79,6 @@ def test_full_scripts_use_consumed_contract_overrides():
|
|
| 77 |
"algorithm.dememwm.revisit.fov_pitch_samples=20",
|
| 78 |
"algorithm.dememwm.revisit.fov_depth_samples=20",
|
| 79 |
"algorithm.dememwm.revisit.plucker_weight=0.10",
|
| 80 |
-
"algorithm.dememwm.stage_policy.noise_bucket_logging=true",
|
| 81 |
"algorithm.dememwm.cache.keep_compressed_records=true",
|
| 82 |
]
|
| 83 |
stale = [
|
|
@@ -95,6 +96,8 @@ def test_full_scripts_use_consumed_contract_overrides():
|
|
| 95 |
"algorithm.dememwm.revisit.min_score",
|
| 96 |
"algorithm.dememwm.revisit.generated_penalty",
|
| 97 |
"algorithm.dememwm.rollout.",
|
|
|
|
|
|
|
| 98 |
]
|
| 99 |
expected_by_script = {
|
| 100 |
"scripts/dememwm_full_train.slurm": [
|
|
@@ -120,10 +123,9 @@ def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
|
|
| 120 |
"_validate_config_contract",
|
| 121 |
"deterministic_pose_retrieval",
|
| 122 |
"exclude_latest_local_frames",
|
| 123 |
-
"noise_bucket_logging",
|
| 124 |
"anchor_effective_enabled",
|
| 125 |
"dynamic_effective_enabled",
|
| 126 |
-
"
|
| 127 |
"stale DeMemWM config fields",
|
| 128 |
"revisit_retrieval_kwargs",
|
| 129 |
"fov_half_h",
|
|
@@ -135,56 +137,58 @@ def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
|
|
| 135 |
assert token in text
|
| 136 |
assert '_cfg_get(revisit_cfg, "topk"' not in text
|
| 137 |
assert "lambda_abstain" not in text
|
|
|
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
|
| 141 |
retrieval = Path("algorithms/worldmem/dememwm/retrieval.py").read_text()
|
| 142 |
labels = Path("algorithms/worldmem/dememwm/labels.py").read_text()
|
| 143 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 144 |
-
diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
|
| 145 |
for token in [
|
| 146 |
"exclude_local_context_frames",
|
| 147 |
"fov_overlap_threshold",
|
| 148 |
"plucker_weight",
|
| 149 |
"high_quality_fov_threshold",
|
| 150 |
-
"
|
|
|
|
|
|
|
|
|
|
| 151 |
"deterministic_fov_coverage_plucker",
|
| 152 |
"valid_revisit_mask",
|
| 153 |
-
"
|
| 154 |
-
"valid_candidate_label_count",
|
| 155 |
-
"valid_revisit_frame_count",
|
| 156 |
-
"no_valid_revisit_count",
|
| 157 |
-
"revisit_selected_frame_count",
|
| 158 |
-
"revisit_frame_fov_overlap",
|
| 159 |
-
"revisit_abstained_count",
|
| 160 |
]:
|
| 161 |
-
assert token in retrieval + labels + algorithm
|
| 162 |
assert "same_video" not in retrieval + labels
|
| 163 |
assert "wrong_video" not in retrieval + labels
|
| 164 |
for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
|
| 165 |
assert f'self._cfg_get(revisit_cfg, "{stale}"' not in algorithm
|
| 166 |
|
| 167 |
|
| 168 |
-
def
|
| 169 |
compression = Path("algorithms/worldmem/dememwm/compression.py").read_text()
|
| 170 |
cache = Path("algorithms/worldmem/dememwm/cache.py").read_text()
|
| 171 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 172 |
for token in [
|
| 173 |
"exclude_latest_local_frames",
|
| 174 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
"dynamic_min_gap_to_target_per_target",
|
| 176 |
"dynamic_max_gap_to_target_per_target",
|
| 177 |
"dynamic_exclude_latest_local_frames",
|
| 178 |
-
"
|
|
|
|
| 179 |
]:
|
| 180 |
-
assert
|
| 181 |
assert "src_frames_b < target, as_tuple=False" not in compression
|
| 182 |
assert "src < int(target), as_tuple=False" not in cache
|
| 183 |
|
| 184 |
|
| 185 |
-
def test_eval_ablation_and_noise_bucket_logging_contracts():
|
| 186 |
schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
|
| 187 |
-
diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
|
| 188 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 189 |
for branch in [
|
| 190 |
"memory_off",
|
|
@@ -202,15 +206,17 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
|
|
| 202 |
"local_context_overlap_fake_revisit",
|
| 203 |
]:
|
| 204 |
assert branch in schedules
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
"noise_bucket_from_noise_levels",
|
| 208 |
"summarize_noise_bucket_diagnostics",
|
| 209 |
-
"noise_bucket_id",
|
| 210 |
"summarize_eval_ablation_diagnostics",
|
| 211 |
"eval_bucket_true_revisit_count",
|
| 212 |
"eval_bucket_no_valid_revisit_count",
|
| 213 |
"eval_bucket_corrupted_memory_count",
|
| 214 |
-
"
|
| 215 |
]:
|
| 216 |
-
assert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
assert "base_video_dit" in text
|
| 8 |
assert "memory_token_cross_attention: true" in text
|
| 9 |
assert "dememwm:" in text
|
| 10 |
+
assert "diagnostics:" not in text
|
| 11 |
+
assert "noise_bucket_logging" not in text
|
| 12 |
assert "ssm_memory" not in text
|
| 13 |
assert "ssm_memory_ckpt_path" not in text
|
| 14 |
|
|
|
|
| 43 |
"plucker_grid_h: 4",
|
| 44 |
"plucker_grid_w: 4",
|
| 45 |
"plucker_focal_length: 0.35",
|
|
|
|
| 46 |
"eval_ablation:",
|
| 47 |
"branch: A_plus_D_plus_R_normal",
|
| 48 |
"generated_history_proxy:",
|
|
|
|
| 64 |
"min_gap_frames",
|
| 65 |
"max_chunks",
|
| 66 |
"chunk_frames",
|
| 67 |
+
"diagnostics:",
|
| 68 |
+
"noise_bucket_logging",
|
| 69 |
):
|
| 70 |
assert forbidden not in text
|
| 71 |
|
|
|
|
| 79 |
"algorithm.dememwm.revisit.fov_pitch_samples=20",
|
| 80 |
"algorithm.dememwm.revisit.fov_depth_samples=20",
|
| 81 |
"algorithm.dememwm.revisit.plucker_weight=0.10",
|
|
|
|
| 82 |
"algorithm.dememwm.cache.keep_compressed_records=true",
|
| 83 |
]
|
| 84 |
stale = [
|
|
|
|
| 96 |
"algorithm.dememwm.revisit.min_score",
|
| 97 |
"algorithm.dememwm.revisit.generated_penalty",
|
| 98 |
"algorithm.dememwm.rollout.",
|
| 99 |
+
"algorithm.dememwm.diagnostics",
|
| 100 |
+
"algorithm.dememwm.stage_policy.noise_bucket_logging",
|
| 101 |
]
|
| 102 |
expected_by_script = {
|
| 103 |
"scripts/dememwm_full_train.slurm": [
|
|
|
|
| 123 |
"_validate_config_contract",
|
| 124 |
"deterministic_pose_retrieval",
|
| 125 |
"exclude_latest_local_frames",
|
|
|
|
| 126 |
"anchor_effective_enabled",
|
| 127 |
"dynamic_effective_enabled",
|
| 128 |
+
"revisit_stage_config_enabled",
|
| 129 |
"stale DeMemWM config fields",
|
| 130 |
"revisit_retrieval_kwargs",
|
| 131 |
"fov_half_h",
|
|
|
|
| 137 |
assert token in text
|
| 138 |
assert '_cfg_get(revisit_cfg, "topk"' not in text
|
| 139 |
assert "lambda_abstain" not in text
|
| 140 |
+
assert "noise_bucket" not in text
|
| 141 |
+
assert "diagnostics" not in text
|
| 142 |
|
| 143 |
|
| 144 |
def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
|
| 145 |
retrieval = Path("algorithms/worldmem/dememwm/retrieval.py").read_text()
|
| 146 |
labels = Path("algorithms/worldmem/dememwm/labels.py").read_text()
|
| 147 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
|
|
|
| 148 |
for token in [
|
| 149 |
"exclude_local_context_frames",
|
| 150 |
"fov_overlap_threshold",
|
| 151 |
"plucker_weight",
|
| 152 |
"high_quality_fov_threshold",
|
| 153 |
+
"dememwm_selected_frame_fov_overlap",
|
| 154 |
+
"best_selected_fov_overlap",
|
| 155 |
+
"best_selected_plucker_overlap",
|
| 156 |
+
"best_selected_gap_frames",
|
| 157 |
"deterministic_fov_coverage_plucker",
|
| 158 |
"valid_revisit_mask",
|
| 159 |
+
"batched_revisit_select_positions",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
]:
|
| 161 |
+
assert token in retrieval + labels + algorithm
|
| 162 |
assert "same_video" not in retrieval + labels
|
| 163 |
assert "wrong_video" not in retrieval + labels
|
| 164 |
for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
|
| 165 |
assert f'self._cfg_get(revisit_cfg, "{stale}"' not in algorithm
|
| 166 |
|
| 167 |
|
| 168 |
+
def test_dynamic_compressor_excludes_c_short_contract_without_diagnostics():
|
| 169 |
compression = Path("algorithms/worldmem/dememwm/compression.py").read_text()
|
| 170 |
cache = Path("algorithms/worldmem/dememwm/cache.py").read_text()
|
| 171 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 172 |
for token in [
|
| 173 |
"exclude_latest_local_frames",
|
| 174 |
+
"source_frames[:, None, :] < (target_frames[:, :, None] - int(exclude_latest_local_frames))",
|
| 175 |
+
"_local_context_exclusion_frames",
|
| 176 |
+
]:
|
| 177 |
+
assert token in compression + cache + algorithm
|
| 178 |
+
for removed in [
|
| 179 |
"dynamic_min_gap_to_target_per_target",
|
| 180 |
"dynamic_max_gap_to_target_per_target",
|
| 181 |
"dynamic_exclude_latest_local_frames",
|
| 182 |
+
"selected_source_count",
|
| 183 |
+
"generated_source_fraction",
|
| 184 |
]:
|
| 185 |
+
assert removed not in compression + algorithm
|
| 186 |
assert "src_frames_b < target, as_tuple=False" not in compression
|
| 187 |
assert "src < int(target), as_tuple=False" not in cache
|
| 188 |
|
| 189 |
|
| 190 |
+
def test_eval_ablation_contracts_without_diagnostic_summaries():
|
| 191 |
schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
|
|
|
|
| 192 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 193 |
for branch in [
|
| 194 |
"memory_off",
|
|
|
|
| 206 |
"local_context_overlap_fake_revisit",
|
| 207 |
]:
|
| 208 |
assert branch in schedules
|
| 209 |
+
assert "apply_revisit_eval_corruption" in algorithm
|
| 210 |
+
for removed in [
|
|
|
|
| 211 |
"summarize_noise_bucket_diagnostics",
|
|
|
|
| 212 |
"summarize_eval_ablation_diagnostics",
|
| 213 |
"eval_bucket_true_revisit_count",
|
| 214 |
"eval_bucket_no_valid_revisit_count",
|
| 215 |
"eval_bucket_corrupted_memory_count",
|
| 216 |
+
"noise_bucket_id",
|
| 217 |
]:
|
| 218 |
+
assert removed not in schedules + algorithm
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def test_dememwm_diagnostics_module_removed():
|
| 222 |
+
assert not Path("algorithms/worldmem/dememwm/diagnostics.py").exists()
|
tests/test_dememwm_dit_extension_static.py
CHANGED
|
@@ -73,17 +73,15 @@ def test_shared_memory_attention_zero_revisit_gate_matches_anchor_only():
|
|
| 73 |
assert torch.allclose(out_anchor, out_packed, atol=1e-5)
|
| 74 |
|
| 75 |
|
| 76 |
-
def
|
| 77 |
attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
|
| 78 |
x = torch.randn(1, 2, 1, 1, 8)
|
| 79 |
c = torch.randn(1, 2, 8)
|
| 80 |
mem = torch.randn(1, 2, 3, 8)
|
| 81 |
mask = torch.ones(1, 2, 3, dtype=torch.bool)
|
| 82 |
type_ids = torch.tensor([0, 1, 2])
|
| 83 |
-
|
| 84 |
-
assert
|
| 85 |
-
assert torch.is_tensor(attn.last_type_gate_dynamic_mean)
|
| 86 |
-
assert torch.is_tensor(attn.last_type_gate_revisit_mean)
|
| 87 |
|
| 88 |
|
| 89 |
def test_dit_accepts_dynamic_rank4_tokens_and_all_false_masks_without_nan():
|
|
@@ -120,7 +118,7 @@ def test_diffusion_methods_accept_option_c_kwargs_by_signature():
|
|
| 120 |
|
| 121 |
|
| 122 |
|
| 123 |
-
def test_fresh_memory_cross_attention_identity_init_delta_ratio_is_zero():
|
| 124 |
attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
|
| 125 |
x = torch.randn(1, 2, 1, 1, 8)
|
| 126 |
c = torch.randn(1, 2, 8)
|
|
@@ -128,10 +126,9 @@ def test_fresh_memory_cross_attention_identity_init_delta_ratio_is_zero():
|
|
| 128 |
mask = torch.ones(1, 2, 3, dtype=torch.bool)
|
| 129 |
delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.ones(1, 2, 1))
|
| 130 |
assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6)
|
| 131 |
-
assert float(attn.last_delta_ratio.item()) <= 1e-7
|
| 132 |
|
| 133 |
|
| 134 |
-
def test_fresh_dit_memory_on_matches_memory_off_and_reports_delta_ratio():
|
| 135 |
model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True)
|
| 136 |
x = torch.randn(1, 2, 2, 4, 4)
|
| 137 |
t = torch.zeros(1, 2, dtype=torch.long)
|
|
@@ -152,6 +149,4 @@ def test_fresh_dit_memory_on_matches_memory_off_and_reports_delta_ratio():
|
|
| 152 |
memory_dynamic_gate=torch.ones(1, 2, 1),
|
| 153 |
memory_retrieval_gate=torch.ones(1, 2, 1),
|
| 154 |
)
|
| 155 |
-
diagnostics = model.memory_adapter_delta_diagnostics()
|
| 156 |
assert torch.allclose(out_on, out_off, atol=1e-6)
|
| 157 |
-
assert diagnostics["memory_adapter_delta_ratio_max"] <= 1e-7
|
|
|
|
| 73 |
assert torch.allclose(out_anchor, out_packed, atol=1e-5)
|
| 74 |
|
| 75 |
|
| 76 |
+
def test_memory_cross_attention_type_ids_apply_stage_gates():
|
| 77 |
attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
|
| 78 |
x = torch.randn(1, 2, 1, 1, 8)
|
| 79 |
c = torch.randn(1, 2, 8)
|
| 80 |
mem = torch.randn(1, 2, 3, 8)
|
| 81 |
mask = torch.ones(1, 2, 3, dtype=torch.bool)
|
| 82 |
type_ids = torch.tensor([0, 1, 2])
|
| 83 |
+
out = attn(x, c, mem, mask, memory_type_ids=type_ids, memory_token_gate=torch.ones(1, 2, 3))
|
| 84 |
+
assert out.shape == x.shape
|
|
|
|
|
|
|
| 85 |
|
| 86 |
|
| 87 |
def test_dit_accepts_dynamic_rank4_tokens_and_all_false_masks_without_nan():
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
|
| 121 |
+
def test_fresh_memory_cross_attention_identity_init_delta_is_zero():
|
| 122 |
attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
|
| 123 |
x = torch.randn(1, 2, 1, 1, 8)
|
| 124 |
c = torch.randn(1, 2, 8)
|
|
|
|
| 126 |
mask = torch.ones(1, 2, 3, dtype=torch.bool)
|
| 127 |
delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.ones(1, 2, 1))
|
| 128 |
assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6)
|
|
|
|
| 129 |
|
| 130 |
|
| 131 |
+
def test_fresh_dit_memory_on_matches_memory_off():
|
| 132 |
model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True)
|
| 133 |
x = torch.randn(1, 2, 2, 4, 4)
|
| 134 |
t = torch.zeros(1, 2, dtype=torch.long)
|
|
|
|
| 149 |
memory_dynamic_gate=torch.ones(1, 2, 1),
|
| 150 |
memory_retrieval_gate=torch.ones(1, 2, 1),
|
| 151 |
)
|
|
|
|
| 152 |
assert torch.allclose(out_on, out_off, atol=1e-6)
|
|
|
tests/test_dememwm_eval_ablation.py
CHANGED
|
@@ -7,7 +7,6 @@ from dememwm_import_helper import install_dememwm_namespace
|
|
| 7 |
install_dememwm_namespace()
|
| 8 |
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
|
| 9 |
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
|
| 10 |
-
from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics
|
| 11 |
from algorithms.worldmem.dememwm.schedules import (
|
| 12 |
EVAL_ABLATION_BRANCHES,
|
| 13 |
EVAL_ABLATION_BRANCH_TO_ID,
|
|
@@ -48,25 +47,6 @@ def test_wave9_branch_registry_is_exact_and_validated():
|
|
| 48 |
normalize_eval_ablation_branch("ratio_sweep")
|
| 49 |
|
| 50 |
|
| 51 |
-
def test_eval_ablation_diagnostics_bucket_counts():
|
| 52 |
-
diag = summarize_eval_ablation_diagnostics(
|
| 53 |
-
enabled=True,
|
| 54 |
-
branch="wrong_pose",
|
| 55 |
-
valid_revisit_mask=torch.tensor([[True, True, True, False]]),
|
| 56 |
-
no_valid_revisit_mask=torch.tensor([[False, False, False, True]]),
|
| 57 |
-
eval_corrupted_revisit_mask=torch.tensor([[False, True, True, False]]),
|
| 58 |
-
)
|
| 59 |
-
assert diag["eval_ablation_enabled"] is True
|
| 60 |
-
assert diag["eval_ablation_branch"] == "wrong_pose"
|
| 61 |
-
assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"]
|
| 62 |
-
assert diag["eval_bucket_true_revisit_count"] == 1
|
| 63 |
-
assert diag["eval_bucket_no_valid_revisit_count"] == 1
|
| 64 |
-
assert diag["eval_bucket_corrupted_memory_count"] == 2
|
| 65 |
-
assert diag["eval_bucket_true_revisit_fraction"] == pytest.approx(0.25)
|
| 66 |
-
assert diag["eval_bucket_no_valid_revisit_fraction"] == pytest.approx(0.25)
|
| 67 |
-
assert diag["eval_bucket_corrupted_memory_fraction"] == pytest.approx(0.5)
|
| 68 |
-
|
| 69 |
-
|
| 70 |
class ConstantGate(torch.nn.Module):
|
| 71 |
def __init__(self, value: float):
|
| 72 |
super().__init__()
|
|
@@ -83,7 +63,6 @@ class DummyDeMemWM(MemoryDiTMixin):
|
|
| 83 |
dememwm=types.SimpleNamespace(
|
| 84 |
enabled=True,
|
| 85 |
training_stage="stage_2",
|
| 86 |
-
debug_force_all_streams=False,
|
| 87 |
token_patch_size=2,
|
| 88 |
curriculum=types.SimpleNamespace(enabled=False),
|
| 89 |
anchor=types.SimpleNamespace(
|
|
@@ -108,7 +87,6 @@ class DummyDeMemWM(MemoryDiTMixin):
|
|
| 108 |
max_frames=2,
|
| 109 |
compress=types.SimpleNamespace(pool_h=1, pool_w=1),
|
| 110 |
),
|
| 111 |
-
stage_policy=types.SimpleNamespace(noise_bucket_logging=True),
|
| 112 |
eval_ablation=types.SimpleNamespace(enabled=True, branch=branch),
|
| 113 |
generated_history_proxy=types.SimpleNamespace(enabled=False),
|
| 114 |
injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
|
|
@@ -190,12 +168,11 @@ def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch():
|
|
| 190 |
assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25))
|
| 191 |
assert torch.count_nonzero(forced_off.revisit_gate).item() == 0
|
| 192 |
assert torch.equal(forced_on.revisit_gate, forced_on.valid_revisit_mask.to(dtype=forced_on.revisit_gate.dtype))
|
| 193 |
-
assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on"
|
| 194 |
|
| 195 |
|
| 196 |
def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
|
| 197 |
wrong_pose = _streams("wrong_pose")
|
| 198 |
assert wrong_pose.valid_revisit_mask.all()
|
|
|
|
| 199 |
assert torch.allclose(wrong_pose.revisit_gate, torch.full_like(wrong_pose.revisit_gate, 0.25))
|
| 200 |
-
assert wrong_pose.
|
| 201 |
-
assert wrong_pose.diagnostics["eval_bucket_true_revisit_count"] == 0
|
|
|
|
| 7 |
install_dememwm_namespace()
|
| 8 |
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
|
| 9 |
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
|
|
|
|
| 10 |
from algorithms.worldmem.dememwm.schedules import (
|
| 11 |
EVAL_ABLATION_BRANCHES,
|
| 12 |
EVAL_ABLATION_BRANCH_TO_ID,
|
|
|
|
| 47 |
normalize_eval_ablation_branch("ratio_sweep")
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
class ConstantGate(torch.nn.Module):
|
| 51 |
def __init__(self, value: float):
|
| 52 |
super().__init__()
|
|
|
|
| 63 |
dememwm=types.SimpleNamespace(
|
| 64 |
enabled=True,
|
| 65 |
training_stage="stage_2",
|
|
|
|
| 66 |
token_patch_size=2,
|
| 67 |
curriculum=types.SimpleNamespace(enabled=False),
|
| 68 |
anchor=types.SimpleNamespace(
|
|
|
|
| 87 |
max_frames=2,
|
| 88 |
compress=types.SimpleNamespace(pool_h=1, pool_w=1),
|
| 89 |
),
|
|
|
|
| 90 |
eval_ablation=types.SimpleNamespace(enabled=True, branch=branch),
|
| 91 |
generated_history_proxy=types.SimpleNamespace(enabled=False),
|
| 92 |
injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
|
|
|
|
| 168 |
assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25))
|
| 169 |
assert torch.count_nonzero(forced_off.revisit_gate).item() == 0
|
| 170 |
assert torch.equal(forced_on.revisit_gate, forced_on.valid_revisit_mask.to(dtype=forced_on.revisit_gate.dtype))
|
|
|
|
| 171 |
|
| 172 |
|
| 173 |
def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
|
| 174 |
wrong_pose = _streams("wrong_pose")
|
| 175 |
assert wrong_pose.valid_revisit_mask.all()
|
| 176 |
+
assert wrong_pose.revisit_mask.any()
|
| 177 |
assert torch.allclose(wrong_pose.revisit_gate, torch.full_like(wrong_pose.revisit_gate, 0.25))
|
| 178 |
+
assert wrong_pose.revisit_best_selected_fov_overlap.shape == wrong_pose.valid_revisit_mask.shape
|
|
|
tests/test_dememwm_freeze_policy.py
CHANGED
|
@@ -62,8 +62,6 @@ def test_dit_freeze_keeps_requires_grad_stable_and_zeroes_optimizer_lr():
|
|
| 62 |
assert frozen_state.dit_train_state == "frozen"
|
| 63 |
assert full_dit_params
|
| 64 |
assert all(param.requires_grad for param in full_dit_params)
|
| 65 |
-
assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == 0
|
| 66 |
-
assert model._last_dememwm_freeze_diagnostics["requires_grad_tensors_full_dit"] == len(full_dit_params)
|
| 67 |
assert all(not param.requires_grad for param in model.vae.parameters())
|
| 68 |
|
| 69 |
for param in full_dit_params:
|
|
@@ -85,7 +83,6 @@ def test_dit_freeze_keeps_requires_grad_stable_and_zeroes_optimizer_lr():
|
|
| 85 |
|
| 86 |
assert full_state.dit_train_state == "full"
|
| 87 |
assert all(param.requires_grad for param in full_dit_params)
|
| 88 |
-
assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == len(full_dit_params)
|
| 89 |
assert lr_by_name["full_dit"] == 1.0e-5
|
| 90 |
assert all(not param.requires_grad for param in model.vae.parameters())
|
| 91 |
|
|
|
|
| 62 |
assert frozen_state.dit_train_state == "frozen"
|
| 63 |
assert full_dit_params
|
| 64 |
assert all(param.requires_grad for param in full_dit_params)
|
|
|
|
|
|
|
| 65 |
assert all(not param.requires_grad for param in model.vae.parameters())
|
| 66 |
|
| 67 |
for param in full_dit_params:
|
|
|
|
| 83 |
|
| 84 |
assert full_state.dit_train_state == "full"
|
| 85 |
assert all(param.requires_grad for param in full_dit_params)
|
|
|
|
| 86 |
assert lr_by_name["full_dit"] == 1.0e-5
|
| 87 |
assert all(not param.requires_grad for param in model.vae.parameters())
|
| 88 |
|
tests/test_dememwm_generated_history_proxy.py
CHANGED
|
@@ -68,7 +68,7 @@ def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_
|
|
| 68 |
source_is_generated = torch.zeros(4, 1, dtype=torch.bool)
|
| 69 |
|
| 70 |
torch.manual_seed(123)
|
| 71 |
-
corrupted, generated, diagnostics = model._apply_generated_history_proxy(
|
| 72 |
source_latents,
|
| 73 |
source_is_generated,
|
| 74 |
)
|
|
@@ -77,8 +77,7 @@ def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_
|
|
| 77 |
assert not torch.equal(corrupted, source_latents)
|
| 78 |
assert generated.all()
|
| 79 |
assert not source_is_generated.any()
|
| 80 |
-
assert
|
| 81 |
-
assert diagnostics["generated_history_proxy_frame_fraction"] == 1.0
|
| 82 |
|
| 83 |
|
| 84 |
|
|
@@ -88,7 +87,7 @@ def test_generated_history_proxy_respects_context_prefix_and_target_window_bound
|
|
| 88 |
source_is_generated = torch.zeros(8, 1, dtype=torch.bool)
|
| 89 |
|
| 90 |
torch.manual_seed(123)
|
| 91 |
-
corrupted, generated, diagnostics = model._apply_generated_history_proxy(
|
| 92 |
source_latents,
|
| 93 |
source_is_generated,
|
| 94 |
context_frame_count=3,
|
|
@@ -104,8 +103,7 @@ def test_generated_history_proxy_respects_context_prefix_and_target_window_bound
|
|
| 104 |
assert torch.equal(corrupted[:3], source_latents[:3])
|
| 105 |
assert not torch.equal(corrupted[3:6], source_latents[3:6])
|
| 106 |
assert torch.equal(corrupted[6:], source_latents[6:])
|
| 107 |
-
assert
|
| 108 |
-
assert diagnostics["generated_history_proxy_frame_fraction"] == 3 / 8
|
| 109 |
|
| 110 |
|
| 111 |
def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
|
|
|
|
| 68 |
source_is_generated = torch.zeros(4, 1, dtype=torch.bool)
|
| 69 |
|
| 70 |
torch.manual_seed(123)
|
| 71 |
+
corrupted, generated = model._apply_generated_history_proxy(
|
| 72 |
source_latents,
|
| 73 |
source_is_generated,
|
| 74 |
)
|
|
|
|
| 77 |
assert not torch.equal(corrupted, source_latents)
|
| 78 |
assert generated.all()
|
| 79 |
assert not source_is_generated.any()
|
| 80 |
+
assert generated.sum().item() == 4
|
|
|
|
| 81 |
|
| 82 |
|
| 83 |
|
|
|
|
| 87 |
source_is_generated = torch.zeros(8, 1, dtype=torch.bool)
|
| 88 |
|
| 89 |
torch.manual_seed(123)
|
| 90 |
+
corrupted, generated = model._apply_generated_history_proxy(
|
| 91 |
source_latents,
|
| 92 |
source_is_generated,
|
| 93 |
context_frame_count=3,
|
|
|
|
| 103 |
assert torch.equal(corrupted[:3], source_latents[:3])
|
| 104 |
assert not torch.equal(corrupted[3:6], source_latents[3:6])
|
| 105 |
assert torch.equal(corrupted[6:], source_latents[6:])
|
| 106 |
+
assert generated.sum().item() == 3
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
|
tests/test_dememwm_injection_static.py
CHANGED
|
@@ -18,23 +18,19 @@ def _streams(dtype=torch.float32):
|
|
| 18 |
anchor_gate=1.0,
|
| 19 |
dynamic_gate=torch.ones(2, 3, 1) * 0.5,
|
| 20 |
revisit_gate=0.0,
|
| 21 |
-
diagnostics={"selected_revisit_frame_record_ids": ["c1"], "dynamic_max_source_frame": torch.tensor(2)},
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
-
def
|
| 26 |
-
kwargs
|
| 27 |
assert set(kwargs) == {"memory_tokens", "memory_token_mask", "memory_dynamic_tokens", "memory_dynamic_mask", "memory_retrieval_tokens", "memory_retrieval_mask", "memory_anchor_gate", "memory_dynamic_gate", "memory_retrieval_gate"}
|
| 28 |
assert kwargs["memory_tokens"].dtype == torch.float64
|
| 29 |
assert kwargs["memory_dynamic_mask"].dtype == torch.bool
|
| 30 |
-
assert
|
| 31 |
-
assert diag["dynamic_valid_fraction"] > 0.0
|
| 32 |
-
assert diag["selected_revisit_frame_record_ids"] == ["c1"]
|
| 33 |
-
assert diag["max_source_frame"] == 2
|
| 34 |
|
| 35 |
|
| 36 |
def test_injection_omit_disabled_streams():
|
| 37 |
-
kwargs
|
| 38 |
assert kwargs["memory_retrieval_tokens"] is None
|
| 39 |
assert kwargs["memory_retrieval_mask"] is None
|
| 40 |
assert kwargs["memory_dynamic_tokens"] is not None
|
|
|
|
| 18 |
anchor_gate=1.0,
|
| 19 |
dynamic_gate=torch.ones(2, 3, 1) * 0.5,
|
| 20 |
revisit_gate=0.0,
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
|
| 24 |
+
def test_injection_kwarg_names_masks_and_dtype():
|
| 25 |
+
kwargs = InjectionAdapter()(_streams(), dtype=torch.float64)
|
| 26 |
assert set(kwargs) == {"memory_tokens", "memory_token_mask", "memory_dynamic_tokens", "memory_dynamic_mask", "memory_retrieval_tokens", "memory_retrieval_mask", "memory_anchor_gate", "memory_dynamic_gate", "memory_retrieval_gate"}
|
| 27 |
assert kwargs["memory_tokens"].dtype == torch.float64
|
| 28 |
assert kwargs["memory_dynamic_mask"].dtype == torch.bool
|
| 29 |
+
assert kwargs["memory_retrieval_tokens"].dtype == torch.float64
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def test_injection_omit_disabled_streams():
|
| 33 |
+
kwargs = InjectionAdapter(omit_disabled=True)(_streams())
|
| 34 |
assert kwargs["memory_retrieval_tokens"] is None
|
| 35 |
assert kwargs["memory_retrieval_mask"] is None
|
| 36 |
assert kwargs["memory_dynamic_tokens"] is not None
|
tests/test_dememwm_noise_bucket.py
CHANGED
|
@@ -1,102 +1,19 @@
|
|
| 1 |
-
import
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
assert diag["revisit_candidate_frame_count"] == 5.0
|
| 19 |
-
assert diag["revisit_candidate_count"] == 5.0
|
| 20 |
-
assert diag["valid_revisit_frame_count"] == 3.0
|
| 21 |
-
assert diag["valid_revisit_count"] == 3.0
|
| 22 |
-
assert diag["revisit_selected_frame_count"] == 3
|
| 23 |
-
assert diag["no_valid_revisit_count"] == 1
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def test_noise_bucket_diagnostics_include_valid_and_no_valid_counts():
|
| 27 |
-
diag = summarize_noise_bucket_diagnostics(
|
| 28 |
-
noise_bucket="high",
|
| 29 |
-
valid_revisit_mask=torch.tensor([[True, True, False]]),
|
| 30 |
-
no_valid_revisit_mask=torch.tensor([[False, False, True]]),
|
| 31 |
-
)
|
| 32 |
-
assert diag["noise_bucket"] == "high"
|
| 33 |
-
assert diag["noise_bucket_id"] == 0
|
| 34 |
-
assert diag["noise_bucket_is_high"] == 1
|
| 35 |
-
assert diag["noise_bucket_is_mid"] == 0
|
| 36 |
-
assert diag["noise_bucket_high_target_count"] == 3
|
| 37 |
-
assert diag["noise_bucket_mid_target_count"] == 0
|
| 38 |
-
assert diag["valid_revisit_noise_bucket_high_count"] == 2
|
| 39 |
-
assert diag["no_valid_revisit_noise_bucket_high_count"] == 1
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def test_noise_bucket_diagnostics_count_per_target_bucket_ids():
|
| 43 |
-
diag = summarize_noise_bucket_diagnostics(
|
| 44 |
-
noise_bucket="mid",
|
| 45 |
-
noise_bucket_ids=torch.tensor([[0, 1, 2]]),
|
| 46 |
-
valid_revisit_mask=torch.tensor([[True, True, False]]),
|
| 47 |
-
no_valid_revisit_mask=torch.tensor([[False, False, True]]),
|
| 48 |
-
)
|
| 49 |
-
assert diag["noise_bucket"] == "mid"
|
| 50 |
-
assert diag["noise_bucket_id"] == 1
|
| 51 |
-
assert diag["noise_bucket_is_mid"] == 1
|
| 52 |
-
assert diag["noise_bucket_high_target_count"] == 1
|
| 53 |
-
assert diag["noise_bucket_mid_target_count"] == 1
|
| 54 |
-
assert diag["noise_bucket_low_target_count"] == 1
|
| 55 |
-
assert diag["valid_revisit_noise_bucket_high_count"] == 1
|
| 56 |
-
assert diag["valid_revisit_noise_bucket_mid_count"] == 1
|
| 57 |
-
assert diag["valid_revisit_noise_bucket_low_count"] == 0
|
| 58 |
-
assert diag["no_valid_revisit_noise_bucket_low_count"] == 1
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def test_noise_bucket_log_allowlist_keeps_target_counts_only():
|
| 62 |
-
keys = MemoryDiTMixin._TRAIN_DIAGNOSTIC_LOG_KEYS
|
| 63 |
-
for key in (
|
| 64 |
-
"anchor_valid_fraction",
|
| 65 |
-
"dynamic_valid_fraction",
|
| 66 |
-
"revisit_valid_fraction",
|
| 67 |
-
"valid_revisit_mask_fraction",
|
| 68 |
-
"revisit_candidate_count",
|
| 69 |
-
"valid_revisit_count",
|
| 70 |
-
"revisit_selected_count",
|
| 71 |
-
"revisit_fov_overlap_mean",
|
| 72 |
-
"revisit_incremental_fov_overlap_mean",
|
| 73 |
-
"revisit_plucker_overlap_mean",
|
| 74 |
-
"causal_violation_count",
|
| 75 |
-
"noise_bucket_id",
|
| 76 |
-
"noise_bucket_is_high",
|
| 77 |
-
"noise_bucket_is_mid",
|
| 78 |
-
"noise_bucket_is_low",
|
| 79 |
-
"revisit_raw_gate_mean",
|
| 80 |
-
"valid_revisit_noise_bucket_high_count",
|
| 81 |
-
"valid_revisit_noise_bucket_mid_count",
|
| 82 |
-
"valid_revisit_noise_bucket_low_count",
|
| 83 |
-
"no_valid_revisit_noise_bucket_high_count",
|
| 84 |
-
"no_valid_revisit_noise_bucket_mid_count",
|
| 85 |
-
"no_valid_revisit_noise_bucket_low_count",
|
| 86 |
-
):
|
| 87 |
-
assert key not in keys
|
| 88 |
-
for key in [
|
| 89 |
-
"noise_bucket_target_count",
|
| 90 |
-
"noise_bucket_high_target_count",
|
| 91 |
-
"noise_bucket_mid_target_count",
|
| 92 |
-
"noise_bucket_low_target_count",
|
| 93 |
-
"revisit_candidate_frame_count",
|
| 94 |
-
"valid_revisit_frame_count",
|
| 95 |
-
"revisit_selected_frame_count",
|
| 96 |
-
"revisit_frame_fov_overlap_mean",
|
| 97 |
-
"revisit_best_selected_frame_fov_overlap_mean",
|
| 98 |
-
"revisit_best_selected_plucker_overlap_mean",
|
| 99 |
-
"revisit_best_selected_gap_frames_mean",
|
| 100 |
-
"revisit_learned_gate_mean",
|
| 101 |
]:
|
| 102 |
-
assert
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_training_logging_keeps_only_core_scalars():
|
| 5 |
+
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 6 |
+
expected_logs = [
|
| 7 |
+
"training/loss",
|
| 8 |
+
"training/denoise_loss",
|
| 9 |
+
"training/revisit_gate",
|
| 10 |
+
]
|
| 11 |
+
for key in expected_logs:
|
| 12 |
+
assert key in algorithm
|
| 13 |
+
for removed_key in [
|
| 14 |
+
"training/dynamic_max_source_frame",
|
| 15 |
+
"training/revisit_valid_count",
|
| 16 |
+
"training/noise_bucket_id",
|
| 17 |
+
"training/eval_bucket_true_revisit_count",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
]:
|
| 19 |
+
assert removed_key not in algorithm
|
tests/test_dememwm_preselection.py
CHANGED
|
@@ -107,7 +107,7 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
|
|
| 107 |
frame_indices = torch.arange(8)[:, None]
|
| 108 |
poses = torch.zeros((8, 1, 5), dtype=torch.float32)
|
| 109 |
target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)
|
| 110 |
-
anchor_banks, _, _,
|
| 111 |
committed_latents=latents,
|
| 112 |
source_frame_indices=frame_indices,
|
| 113 |
source_is_generated=None,
|
|
@@ -133,7 +133,7 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
|
|
| 133 |
)
|
| 134 |
|
| 135 |
assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
|
| 136 |
-
assert
|
| 137 |
|
| 138 |
|
| 139 |
def test_streaming_diverse_anchor_selection_uses_context_frames():
|
|
@@ -168,7 +168,7 @@ def test_preselected_memory_banks_project_only_selected_frames():
|
|
| 168 |
target_frame_indices = torch.tensor([[10], [11]])
|
| 169 |
poses = torch.zeros((20, 1, 5), dtype=torch.float32)
|
| 170 |
target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)
|
| 171 |
-
anchor_banks, revisit_banks, tokens_per_frame,
|
| 172 |
committed_latents=latents,
|
| 173 |
source_frame_indices=frame_indices,
|
| 174 |
source_is_generated=None,
|
|
@@ -195,10 +195,10 @@ def test_preselected_memory_banks_project_only_selected_frames():
|
|
| 195 |
assert tokens_per_frame == 1
|
| 196 |
assert len(anchor_banks[0].records) == 4
|
| 197 |
assert len(revisit_banks[0].records) == 3
|
| 198 |
-
assert
|
| 199 |
-
assert
|
| 200 |
-
assert
|
| 201 |
-
assert harness.project_call_lengths == [4,
|
| 202 |
assert 20 not in harness.project_call_lengths
|
| 203 |
|
| 204 |
|
|
@@ -221,7 +221,7 @@ def test_preselected_revisit_projects_best_fov_frame_not_latest():
|
|
| 221 |
)
|
| 222 |
poses = pose_rows[:, None, :]
|
| 223 |
|
| 224 |
-
_, revisit_banks, _, _ = harness._build_preselected_causal_memory_banks(
|
| 225 |
committed_latents=latents,
|
| 226 |
source_frame_indices=frame_indices,
|
| 227 |
source_is_generated=None,
|
|
@@ -246,6 +246,7 @@ def test_preselected_revisit_projects_best_fov_frame_not_latest():
|
|
| 246 |
token_patch_size=2,
|
| 247 |
)
|
| 248 |
|
|
|
|
| 249 |
assert len(revisit_banks[0].records) == 1
|
| 250 |
assert revisit_banks[0].records[0].metadata["dememwm_selected_frame_index"] == 1
|
| 251 |
assert harness.project_call_values == [[1.0]]
|
|
|
|
| 107 |
frame_indices = torch.arange(8)[:, None]
|
| 108 |
poses = torch.zeros((8, 1, 5), dtype=torch.float32)
|
| 109 |
target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)
|
| 110 |
+
anchor_banks, _, _, _, _ = harness._build_preselected_causal_memory_banks(
|
| 111 |
committed_latents=latents,
|
| 112 |
source_frame_indices=frame_indices,
|
| 113 |
source_is_generated=None,
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
|
| 136 |
+
assert harness.project_call_lengths == [2]
|
| 137 |
|
| 138 |
|
| 139 |
def test_streaming_diverse_anchor_selection_uses_context_frames():
|
|
|
|
| 168 |
target_frame_indices = torch.tensor([[10], [11]])
|
| 169 |
poses = torch.zeros((20, 1, 5), dtype=torch.float32)
|
| 170 |
target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)
|
| 171 |
+
anchor_banks, revisit_banks, tokens_per_frame, selected_by_target, stats = harness._build_preselected_causal_memory_banks(
|
| 172 |
committed_latents=latents,
|
| 173 |
source_frame_indices=frame_indices,
|
| 174 |
source_is_generated=None,
|
|
|
|
| 195 |
assert tokens_per_frame == 1
|
| 196 |
assert len(anchor_banks[0].records) == 4
|
| 197 |
assert len(revisit_banks[0].records) == 3
|
| 198 |
+
assert selected_by_target is not None
|
| 199 |
+
assert stats is not None
|
| 200 |
+
assert [len(records) for records in selected_by_target[0]] == [2, 2]
|
| 201 |
+
assert harness.project_call_lengths == [4, 3]
|
| 202 |
assert 20 not in harness.project_call_lengths
|
| 203 |
|
| 204 |
|
|
|
|
| 221 |
)
|
| 222 |
poses = pose_rows[:, None, :]
|
| 223 |
|
| 224 |
+
_, revisit_banks, _, selected_by_target, _ = harness._build_preselected_causal_memory_banks(
|
| 225 |
committed_latents=latents,
|
| 226 |
source_frame_indices=frame_indices,
|
| 227 |
source_is_generated=None,
|
|
|
|
| 246 |
token_patch_size=2,
|
| 247 |
)
|
| 248 |
|
| 249 |
+
assert selected_by_target is not None
|
| 250 |
assert len(revisit_banks[0].records) == 1
|
| 251 |
assert revisit_banks[0].records[0].metadata["dememwm_selected_frame_index"] == 1
|
| 252 |
assert harness.project_call_values == [[1.0]]
|
tests/test_dememwm_retrieval.py
CHANGED
|
@@ -91,21 +91,19 @@ def test_revisit_candidates_require_causal_c_short_gap():
|
|
| 91 |
exclude_local_context_frames=4,
|
| 92 |
)
|
| 93 |
assert [r.max_source_frame for r in result.records] == [1]
|
| 94 |
-
assert result.
|
| 95 |
-
assert result.
|
| 96 |
-
assert result.
|
| 97 |
-
assert result.diagnostics["valid_revisit_count"] == 1
|
| 98 |
-
assert result.diagnostics["valid_candidate_label_count"] == 1
|
| 99 |
-
assert result.diagnostics["revisit_min_gap_to_target"] == 5
|
| 100 |
-
assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
|
| 101 |
|
| 102 |
|
| 103 |
def test_revisit_abstains_when_no_valid_candidate():
|
| 104 |
result = deterministic_revisit_retrieval([rec(2, 2), rec(3, 3)], target_frame=6, topk=2, exclude_local_context_frames=4)
|
| 105 |
assert result.records == []
|
| 106 |
-
assert result.
|
| 107 |
-
assert result.
|
| 108 |
-
assert result.
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
def test_revisit_retrieval_rejects_non_vectorized_inputs():
|
|
@@ -150,14 +148,11 @@ def test_fov_threshold_filters_candidates_without_action():
|
|
| 150 |
exclude_local_context_frames=4,
|
| 151 |
topk=4,
|
| 152 |
)
|
| 153 |
-
assert result.
|
| 154 |
-
assert result.
|
| 155 |
-
assert result.
|
| 156 |
-
assert result.
|
| 157 |
-
assert result.
|
| 158 |
-
assert result.diagnostics["best_selected_gap_frames"] == 10
|
| 159 |
-
assert result.diagnostics["revisit_fov_overlap_max"] == 1.0
|
| 160 |
-
assert result.diagnostics["revisit_plucker_overlap_max"] > 0.0
|
| 161 |
|
| 162 |
|
| 163 |
def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
|
|
@@ -177,13 +172,8 @@ def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
|
|
| 177 |
pose_preselect_topk=1,
|
| 178 |
)
|
| 179 |
|
| 180 |
-
assert result.
|
| 181 |
-
assert result.
|
| 182 |
-
assert result.diagnostics["revisit_pose_preselect_scored_count"] == 3
|
| 183 |
-
assert result.diagnostics["revisit_pose_preselect_selected_count"] == 1
|
| 184 |
-
assert result.diagnostics["revisit_exact_fov_candidate_count"] == 1
|
| 185 |
-
assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
|
| 186 |
-
assert abs(result.diagnostics["revisit_pose_preselect_min_distance"] - (1.0 / 30.0)) < 1e-6
|
| 187 |
|
| 188 |
|
| 189 |
def test_selected_frame_carries_frame_metadata_for_projection():
|
|
@@ -197,15 +187,14 @@ def test_selected_frame_carries_frame_metadata_for_projection():
|
|
| 197 |
topk=1,
|
| 198 |
)
|
| 199 |
|
| 200 |
-
assert result.
|
| 201 |
assert result.selected_frame_ids == [1]
|
| 202 |
assert result.records[0].metadata["dememwm_selected_frame_index"] == 1
|
|
|
|
| 203 |
assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
|
| 204 |
-
assert result.diagnostics["best_selected_frame_index"] == 1
|
| 205 |
-
assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
|
| 206 |
|
| 207 |
|
| 208 |
-
def
|
| 209 |
result = deterministic_revisit_retrieval(
|
| 210 |
[rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
|
| 211 |
target_frame=10,
|
|
@@ -215,9 +204,10 @@ def test_high_quality_threshold_is_selected_target_diagnostic_only():
|
|
| 215 |
exclude_local_context_frames=4,
|
| 216 |
topk=1,
|
| 217 |
)
|
| 218 |
-
assert result.
|
| 219 |
-
assert result.
|
| 220 |
-
assert 0.30 <= result.
|
|
|
|
| 221 |
|
| 222 |
|
| 223 |
def test_video_metadata_does_not_filter_revisit_candidates():
|
|
@@ -233,8 +223,8 @@ def test_video_metadata_does_not_filter_revisit_candidates():
|
|
| 233 |
exclude_local_context_frames=4,
|
| 234 |
topk=4,
|
| 235 |
)
|
| 236 |
-
assert result.
|
| 237 |
-
assert result.
|
| 238 |
|
| 239 |
|
| 240 |
def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
|
|
@@ -244,4 +234,4 @@ def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
|
|
| 244 |
rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"),
|
| 245 |
]
|
| 246 |
result = deterministic_revisit_retrieval(records, target_frame=10, target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), exclude_local_context_frames=4, topk=3)
|
| 247 |
-
assert result.
|
|
|
|
| 91 |
exclude_local_context_frames=4,
|
| 92 |
)
|
| 93 |
assert [r.max_source_frame for r in result.records] == [1]
|
| 94 |
+
assert result.selected_frame_ids == [1]
|
| 95 |
+
assert result.scores.numel() == 1
|
| 96 |
+
assert result.best_selected_gap_frames.item() == pytest.approx(5.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
def test_revisit_abstains_when_no_valid_candidate():
|
| 100 |
result = deterministic_revisit_retrieval([rec(2, 2), rec(3, 3)], target_frame=6, topk=2, exclude_local_context_frames=4)
|
| 101 |
assert result.records == []
|
| 102 |
+
assert result.selected_frame_ids == []
|
| 103 |
+
assert result.scores.numel() == 0
|
| 104 |
+
assert result.best_selected_fov_overlap.item() == pytest.approx(0.0)
|
| 105 |
+
assert result.best_selected_plucker_overlap.item() == pytest.approx(0.0)
|
| 106 |
+
assert result.best_selected_gap_frames.item() == pytest.approx(-1.0)
|
| 107 |
|
| 108 |
|
| 109 |
def test_revisit_retrieval_rejects_non_vectorized_inputs():
|
|
|
|
| 148 |
exclude_local_context_frames=4,
|
| 149 |
topk=4,
|
| 150 |
)
|
| 151 |
+
assert [record.chunk_id for record in result.records] == ["c0"]
|
| 152 |
+
assert result.selected_frame_ids == [0]
|
| 153 |
+
assert result.best_selected_fov_overlap.item() == pytest.approx(1.0)
|
| 154 |
+
assert result.best_selected_plucker_overlap.item() > 0.0
|
| 155 |
+
assert result.best_selected_gap_frames.item() == pytest.approx(10.0)
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
|
| 158 |
def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
|
|
|
|
| 172 |
pose_preselect_topk=1,
|
| 173 |
)
|
| 174 |
|
| 175 |
+
assert [record.chunk_id for record in result.records] == ["near_same_direction"]
|
| 176 |
+
assert result.selected_frame_ids == [2]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
def test_selected_frame_carries_frame_metadata_for_projection():
|
|
|
|
| 187 |
topk=1,
|
| 188 |
)
|
| 189 |
|
| 190 |
+
assert [record.chunk_id for record in result.records] == ["frame_1"]
|
| 191 |
assert result.selected_frame_ids == [1]
|
| 192 |
assert result.records[0].metadata["dememwm_selected_frame_index"] == 1
|
| 193 |
+
assert result.records[0].metadata["dememwm_selected_frame_fov_overlap"] == pytest.approx(1.0)
|
| 194 |
assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
|
|
|
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
+
def test_high_quality_threshold_marks_selected_frame_metadata():
|
| 198 |
result = deterministic_revisit_retrieval(
|
| 199 |
[rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
|
| 200 |
target_frame=10,
|
|
|
|
| 204 |
exclude_local_context_frames=4,
|
| 205 |
topk=1,
|
| 206 |
)
|
| 207 |
+
assert [record.chunk_id for record in result.records] == ["c0"]
|
| 208 |
+
assert len(result.records) == 1
|
| 209 |
+
assert 0.30 <= result.best_selected_fov_overlap.item() < 0.70
|
| 210 |
+
assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is False
|
| 211 |
|
| 212 |
|
| 213 |
def test_video_metadata_does_not_filter_revisit_candidates():
|
|
|
|
| 223 |
exclude_local_context_frames=4,
|
| 224 |
topk=4,
|
| 225 |
)
|
| 226 |
+
assert [record.chunk_id for record in result.records] == ["c1", "c0"]
|
| 227 |
+
assert result.selected_frame_ids == [1, 0]
|
| 228 |
|
| 229 |
|
| 230 |
def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
|
|
|
|
| 234 |
rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"),
|
| 235 |
]
|
| 236 |
result = deterministic_revisit_retrieval(records, target_frame=10, target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), exclude_local_context_frames=4, topk=3)
|
| 237 |
+
assert [record.chunk_id for record in result.records] == ["c", "a", "b"]
|
tests/test_dememwm_schedules.py
CHANGED
|
@@ -4,9 +4,7 @@ from types import SimpleNamespace
|
|
| 4 |
|
| 5 |
from algorithms.worldmem.dememwm.schedules import (
|
| 6 |
compute_stream_gates,
|
| 7 |
-
|
| 8 |
-
noise_bucket_from_noise_levels,
|
| 9 |
-
noise_bucket_ids_from_noise_levels,
|
| 10 |
resolve_curriculum,
|
| 11 |
)
|
| 12 |
|
|
@@ -43,37 +41,16 @@ def test_two_stage_curriculum_switches_at_full_stage_start():
|
|
| 43 |
assert stage_1.anchor_enabled and stage_1.dynamic_enabled and stage_1.revisit_enabled
|
| 44 |
assert stage_1.dit_train_state == "frozen"
|
| 45 |
assert not hasattr(stage_1, "dit_late_blocks_trainable")
|
| 46 |
-
assert all("late" not in key for key in stage_1.diagnostics())
|
| 47 |
assert stage_2.stage == "stage_2"
|
| 48 |
assert stage_2.anchor_enabled and stage_2.dynamic_enabled and stage_2.revisit_enabled
|
| 49 |
assert stage_2.dit_train_state == "full"
|
| 50 |
|
| 51 |
|
| 52 |
-
def test_debug_force_all_streams_overrides_stage():
|
| 53 |
-
gates = compute_stream_gates("stage_1", debug_force_all_streams=True)
|
| 54 |
-
assert gates.anchor_enabled and gates.dynamic_enabled and gates.revisit_enabled
|
| 55 |
-
assert gates.reason == "debug_force_all_streams"
|
| 56 |
-
|
| 57 |
-
|
| 58 |
def test_unknown_stage_fails():
|
| 59 |
with pytest.raises(ValueError):
|
| 60 |
compute_stream_gates("unknown")
|
| 61 |
|
| 62 |
|
| 63 |
-
def
|
| 64 |
-
assert noise_bucket_from_denoising_fraction(0.0) == "high"
|
| 65 |
-
assert noise_bucket_from_denoising_fraction(0.5) == "mid"
|
| 66 |
-
assert noise_bucket_from_denoising_fraction(1.0) == "low"
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def test_noise_bucket_from_training_noise_levels():
|
| 70 |
-
import torch
|
| 71 |
-
assert noise_bucket_from_noise_levels(torch.tensor([9, 8]), 10) == "high"
|
| 72 |
-
assert noise_bucket_from_noise_levels(torch.tensor([5, 4]), 10) == "mid"
|
| 73 |
-
assert noise_bucket_from_noise_levels(torch.tensor([1, 0]), 10) == "low"
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def test_noise_bucket_ids_from_training_noise_levels():
|
| 77 |
import torch
|
| 78 |
-
|
| 79 |
-
assert bucket_ids.tolist() == [[0, 1, 2]]
|
|
|
|
| 4 |
|
| 5 |
from algorithms.worldmem.dememwm.schedules import (
|
| 6 |
compute_stream_gates,
|
| 7 |
+
denoising_fraction_from_noise_levels,
|
|
|
|
|
|
|
| 8 |
resolve_curriculum,
|
| 9 |
)
|
| 10 |
|
|
|
|
| 41 |
assert stage_1.anchor_enabled and stage_1.dynamic_enabled and stage_1.revisit_enabled
|
| 42 |
assert stage_1.dit_train_state == "frozen"
|
| 43 |
assert not hasattr(stage_1, "dit_late_blocks_trainable")
|
|
|
|
| 44 |
assert stage_2.stage == "stage_2"
|
| 45 |
assert stage_2.anchor_enabled and stage_2.dynamic_enabled and stage_2.revisit_enabled
|
| 46 |
assert stage_2.dit_train_state == "full"
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def test_unknown_stage_fails():
|
| 50 |
with pytest.raises(ValueError):
|
| 51 |
compute_stream_gates("unknown")
|
| 52 |
|
| 53 |
|
| 54 |
+
def test_denoising_fraction_from_training_noise_levels():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
import torch
|
| 56 |
+
assert denoising_fraction_from_noise_levels(torch.tensor([9, 0]), 10) == pytest.approx(0.5)
|
|
|
train_dememwm_full_berzelius.sh
CHANGED
|
@@ -51,7 +51,6 @@ srun python -m main \
|
|
| 51 |
++algorithm.context_frames=100 \
|
| 52 |
++algorithm.log_video=true \
|
| 53 |
++algorithm.diffusion.sampling_timesteps=20 \
|
| 54 |
-
++algorithm.dememwm.debug_force_all_streams=false \
|
| 55 |
++algorithm.dememwm.generated_history_proxy.enabled=true \
|
| 56 |
++algorithm.dememwm.generated_history_proxy.start_step=40000 \
|
| 57 |
++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
|
|
@@ -76,7 +75,6 @@ srun python -m main \
|
|
| 76 |
++algorithm.dememwm.revisit.plucker_weight=0.10 \
|
| 77 |
++algorithm.dememwm.revisit.max_frames=2 \
|
| 78 |
++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
|
| 79 |
-
++algorithm.dememwm.stage_policy.noise_bucket_logging=true \
|
| 80 |
++algorithm.dememwm.cache.enabled=true \
|
| 81 |
++algorithm.dememwm.cache.device=cpu \
|
| 82 |
++algorithm.dememwm.cache.keep_raw_latents=all \
|
|
|
|
| 51 |
++algorithm.context_frames=100 \
|
| 52 |
++algorithm.log_video=true \
|
| 53 |
++algorithm.diffusion.sampling_timesteps=20 \
|
|
|
|
| 54 |
++algorithm.dememwm.generated_history_proxy.enabled=true \
|
| 55 |
++algorithm.dememwm.generated_history_proxy.start_step=40000 \
|
| 56 |
++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
|
|
|
|
| 75 |
++algorithm.dememwm.revisit.plucker_weight=0.10 \
|
| 76 |
++algorithm.dememwm.revisit.max_frames=2 \
|
| 77 |
++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
|
|
|
|
| 78 |
++algorithm.dememwm.cache.enabled=true \
|
| 79 |
++algorithm.dememwm.cache.device=cpu \
|
| 80 |
++algorithm.dememwm.cache.keep_raw_latents=all \
|