BonanDing commited on
Commit
1dae740
·
1 Parent(s): 5dc5f97

Optimize DeMemWM memory retrieval and remove diagnostics

Browse files
algorithms/worldmem/dememwm/algorithm.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from __future__ import annotations
3
 
4
  import math
@@ -10,13 +9,13 @@ from einops import rearrange
10
 
11
  from .cache import StreamingCache
12
  from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens
13
- from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics
14
  from .injection import InjectionAdapter
15
  from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
16
  from .negatives import apply_revisit_eval_corruption
17
- from .retrieval import deterministic_revisit_retrieval
18
- from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum
19
  from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors
 
20
 
21
 
22
  class MemoryDiTMixin:
@@ -36,35 +35,9 @@ class MemoryDiTMixin:
36
  strict_key_substrings = (
37
  ".memory_token_cross_attn.",
38
  )
39
- _TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({
40
- "revisit_candidate_frame_count",
41
- "revisit_pose_preselect_input_count",
42
- "revisit_pose_preselect_selected_count",
43
- "revisit_exact_fov_candidate_count",
44
- "valid_revisit_frame_count",
45
- "no_valid_revisit_count",
46
- "revisit_selected_frame_count",
47
- "revisit_frame_fov_overlap_mean",
48
- "revisit_best_selected_frame_fov_overlap_mean",
49
- "revisit_best_selected_plucker_overlap_mean",
50
- "revisit_best_selected_gap_frames_mean",
51
- "revisit_gate_raw",
52
- "revisit_gate_eff",
53
- "revisit_learned_gate_mean",
54
- "revisit_effective_gate_mean",
55
- "generated_history_proxy_prob",
56
- "noise_bucket_target_count",
57
- "noise_bucket_high_target_count",
58
- "noise_bucket_mid_target_count",
59
- "noise_bucket_low_target_count",
60
- })
61
- _VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({
62
- "cache_records",
63
- "cache_slots",
64
- })
65
 
66
  def _memory_cfg(self):
67
- return getattr(self.cfg, "dememwm", None)
68
 
69
  def _cfg_get(self, obj, name, default):
70
  if obj is None:
@@ -84,8 +57,6 @@ class MemoryDiTMixin:
84
  except Exception:
85
  return False
86
 
87
- def _stage_policy_cfg(self):
88
- return self._cfg_get(self._memory_cfg(), "stage_policy", None)
89
 
90
  def _eval_ablation_cfg(self):
91
  return self._cfg_get(self._memory_cfg(), "eval_ablation", None)
@@ -99,7 +70,7 @@ class MemoryDiTMixin:
99
  branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal"))
100
  return enabled, branch
101
 
102
- def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict:
103
  memory_cfg = self._memory_cfg()
104
  anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
105
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
@@ -110,12 +81,9 @@ class MemoryDiTMixin:
110
  revisit_config_enabled = self._stream_enabled(revisit_cfg)
111
  curriculum_state = self._curriculum_state()
112
  eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state()
113
- debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False))
114
- resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction)
115
  gates = compute_stream_gates(
116
  curriculum_state.stage,
117
  denoising_fraction=denoising_fraction,
118
- debug_force_all_streams=debug_force,
119
  anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)),
120
  dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)),
121
  revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)),
@@ -139,7 +107,6 @@ class MemoryDiTMixin:
139
  return {
140
  "curriculum_state": curriculum_state,
141
  "gates": gates,
142
- "resolved_noise_bucket": resolved_noise_bucket,
143
  "anchor_config_enabled": anchor_config_enabled,
144
  "dynamic_config_enabled": dynamic_config_enabled,
145
  "revisit_config_enabled": revisit_config_enabled,
@@ -152,14 +119,13 @@ class MemoryDiTMixin:
152
  "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"),
153
  }
154
 
155
- def _validate_config_contract(self) -> dict:
156
  if bool(getattr(self, "_dememwm_contract_validated", False)):
157
- return getattr(self, "_last_dememwm_config_diagnostics", {})
158
  memory_cfg = self._memory_cfg()
159
  if memory_cfg is None:
160
  self._dememwm_contract_validated = True
161
- self._last_dememwm_config_diagnostics = {}
162
- return {}
163
 
164
  stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
165
  if stale_sections:
@@ -193,10 +159,8 @@ class MemoryDiTMixin:
193
  if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
194
  raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
195
  fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
196
- if fov_overlap_threshold is not None:
197
- fov_overlap_threshold = float(fov_overlap_threshold)
198
- if fov_overlap_threshold < 0.0:
199
- raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
200
  high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
201
  if high_quality_fov_threshold < 0.0:
202
  raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
@@ -222,9 +186,6 @@ class MemoryDiTMixin:
222
  value = int(self._cfg_get(revisit_cfg, field_name, default))
223
  if value <= 0:
224
  raise ValueError(f"dememwm.revisit.{field_name} must be positive")
225
- stage_policy_cfg = self._stage_policy_cfg()
226
- if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)):
227
- raise ValueError("final DeMemWM keeps noise_bucket logging enabled")
228
  proxy_cfg = self._generated_history_proxy_cfg()
229
  proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
230
  proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
@@ -240,18 +201,7 @@ class MemoryDiTMixin:
240
  raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
241
  eval_ablation_cfg = self._eval_ablation_cfg()
242
  normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))
243
-
244
- diagnostics = {
245
- "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
246
- "revisit_deterministic_fov_plucker_retrieval": True,
247
- "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
248
- "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
249
- "revisit_plucker_weight": plucker_weight,
250
- "stage_policy_noise_bucket_logging": True,
251
- }
252
  self._dememwm_contract_validated = True
253
- self._last_dememwm_config_diagnostics = diagnostics
254
- return diagnostics
255
 
256
  def _stream_enabled(self, stream_cfg) -> bool:
257
  return bool(self._cfg_get(stream_cfg, "enabled", True))
@@ -294,25 +244,17 @@ class MemoryDiTMixin:
294
  source_is_generated: torch.Tensor | None,
295
  context_frame_count: int | None = None,
296
  target_start_frame: int | None = None,
297
- ) -> tuple[torch.Tensor, torch.Tensor, dict]:
298
  cfg = self._generated_history_proxy_cfg()
299
  prob = self._generated_history_proxy_prob()
300
  noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
301
  dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
302
- diagnostics = {
303
- "generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)),
304
- "generated_history_proxy_prob": float(prob),
305
- "generated_history_proxy_noise_std": float(noise_std),
306
- "generated_history_proxy_dropout_prob": float(dropout_prob),
307
- "generated_history_proxy_frame_count": 0,
308
- "generated_history_proxy_frame_fraction": 0.0,
309
- }
310
  if source_is_generated is None:
311
  source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
312
  else:
313
  source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
314
  if prob <= 0.0 or source_latents.numel() == 0:
315
- return source_latents, source_is_generated, diagnostics
316
 
317
  eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
318
  if context_frame_count is not None or target_start_frame is not None:
@@ -322,12 +264,8 @@ class MemoryDiTMixin:
322
  if target_start_frame is not None:
323
  eligible_mask &= frame_positions < max(0, int(target_start_frame))
324
  proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
325
- proxy_count = int(proxy_mask.detach().long().sum().item())
326
- total_count = max(1, int(proxy_mask.numel()))
327
- diagnostics["generated_history_proxy_frame_count"] = proxy_count
328
- diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count)
329
- if proxy_count == 0:
330
- return source_latents, source_is_generated, diagnostics
331
 
332
  corrupt_latents = source_latents.clone()
333
  frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
@@ -342,7 +280,7 @@ class MemoryDiTMixin:
342
  corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
343
  source_is_generated = source_is_generated.clone()
344
  source_is_generated |= proxy_mask
345
- return corrupt_latents, source_is_generated, diagnostics
346
 
347
  def _checkpoint_cfg(self):
348
  return self._cfg_get(self._memory_cfg(), "checkpoint", None)
@@ -397,62 +335,21 @@ class MemoryDiTMixin:
397
 
398
  def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
399
  state = self._curriculum_state(step)
400
-
401
- # Keep DDP's trainable graph stable: DiT params stay requires_grad=True
402
- # from step 0 and are frozen by optimizer LR=0 until the full stage.
403
- # Re-walk only when curriculum diagnostics can change.
404
  freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
405
- last_key = getattr(self, "_last_freeze_key", None)
406
- if last_key != freeze_key:
407
- trainable_tensors = {
408
- "dememwm_modules": 0,
409
- "memory_adapters": 0,
410
- "full_dit": 0,
411
- "excluded_frozen": 0,
412
- }
413
- trainable_scalars = {key: 0 for key in trainable_tensors}
414
- requires_grad_tensors = {key: 0 for key in trainable_tensors}
415
- requires_grad_scalars = {key: 0 for key in trainable_tensors}
416
  for name, param in self.named_parameters():
417
  group_name = self._param_group_name(name, state)
418
- should_train = self._group_trainable(group_name, state)
419
  if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
420
- should_train = False
421
- should_require_grad = False
422
  else:
423
- should_require_grad = True
424
- param.requires_grad_(should_require_grad)
425
- if should_train:
426
- trainable_tensors[group_name] = trainable_tensors.get(group_name, 0) + 1
427
- trainable_scalars[group_name] = trainable_scalars.get(group_name, 0) + int(param.numel())
428
- if should_require_grad:
429
- requires_grad_tensors[group_name] = requires_grad_tensors.get(group_name, 0) + 1
430
- requires_grad_scalars[group_name] = requires_grad_scalars.get(group_name, 0) + int(param.numel())
431
  self._last_freeze_key = freeze_key
432
- self._last_trainable_tensors = trainable_tensors
433
- self._last_trainable_scalars = trainable_scalars
434
- self._last_requires_grad_tensors = requires_grad_tensors
435
- self._last_requires_grad_scalars = requires_grad_scalars
436
- else:
437
- trainable_tensors = getattr(self, "_last_trainable_tensors", {})
438
- trainable_scalars = getattr(self, "_last_trainable_scalars", {})
439
- requires_grad_tensors = getattr(self, "_last_requires_grad_tensors", {})
440
- requires_grad_scalars = getattr(self, "_last_requires_grad_scalars", {})
441
 
442
  if optimizer is not None:
443
  for param_group in optimizer.param_groups:
444
  group_name = param_group.get("name", "")
445
  trainable = self._group_trainable(group_name, state)
446
  param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0
447
-
448
- diagnostics = state.diagnostics()
449
- for group_name in ("dememwm_modules", "memory_adapters", "full_dit"):
450
- diagnostics[f"trainable_tensors_{group_name}"] = trainable_tensors.get(group_name, 0)
451
- diagnostics[f"trainable_params_{group_name}"] = trainable_scalars.get(group_name, 0)
452
- diagnostics[f"requires_grad_tensors_{group_name}"] = requires_grad_tensors.get(group_name, 0)
453
- diagnostics[f"requires_grad_params_{group_name}"] = requires_grad_scalars.get(group_name, 0)
454
- diagnostics[f"optimizer_lr_{group_name}"] = self._group_lr(group_name, state) if self._group_trainable(group_name, state) else 0.0
455
- self._last_dememwm_freeze_diagnostics = diagnostics
456
  return state
457
 
458
  def configure_optimizers(self):
@@ -1101,7 +998,7 @@ class MemoryDiTMixin:
1101
  plucker_weight: float,
1102
  revisit_retrieval_kwargs: dict | None,
1103
  token_patch_size: int,
1104
- ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]:
1105
  if committed_latents.ndim != 5:
1106
  raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
1107
  T_src, B, _, H, W = committed_latents.shape
@@ -1126,12 +1023,6 @@ class MemoryDiTMixin:
1126
  revisit_banks: list[CausalMemoryBank] = []
1127
  dummy_tokens = committed_latents.new_zeros((1, hidden_size))
1128
  dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
1129
- preselection_candidate_count = 0
1130
- preselection_valid_candidate_label_count = 0
1131
- preselection_selected_count = 0
1132
- projected_anchor_frames = 0
1133
- projected_revisit_frames = 0
1134
- projected_revisit_records = 0
1135
  retrieval_kwargs = dict(revisit_retrieval_kwargs or {})
1136
 
1137
  # Pre-convert pose tensors to stream_device once so that the
@@ -1188,6 +1079,175 @@ class MemoryDiTMixin:
1188
  def _pose_subset(positions: torch.Tensor, batch_idx: int):
1189
  return _tensor_subset(pose, positions, batch_idx)
1190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1191
  def _candidate_record(
1192
  *,
1193
  batch_idx: int,
@@ -1238,7 +1298,6 @@ class MemoryDiTMixin:
1238
  if selected_anchor_positions:
1239
  anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
1240
  if anchor_positions.numel() > 0:
1241
- projected_anchor_frames += int(anchor_positions.numel())
1242
  anchor_projected = self._project_latent_patch_tokens(
1243
  committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
1244
  self.dememwm_anchor_proj,
@@ -1258,9 +1317,7 @@ class MemoryDiTMixin:
1258
 
1259
  candidate_records: list[MemoryRecord] = []
1260
  candidate_positions: dict[str, torch.Tensor] = {}
1261
- src_frames_cpu = src_frames.detach().cpu()
1262
- target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long)
1263
- latest_valid_source_frame_exclusive = int(target_frames_cpu.max().item()) - int(exclude_local_context_frames)
1264
  for prefix, positions, source_type, is_generated in (
1265
  ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
1266
  (
@@ -1270,14 +1327,15 @@ class MemoryDiTMixin:
1270
  True,
1271
  ),
1272
  ):
1273
- if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0:
1274
  continue
1275
- positions_cpu = positions.detach().cpu().to(dtype=torch.long)
1276
- for frame_position_cpu in positions_cpu:
1277
- frame = int(src_frames_cpu[int(frame_position_cpu.item())].item())
1278
- if frame >= latest_valid_source_frame_exclusive:
1279
- continue
1280
- frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long)
 
1281
  record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
1282
  candidate_positions[record_id] = frame_position
1283
  candidate_records.append(_candidate_record(
@@ -1301,12 +1359,9 @@ class MemoryDiTMixin:
1301
  exclude_local_context_frames=exclude_local_context_frames,
1302
  fov_overlap_threshold=fov_overlap_threshold,
1303
  plucker_weight=plucker_weight,
1304
- target_video_id=_target_video_id(batch_idx, target_idx),
1305
  **retrieval_kwargs,
1306
  )
1307
- preselection_candidate_count += int(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
1308
- preselection_valid_candidate_label_count += int(result.diagnostics.get("valid_candidate_label_count", 0))
1309
- preselection_selected_count += int(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
1310
  for selected_record in result.records:
1311
  if selected_record.chunk_id is None:
1312
  continue
@@ -1319,8 +1374,6 @@ class MemoryDiTMixin:
1319
  continue
1320
  record_id = str(record.chunk_id)
1321
  frame_position = candidate_positions[record_id]
1322
- projected_revisit_records += 1
1323
- projected_revisit_frames += int(frame_position.numel())
1324
  revisit_projected = self._project_latent_patch_tokens(
1325
  committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
1326
  self.dememwm_revisit_proj,
@@ -1344,17 +1397,7 @@ class MemoryDiTMixin:
1344
  anchor_banks.append(anchor_bank)
1345
  revisit_banks.append(revisit_bank)
1346
 
1347
- diagnostics = {
1348
- "preselected_anchor_projected_frame_count": projected_anchor_frames,
1349
- "preselected_revisit_projected_frame_count": projected_revisit_frames,
1350
- "preselected_revisit_projected_frame_record_count": projected_revisit_records,
1351
- "preselected_revisit_candidate_frame_count": preselection_candidate_count,
1352
- "preselected_revisit_candidate_count": preselection_candidate_count,
1353
- "preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count,
1354
- "preselected_revisit_selected_frame_count": preselection_selected_count,
1355
- "preselected_revisit_selected_count": preselection_selected_count,
1356
- }
1357
- return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
1358
 
1359
  def _causal_cached_revisit_records(
1360
  self,
@@ -1486,8 +1529,6 @@ class MemoryDiTMixin:
1486
  target_video_ids=None,
1487
  source_is_generated: torch.Tensor | None = None,
1488
  denoising_fraction: float | None = None,
1489
- noise_bucket: str | None = None,
1490
- noise_bucket_ids: torch.Tensor | None = None,
1491
  streaming_cache: StreamingCache | None = None,
1492
  ) -> MemoryStreamTensors:
1493
  if target_frame_indices is None:
@@ -1499,10 +1540,9 @@ class MemoryDiTMixin:
1499
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
1500
  revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
1501
  injection_cfg = self._cfg_get(memory_cfg, "injection", None)
1502
- contract_diag = self._validate_config_contract()
1503
  gate_state = self._effective_gate_state(
1504
  denoising_fraction=denoising_fraction,
1505
- noise_bucket=noise_bucket,
1506
  )
1507
  anchor_config_enabled = gate_state["anchor_config_enabled"]
1508
  dynamic_config_enabled = gate_state["dynamic_config_enabled"]
@@ -1510,7 +1550,6 @@ class MemoryDiTMixin:
1510
  curriculum_state = gate_state["curriculum_state"]
1511
  eval_ablation_enabled = gate_state["eval_ablation_enabled"]
1512
  eval_ablation_branch = gate_state["eval_ablation_branch"]
1513
- resolved_noise_bucket = gate_state["resolved_noise_bucket"]
1514
  gates = gate_state["gates"]
1515
  anchor_effective_enabled = gate_state["anchor_effective_enabled"]
1516
  dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
@@ -1565,12 +1604,12 @@ class MemoryDiTMixin:
1565
  "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
1566
  "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
1567
  }
1568
- preselection_diag = {}
 
1569
  use_cache_revisit_records = False
1570
  revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None
1571
 
1572
  cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
1573
- cache_diag = cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0}
1574
  if committed_latents is not None:
1575
  stream_device = committed_latents.device
1576
  stream_dtype = committed_latents.dtype
@@ -1638,7 +1677,7 @@ class MemoryDiTMixin:
1638
  B = committed_latents.shape[1]
1639
  hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
1640
  target_pose_source = target_pose if target_pose is not None else pose
1641
- anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks(
1642
  committed_latents,
1643
  source_frame_indices.to(device=stream_device),
1644
  None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
@@ -1710,29 +1749,21 @@ class MemoryDiTMixin:
1710
  dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
1711
  dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
1712
  dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
1713
- dynamic_diag = {
1714
- "selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
1715
- "max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1716
- "generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device),
1717
- "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1718
- "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1719
- "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
1720
- "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
1721
- }
1722
  else:
1723
  # Pre-select dynamic source frame positions using only frame index metadata
1724
  # before touching latents, so we pass a small slice instead of the full
1725
  # 1000-frame tensor to the compressor.
1726
  _dfi = dynamic_frame_indices.to(device=stream_device)
1727
  _max_src = self.dememwm_dynamic_compressor.max_source_frames
1728
- _needed: list[int] = []
1729
  for _b in range(B):
1730
  for _j in range(T_tgt):
1731
- _target = int(target_frame_indices[_j, _b].item())
1732
  _valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
1733
- _needed.extend(_valid[-_max_src:].tolist())
1734
- if _needed:
1735
- _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
 
1736
  _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
1737
  _dynamic_fi_small = _dfi.index_select(0, _needed_idx)
1738
  _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
@@ -1745,7 +1776,7 @@ class MemoryDiTMixin:
1745
  _dynamic_fi_small = _dfi[:0]
1746
  _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
1747
  _dynamic_gen_small = None
1748
- dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor(
1749
  _dynamic_latents_small,
1750
  _dynamic_fi_small,
1751
  _dynamic_pose_small,
@@ -1754,18 +1785,6 @@ class MemoryDiTMixin:
1754
  exclude_latest_local_frames=dynamic_recent_exclusion_frames,
1755
  )
1756
 
1757
- dynamic_min_gap_tensor = torch.as_tensor(
1758
- dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
1759
- device=stream_device,
1760
- )
1761
- dynamic_max_gap_tensor = torch.as_tensor(
1762
- dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)),
1763
- device=stream_device,
1764
- )
1765
- dynamic_gap_valid = dynamic_min_gap_tensor >= 0
1766
- dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1
1767
- dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0
1768
- dynamic_max_gap_to_target = int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1
1769
  def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
1770
  if tensor is None or tensor.ndim < 2:
1771
  return None
@@ -1804,15 +1823,11 @@ class MemoryDiTMixin:
1804
  revisit_mask_rows = []
1805
  revisit_max_rows = []
1806
  valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
1807
- revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1808
- revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1809
  revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1810
  revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1811
  revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
1812
  eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
1813
- revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
1814
  eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
1815
- revisit_result_diagnostics = []
1816
  projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
1817
  if revisit_record_batches is None:
1818
  revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
@@ -1823,49 +1838,65 @@ class MemoryDiTMixin:
1823
  batch_max_rows = []
1824
  for target_idx in range(T_tgt):
1825
  target_frame = int(target_frame_indices[target_idx, batch_idx].item())
1826
- if use_cache_revisit_records:
1827
- candidate_records = self._causal_cached_revisit_records(
1828
- revisit_record_batches[batch_idx],
1829
- target_frame,
1830
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
1831
  else:
1832
- candidate_records = revisit_bank.query(
1833
- MemoryBankQuery(
1834
- target_frame=target_frame,
1835
- include_generated=True,
1836
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1837
  )
1838
- result = deterministic_revisit_retrieval(
1839
- candidate_records,
1840
- target_frame=target_frame,
1841
- target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
1842
- target_summary=None,
1843
- topk=revisit_max_frames,
1844
- exclude_local_context_frames=revisit_context_window_exclusion_frames,
1845
- fov_overlap_threshold=fov_overlap_threshold,
1846
- plucker_weight=plucker_weight,
1847
- target_video_id=_target_video_id_or_none(batch_idx, target_idx),
1848
- **revisit_retrieval_kwargs,
1849
- )
1850
- selected_records = result.records
1851
- if use_cache_revisit_records and selected_records:
1852
- selected_records = self._project_streaming_revisit_records(
1853
- cache=cache,
1854
- batch_idx=batch_idx,
1855
- records=selected_records,
1856
- device=stream_device,
1857
- dtype=stream_dtype,
1858
- token_patch_size=token_patch_size,
1859
- revisit_pool_h=revisit_pool_h,
1860
- revisit_pool_w=revisit_pool_w,
1861
- projection_cache=projected_revisit_record_cache,
1862
- )
1863
- revisit_result_diagnostics.append(result.diagnostics)
1864
- revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0)))
1865
- revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0)))
1866
- revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
1867
- revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
1868
- revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
1869
  revisit_bank.assert_causal(target_frame, selected_records)
1870
  if selected_records:
1871
  valid_revisit_mask[batch_idx, target_idx] = True
@@ -1876,7 +1907,6 @@ class MemoryDiTMixin:
1876
  stream_device,
1877
  stream_dtype,
1878
  )
1879
- revisit_causal_max[batch_idx, target_idx] = max_source_frame
1880
  if eval_corruption_enabled:
1881
  stream_tokens, was_corrupted = apply_revisit_eval_corruption(
1882
  tokens=stream_tokens,
@@ -1932,8 +1962,6 @@ class MemoryDiTMixin:
1932
  if not revisit_stage_config_enabled:
1933
  revisit_mask = torch.zeros_like(revisit_mask)
1934
  valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
1935
- revisit_candidate_count = torch.zeros_like(revisit_candidate_count)
1936
- revisit_selected_count = torch.zeros_like(revisit_selected_count)
1937
  revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
1938
  revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
1939
  revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
@@ -1941,85 +1969,6 @@ class MemoryDiTMixin:
1941
  valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
1942
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
1943
  revisit_gate = torch.zeros_like(revisit_gate)
1944
- no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
1945
- revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask)
1946
- causal_violation_count = 0
1947
- for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max):
1948
- if source_max is None:
1949
- continue
1950
- source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device)
1951
- valid = source_max_t >= 0
1952
- if valid.any():
1953
- causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item())
1954
- diagnostics = {
1955
- **curriculum_state.diagnostics(),
1956
- **getattr(self, "_last_dememwm_freeze_diagnostics", {}),
1957
- **contract_diag,
1958
- **cache_diag,
1959
- **preselection_diag,
1960
- **revisit_diag,
1961
- "dememwm_stage": gates.stage,
1962
- "dememwm_gate_reason": gates.reason,
1963
- "anchor_config_enabled": anchor_config_enabled,
1964
- "dynamic_config_enabled": dynamic_config_enabled,
1965
- "revisit_config_enabled": revisit_config_enabled,
1966
- "anchor_effective_enabled": anchor_effective_enabled,
1967
- "dynamic_effective_enabled": dynamic_effective_enabled,
1968
- "revisit_effective_enabled": revisit_effective_enabled,
1969
- "revisit_stage_config_enabled": revisit_stage_config_enabled,
1970
- "revisit_gate_raw": revisit_gate_raw.detach(),
1971
- "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
1972
- "no_valid_revisit_mask": no_valid_revisit_mask,
1973
- "valid_revisit_eff_mask": valid_revisit_eff_mask,
1974
- "revisit_candidate_frame_count_per_target": revisit_candidate_count,
1975
- "revisit_selected_frame_count_per_target": revisit_selected_count,
1976
- "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
1977
- "revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap,
1978
- "revisit_selected_gap_frames_per_target": revisit_selected_gap_frames,
1979
- "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
1980
- "revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()),
1981
- **summarize_noise_bucket_diagnostics(
1982
- noise_bucket=resolved_noise_bucket,
1983
- valid_revisit_mask=valid_revisit_mask,
1984
- no_valid_revisit_mask=no_valid_revisit_mask,
1985
- noise_bucket_ids=noise_bucket_ids,
1986
- ),
1987
- **summarize_eval_ablation_diagnostics(
1988
- enabled=eval_ablation_enabled,
1989
- branch=eval_ablation_branch,
1990
- valid_revisit_mask=valid_revisit_mask,
1991
- no_valid_revisit_mask=no_valid_revisit_mask,
1992
- eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None,
1993
- ),
1994
- "token_patch_size": token_patch_size,
1995
- "tokens_per_frame": tokens_per_frame,
1996
- "anchor_token_slots": int(anchor_tokens.shape[-2]),
1997
- "anchor_target_slots": anchor_num_tokens,
1998
- "anchor_pool_h": anchor_pool_h,
1999
- "anchor_pool_w": anchor_pool_w,
2000
- "dynamic_token_slots": int(dynamic_tokens.shape[-2]),
2001
- "dynamic_target_slots": int(dynamic_tokens.shape[-2]),
2002
- "dynamic_min_gap_to_target": dynamic_min_gap_to_target,
2003
- "dynamic_max_gap_to_target": dynamic_max_gap_to_target,
2004
- "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
2005
- "revisit_token_slots": int(revisit_tokens.shape[-2]),
2006
- "revisit_target_slots": revisit_target_slots,
2007
- "revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
2008
- "revisit_pool_h": revisit_pool_h,
2009
- "revisit_pool_w": revisit_pool_w,
2010
- "revisit_max_frames": revisit_max_frames,
2011
- "anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0,
2012
- "dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0,
2013
- "revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0,
2014
- "causal_violation_count": causal_violation_count,
2015
- "anchor_max_source_frame": anchor_max,
2016
- "dynamic_max_source_frame": dynamic_diag.get("max_source_frame"),
2017
- "revisit_max_source_frame": revisit_max,
2018
- "dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"),
2019
- }
2020
- if eval_corruption_enabled:
2021
- diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask
2022
-
2023
  return MemoryStreamTensors(
2024
  anchor_tokens=anchor_tokens,
2025
  anchor_mask=anchor_mask,
@@ -2032,19 +1981,18 @@ class MemoryDiTMixin:
2032
  revisit_gate=revisit_gate,
2033
  revisit_gate_raw=revisit_gate_raw,
2034
  valid_revisit_mask=valid_revisit_mask,
2035
- no_valid_revisit_mask=no_valid_revisit_mask,
2036
- diagnostics=diagnostics,
 
2037
  )
2038
 
2039
  def _refresh_stream_gates(
2040
  self,
2041
  streams: MemoryStreamTensors,
2042
  denoising_fraction: float | None = None,
2043
- noise_bucket: str | None = None,
2044
  ) -> MemoryStreamTensors:
2045
  gate_state = self._effective_gate_state(
2046
  denoising_fraction=denoising_fraction,
2047
- noise_bucket=noise_bucket,
2048
  )
2049
  gates = gate_state["gates"]
2050
  device = streams.anchor_tokens.device
@@ -2056,20 +2004,17 @@ class MemoryDiTMixin:
2056
  else:
2057
  valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool)
2058
 
2059
- diagnostics = dict(streams.diagnostics)
2060
-
2061
- def _diagnostic_tensor(name: str, fill_value: float = 0.0) -> torch.Tensor:
2062
- value = diagnostics.get(name)
2063
  if value is None:
2064
  return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32)
2065
- tensor = torch.as_tensor(value, device=device, dtype=torch.float32)
2066
  if tensor.ndim == 0:
2067
  return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32)
2068
  return tensor.expand((B, T_tgt))
2069
 
2070
- revisit_best_selected_fov_overlap = _diagnostic_tensor("revisit_best_selected_fov_overlap_per_target")
2071
- revisit_best_selected_plucker_overlap = _diagnostic_tensor("revisit_best_selected_plucker_overlap_per_target")
2072
- revisit_selected_gap_frames = _diagnostic_tensor("revisit_selected_gap_frames_per_target", -1.0)
2073
 
2074
  anchor_effective_enabled = gate_state["anchor_effective_enabled"]
2075
  dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
@@ -2086,53 +2031,16 @@ class MemoryDiTMixin:
2086
  best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
2087
  selected_gap_frames=revisit_selected_gap_frames,
2088
  ).to(device=device, dtype=dtype)
2089
- valid_revisit_eff_mask = valid_revisit_mask
2090
  if not revisit_stage_config_enabled or gate_state["force_revisit_off"]:
2091
  revisit_gate = torch.zeros_like(revisit_gate_raw)
2092
  elif gate_state["force_revisit_on"]:
2093
- revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw)
2094
  else:
2095
- revisit_gate = valid_revisit_eff_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate)
2096
  if not revisit_stage_config_enabled:
2097
  valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
2098
- valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
2099
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
2100
  revisit_gate = torch.zeros_like(revisit_gate)
2101
- no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask)
2102
- eval_corrupted_revisit_mask = diagnostics.get("eval_corrupted_revisit_mask")
2103
- if eval_corrupted_revisit_mask is not None:
2104
- eval_corrupted_revisit_mask = torch.as_tensor(eval_corrupted_revisit_mask, device=device, dtype=torch.bool)
2105
- revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item())
2106
- diagnostics.update(gate_state["curriculum_state"].diagnostics())
2107
- diagnostics.update({
2108
- "dememwm_stage": gates.stage,
2109
- "dememwm_gate_reason": gates.reason,
2110
- "anchor_config_enabled": gate_state["anchor_config_enabled"],
2111
- "dynamic_config_enabled": gate_state["dynamic_config_enabled"],
2112
- "revisit_config_enabled": gate_state["revisit_config_enabled"],
2113
- "anchor_effective_enabled": anchor_effective_enabled,
2114
- "dynamic_effective_enabled": dynamic_effective_enabled,
2115
- "revisit_effective_enabled": revisit_effective_enabled,
2116
- "revisit_stage_config_enabled": revisit_stage_config_enabled,
2117
- "revisit_gate_raw": revisit_gate_raw.detach(),
2118
- "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
2119
- "no_valid_revisit_mask": no_valid_revisit_mask,
2120
- "valid_revisit_eff_mask": valid_revisit_eff_mask,
2121
- "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0,
2122
- "revisit_effective_gate_mean": float(revisit_gate.detach().float().mean().item()) if revisit_gate.numel() else 0.0,
2123
- })
2124
- diagnostics.update(summarize_noise_bucket_diagnostics(
2125
- noise_bucket=gate_state["resolved_noise_bucket"],
2126
- valid_revisit_mask=valid_revisit_mask,
2127
- no_valid_revisit_mask=no_valid_revisit_mask,
2128
- ))
2129
- diagnostics.update(summarize_eval_ablation_diagnostics(
2130
- enabled=gate_state["eval_ablation_enabled"],
2131
- branch=gate_state["eval_ablation_branch"],
2132
- valid_revisit_mask=valid_revisit_mask,
2133
- no_valid_revisit_mask=no_valid_revisit_mask,
2134
- eval_corrupted_revisit_mask=eval_corrupted_revisit_mask,
2135
- ))
2136
  return replace(
2137
  streams,
2138
  anchor_gate=anchor_gate,
@@ -2140,84 +2048,90 @@ class MemoryDiTMixin:
2140
  revisit_gate=revisit_gate,
2141
  revisit_gate_raw=revisit_gate_raw,
2142
  valid_revisit_mask=valid_revisit_mask,
2143
- no_valid_revisit_mask=no_valid_revisit_mask,
2144
- diagnostics=diagnostics,
 
2145
  )
2146
 
2147
- def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> tuple[dict, dict]:
2148
- memory_kwargs, diagnostics = self.dememwm_injection_adapter(streams, device=streams.anchor_tokens.device, dtype=streams.anchor_tokens.dtype)
2149
- return memory_kwargs, diagnostics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2150
 
2151
- def build_memory_kwargs(self, *args, **kwargs) -> tuple[dict, dict]:
2152
  streams = self.build_memory_streams(*args, **kwargs)
2153
  return self._streams_to_kwargs(streams)
2154
 
2155
- def _memory_adapter_delta_diagnostics(self) -> dict:
2156
- dit_model = getattr(getattr(self, "diffusion_model", None), "model", None)
2157
- diagnostics_fn = getattr(dit_model, "memory_adapter_delta_diagnostics", None)
2158
- if diagnostics_fn is None:
2159
- return {}
2160
- return diagnostics_fn()
2161
-
2162
- def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None:
2163
- if namespace == "training/dememwm":
2164
- allowed_keys = self._TRAIN_DIAGNOSTIC_LOG_KEYS
2165
- elif namespace.endswith("/dememwm"):
2166
- allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS
2167
- else:
2168
- allowed_keys = None
2169
- for key, value in diagnostics.items():
2170
- if allowed_keys is not None and key not in allowed_keys:
2171
- continue
2172
- if isinstance(value, str) or value is None:
2173
- continue
2174
- if torch.is_tensor(value):
2175
- if value.numel() > 0:
2176
- self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True)
2177
- elif isinstance(value, (bool, int, float)):
2178
- self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True)
2179
-
2180
- def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx):
2181
- from ..df_video import convert_to_plucker
2182
- image_height, image_width = self._image_size(xs)
2183
  if self.use_plucker:
2184
  if self.relative_embedding:
 
 
 
 
 
 
 
2185
  input_pose_condition = []
2186
  frame_idx_list = []
2187
- ref_c2w = c2w_mat[-self.memory_condition_length:] if self.memory_condition_length else c2w_mat[:0]
2188
- ref_idx = frame_idx[-self.memory_condition_length:] if self.memory_condition_length else frame_idx[:0]
2189
- for i in range(c2w_mat.shape[0]):
2190
  input_pose_condition.append(
2191
  convert_to_plucker(
2192
- torch.cat([c2w_mat[i:i + 1], ref_c2w]).clone(),
2193
  0,
2194
  focal_length=self.focal_length,
2195
- image_height=image_height, image_width=image_width
2196
- ).to(xs.dtype)
 
 
 
 
 
 
 
2197
  )
2198
- frame_idx_list.append(torch.cat([frame_idx[i:i + 1] - frame_idx[i:i + 1], ref_idx - frame_idx[i:i + 1]]).clone())
2199
  return torch.cat(input_pose_condition), torch.cat(frame_idx_list)
2200
- return convert_to_plucker(
2201
- c2w_mat, 0, focal_length=self.focal_length,
2202
- image_height=image_height, image_width=image_width
2203
- ).to(xs.dtype), frame_idx
2204
- return pose_conditions.to(xs.dtype), None
2205
-
2206
- def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]:
2207
- total_frames = max(0, int(total_frames))
2208
- n_tokens = max(1, min(int(self.n_tokens), total_frames))
2209
- max_start = max(0, total_frames - n_tokens)
2210
- if max_start == 0:
2211
- return 0, n_tokens
2212
- context_start = self._context_frame_count()
2213
- min_start = min(context_start, max_start)
2214
- if min_start == max_start:
2215
- return min_start, min_start + n_tokens
2216
- start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item())
2217
- return start, start + n_tokens
2218
 
2219
  def training_step(self, batch, batch_idx):
2220
  xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
 
2221
  xs = self._as_latents(xs)
2222
 
2223
  # Randomly select a contiguous n_tokens denoising window inside the long
@@ -2231,7 +2145,12 @@ class MemoryDiTMixin:
2231
  frame_idx_window = frame_idx[start:end]
2232
 
2233
  input_pose_condition, frame_idx_list = self._training_pose_condition(
2234
- xs_window, pose_conditions[start:end], c2w_mat[start:end], frame_idx_window
 
 
 
 
 
2235
  )
2236
 
2237
  noise_levels = self._generate_noise_levels(xs_window)
@@ -2239,17 +2158,15 @@ class MemoryDiTMixin:
2239
  noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
2240
  conditions_window[-self.memory_condition_length:] *= 0
2241
  source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
2242
- memory_source_latents, source_is_generated, proxy_diagnostics = self._apply_generated_history_proxy(
2243
  xs,
2244
  source_is_generated,
2245
  context_frame_count=self._context_frame_count(),
2246
  target_start_frame=start,
2247
  )
2248
  timesteps = int(getattr(self, "timesteps", 0) or 0)
2249
- training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps)
2250
- training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps)
2251
  training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
2252
- memory_kwargs, diagnostics = self.build_memory_kwargs(
2253
  memory_source_latents,
2254
  frame_idx,
2255
  target_frame_indices=frame_idx_window,
@@ -2259,10 +2176,7 @@ class MemoryDiTMixin:
2259
  target_action=conditions_window,
2260
  source_is_generated=source_is_generated,
2261
  denoising_fraction=training_denoising_fraction,
2262
- noise_bucket=training_noise_bucket,
2263
- noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1),
2264
  )
2265
- diagnostics.update(proxy_diagnostics)
2266
  _, loss = self.diffusion_model(
2267
  xs_window,
2268
  conditions_window,
@@ -2272,19 +2186,19 @@ class MemoryDiTMixin:
2272
  frame_idx=frame_idx_list,
2273
  **memory_kwargs,
2274
  )
2275
- diagnostics.update(self._memory_adapter_delta_diagnostics())
2276
  if self.memory_condition_length:
2277
  loss = loss[:-self.memory_condition_length]
2278
  loss_denoise = self.reweight_loss(loss, None)
2279
  loss_total = loss_denoise
2280
- diagnostics["training_window_start"] = int(start)
2281
- diagnostics["training_window_end"] = int(end)
2282
- diagnostics["training_window_size"] = int(end - start)
2283
- diagnostics["loss_denoise"] = float(loss_denoise.detach().item())
2284
- diagnostics["loss_total"] = float(loss_total.detach().item())
2285
  if batch_idx % 20 == 0:
2286
- self.log("training/loss", loss_total.detach().cpu())
2287
- self._log_memory_diagnostics("training/dememwm", diagnostics)
 
 
 
 
 
 
2288
  return {"loss": loss_total}
2289
 
2290
  def validation_step(self, batch, batch_idx, namespace="validation"):
@@ -2308,7 +2222,6 @@ class MemoryDiTMixin:
2308
  streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
2309
  cached_until = 0
2310
  pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
2311
- last_diagnostics = None
2312
  while curr_frame < n_frames:
2313
  if streaming_cache is not None and curr_frame > cached_until:
2314
  new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
@@ -2375,7 +2288,7 @@ class MemoryDiTMixin:
2375
  from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
2376
  denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0)
2377
  step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
2378
- memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams)
2379
  xs_pred[start_frame:] = self.diffusion_model.sample_step(
2380
  xs_pred[start_frame:].to(input_condition.device),
2381
  input_condition,
@@ -2405,12 +2318,8 @@ class MemoryDiTMixin:
2405
  action=conditions[cached_until:curr_frame],
2406
  )
2407
  cached_until = curr_frame
2408
- if last_diagnostics is not None:
2409
- last_diagnostics.update(streaming_cache.diagnostics("cache"))
2410
  pbar.update(horizon)
2411
  pbar.close()
2412
- if last_diagnostics is not None:
2413
- self._log_memory_diagnostics(f"{namespace}/dememwm", last_diagnostics)
2414
  xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
2415
  xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
2416
  self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
@@ -2440,10 +2349,7 @@ class MemoryDiTMixin:
2440
  # Compatibility aliases for old DeMemWM test and experiment call sites.
2441
  dememwm_strict_key_prefixes = strict_key_prefixes
2442
  dememwm_strict_key_substrings = strict_key_substrings
2443
- _DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS
2444
- _DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS
2445
  _dememwm_cfg = _memory_cfg
2446
- _dememwm_stage_policy_cfg = _stage_policy_cfg
2447
  _dememwm_eval_ablation_cfg = _eval_ablation_cfg
2448
  _dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
2449
  _dememwm_eval_ablation_state = _eval_ablation_state
@@ -2476,8 +2382,6 @@ class MemoryDiTMixin:
2476
  _dememwm_refresh_stream_gates = _refresh_stream_gates
2477
  _dememwm_streams_to_kwargs = _streams_to_kwargs
2478
  build_dememwm_memory_kwargs = build_memory_kwargs
2479
- _dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics
2480
- _log_dememwm_diagnostics = _log_memory_diagnostics
2481
  _dememwm_training_window_bounds = _training_window_bounds
2482
  strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
2483
 
 
 
1
  from __future__ import annotations
2
 
3
  import math
 
9
 
10
  from .cache import StreamingCache
11
  from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens
 
12
  from .injection import InjectionAdapter
13
  from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
14
  from .negatives import apply_revisit_eval_corruption
15
+ from .retrieval import batched_revisit_select_positions, deterministic_revisit_retrieval
16
+ from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum
17
  from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors
18
+ from ..df_video import convert_to_plucker
19
 
20
 
21
  class MemoryDiTMixin:
 
35
  strict_key_substrings = (
36
  ".memory_token_cross_attn.",
37
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  def _memory_cfg(self):
40
+ return getattr(getattr(self, "cfg", None), "dememwm", None)
41
 
42
  def _cfg_get(self, obj, name, default):
43
  if obj is None:
 
57
  except Exception:
58
  return False
59
 
 
 
60
 
61
  def _eval_ablation_cfg(self):
62
  return self._cfg_get(self._memory_cfg(), "eval_ablation", None)
 
70
  branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal"))
71
  return enabled, branch
72
 
73
+ def _effective_gate_state(self, denoising_fraction: float | None = None) -> dict:
74
  memory_cfg = self._memory_cfg()
75
  anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
76
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
 
81
  revisit_config_enabled = self._stream_enabled(revisit_cfg)
82
  curriculum_state = self._curriculum_state()
83
  eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state()
 
 
84
  gates = compute_stream_gates(
85
  curriculum_state.stage,
86
  denoising_fraction=denoising_fraction,
 
87
  anchor_gate=float(self._cfg_get(injection_cfg, "anchor_gate", 1.0)),
88
  dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)),
89
  revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)),
 
107
  return {
108
  "curriculum_state": curriculum_state,
109
  "gates": gates,
 
110
  "anchor_config_enabled": anchor_config_enabled,
111
  "dynamic_config_enabled": dynamic_config_enabled,
112
  "revisit_config_enabled": revisit_config_enabled,
 
119
  "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"),
120
  }
121
 
122
+ def _validate_config_contract(self) -> None:
123
  if bool(getattr(self, "_dememwm_contract_validated", False)):
124
+ return
125
  memory_cfg = self._memory_cfg()
126
  if memory_cfg is None:
127
  self._dememwm_contract_validated = True
128
+ return
 
129
 
130
  stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)]
131
  if stale_sections:
 
159
  if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)):
160
  raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval")
161
  fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
162
+ if fov_overlap_threshold is not None and float(fov_overlap_threshold) < 0.0:
163
+ raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative")
 
 
164
  high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
165
  if high_quality_fov_threshold < 0.0:
166
  raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative")
 
186
  value = int(self._cfg_get(revisit_cfg, field_name, default))
187
  if value <= 0:
188
  raise ValueError(f"dememwm.revisit.{field_name} must be positive")
 
 
 
189
  proxy_cfg = self._generated_history_proxy_cfg()
190
  proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0))
191
  proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0))
 
201
  raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative")
202
  eval_ablation_cfg = self._eval_ablation_cfg()
203
  normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal"))
 
 
 
 
 
 
 
 
 
204
  self._dememwm_contract_validated = True
 
 
205
 
206
  def _stream_enabled(self, stream_cfg) -> bool:
207
  return bool(self._cfg_get(stream_cfg, "enabled", True))
 
244
  source_is_generated: torch.Tensor | None,
245
  context_frame_count: int | None = None,
246
  target_start_frame: int | None = None,
247
+ ) -> tuple[torch.Tensor, torch.Tensor]:
248
  cfg = self._generated_history_proxy_cfg()
249
  prob = self._generated_history_proxy_prob()
250
  noise_std = float(self._cfg_get(cfg, "noise_std", 0.0))
251
  dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0))
 
 
 
 
 
 
 
 
252
  if source_is_generated is None:
253
  source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
254
  else:
255
  source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool)
256
  if prob <= 0.0 or source_latents.numel() == 0:
257
+ return source_latents, source_is_generated
258
 
259
  eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool)
260
  if context_frame_count is not None or target_start_frame is not None:
 
264
  if target_start_frame is not None:
265
  eligible_mask &= frame_positions < max(0, int(target_start_frame))
266
  proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask
267
+ if not proxy_mask.any():
268
+ return source_latents, source_is_generated
 
 
 
 
269
 
270
  corrupt_latents = source_latents.clone()
271
  frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype)
 
280
  corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents)
281
  source_is_generated = source_is_generated.clone()
282
  source_is_generated |= proxy_mask
283
+ return corrupt_latents, source_is_generated
284
 
285
  def _checkpoint_cfg(self):
286
  return self._cfg_get(self._memory_cfg(), "checkpoint", None)
 
335
 
336
  def _apply_freeze_policy(self, optimizer=None, step: int | None = None):
337
  state = self._curriculum_state(step)
 
 
 
 
338
  freeze_key = (state.stage, state.dit_train_state, state.freeze_vae)
339
+ if getattr(self, "_last_freeze_key", None) != freeze_key:
 
 
 
 
 
 
 
 
 
 
340
  for name, param in self.named_parameters():
341
  group_name = self._param_group_name(name, state)
 
342
  if group_name == "excluded_frozen" or (name.startswith("vae.") and state.freeze_vae):
343
+ param.requires_grad_(False)
 
344
  else:
345
+ param.requires_grad_(True)
 
 
 
 
 
 
 
346
  self._last_freeze_key = freeze_key
 
 
 
 
 
 
 
 
 
347
 
348
  if optimizer is not None:
349
  for param_group in optimizer.param_groups:
350
  group_name = param_group.get("name", "")
351
  trainable = self._group_trainable(group_name, state)
352
  param_group["lr"] = self._group_lr(group_name, state) if trainable else 0.0
 
 
 
 
 
 
 
 
 
353
  return state
354
 
355
  def configure_optimizers(self):
 
998
  plucker_weight: float,
999
  revisit_retrieval_kwargs: dict | None,
1000
  token_patch_size: int,
1001
+ ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, list[list[tuple[MemoryRecord, ...]]] | None, dict | None]:
1002
  if committed_latents.ndim != 5:
1003
  raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
1004
  T_src, B, _, H, W = committed_latents.shape
 
1023
  revisit_banks: list[CausalMemoryBank] = []
1024
  dummy_tokens = committed_latents.new_zeros((1, hidden_size))
1025
  dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
 
 
 
 
 
 
1026
  retrieval_kwargs = dict(revisit_retrieval_kwargs or {})
1027
 
1028
  # Pre-convert pose tensors to stream_device once so that the
 
1079
  def _pose_subset(positions: torch.Tensor, batch_idx: int):
1080
  return _tensor_subset(pose, positions, batch_idx)
1081
 
1082
+ fast_source_pose_shape = (
1083
+ pose is not None
1084
+ and pose.ndim >= 3
1085
+ and ((pose.shape[0] == T_src and pose.shape[1] == B) or (pose.shape[0] == B and pose.shape[1] == T_src))
1086
+ )
1087
+ fast_target_pose_shape = (
1088
+ target_pose is not None
1089
+ and target_pose.ndim >= 3
1090
+ and ((target_pose.shape[0] == T_tgt and target_pose.shape[1] == B) or (target_pose.shape[0] == B and target_pose.shape[1] == T_tgt))
1091
+ )
1092
+ if fast_source_pose_shape and fast_target_pose_shape:
1093
+ selection = batched_revisit_select_positions(
1094
+ source_frame_indices,
1095
+ pose,
1096
+ target_frame_indices,
1097
+ target_pose,
1098
+ topk=revisit_max_frames,
1099
+ exclude_local_context_frames=exclude_local_context_frames,
1100
+ fov_overlap_threshold=fov_overlap_threshold,
1101
+ plucker_weight=plucker_weight,
1102
+ fov_half_h=float(retrieval_kwargs.get("fov_half_h", 52.5)),
1103
+ fov_half_v=float(retrieval_kwargs.get("fov_half_v", 37.5)),
1104
+ fov_yaw_samples=int(retrieval_kwargs.get("fov_yaw_samples", 25)),
1105
+ fov_pitch_samples=int(retrieval_kwargs.get("fov_pitch_samples", 20)),
1106
+ fov_depth_samples=int(retrieval_kwargs.get("fov_depth_samples", 20)),
1107
+ fov_radius=float(retrieval_kwargs.get("fov_radius", 30.0)),
1108
+ plucker_grid_h=int(retrieval_kwargs.get("plucker_grid_h", 4)),
1109
+ plucker_grid_w=int(retrieval_kwargs.get("plucker_grid_w", 4)),
1110
+ plucker_focal_length=float(retrieval_kwargs.get("plucker_focal_length", 0.35)),
1111
+ pose_preselect_topk=retrieval_kwargs.get("pose_preselect_topk", 64),
1112
+ )
1113
+ selected_records_by_target: list[list[tuple[MemoryRecord, ...]]] = [
1114
+ [tuple() for _ in range(T_tgt)] for _ in range(B)
1115
+ ]
1116
+
1117
+ for batch_idx in range(B):
1118
+ anchor_bank = CausalMemoryBank()
1119
+ revisit_bank = CausalMemoryBank()
1120
+ src_frames = source_frame_indices[:, batch_idx]
1121
+ if generated is None:
1122
+ non_generated = torch.ones_like(src_frames, dtype=torch.bool)
1123
+ else:
1124
+ non_generated = ~generated[:, batch_idx]
1125
+
1126
+ source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
1127
+ anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long)
1128
+ if anchor_indices and source_positions.numel() > 0:
1129
+ if anchor_diverse:
1130
+ anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
1131
+ if anchor_source_positions.numel() > 0:
1132
+ anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
1133
+ anchor_positions = self._select_diverse_anchor_positions(
1134
+ anchor_source_positions, anchor_pose, len(anchor_indices)
1135
+ ).to(device=stream_device, dtype=torch.long)
1136
+ else:
1137
+ selected_anchor_positions = []
1138
+ for anchor_idx in anchor_indices:
1139
+ if 0 <= int(anchor_idx) < source_positions.numel():
1140
+ selected_anchor_positions.append(source_positions[int(anchor_idx)])
1141
+ if selected_anchor_positions:
1142
+ anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
1143
+ if anchor_positions.numel() > 0:
1144
+ anchor_projected = self._project_latent_patch_tokens(
1145
+ committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
1146
+ self.dememwm_anchor_proj,
1147
+ token_patch_size,
1148
+ )[0]
1149
+ for local_idx, source_pos in enumerate(anchor_positions):
1150
+ source_pos_i = int(source_pos.item())
1151
+ anchor_tokens = self._spatial_pool_tokens(anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w)
1152
+ n_slots = anchor_tokens.shape[0]
1153
+ record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool)
1154
+ anchor_bank.add_prefix_anchors(
1155
+ anchor_tokens.unsqueeze(0),
1156
+ record_mask.unsqueeze(0),
1157
+ src_frames[source_pos_i:source_pos_i + 1],
1158
+ slots_per_anchor=n_slots,
1159
+ )
1160
+
1161
+ selected_b = selection.selected_positions[batch_idx]
1162
+ selected_mask_b = selection.selected_mask[batch_idx]
1163
+ selected_fov_b = selection.selected_fov_overlap[batch_idx]
1164
+ selected_plucker_b = selection.selected_plucker_overlap[batch_idx]
1165
+ selected_gap_b = selection.selected_gap_frames[batch_idx]
1166
+ high_quality_fov_threshold = float(retrieval_kwargs.get("high_quality_fov_threshold", 0.70))
1167
+ metadata_by_position: dict[int, dict] = {}
1168
+ for target_idx in range(T_tgt):
1169
+ for slot_idx in range(selected_b.shape[1]):
1170
+ if not bool(selected_mask_b[target_idx, slot_idx].detach().item()):
1171
+ continue
1172
+ source_pos_i = int(selected_b[target_idx, slot_idx].detach().item())
1173
+ if source_pos_i < 0:
1174
+ continue
1175
+ frame = int(src_frames[source_pos_i].detach().item())
1176
+ fov_overlap = float(selected_fov_b[target_idx, slot_idx].detach().item())
1177
+ plucker_overlap = float(selected_plucker_b[target_idx, slot_idx].detach().item())
1178
+ gap_frames = float(selected_gap_b[target_idx, slot_idx].detach().item())
1179
+ existing = metadata_by_position.get(source_pos_i)
1180
+ if existing is not None:
1181
+ existing_rank = (
1182
+ float(existing.get("dememwm_selected_frame_fov_overlap", 0.0)),
1183
+ -float(existing.get("dememwm_selected_gap_frames", 1.0e9)),
1184
+ )
1185
+ new_rank = (fov_overlap, -gap_frames)
1186
+ if existing_rank >= new_rank:
1187
+ continue
1188
+ metadata_by_position[source_pos_i] = {
1189
+ "dememwm_selected_revisit_fov_overlap": fov_overlap,
1190
+ "dememwm_selected_revisit_plucker_overlap": plucker_overlap,
1191
+ "dememwm_selected_gap_frames": gap_frames,
1192
+ "dememwm_selected_frame_index": frame,
1193
+ "dememwm_selected_frame_fov_overlap": fov_overlap,
1194
+ "dememwm_selected_frame_fov_threshold": high_quality_fov_threshold,
1195
+ "dememwm_selected_frame_passes_high_quality": bool(fov_overlap >= high_quality_fov_threshold),
1196
+ }
1197
+ flat_selected = selected_b[selected_b >= 0].to(device=stream_device, dtype=torch.long)
1198
+ unique_positions = torch.unique(flat_selected, sorted=True) if flat_selected.numel() > 0 else flat_selected
1199
+ records_by_position: dict[int, MemoryRecord] = {}
1200
+ if unique_positions.numel() > 0:
1201
+ revisit_projected = self._project_latent_patch_tokens(
1202
+ committed_latents.index_select(0, unique_positions)[:, batch_idx:batch_idx + 1],
1203
+ self.dememwm_revisit_proj,
1204
+ token_patch_size,
1205
+ )[0]
1206
+ for local_idx, source_pos in enumerate(unique_positions):
1207
+ source_pos_i = int(source_pos.item())
1208
+ frame_index = src_frames[source_pos_i]
1209
+ frame = int(frame_index.detach().item())
1210
+ is_generated = False if generated is None else bool(generated[source_pos_i, batch_idx].detach().item())
1211
+ source_type = MemorySourceType.GENERATED if is_generated else MemorySourceType.PREFIX_GT
1212
+ prefix = "generated" if is_generated else "prefix"
1213
+ frame_tokens = self._spatial_pool_tokens(revisit_projected[local_idx], revisit_pool_h, revisit_pool_w, src_h, src_w)
1214
+ frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool)
1215
+ record = MemoryRecord(
1216
+ tokens=frame_tokens,
1217
+ mask=frame_mask,
1218
+ source_start=frame,
1219
+ source_end=frame + 1,
1220
+ frame_indices=frame_index.reshape(1).to(device=stream_device),
1221
+ pose=_pose_subset(source_pos.reshape(1), batch_idx),
1222
+ source_type=source_type,
1223
+ is_generated=is_generated,
1224
+ chunk_id=f"{prefix}_revisit_b{batch_idx}_f{frame}",
1225
+ metadata=metadata_by_position.get(source_pos_i, {}),
1226
+ )
1227
+ revisit_bank.add_record(record)
1228
+ records_by_position[source_pos_i] = record
1229
+
1230
+ for target_idx in range(T_tgt):
1231
+ target_records: list[MemoryRecord] = []
1232
+ for source_pos in selected_b[target_idx]:
1233
+ source_pos_i = int(source_pos.detach().item())
1234
+ if source_pos_i < 0:
1235
+ continue
1236
+ record = records_by_position.get(source_pos_i)
1237
+ if record is not None:
1238
+ target_records.append(record)
1239
+ selected_records_by_target[batch_idx][target_idx] = tuple(target_records)
1240
+
1241
+ anchor_banks.append(anchor_bank)
1242
+ revisit_banks.append(revisit_bank)
1243
+
1244
+ fast_revisit_stats = {
1245
+ "best_selected_fov_overlap": selection.best_selected_fov_overlap.to(device=stream_device),
1246
+ "best_selected_plucker_overlap": selection.best_selected_plucker_overlap.to(device=stream_device),
1247
+ "best_selected_gap_frames": selection.best_selected_gap_frames.to(device=stream_device),
1248
+ }
1249
+ return anchor_banks, revisit_banks, tokens_per_frame, selected_records_by_target, fast_revisit_stats
1250
+
1251
  def _candidate_record(
1252
  *,
1253
  batch_idx: int,
 
1298
  if selected_anchor_positions:
1299
  anchor_positions = torch.stack(selected_anchor_positions).to(device=stream_device, dtype=torch.long)
1300
  if anchor_positions.numel() > 0:
 
1301
  anchor_projected = self._project_latent_patch_tokens(
1302
  committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
1303
  self.dememwm_anchor_proj,
 
1317
 
1318
  candidate_records: list[MemoryRecord] = []
1319
  candidate_positions: dict[str, torch.Tensor] = {}
1320
+ latest_valid_source_frame_exclusive = target_frame_indices[:, batch_idx].amax() - int(exclude_local_context_frames)
 
 
1321
  for prefix, positions, source_type, is_generated in (
1322
  ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
1323
  (
 
1327
  True,
1328
  ),
1329
  ):
1330
+ if positions.numel() == 0:
1331
  continue
1332
+ positions = positions.to(device=stream_device, dtype=torch.long)
1333
+ frame_values = src_frames.index_select(0, positions).to(device=stream_device)
1334
+ valid_positions = positions[frame_values < latest_valid_source_frame_exclusive]
1335
+ valid_frames = src_frames.index_select(0, valid_positions) if valid_positions.numel() else src_frames[:0]
1336
+ for frame_position, frame_tensor in zip(valid_positions.unbind(0), valid_frames.unbind(0)):
1337
+ frame = int(frame_tensor.item())
1338
+ frame_position = frame_position.reshape(1)
1339
  record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
1340
  candidate_positions[record_id] = frame_position
1341
  candidate_records.append(_candidate_record(
 
1359
  exclude_local_context_frames=exclude_local_context_frames,
1360
  fov_overlap_threshold=fov_overlap_threshold,
1361
  plucker_weight=plucker_weight,
1362
+ target_video_id=None,
1363
  **retrieval_kwargs,
1364
  )
 
 
 
1365
  for selected_record in result.records:
1366
  if selected_record.chunk_id is None:
1367
  continue
 
1374
  continue
1375
  record_id = str(record.chunk_id)
1376
  frame_position = candidate_positions[record_id]
 
 
1377
  revisit_projected = self._project_latent_patch_tokens(
1378
  committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
1379
  self.dememwm_revisit_proj,
 
1397
  anchor_banks.append(anchor_bank)
1398
  revisit_banks.append(revisit_bank)
1399
 
1400
+ return anchor_banks, revisit_banks, tokens_per_frame, None, None
 
 
 
 
 
 
 
 
 
 
1401
 
1402
  def _causal_cached_revisit_records(
1403
  self,
 
1529
  target_video_ids=None,
1530
  source_is_generated: torch.Tensor | None = None,
1531
  denoising_fraction: float | None = None,
 
 
1532
  streaming_cache: StreamingCache | None = None,
1533
  ) -> MemoryStreamTensors:
1534
  if target_frame_indices is None:
 
1540
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
1541
  revisit_cfg = self._cfg_get(memory_cfg, "revisit", None)
1542
  injection_cfg = self._cfg_get(memory_cfg, "injection", None)
1543
+ self._validate_config_contract()
1544
  gate_state = self._effective_gate_state(
1545
  denoising_fraction=denoising_fraction,
 
1546
  )
1547
  anchor_config_enabled = gate_state["anchor_config_enabled"]
1548
  dynamic_config_enabled = gate_state["dynamic_config_enabled"]
 
1550
  curriculum_state = gate_state["curriculum_state"]
1551
  eval_ablation_enabled = gate_state["eval_ablation_enabled"]
1552
  eval_ablation_branch = gate_state["eval_ablation_branch"]
 
1553
  gates = gate_state["gates"]
1554
  anchor_effective_enabled = gate_state["anchor_effective_enabled"]
1555
  dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
 
1604
  "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)),
1605
  "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)),
1606
  }
1607
+ preselected_revisit_records_by_target: list[list[tuple[MemoryRecord, ...]]] | None = None
1608
+ preselected_revisit_stats: dict | None = None
1609
  use_cache_revisit_records = False
1610
  revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None
1611
 
1612
  cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None
 
1613
  if committed_latents is not None:
1614
  stream_device = committed_latents.device
1615
  stream_dtype = committed_latents.dtype
 
1677
  B = committed_latents.shape[1]
1678
  hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
1679
  target_pose_source = target_pose if target_pose is not None else pose
1680
+ anchor_banks, revisit_banks, tokens_per_frame, preselected_revisit_records_by_target, preselected_revisit_stats = self._build_preselected_causal_memory_banks(
1681
  committed_latents,
1682
  source_frame_indices.to(device=stream_device),
1683
  None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool),
 
1749
  dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, _fallback_w)
1750
  dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype)
1751
  dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool)
 
 
 
 
 
 
 
 
 
1752
  else:
1753
  # Pre-select dynamic source frame positions using only frame index metadata
1754
  # before touching latents, so we pass a small slice instead of the full
1755
  # 1000-frame tensor to the compressor.
1756
  _dfi = dynamic_frame_indices.to(device=stream_device)
1757
  _max_src = self.dememwm_dynamic_compressor.max_source_frames
1758
+ _needed_tensors: list[torch.Tensor] = []
1759
  for _b in range(B):
1760
  for _j in range(T_tgt):
1761
+ _target = target_frame_indices[_j, _b]
1762
  _valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
1763
+ if _valid.numel() > 0:
1764
+ _needed_tensors.append(_valid[-_max_src:])
1765
+ if _needed_tensors:
1766
+ _needed_idx = torch.unique(torch.cat(_needed_tensors, dim=0), sorted=True).to(device=stream_device, dtype=torch.long)
1767
  _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx)
1768
  _dynamic_fi_small = _dfi.index_select(0, _needed_idx)
1769
  _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None
 
1776
  _dynamic_fi_small = _dfi[:0]
1777
  _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None
1778
  _dynamic_gen_small = None
1779
+ dynamic_tokens, dynamic_mask = self.dememwm_dynamic_compressor(
1780
  _dynamic_latents_small,
1781
  _dynamic_fi_small,
1782
  _dynamic_pose_small,
 
1785
  exclude_latest_local_frames=dynamic_recent_exclusion_frames,
1786
  )
1787
 
 
 
 
 
 
 
 
 
 
 
 
 
1788
  def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
1789
  if tensor is None or tensor.ndim < 2:
1790
  return None
 
1823
  revisit_mask_rows = []
1824
  revisit_max_rows = []
1825
  valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
 
 
1826
  revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1827
  revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1828
  revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
1829
  eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
 
1830
  eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
 
1831
  projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {}
1832
  if revisit_record_batches is None:
1833
  revisit_record_batches = [tuple(bank.records) for bank in revisit_banks]
 
1838
  batch_max_rows = []
1839
  for target_idx in range(T_tgt):
1840
  target_frame = int(target_frame_indices[target_idx, batch_idx].item())
1841
+ if preselected_revisit_records_by_target is not None:
1842
+ selected_records = list(preselected_revisit_records_by_target[batch_idx][target_idx])
1843
+ if preselected_revisit_stats is not None:
1844
+ revisit_best_selected_fov_overlap[batch_idx, target_idx] = torch.as_tensor(
1845
+ preselected_revisit_stats["best_selected_fov_overlap"][batch_idx, target_idx],
1846
+ device=stream_device,
1847
+ dtype=torch.float32,
1848
+ )
1849
+ revisit_best_selected_plucker_overlap[batch_idx, target_idx] = torch.as_tensor(
1850
+ preselected_revisit_stats["best_selected_plucker_overlap"][batch_idx, target_idx],
1851
+ device=stream_device,
1852
+ dtype=torch.float32,
1853
+ )
1854
+ revisit_selected_gap_frames[batch_idx, target_idx] = torch.as_tensor(
1855
+ preselected_revisit_stats["best_selected_gap_frames"][batch_idx, target_idx],
1856
+ device=stream_device,
1857
+ dtype=torch.float32,
1858
+ )
1859
  else:
1860
+ if use_cache_revisit_records:
1861
+ candidate_records = self._causal_cached_revisit_records(
1862
+ revisit_record_batches[batch_idx],
1863
+ target_frame,
1864
  )
1865
+ else:
1866
+ candidate_records = revisit_bank.query(
1867
+ MemoryBankQuery(
1868
+ target_frame=target_frame,
1869
+ include_generated=True,
1870
+ )
1871
+ )
1872
+ result = deterministic_revisit_retrieval(
1873
+ candidate_records,
1874
+ target_frame=target_frame,
1875
+ target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
1876
+ target_summary=None,
1877
+ topk=revisit_max_frames,
1878
+ exclude_local_context_frames=revisit_context_window_exclusion_frames,
1879
+ fov_overlap_threshold=fov_overlap_threshold,
1880
+ plucker_weight=plucker_weight,
1881
+ target_video_id=None,
1882
+ **revisit_retrieval_kwargs,
1883
  )
1884
+ selected_records = result.records
1885
+ if use_cache_revisit_records and selected_records:
1886
+ selected_records = self._project_streaming_revisit_records(
1887
+ cache=cache,
1888
+ batch_idx=batch_idx,
1889
+ records=selected_records,
1890
+ device=stream_device,
1891
+ dtype=stream_dtype,
1892
+ token_patch_size=token_patch_size,
1893
+ revisit_pool_h=revisit_pool_h,
1894
+ revisit_pool_w=revisit_pool_w,
1895
+ projection_cache=projected_revisit_record_cache,
1896
+ )
1897
+ revisit_best_selected_fov_overlap[batch_idx, target_idx] = result.best_selected_fov_overlap.to(device=stream_device, dtype=torch.float32)
1898
+ revisit_best_selected_plucker_overlap[batch_idx, target_idx] = result.best_selected_plucker_overlap.to(device=stream_device, dtype=torch.float32)
1899
+ revisit_selected_gap_frames[batch_idx, target_idx] = result.best_selected_gap_frames.to(device=stream_device, dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1900
  revisit_bank.assert_causal(target_frame, selected_records)
1901
  if selected_records:
1902
  valid_revisit_mask[batch_idx, target_idx] = True
 
1907
  stream_device,
1908
  stream_dtype,
1909
  )
 
1910
  if eval_corruption_enabled:
1911
  stream_tokens, was_corrupted = apply_revisit_eval_corruption(
1912
  tokens=stream_tokens,
 
1962
  if not revisit_stage_config_enabled:
1963
  revisit_mask = torch.zeros_like(revisit_mask)
1964
  valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
 
 
1965
  revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
1966
  revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
1967
  revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
 
1969
  valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
1970
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
1971
  revisit_gate = torch.zeros_like(revisit_gate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1972
  return MemoryStreamTensors(
1973
  anchor_tokens=anchor_tokens,
1974
  anchor_mask=anchor_mask,
 
1981
  revisit_gate=revisit_gate,
1982
  revisit_gate_raw=revisit_gate_raw,
1983
  valid_revisit_mask=valid_revisit_mask,
1984
+ revisit_best_selected_fov_overlap=revisit_best_selected_fov_overlap,
1985
+ revisit_best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
1986
+ revisit_selected_gap_frames=revisit_selected_gap_frames,
1987
  )
1988
 
1989
  def _refresh_stream_gates(
1990
  self,
1991
  streams: MemoryStreamTensors,
1992
  denoising_fraction: float | None = None,
 
1993
  ) -> MemoryStreamTensors:
1994
  gate_state = self._effective_gate_state(
1995
  denoising_fraction=denoising_fraction,
 
1996
  )
1997
  gates = gate_state["gates"]
1998
  device = streams.anchor_tokens.device
 
2004
  else:
2005
  valid_revisit_mask = valid_revisit_mask.to(device=device, dtype=torch.bool)
2006
 
2007
+ def _gate_feature(value: torch.Tensor | None, fill_value: float = 0.0) -> torch.Tensor:
 
 
 
2008
  if value is None:
2009
  return torch.full((B, T_tgt), float(fill_value), device=device, dtype=torch.float32)
2010
+ tensor = value.to(device=device, dtype=torch.float32)
2011
  if tensor.ndim == 0:
2012
  return torch.full((B, T_tgt), float(tensor.item()), device=device, dtype=torch.float32)
2013
  return tensor.expand((B, T_tgt))
2014
 
2015
+ revisit_best_selected_fov_overlap = _gate_feature(streams.revisit_best_selected_fov_overlap)
2016
+ revisit_best_selected_plucker_overlap = _gate_feature(streams.revisit_best_selected_plucker_overlap)
2017
+ revisit_selected_gap_frames = _gate_feature(streams.revisit_selected_gap_frames, -1.0)
2018
 
2019
  anchor_effective_enabled = gate_state["anchor_effective_enabled"]
2020
  dynamic_effective_enabled = gate_state["dynamic_effective_enabled"]
 
2031
  best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
2032
  selected_gap_frames=revisit_selected_gap_frames,
2033
  ).to(device=device, dtype=dtype)
 
2034
  if not revisit_stage_config_enabled or gate_state["force_revisit_off"]:
2035
  revisit_gate = torch.zeros_like(revisit_gate_raw)
2036
  elif gate_state["force_revisit_on"]:
2037
+ revisit_gate = valid_revisit_mask.to(device=device, dtype=dtype) * torch.ones_like(revisit_gate_raw)
2038
  else:
2039
+ revisit_gate = valid_revisit_mask.to(device=device, dtype=dtype) * revisit_gate_raw * float(gates.revisit_gate)
2040
  if not revisit_stage_config_enabled:
2041
  valid_revisit_mask = torch.zeros_like(valid_revisit_mask)
 
2042
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
2043
  revisit_gate = torch.zeros_like(revisit_gate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2044
  return replace(
2045
  streams,
2046
  anchor_gate=anchor_gate,
 
2048
  revisit_gate=revisit_gate,
2049
  revisit_gate_raw=revisit_gate_raw,
2050
  valid_revisit_mask=valid_revisit_mask,
2051
+ revisit_best_selected_fov_overlap=revisit_best_selected_fov_overlap,
2052
+ revisit_best_selected_plucker_overlap=revisit_best_selected_plucker_overlap,
2053
+ revisit_selected_gap_frames=revisit_selected_gap_frames,
2054
  )
2055
 
2056
+ def _training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]:
2057
+ total_frames = int(total_frames)
2058
+ window = int(getattr(self, "n_tokens", total_frames) or total_frames)
2059
+ if total_frames <= 0:
2060
+ return 0, 0
2061
+ if window <= 0 or total_frames <= window:
2062
+ return 0, total_frames
2063
+ max_start = max(0, total_frames - window)
2064
+ min_start = min(max(0, self._context_frame_count()), max_start)
2065
+ if max_start <= min_start:
2066
+ start = min_start
2067
+ else:
2068
+ start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item())
2069
+ return start, start + window
2070
+
2071
+ def _streams_to_kwargs(self, streams: MemoryStreamTensors) -> dict:
2072
+ return self.dememwm_injection_adapter(
2073
+ streams,
2074
+ device=streams.anchor_tokens.device,
2075
+ dtype=streams.anchor_tokens.dtype,
2076
+ )
2077
 
2078
+ def build_memory_kwargs(self, *args, **kwargs) -> dict:
2079
  streams = self.build_memory_streams(*args, **kwargs)
2080
  return self._streams_to_kwargs(streams)
2081
 
2082
+ def _training_pose_condition(
2083
+ self,
2084
+ pose_conditions: torch.Tensor,
2085
+ c2w_mat: torch.Tensor,
2086
+ frame_idx: torch.Tensor,
2087
+ *,
2088
+ dtype: torch.dtype,
2089
+ image_width: int,
2090
+ image_height: int,
2091
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2092
  if self.use_plucker:
2093
  if self.relative_embedding:
2094
+ memory_condition_length = max(0, int(self.memory_condition_length or 0))
2095
+ if memory_condition_length:
2096
+ memory_c2w = c2w_mat[-memory_condition_length:]
2097
+ memory_frame_idx = frame_idx[-memory_condition_length:]
2098
+ else:
2099
+ memory_c2w = c2w_mat[:0]
2100
+ memory_frame_idx = frame_idx[:0]
2101
  input_pose_condition = []
2102
  frame_idx_list = []
2103
+ for target_idx in range(c2w_mat.shape[0]):
 
 
2104
  input_pose_condition.append(
2105
  convert_to_plucker(
2106
+ torch.cat([c2w_mat[target_idx:target_idx + 1], memory_c2w]).clone(),
2107
  0,
2108
  focal_length=self.focal_length,
2109
+ image_width=image_width,
2110
+ image_height=image_height,
2111
+ ).to(dtype)
2112
+ )
2113
+ frame_idx_list.append(
2114
+ torch.cat([
2115
+ frame_idx[target_idx:target_idx + 1] - frame_idx[target_idx:target_idx + 1],
2116
+ memory_frame_idx - frame_idx[target_idx:target_idx + 1],
2117
+ ]).clone()
2118
  )
 
2119
  return torch.cat(input_pose_condition), torch.cat(frame_idx_list)
2120
+ return (
2121
+ convert_to_plucker(
2122
+ c2w_mat,
2123
+ 0,
2124
+ focal_length=self.focal_length,
2125
+ image_width=image_width,
2126
+ image_height=image_height,
2127
+ ).to(dtype),
2128
+ frame_idx,
2129
+ )
2130
+ return pose_conditions.to(dtype), None
 
 
 
 
 
 
 
2131
 
2132
  def training_step(self, batch, batch_idx):
2133
  xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
2134
+ image_height, image_width = self._image_size(xs)
2135
  xs = self._as_latents(xs)
2136
 
2137
  # Randomly select a contiguous n_tokens denoising window inside the long
 
2145
  frame_idx_window = frame_idx[start:end]
2146
 
2147
  input_pose_condition, frame_idx_list = self._training_pose_condition(
2148
+ pose_conditions[start:end],
2149
+ c2w_mat[start:end],
2150
+ frame_idx_window,
2151
+ dtype=xs_window.dtype,
2152
+ image_width=image_width,
2153
+ image_height=image_height,
2154
  )
2155
 
2156
  noise_levels = self._generate_noise_levels(xs_window)
 
2158
  noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
2159
  conditions_window[-self.memory_condition_length:] *= 0
2160
  source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool)
2161
+ memory_source_latents, source_is_generated = self._apply_generated_history_proxy(
2162
  xs,
2163
  source_is_generated,
2164
  context_frame_count=self._context_frame_count(),
2165
  target_start_frame=start,
2166
  )
2167
  timesteps = int(getattr(self, "timesteps", 0) or 0)
 
 
2168
  training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps)
2169
+ memory_kwargs = self.build_memory_kwargs(
2170
  memory_source_latents,
2171
  frame_idx,
2172
  target_frame_indices=frame_idx_window,
 
2176
  target_action=conditions_window,
2177
  source_is_generated=source_is_generated,
2178
  denoising_fraction=training_denoising_fraction,
 
 
2179
  )
 
2180
  _, loss = self.diffusion_model(
2181
  xs_window,
2182
  conditions_window,
 
2186
  frame_idx=frame_idx_list,
2187
  **memory_kwargs,
2188
  )
 
2189
  if self.memory_condition_length:
2190
  loss = loss[:-self.memory_condition_length]
2191
  loss_denoise = self.reweight_loss(loss, None)
2192
  loss_total = loss_denoise
 
 
 
 
 
2193
  if batch_idx % 20 == 0:
2194
+ revisit_gate = memory_kwargs.get("memory_retrieval_gate")
2195
+ if torch.is_tensor(revisit_gate):
2196
+ revisit_gate_value = revisit_gate.detach().float().mean()
2197
+ else:
2198
+ revisit_gate_value = loss_total.detach().new_tensor(0.0 if revisit_gate is None else float(revisit_gate))
2199
+ self.log("training/loss", loss_total.detach(), prog_bar=True, sync_dist=True)
2200
+ self.log("training/denoise_loss", loss_denoise.detach(), prog_bar=False, sync_dist=True)
2201
+ self.log("training/revisit_gate", revisit_gate_value, prog_bar=False, sync_dist=True)
2202
  return {"loss": loss_total}
2203
 
2204
  def validation_step(self, batch, batch_idx, namespace="validation"):
 
2222
  streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}")
2223
  cached_until = 0
2224
  pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
 
2225
  while curr_frame < n_frames:
2226
  if streaming_cache is not None and curr_frame > cached_until:
2227
  new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device)
 
2288
  from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length)
2289
  denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0)
2290
  step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac)
2291
+ memory_kwargs = self._streams_to_kwargs(step_streams)
2292
  xs_pred[start_frame:] = self.diffusion_model.sample_step(
2293
  xs_pred[start_frame:].to(input_condition.device),
2294
  input_condition,
 
2318
  action=conditions[cached_until:curr_frame],
2319
  )
2320
  cached_until = curr_frame
 
 
2321
  pbar.update(horizon)
2322
  pbar.close()
 
 
2323
  xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
2324
  xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
2325
  self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
 
2349
  # Compatibility aliases for old DeMemWM test and experiment call sites.
2350
  dememwm_strict_key_prefixes = strict_key_prefixes
2351
  dememwm_strict_key_substrings = strict_key_substrings
 
 
2352
  _dememwm_cfg = _memory_cfg
 
2353
  _dememwm_eval_ablation_cfg = _eval_ablation_cfg
2354
  _dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg
2355
  _dememwm_eval_ablation_state = _eval_ablation_state
 
2382
  _dememwm_refresh_stream_gates = _refresh_stream_gates
2383
  _dememwm_streams_to_kwargs = _streams_to_kwargs
2384
  build_dememwm_memory_kwargs = build_memory_kwargs
 
 
2385
  _dememwm_training_window_bounds = _training_window_bounds
2386
  strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check
2387
 
algorithms/worldmem/dememwm/cache.py CHANGED
@@ -486,23 +486,6 @@ class StreamingCache:
486
  ).to(device=device)
487
  return latents, frame_indices, generated, pose
488
 
489
- def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
490
- return {
491
- f"{prefix}_enabled": bool(self.enabled),
492
- f"{prefix}_records": int(self.record_count),
493
- f"{prefix}_anchor_records": int(self.records_count("anchor")),
494
- f"{prefix}_revisit_records": int(self.records_count("revisit")),
495
- f"{prefix}_slots": int(self.slot_count),
496
- f"{prefix}_raw_frame_slots": int(self.raw_frame_slots),
497
- f"{prefix}_raw_segments": int(self.raw_segment_count),
498
- f"{prefix}_evictions": int(self.evictions),
499
- f"{prefix}_resets": int(self.reset_count),
500
- f"{prefix}_capacity_exceeded": int(self.capacity_exceeded_count),
501
- f"{prefix}_device": self.device,
502
- f"{prefix}_current_video_id": self.current_video_id,
503
- f"{prefix}_clear_between_videos": bool(self.clear_between_videos),
504
- f"{prefix}_no_evict": bool(self.no_evict),
505
- }
506
 
507
 
508
  DeMemWMStreamingCache = StreamingCache
 
486
  ).to(device=device)
487
  return latents, frame_indices, generated, pose
488
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
 
491
  DeMemWMStreamingCache = StreamingCache
algorithms/worldmem/dememwm/compression.py CHANGED
@@ -163,7 +163,8 @@ class CausalConv3DDynamicCompressor(nn.Module):
163
  target_frame_indices: torch.Tensor,
164
  source_is_generated: Optional[torch.Tensor] = None,
165
  exclude_latest_local_frames: Optional[int] = None,
166
- ) -> tuple[torch.Tensor, torch.Tensor, dict]:
 
167
  if latents.ndim != 5:
168
  raise ValueError("latents must have shape (T_src,B,C,H,W)")
169
  exclude_latest_local_frames = (
@@ -175,86 +176,78 @@ class CausalConv3DDynamicCompressor(nn.Module):
175
  p = self.patch_size
176
  if H % p != 0 or W % p != 0:
177
  raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
 
 
178
  if target_frame_indices.ndim == 1:
179
  target_frame_indices = target_frame_indices[:, None].expand(-1, B)
180
- T_tgt = target_frame_indices.shape[0]
 
 
181
  device = latents.device
182
- generated_flags = None if source_is_generated is None else source_is_generated.to(device=device, dtype=torch.bool)
 
 
183
  n_spatial = (H // p) * (W // p)
184
  T_out = self._temporal_output_count()
185
  num_slots = T_out * n_spatial
186
  output_time_idx = self._output_time_indices(device)
187
- selected_source_count = torch.zeros((B, T_tgt), dtype=torch.long, device=device)
188
- max_source_frame = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
189
- generated_source_fraction = torch.zeros((B, T_tgt), dtype=torch.float32, device=device)
190
- min_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
191
- max_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
192
- output_rows, mask_rows = [], []
193
- for b in range(B):
194
- src_frames_b = frame_indices[:, b]
195
- tgt_outputs, tgt_masks = [], []
196
- for j in range(T_tgt):
197
- target = int(target_frame_indices[j, b].item())
198
- valid_idx = (
199
- src_frames_b < target - exclude_latest_local_frames
200
- ).nonzero(as_tuple=False).flatten()
201
- if valid_idx.numel() == 0:
202
- tgt_outputs.append(latents.new_zeros(num_slots, self.dit_hidden_size))
203
- tgt_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
204
- continue
205
- selected_frames = src_frames_b.index_select(0, valid_idx)
206
- order = torch.argsort(selected_frames)
207
- valid_idx = valid_idx.index_select(0, order)[-self.max_source_frames:]
208
- selected_frames = src_frames_b.index_select(0, valid_idx)
209
- selected_source_count[b, j] = int(selected_frames.numel())
210
- max_source_frame[b, j] = selected_frames.max()
211
- gaps = target - selected_frames
212
- min_gap[b, j] = gaps.min()
213
- max_gap[b, j] = gaps.max()
214
- if generated_flags is not None:
215
- generated = generated_flags.index_select(0, valid_idx)[:, b]
216
- generated_source_fraction[b, j] = generated.float().mean()
217
- chunk = latents[valid_idx, b]
218
- real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
219
- if chunk.shape[0] < self.max_source_frames:
220
- pad = chunk.new_zeros(self.max_source_frames - chunk.shape[0], C, H, W)
221
- chunk = torch.cat([pad, chunk], dim=0)
222
- real_mask = torch.cat([
223
- torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
224
- real_mask,
225
- ])
226
- inp = chunk.clone()
227
- inp[1:] = chunk[1:] - chunk[:-1]
228
- x = inp.permute(1, 0, 2, 3).unsqueeze(0) # (1,C,T,H,W)
229
- x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0)) # left-pad time
230
- x = self.conv3d(x) # (1,D,T_out,H//p,W//p)
231
- x = x.squeeze(0).permute(1, 2, 3, 0) # (T_out,H//p,W//p,D)
232
- x = self.out_norm(x)
233
- tokens = x.reshape(num_slots, self.dit_hidden_size)
234
- clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
235
- temporal_mask = (
236
- (output_time_idx >= 0)
237
- & (output_time_idx < self.max_source_frames)
238
- & real_mask.index_select(0, clamped_time_idx)
239
- )
240
- mask = temporal_mask[:, None].expand(T_out, n_spatial).reshape(num_slots)
241
- tgt_outputs.append(tokens)
242
- tgt_masks.append(mask)
243
- output_rows.append(torch.stack(tgt_outputs))
244
- mask_rows.append(torch.stack(tgt_masks))
245
- out_tokens = torch.stack(output_rows)
246
- out_mask = torch.stack(mask_rows)
247
- diagnostics = {
248
- "num_dynamic_slots": num_slots,
249
- "dynamic_T_out": T_out,
250
- "dynamic_n_spatial": n_spatial,
251
- "dynamic_temporal_left_pad": self.causal_pad,
252
- "dynamic_output_time_indices": output_time_idx,
253
- "selected_source_count": selected_source_count,
254
- "max_source_frame": max_source_frame,
255
- "generated_source_fraction": generated_source_fraction,
256
- "dynamic_min_gap_to_target_per_target": min_gap,
257
- "dynamic_max_gap_to_target_per_target": max_gap,
258
- "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
259
- }
260
- return out_tokens, out_mask, diagnostics
 
163
  target_frame_indices: torch.Tensor,
164
  source_is_generated: Optional[torch.Tensor] = None,
165
  exclude_latest_local_frames: Optional[int] = None,
166
+ ) -> tuple[torch.Tensor, torch.Tensor]:
167
+ del pose, source_is_generated
168
  if latents.ndim != 5:
169
  raise ValueError("latents must have shape (T_src,B,C,H,W)")
170
  exclude_latest_local_frames = (
 
176
  p = self.patch_size
177
  if H % p != 0 or W % p != 0:
178
  raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
179
+ if frame_indices.shape != (T_src, B):
180
+ raise ValueError("frame_indices must have shape (T_src,B)")
181
  if target_frame_indices.ndim == 1:
182
  target_frame_indices = target_frame_indices[:, None].expand(-1, B)
183
+ if target_frame_indices.ndim != 2 or target_frame_indices.shape[1] != B:
184
+ raise ValueError("target_frame_indices must have shape (T_tgt,B)")
185
+
186
  device = latents.device
187
+ frame_indices = frame_indices.to(device=device)
188
+ target_frame_indices = target_frame_indices.to(device=device)
189
+ T_tgt = target_frame_indices.shape[0]
190
  n_spatial = (H // p) * (W // p)
191
  T_out = self._temporal_output_count()
192
  num_slots = T_out * n_spatial
193
  output_time_idx = self._output_time_indices(device)
194
+ if T_src == 0:
195
+ out_tokens = latents.new_zeros((B, T_tgt, num_slots, self.dit_hidden_size))
196
+ out_mask = torch.zeros((B, T_tgt, num_slots), device=device, dtype=torch.bool)
197
+ return out_tokens, out_mask
198
+
199
+ source_frames = frame_indices.transpose(0, 1).contiguous()
200
+ target_frames = target_frame_indices.transpose(0, 1).contiguous()
201
+ valid = source_frames[:, None, :] < (target_frames[:, :, None] - int(exclude_latest_local_frames))
202
+ valid_flat = valid.reshape(B * T_tgt, T_src)
203
+ source_frames_flat = source_frames[:, None, :].expand(B, T_tgt, T_src).reshape(B * T_tgt, T_src)
204
+
205
+ topk = min(int(self.max_source_frames), T_src)
206
+ rank = source_frames_flat.to(dtype=torch.float64).masked_fill(~valid_flat, -float("inf"))
207
+ top = torch.topk(rank, k=topk, dim=1, largest=True, sorted=True)
208
+ selected_idx = top.indices.flip(dims=(1,))
209
+ selected_valid = torch.isfinite(top.values).flip(dims=(1,))
210
+ if topk < self.max_source_frames:
211
+ pad_count = self.max_source_frames - topk
212
+ selected_idx = torch.cat([
213
+ torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.long),
214
+ selected_idx,
215
+ ], dim=1)
216
+ selected_valid = torch.cat([
217
+ torch.zeros((B * T_tgt, pad_count), device=device, dtype=torch.bool),
218
+ selected_valid,
219
+ ], dim=1)
220
+
221
+ selected_idx_clamped = selected_idx.to(device=device, dtype=torch.long).clamp(min=0, max=max(0, T_src - 1))
222
+ has_valid = selected_valid.any(dim=1)
223
+ batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
224
+ latents_by_batch = latents.permute(1, 0, 2, 3, 4).contiguous()
225
+ latents_per_query = latents_by_batch.index_select(0, batch_ids)
226
+ gather_idx = selected_idx_clamped.reshape(B * T_tgt, self.max_source_frames, 1, 1, 1).expand(
227
+ -1, -1, C, H, W
228
+ )
229
+ chunk = torch.gather(latents_per_query, 1, gather_idx)
230
+ chunk = torch.where(
231
+ selected_valid[:, :, None, None, None],
232
+ chunk,
233
+ torch.zeros((), device=device, dtype=latents.dtype),
234
+ )
235
+
236
+ inp = chunk.clone()
237
+ inp[:, 1:] = chunk[:, 1:] - chunk[:, :-1]
238
+ x = inp.permute(0, 2, 1, 3, 4)
239
+ x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))
240
+ x = self.conv3d(x)
241
+ x = self.out_norm(x.permute(0, 2, 3, 4, 1))
242
+ tokens_flat = x.reshape(B * T_tgt, num_slots, self.dit_hidden_size)
243
+ tokens_flat = torch.where(has_valid[:, None, None], tokens_flat, torch.zeros_like(tokens_flat))
244
+ out_tokens = tokens_flat.reshape(B, T_tgt, num_slots, self.dit_hidden_size)
245
+
246
+ clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
247
+ temporal_mask = (
248
+ (output_time_idx >= 0)
249
+ & (output_time_idx < self.max_source_frames)
250
+ & selected_valid.index_select(1, clamped_time_idx)
251
+ )
252
+ out_mask = temporal_mask[:, :, None].expand(B * T_tgt, T_out, n_spatial).reshape(B, T_tgt, num_slots)
253
+ return out_tokens, out_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
algorithms/worldmem/dememwm/diagnostics.py DELETED
@@ -1,172 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Any
4
-
5
- import torch
6
-
7
- from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket
8
-
9
-
10
- _REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
11
-
12
-
13
- def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
14
- if mask is None or mask.numel() == 0:
15
- return 0.0
16
- return float(mask.detach().bool().float().mean().item())
17
-
18
-
19
- def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
20
- if gate is None:
21
- return {"mean": 0.0, "min": 0.0, "max": 0.0}
22
- if not torch.is_tensor(gate):
23
- value = float(gate)
24
- return {"mean": value, "min": value, "max": value}
25
- g = gate.detach().float()
26
- return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}
27
-
28
-
29
- def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
30
- return {f"{name}_tokens_shape": None if tokens is None else tuple(tokens.shape), f"{name}_valid_fraction": tensor_valid_fraction(mask), f"{name}_valid_tokens": 0 if mask is None else int(mask.detach().bool().sum().item()), f"{name}_gate": gate_stats(gate)}
31
-
32
-
33
- def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
34
- max_src = int(max_source_frame.detach().max().item()) if torch.is_tensor(max_source_frame) else int(max_source_frame)
35
- if max_src >= int(target_frame):
36
- raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")
37
-
38
-
39
- def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
40
- values: list[float] = []
41
- for diag in result_diagnostics:
42
- for value in diag.get(key, []) or []:
43
- values.append(float(value))
44
- return values
45
-
46
-
47
- def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
48
- if not values:
49
- return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
50
- return {
51
- f"{prefix}_mean": float(sum(values) / len(values)),
52
- f"{prefix}_min": float(min(values)),
53
- f"{prefix}_max": float(max(values)),
54
- }
55
-
56
-
57
- def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
58
- target_count = len(result_diagnostics)
59
- candidate_count = sum(int(diag.get("revisit_candidate_frame_count", diag.get("revisit_candidate_count", diag.get("candidate_count", 0)))) for diag in result_diagnostics)
60
- candidate_count_mean = float(candidate_count / target_count) if target_count else 0.0
61
- valid_candidate_label_count = sum(int(diag.get("valid_candidate_label_count", diag.get("valid_candidate_count", 0))) for diag in result_diagnostics)
62
- pose_preselect_input_count = sum(int(diag.get("revisit_pose_preselect_input_count", 0)) for diag in result_diagnostics)
63
- pose_preselect_selected_count = sum(int(diag.get("revisit_pose_preselect_selected_count", 0)) for diag in result_diagnostics)
64
- exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
65
- valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
66
- valid_count_mean = float(valid_count / target_count) if target_count else 0.0
67
- selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
68
- no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
69
- abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
70
- selected_gaps = [int(diag["revisit_min_gap_to_target"]) for diag in result_diagnostics if int(diag.get("revisit_min_gap_to_target", -1)) >= 0]
71
- diagnostics: dict[str, Any] = {
72
- "revisit_candidate_frame_count": candidate_count_mean,
73
- "revisit_candidate_count": candidate_count_mean,
74
- "valid_candidate_label_count": int(valid_candidate_label_count),
75
- "revisit_pose_preselect_input_count": float(pose_preselect_input_count / target_count) if target_count else 0.0,
76
- "revisit_pose_preselect_selected_count": float(pose_preselect_selected_count / target_count) if target_count else 0.0,
77
- "revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
78
- "valid_revisit_frame_count": valid_count_mean,
79
- "valid_revisit_count": valid_count_mean,
80
- "no_valid_revisit_count": int(no_valid_count),
81
- "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
82
- "revisit_selected_frame_count": int(selected_count),
83
- "revisit_selected_count": int(selected_count),
84
- "revisit_abstained_count": int(abstained_count),
85
- "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
86
- "revisit_label_source": _REVISIT_LABEL_SOURCE,
87
- }
88
- frame_fov_values = _collect_values(result_diagnostics, "frame_fov_overlap_values")
89
- if not frame_fov_values:
90
- frame_fov_values = _collect_values(result_diagnostics, "fov_overlap_values")
91
- diagnostics.update(_value_stats(frame_fov_values, "revisit_frame_fov_overlap"))
92
- diagnostics.update(_value_stats(frame_fov_values, "revisit_fov_overlap"))
93
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "plucker_overlap_values"), "revisit_plucker_overlap"))
94
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_fov_overlap_values"), "revisit_best_selected_fov_overlap"))
95
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_plucker_overlap_values"), "revisit_best_selected_plucker_overlap"))
96
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_gap_frame_values"), "revisit_best_selected_gap_frames"))
97
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_frame_fov_overlap_values"), "revisit_best_selected_frame_fov_overlap"))
98
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_frame_fov_overlap_values"), "revisit_selected_frame_fov_overlap"))
99
- diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_incremental_fov_overlap_values"), "revisit_incremental_fov_overlap"))
100
- return diagnostics
101
-
102
-
103
- def summarize_noise_bucket_diagnostics(
104
- *,
105
- noise_bucket: str | None,
106
- valid_revisit_mask: torch.Tensor | None,
107
- no_valid_revisit_mask: torch.Tensor | None,
108
- noise_bucket_ids: torch.Tensor | None = None,
109
- ) -> dict[str, Any]:
110
- bucket = normalize_noise_bucket(noise_bucket)
111
- diagnostics: dict[str, Any] = {
112
- "noise_bucket": bucket,
113
- "noise_bucket_id": int(NOISE_BUCKET_TO_ID[bucket]),
114
- }
115
- for candidate in NOISE_BUCKETS:
116
- diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)
117
-
118
- valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
119
- no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
120
- target_count = int(valid.numel())
121
- diagnostics["noise_bucket_target_count"] = target_count
122
- if noise_bucket_ids is None:
123
- target_bucket_ids = torch.full((target_count,), int(NOISE_BUCKET_TO_ID[bucket]), dtype=torch.long)
124
- else:
125
- target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
126
- if int(target_bucket_ids.numel()) != target_count:
127
- raise ValueError(
128
- f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
129
- )
130
-
131
- for bucket_name in NOISE_BUCKETS:
132
- bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
133
- diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())
134
-
135
- mask_specs = (
136
- ("valid_revisit", valid),
137
- ("no_valid_revisit", no_valid),
138
- )
139
- for mask_name, mask in mask_specs:
140
- for bucket_name in NOISE_BUCKETS:
141
- bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
142
- count = int((mask & bucket_mask).long().sum().item()) if mask.numel() else 0
143
- diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
144
- return diagnostics
145
-
146
-
147
- def summarize_eval_ablation_diagnostics(
148
- *,
149
- enabled: bool,
150
- branch: str | None,
151
- valid_revisit_mask: torch.Tensor | None,
152
- no_valid_revisit_mask: torch.Tensor | None,
153
- eval_corrupted_revisit_mask: torch.Tensor | None,
154
- ) -> dict[str, Any]:
155
- branch = normalize_eval_ablation_branch(branch)
156
- valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
157
- no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
158
- corrupted = torch.zeros_like(valid) if eval_corrupted_revisit_mask is None else eval_corrupted_revisit_mask.detach().bool().reshape(-1).cpu()
159
- true_revisit = valid & (~corrupted)
160
- diagnostics: dict[str, Any] = {
161
- "eval_ablation_enabled": bool(enabled),
162
- "eval_ablation_branch": branch,
163
- "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch]),
164
- "eval_bucket_true_revisit_count": int(true_revisit.long().sum().item()),
165
- "eval_bucket_no_valid_revisit_count": int(no_valid.long().sum().item()),
166
- "eval_bucket_corrupted_memory_count": int(corrupted.long().sum().item()),
167
- }
168
- total = max(int(valid.numel()), 1)
169
- diagnostics["eval_bucket_true_revisit_fraction"] = float(diagnostics["eval_bucket_true_revisit_count"] / total)
170
- diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(diagnostics["eval_bucket_no_valid_revisit_count"] / total)
171
- diagnostics["eval_bucket_corrupted_memory_fraction"] = float(diagnostics["eval_bucket_corrupted_memory_count"] / total)
172
- return diagnostics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
algorithms/worldmem/dememwm/injection.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass
@@ -6,7 +5,6 @@ from typing import Any
6
 
7
  import torch
8
 
9
- from .diagnostics import summarize_stream
10
  from .types import MemoryStreamTensors
11
 
12
 
@@ -31,7 +29,12 @@ class InjectionAdapter:
31
  return gate.to(device=device, dtype=dtype)
32
  return torch.tensor(float(gate), device=device, dtype=dtype)
33
 
34
- def __call__(self, streams: MemoryStreamTensors, device=None, dtype=None) -> tuple[dict[str, Any], dict[str, Any]]:
 
 
 
 
 
35
  ref = streams.anchor_tokens
36
  device = device or ref.device
37
  dtype = dtype or ref.dtype
@@ -62,22 +65,7 @@ class InjectionAdapter:
62
  if not revisit_mask.any():
63
  kwargs["memory_retrieval_tokens"] = None
64
  kwargs["memory_retrieval_mask"] = None
65
- diagnostics = dict(streams.diagnostics)
66
- diagnostics.update(summarize_stream("anchor", anchor_tokens, anchor_mask, kwargs["memory_anchor_gate"]))
67
- diagnostics.update(summarize_stream("dynamic", dynamic_tokens, dynamic_mask, kwargs["memory_dynamic_gate"]))
68
- diagnostics.update(summarize_stream("revisit", revisit_tokens, revisit_mask, kwargs["memory_retrieval_gate"]))
69
- if streams.revisit_gate_raw is not None:
70
- raw_gate = streams.revisit_gate_raw.to(device=device, dtype=dtype)
71
- diagnostics["revisit_gate_raw"] = raw_gate
72
- diagnostics["revisit_gate_raw_mean"] = float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0
73
- diagnostics["revisit_gate_raw_min"] = float(raw_gate.detach().float().min().item()) if raw_gate.numel() else 0.0
74
- diagnostics["revisit_gate_raw_max"] = float(raw_gate.detach().float().max().item()) if raw_gate.numel() else 0.0
75
- if streams.no_valid_revisit_mask is not None:
76
- diagnostics["no_valid_revisit_mask"] = streams.no_valid_revisit_mask.to(device=device, dtype=torch.bool)
77
- max_sources = [v for k, v in streams.diagnostics.items() if k.endswith("max_source_frame")]
78
- if max_sources:
79
- diagnostics["max_source_frame"] = max(int(torch.as_tensor(v).max().item()) for v in max_sources)
80
- return kwargs, diagnostics
81
 
82
 
83
  DeMemWMInjectionAdapter = InjectionAdapter
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
 
5
 
6
  import torch
7
 
 
8
  from .types import MemoryStreamTensors
9
 
10
 
 
29
  return gate.to(device=device, dtype=dtype)
30
  return torch.tensor(float(gate), device=device, dtype=dtype)
31
 
32
+ def __call__(
33
+ self,
34
+ streams: MemoryStreamTensors,
35
+ device=None,
36
+ dtype=None,
37
+ ) -> dict[str, Any]:
38
  ref = streams.anchor_tokens
39
  device = device or ref.device
40
  dtype = dtype or ref.dtype
 
65
  if not revisit_mask.any():
66
  kwargs["memory_retrieval_tokens"] = None
67
  kwargs["memory_retrieval_mask"] = None
68
+ return kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
  DeMemWMInjectionAdapter = InjectionAdapter
algorithms/worldmem/dememwm/retrieval.py CHANGED
@@ -1,7 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import math
4
- from dataclasses import replace
5
  from typing import Any, Optional
6
 
7
  import torch
@@ -9,6 +9,7 @@ import torch
9
  from .labels import (
10
  LABEL_SOURCE,
11
  RevisitCandidateLabel,
 
12
  _inside_fov_3d_hv,
13
  _plucker_descriptor,
14
  _target_fov_points,
@@ -16,24 +17,19 @@ from .labels import (
16
  from .types import MemoryRecord, RevisitRetrievalResult
17
 
18
 
19
- def _overlap_values(labels, name: str) -> list[float]:
20
- values: list[float] = []
21
- for label in labels:
22
- value = getattr(label, name)
23
- if value is not None:
24
- values.append(float(value))
25
- return values
 
 
 
 
26
 
27
 
28
- def _overlap_stats(values: list[float], prefix: str) -> dict[str, float]:
29
- if not values:
30
- return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
31
- return {
32
- f"{prefix}_mean": float(sum(values) / len(values)),
33
- f"{prefix}_min": float(min(values)),
34
- f"{prefix}_max": float(max(values)),
35
- }
36
-
37
 
38
  def _pose_rows(pose) -> torch.Tensor | None:
39
  if pose is None:
@@ -58,6 +54,348 @@ def _pose_forward(poses: torch.Tensor) -> torch.Tensor:
58
  )
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
62
  if int(record.frame_indices.numel()) != 1:
63
  return None
@@ -83,19 +421,9 @@ def _vectorized_frame_candidate_labels(
83
  plucker_grid_w: int,
84
  plucker_focal_length: float,
85
  pose_preselect_topk: Optional[int],
86
- ) -> tuple[list[RevisitCandidateLabel], dict[str, float | int]]:
87
- diagnostics: dict[str, float | int] = {
88
- "revisit_pose_preselect_input_count": len(records),
89
- "revisit_pose_preselect_scored_count": len(records),
90
- "revisit_pose_preselect_unscored_count": 0,
91
- "revisit_pose_preselect_selected_count": len(records),
92
- "revisit_pose_preselect_min_distance": 0.0,
93
- "revisit_pose_preselect_max_distance": 0.0,
94
- "revisit_exact_fov_candidate_count": len(records),
95
- "revisit_vectorized_frame_scorer_used": 1,
96
- }
97
  if not records:
98
- return [], diagnostics
99
 
100
  target_poses = _pose_rows(target_pose)
101
  if target_poses is None:
@@ -137,9 +465,6 @@ def _vectorized_frame_candidate_labels(
137
  ]
138
  ranked.sort()
139
  selected_indices = [idx for *_, idx in ranked[:topk]]
140
- diagnostics["revisit_pose_preselect_selected_count"] = len(selected_indices)
141
- diagnostics["revisit_pose_preselect_min_distance"] = float(min(distance_values))
142
- diagnostics["revisit_pose_preselect_max_distance"] = float(max(distance_values))
143
 
144
  selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
145
  selected_records = [records[idx] for idx in selected_indices]
@@ -175,7 +500,6 @@ def _vectorized_frame_candidate_labels(
175
  if fov_overlap_threshold is not None:
176
  valid_mask = fov_values >= float(fov_overlap_threshold)
177
 
178
- diagnostics["revisit_exact_fov_candidate_count"] = len(selected_records)
179
  fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
180
  plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
181
  valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]
@@ -200,7 +524,7 @@ def _vectorized_frame_candidate_labels(
200
  best_frame_fov_overlap=fov_overlap,
201
  )
202
  )
203
- return labels, diagnostics
204
 
205
 
206
  def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
@@ -297,23 +621,6 @@ def _best_selected_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidat
297
  )
298
 
299
 
300
- def _best_selected_frame_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
301
- frame_labels = [label for label in labels if label.best_frame_fov_overlap is not None]
302
- if not frame_labels:
303
- return None
304
- return max(
305
- frame_labels,
306
- key=lambda label: (
307
- float(label.best_frame_fov_overlap),
308
- 0.0 if label.fov_overlap is None else float(label.fov_overlap),
309
- 0.0 if label.plucker_overlap is None else float(label.plucker_overlap),
310
- -int(label.gap_to_target),
311
- -int(label.record.source_start),
312
- str(label.record.chunk_id or ""),
313
- ),
314
- )
315
-
316
-
317
  def _record_with_selected_frame_metadata(
318
  label: RevisitCandidateLabel,
319
  *,
@@ -334,6 +641,11 @@ def _record_with_selected_frame_metadata(
334
  return replace(label.record, metadata=metadata)
335
 
336
 
 
 
 
 
 
337
  def deterministic_revisit_retrieval(
338
  records: list[MemoryRecord],
339
  target_frame: int,
@@ -367,7 +679,7 @@ def deterministic_revisit_retrieval(
367
  for record in causal_records
368
  if int(record.source_end) <= target_frame - exclude_local_context_frames
369
  ]
370
- labels, pose_preselect_diagnostics = _vectorized_frame_candidate_labels(
371
  score_records,
372
  target_frame=target_frame,
373
  target_pose=target_pose,
@@ -383,21 +695,16 @@ def deterministic_revisit_retrieval(
383
  plucker_focal_length=plucker_focal_length,
384
  pose_preselect_topk=pose_preselect_topk,
385
  )
386
- exact_fov_candidate_count = int(pose_preselect_diagnostics["revisit_exact_fov_candidate_count"])
387
  valid_labels = [label for label in labels if label.valid]
388
- selected_labels, selected_scores, selected_gains = _select_greedy_coverage(
389
  valid_labels,
390
  topk=topk,
391
  plucker_weight=float(plucker_weight),
392
  )
393
  best_selected = _best_selected_label(selected_labels)
394
- best_selected_frame = _best_selected_frame_label(selected_labels)
395
  best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
396
  best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
397
  best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
398
- best_selected_frame_fov = 0.0 if best_selected_frame is None else float(best_selected_frame.best_frame_fov_overlap)
399
- best_selected_frame_index = -1 if best_selected_frame is None or best_selected_frame.best_frame_index is None else int(best_selected_frame.best_frame_index)
400
- high_quality_selected = int(best_selected_frame is not None and best_selected_frame_fov >= float(high_quality_fov_threshold))
401
  selected_records = [
402
  _record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
403
  for label in selected_labels
@@ -405,71 +712,11 @@ def deterministic_revisit_retrieval(
405
  score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
406
  scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)
407
 
408
- fov_values = _overlap_values(valid_labels, "fov_overlap")
409
- plucker_values = _overlap_values(valid_labels, "plucker_overlap")
410
- selected_gaps = [label.gap_to_target for label in selected_labels]
411
- selected_frame_fov_values = [
412
- float(label.best_frame_fov_overlap)
413
- for label in selected_labels
414
- if label.best_frame_fov_overlap is not None
415
- ]
416
- diagnostics = {
417
- "target_frame": int(target_frame),
418
- "candidate_count": len(causal_records),
419
- "candidate_frame_count": len(causal_records),
420
- "valid_candidate_count": len(valid_labels),
421
- "revisit_exact_fov_candidate_count": exact_fov_candidate_count,
422
- "valid_candidate_frame_count": len(valid_labels),
423
- "valid_candidate_label_count": len(valid_labels),
424
- "selected_count": len(selected_records),
425
- "selected_frame_count": len(selected_records),
426
- "revisit_candidate_frame_count": len(causal_records),
427
- "revisit_candidate_count": len(causal_records),
428
- "valid_revisit_frame_count": len(valid_labels),
429
- "valid_revisit_count": len(valid_labels),
430
- "no_valid_revisit_count": int(len(valid_labels) == 0),
431
- "valid_revisit_mask": int(len(valid_labels) > 0),
432
- "revisit_abstained_count": int(len(selected_records) == 0),
433
- "abstained": bool(len(selected_records) == 0),
434
- "revisit_selected_frame_count": len(selected_records),
435
- "revisit_selected_count": len(selected_records),
436
- "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
437
- "best_selected_fov_overlap": best_selected_fov,
438
- "best_selected_plucker_overlap": best_selected_plucker,
439
- "best_selected_gap_frames": best_selected_gap,
440
- "best_selected_frame_index": best_selected_frame_index,
441
- "best_selected_frame_fov_overlap": best_selected_frame_fov,
442
- "best_selected_frame_passes_high_quality": high_quality_selected,
443
- "high_quality_selected_revisit": high_quality_selected,
444
- "high_quality_fov_threshold": float(high_quality_fov_threshold),
445
- "revisit_label_source": LABEL_SOURCE,
446
- "selected_frame_ids": [int(record.max_source_frame) for record in selected_records],
447
- "selected_frame_record_ids": [record.chunk_id for record in selected_records],
448
- "selected_ranges": [(record.source_start, record.source_end) for record in selected_records],
449
- "frame_fov_overlap_values": fov_values,
450
- "fov_overlap_values": fov_values,
451
- "plucker_overlap_values": plucker_values,
452
- "best_selected_fov_overlap_values": [] if best_selected is None else [best_selected_fov],
453
- "best_selected_plucker_overlap_values": [] if best_selected is None else [best_selected_plucker],
454
- "best_selected_gap_frame_values": [] if best_selected is None else [best_selected_gap],
455
- "best_selected_frame_fov_overlap_values": [] if best_selected_frame is None else [best_selected_frame_fov],
456
- "selected_frame_fov_overlap_values": selected_frame_fov_values,
457
- "selected_incremental_fov_overlap_values": selected_gains,
458
- "selected_revisit_scores": selected_scores,
459
- **pose_preselect_diagnostics,
460
- }
461
- diagnostics.update(_overlap_stats(fov_values, "revisit_frame_fov_overlap"))
462
- diagnostics.update(_overlap_stats(fov_values, "revisit_fov_overlap"))
463
- diagnostics.update(_overlap_stats(plucker_values, "revisit_plucker_overlap"))
464
- diagnostics.update(_overlap_stats(diagnostics["best_selected_fov_overlap_values"], "revisit_best_selected_fov_overlap"))
465
- diagnostics.update(_overlap_stats(diagnostics["best_selected_plucker_overlap_values"], "revisit_best_selected_plucker_overlap"))
466
- diagnostics.update(_overlap_stats(diagnostics["best_selected_gap_frame_values"], "revisit_best_selected_gap_frames"))
467
- diagnostics.update(_overlap_stats(diagnostics["best_selected_frame_fov_overlap_values"], "revisit_best_selected_frame_fov_overlap"))
468
- diagnostics.update(_overlap_stats(selected_frame_fov_values, "revisit_selected_frame_fov_overlap"))
469
- diagnostics.update(_overlap_stats(selected_gains, "revisit_incremental_fov_overlap"))
470
  return RevisitRetrievalResult(
471
  records=selected_records,
472
  scores=scores,
473
  selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
474
- diagnostics=diagnostics,
 
 
475
  )
 
1
  from __future__ import annotations
2
 
3
  import math
4
+ from dataclasses import dataclass, replace
5
  from typing import Any, Optional
6
 
7
  import torch
 
9
  from .labels import (
10
  LABEL_SOURCE,
11
  RevisitCandidateLabel,
12
+ _angle_diff_degrees,
13
  _inside_fov_3d_hv,
14
  _plucker_descriptor,
15
  _target_fov_points,
 
17
  from .types import MemoryRecord, RevisitRetrievalResult
18
 
19
 
20
+ @dataclass
21
+ class BatchedRevisitSelectionResult:
22
+ selected_positions: torch.Tensor
23
+ selected_mask: torch.Tensor
24
+ selected_scores: torch.Tensor
25
+ selected_fov_overlap: torch.Tensor
26
+ selected_plucker_overlap: torch.Tensor
27
+ selected_gap_frames: torch.Tensor
28
+ best_selected_fov_overlap: torch.Tensor
29
+ best_selected_plucker_overlap: torch.Tensor
30
+ best_selected_gap_frames: torch.Tensor
31
 
32
 
 
 
 
 
 
 
 
 
 
33
 
34
  def _pose_rows(pose) -> torch.Tensor | None:
35
  if pose is None:
 
54
  )
55
 
56
 
57
+
58
+ def _time_batch_pose_rows(
59
+ pose: torch.Tensor,
60
+ *,
61
+ time: int,
62
+ batch: int,
63
+ name: str,
64
+ ) -> torch.Tensor:
65
+ if pose is None:
66
+ raise ValueError(f"{name} is required for batched DeMemWM revisit retrieval")
67
+ pose_tensor = pose if torch.is_tensor(pose) else torch.as_tensor(pose, dtype=torch.float32)
68
+ if pose_tensor.ndim < 3 or pose_tensor.shape[-1] < 5:
69
+ raise ValueError(f"{name} must have shape (T,B,D) or (B,T,D) with D >= 5")
70
+ if pose_tensor.shape[0] == time and pose_tensor.shape[1] == batch:
71
+ pose_tb = pose_tensor
72
+ elif pose_tensor.shape[0] == batch and pose_tensor.shape[1] == time:
73
+ pose_tb = pose_tensor.transpose(0, 1)
74
+ else:
75
+ raise ValueError(f"{name} must match time/batch dimensions ({time},{batch})")
76
+ return pose_tb[..., :5].detach().to(dtype=torch.float32)
77
+
78
+
79
+ def _time_batch_mask(
80
+ mask: torch.Tensor | None,
81
+ *,
82
+ time: int,
83
+ batch: int,
84
+ device: torch.device,
85
+ ) -> torch.Tensor:
86
+ if mask is None:
87
+ return torch.ones((time, batch), device=device, dtype=torch.bool)
88
+ mask_tensor = mask if torch.is_tensor(mask) else torch.as_tensor(mask)
89
+ if mask_tensor.ndim != 2:
90
+ raise ValueError("source_candidate_mask must have shape (T,B) or (B,T)")
91
+ if mask_tensor.shape == (time, batch):
92
+ mask_tb = mask_tensor
93
+ elif mask_tensor.shape == (batch, time):
94
+ mask_tb = mask_tensor.transpose(0, 1)
95
+ else:
96
+ raise ValueError(f"source_candidate_mask must match time/batch dimensions ({time},{batch})")
97
+ return mask_tb.to(device=device, dtype=torch.bool)
98
+
99
+
100
+ def _target_fov_points_batched(
101
+ target_poses: torch.Tensor,
102
+ *,
103
+ fov_half_h: float,
104
+ fov_half_v: float,
105
+ yaw_samples: int,
106
+ pitch_samples: int,
107
+ depth_samples: int,
108
+ radius: float,
109
+ ) -> torch.Tensor:
110
+ yaw_samples = max(1, int(yaw_samples))
111
+ pitch_samples = max(1, int(pitch_samples))
112
+ depth_samples = max(1, int(depth_samples))
113
+ device = target_poses.device
114
+ dtype = target_poses.dtype
115
+ if yaw_samples == 1:
116
+ yaw_offsets = torch.zeros((1,), device=device, dtype=dtype)
117
+ else:
118
+ yaw_offsets = torch.linspace(-float(fov_half_h), float(fov_half_h), yaw_samples + 2, device=device, dtype=dtype)[1:-1]
119
+ if pitch_samples == 1:
120
+ pitch_offsets = torch.zeros((1,), device=device, dtype=dtype)
121
+ else:
122
+ pitch_offsets = torch.linspace(-float(fov_half_v), float(fov_half_v), pitch_samples + 2, device=device, dtype=dtype)[1:-1]
123
+ if depth_samples == 1:
124
+ depths = torch.full((1,), float(radius), device=device, dtype=dtype)
125
+ else:
126
+ depths = torch.linspace(float(radius) / float(depth_samples), float(radius), depth_samples, device=device, dtype=dtype)
127
+ depth_grid, pitch_grid, yaw_grid = torch.meshgrid(depths, pitch_offsets, yaw_offsets, indexing="ij")
128
+ pitch_offsets_flat = pitch_grid.reshape(1, -1)
129
+ yaw_offsets_flat = yaw_grid.reshape(1, -1)
130
+ depth = depth_grid.reshape(1, -1)
131
+ pitch = torch.deg2rad(target_poses[:, 3:4] + pitch_offsets_flat)
132
+ yaw = torch.deg2rad(target_poses[:, 4:5] + yaw_offsets_flat)
133
+ cos_pitch = torch.cos(pitch)
134
+ vectors = torch.stack(
135
+ [
136
+ depth * cos_pitch * torch.sin(yaw),
137
+ depth * torch.sin(pitch),
138
+ depth * cos_pitch * torch.cos(yaw),
139
+ ],
140
+ dim=-1,
141
+ )
142
+ return target_poses[:, None, :3] + vectors
143
+
144
+
145
+ def _inside_fov_3d_hv_batched(
146
+ points: torch.Tensor,
147
+ poses: torch.Tensor,
148
+ *,
149
+ fov_half_h: float,
150
+ fov_half_v: float,
151
+ ) -> torch.Tensor:
152
+ vectors = points[:, None, :, :] - poses[:, :, None, :3]
153
+ x = vectors[..., 0]
154
+ y = vectors[..., 1]
155
+ z = vectors[..., 2]
156
+ azimuth = torch.atan2(x, z) * (180.0 / math.pi)
157
+ elevation = torch.atan2(y, torch.sqrt(x.square() + z.square()).clamp_min(1e-8)) * (180.0 / math.pi)
158
+ diff_azimuth = _angle_diff_degrees(azimuth - poses[:, :, None, 4])
159
+ diff_elevation = _angle_diff_degrees(elevation - poses[:, :, None, 3])
160
+ return (diff_azimuth < float(fov_half_h)) & (diff_elevation < float(fov_half_v))
161
+
162
+
163
+ def _batched_tie_mask(mask: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
164
+ neg_inf = torch.full_like(values, -float("inf"))
165
+ best = torch.where(mask, values, neg_inf).max(dim=1).values
166
+ return mask & torch.isclose(values, best[:, None], rtol=0.0, atol=1e-12)
167
+
168
+
169
+ def batched_revisit_select_positions(
170
+ source_frame_indices: torch.Tensor,
171
+ source_pose: torch.Tensor,
172
+ target_frame_indices: torch.Tensor,
173
+ target_pose: torch.Tensor,
174
+ *,
175
+ source_candidate_mask: torch.Tensor | None = None,
176
+ topk: int = 2,
177
+ exclude_local_context_frames: int = 0,
178
+ fov_overlap_threshold: Optional[float] = 0.30,
179
+ plucker_weight: float = 0.1,
180
+ fov_half_h: float = 105.0 / 2.0,
181
+ fov_half_v: float = 75.0 / 2.0,
182
+ fov_yaw_samples: int = 25,
183
+ fov_pitch_samples: int = 20,
184
+ fov_depth_samples: int = 20,
185
+ fov_radius: float = 30.0,
186
+ plucker_grid_h: int = 4,
187
+ plucker_grid_w: int = 4,
188
+ plucker_focal_length: float = 0.35,
189
+ pose_preselect_topk: Optional[int] = 64,
190
+ query_chunk_size: int = 16,
191
+ ) -> BatchedRevisitSelectionResult:
192
+ if source_frame_indices.ndim != 2:
193
+ raise ValueError("source_frame_indices must have shape (T_src,B)")
194
+ if target_frame_indices.ndim == 1:
195
+ target_frame_indices = target_frame_indices[:, None]
196
+ if target_frame_indices.ndim != 2:
197
+ raise ValueError("target_frame_indices must have shape (T_tgt,B)")
198
+
199
+ T_src, B = source_frame_indices.shape
200
+ T_tgt, B_tgt = target_frame_indices.shape
201
+ if B_tgt != B:
202
+ raise ValueError("source_frame_indices and target_frame_indices must share batch dimension")
203
+
204
+ source_pose_tensor = source_pose if torch.is_tensor(source_pose) else torch.as_tensor(source_pose, dtype=torch.float32)
205
+ target_pose_tensor = target_pose if torch.is_tensor(target_pose) else torch.as_tensor(target_pose, dtype=torch.float32)
206
+ device = source_pose_tensor.device
207
+ if target_pose_tensor.is_cuda:
208
+ device = target_pose_tensor.device
209
+ elif source_pose_tensor.is_cuda:
210
+ device = source_pose_tensor.device
211
+ source_frames_tb = source_frame_indices.to(device=device)
212
+ target_frames_tb = target_frame_indices.to(device=device)
213
+ source_pose_tb = _time_batch_pose_rows(source_pose_tensor.to(device=device), time=T_src, batch=B, name="source_pose")
214
+ target_pose_tb = _time_batch_pose_rows(target_pose_tensor.to(device=device), time=T_tgt, batch=B, name="target_pose")
215
+ candidate_mask_tb = _time_batch_mask(source_candidate_mask, time=T_src, batch=B, device=device)
216
+
217
+ topk = max(0, int(topk))
218
+ selected_positions = torch.full((B, T_tgt, topk), -1, device=device, dtype=torch.long)
219
+ selected_mask = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.bool)
220
+ selected_scores = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
221
+ selected_fov_overlap = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
222
+ selected_plucker_overlap = torch.zeros((B, T_tgt, topk), device=device, dtype=torch.float32)
223
+ selected_gap_frames = torch.full((B, T_tgt, topk), -1.0, device=device, dtype=torch.float32)
224
+ best_fov = torch.zeros((B, T_tgt), device=device, dtype=torch.float32)
225
+ best_plucker = torch.zeros((B, T_tgt), device=device, dtype=torch.float32)
226
+ best_gap = torch.full((B, T_tgt), -1.0, device=device, dtype=torch.float32)
227
+ if topk == 0 or T_src == 0 or T_tgt == 0 or B == 0:
228
+ return BatchedRevisitSelectionResult(
229
+ selected_positions=selected_positions,
230
+ selected_mask=selected_mask,
231
+ selected_scores=selected_scores,
232
+ selected_fov_overlap=selected_fov_overlap,
233
+ selected_plucker_overlap=selected_plucker_overlap,
234
+ selected_gap_frames=selected_gap_frames,
235
+ best_selected_fov_overlap=best_fov,
236
+ best_selected_plucker_overlap=best_plucker,
237
+ best_selected_gap_frames=best_gap,
238
+ )
239
+
240
+ source_pose_flat = source_pose_tb.reshape(-1, source_pose_tb.shape[-1])
241
+ source_forward_tb = _pose_forward(source_pose_flat).reshape(T_src, B, 3)
242
+ source_desc_tb = _plucker_descriptor(
243
+ source_pose_flat,
244
+ grid_h=plucker_grid_h,
245
+ grid_w=plucker_grid_w,
246
+ focal_length=plucker_focal_length,
247
+ ).reshape(T_src, B, -1)
248
+ target_frames_flat = target_frames_tb.transpose(0, 1).contiguous().reshape(-1)
249
+ target_pose_flat = target_pose_tb.transpose(0, 1).contiguous().reshape(-1, target_pose_tb.shape[-1])
250
+ target_forward_flat = _pose_forward(target_pose_flat)
251
+ target_desc_flat = _plucker_descriptor(
252
+ target_pose_flat,
253
+ grid_h=plucker_grid_h,
254
+ grid_w=plucker_grid_w,
255
+ focal_length=plucker_focal_length,
256
+ )
257
+ batch_ids = torch.arange(B, device=device, dtype=torch.long).repeat_interleave(T_tgt)
258
+ Q = int(target_frames_flat.numel())
259
+ chunk_size = max(1, int(query_chunk_size))
260
+ source_positions = torch.arange(T_src, device=device, dtype=torch.long)
261
+ pose_topk = None if pose_preselect_topk is None else int(pose_preselect_topk)
262
+
263
+ selected_positions_flat_view = selected_positions.reshape(-1, topk)
264
+ selected_mask_flat_view = selected_mask.reshape(-1, topk)
265
+ selected_scores_flat_view = selected_scores.reshape(-1, topk)
266
+ selected_fov_flat_view = selected_fov_overlap.reshape(-1, topk)
267
+ selected_plucker_flat_view = selected_plucker_overlap.reshape(-1, topk)
268
+ selected_gap_flat_view = selected_gap_frames.reshape(-1, topk)
269
+ best_fov_flat_view = best_fov.reshape(-1)
270
+ best_plucker_flat_view = best_plucker.reshape(-1)
271
+ best_gap_flat_view = best_gap.reshape(-1)
272
+
273
+ for start in range(0, Q, chunk_size):
274
+ end = min(Q, start + chunk_size)
275
+ b_idx = batch_ids[start:end]
276
+ target_frames = target_frames_flat[start:end]
277
+ target_poses = target_pose_flat[start:end]
278
+ q = int(end - start)
279
+
280
+ source_frames = source_frames_tb.index_select(1, b_idx).transpose(0, 1).contiguous()
281
+ source_candidates = candidate_mask_tb.index_select(1, b_idx).transpose(0, 1).contiguous()
282
+ score_valid = source_candidates & (source_frames < (target_frames[:, None] - int(exclude_local_context_frames)))
283
+
284
+ if pose_topk is None or pose_topk <= 0 or T_src <= pose_topk:
285
+ preselect_idx = source_positions.reshape(1, -1).expand(q, -1)
286
+ preselected_valid = score_valid
287
+ else:
288
+ source_poses_q = source_pose_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
289
+ source_forward_q = source_forward_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
290
+ translation_norm = torch.linalg.vector_norm(source_poses_q[:, :, :3] - target_poses[:, None, :3], dim=-1) / max(float(fov_radius), 1e-6)
291
+ dot = (
292
+ source_forward_q * target_forward_flat[start:end].reshape(q, 1, 3)
293
+ ).sum(dim=-1).clamp(-1.0, 1.0)
294
+ pose_distance = translation_norm + (torch.acos(dot) / math.pi)
295
+ rank = (
296
+ pose_distance.to(dtype=torch.float64)
297
+ - source_frames.to(dtype=torch.float64) * 1e-12
298
+ + source_frames.to(dtype=torch.float64) * 1e-15
299
+ )
300
+ rank = rank.masked_fill(~score_valid, float("inf"))
301
+ k_pre = min(max(1, pose_topk), T_src)
302
+ top = torch.topk(rank, k=k_pre, largest=False, sorted=True)
303
+ preselect_idx = top.indices
304
+ preselected_valid = torch.isfinite(top.values) & torch.gather(score_valid, 1, preselect_idx)
305
+
306
+ K = int(preselect_idx.shape[1])
307
+ if K == 0:
308
+ continue
309
+
310
+ gather_pose_idx = preselect_idx.unsqueeze(-1).expand(-1, -1, source_pose_tb.shape[-1])
311
+ source_poses_q = source_pose_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
312
+ selected_poses = torch.gather(source_poses_q, 1, gather_pose_idx)
313
+ selected_frames = torch.gather(source_frames, 1, preselect_idx)
314
+
315
+ points = _target_fov_points_batched(
316
+ target_poses,
317
+ fov_half_h=fov_half_h,
318
+ fov_half_v=fov_half_v,
319
+ yaw_samples=fov_yaw_samples,
320
+ pitch_samples=fov_pitch_samples,
321
+ depth_samples=fov_depth_samples,
322
+ radius=fov_radius,
323
+ )
324
+ inside = _inside_fov_3d_hv_batched(points, selected_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
325
+ fov_values = inside.float().mean(dim=2)
326
+
327
+ source_desc_q = source_desc_tb.index_select(1, b_idx).permute(1, 0, 2).contiguous()
328
+ gather_desc_idx = preselect_idx.unsqueeze(-1).expand(-1, -1, source_desc_tb.shape[-1])
329
+ selected_desc = torch.gather(source_desc_q, 1, gather_desc_idx)
330
+ diff = selected_desc - target_desc_flat[start:end, None, :]
331
+ plucker_distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
332
+ plucker_values = 1.0 / (1.0 + plucker_distance.clamp_min(0.0))
333
+
334
+ valid_mask = preselected_valid
335
+ if fov_overlap_threshold is not None:
336
+ valid_mask = valid_mask & (fov_values >= float(fov_overlap_threshold))
337
+
338
+ remaining = valid_mask.clone()
339
+ covered = torch.zeros((q, inside.shape[2]), device=device, dtype=torch.bool)
340
+ chosen_positions = torch.full((q, topk), -1, device=device, dtype=torch.long)
341
+ chosen_scores = torch.zeros((q, topk), device=device, dtype=torch.float32)
342
+ chosen_fov = torch.zeros((q, topk), device=device, dtype=torch.float32)
343
+ chosen_plucker = torch.zeros((q, topk), device=device, dtype=torch.float32)
344
+ chosen_gap = torch.full((q, topk), -1.0, device=device, dtype=torch.float32)
345
+ row_idx = torch.arange(q, device=device, dtype=torch.long)
346
+ gap_values = target_frames[:, None] - selected_frames
347
+ for slot in range(topk):
348
+ active = remaining.any(dim=1)
349
+ gains = (inside & ~covered[:, None, :]).float().mean(dim=2)
350
+ tied = _batched_tie_mask(remaining, gains)
351
+ tied = _batched_tie_mask(tied, fov_values)
352
+ tied = _batched_tie_mask(tied, plucker_values * float(plucker_weight))
353
+ tied = _batched_tie_mask(tied, -gap_values.to(dtype=torch.float32))
354
+ tied = _batched_tie_mask(tied, -selected_frames.to(dtype=torch.float32))
355
+ best_idx = tied.to(dtype=torch.long).argmax(dim=1)
356
+ chosen_positions[active, slot] = preselect_idx[row_idx[active], best_idx[active]]
357
+ chosen_scores[active, slot] = gains[row_idx[active], best_idx[active]]
358
+ chosen_fov[active, slot] = fov_values[row_idx[active], best_idx[active]]
359
+ chosen_plucker[active, slot] = plucker_values[row_idx[active], best_idx[active]]
360
+ chosen_gap[active, slot] = gap_values[row_idx[active], best_idx[active]].to(dtype=torch.float32)
361
+ covered[active] = covered[active] | inside[row_idx[active], best_idx[active]]
362
+ remaining[row_idx[active], best_idx[active]] = False
363
+
364
+ chosen_mask = chosen_positions >= 0
365
+ chosen_rank = (
366
+ chosen_fov.to(dtype=torch.float64)
367
+ + chosen_plucker.to(dtype=torch.float64) * 1e-9
368
+ - chosen_gap.to(dtype=torch.float64) * 1e-12
369
+ ).masked_fill(~chosen_mask, -float("inf"))
370
+ has_choice = chosen_mask.any(dim=1)
371
+ best_slot = chosen_rank.argmax(dim=1)
372
+ best_fov_flat = torch.where(has_choice, chosen_fov[row_idx, best_slot], torch.zeros((q,), device=device, dtype=torch.float32))
373
+ best_plucker_flat = torch.where(has_choice, chosen_plucker[row_idx, best_slot], torch.zeros((q,), device=device, dtype=torch.float32))
374
+ best_gap_flat = torch.where(has_choice, chosen_gap[row_idx, best_slot], torch.full((q,), -1.0, device=device, dtype=torch.float32))
375
+
376
+ selected_positions_flat_view[start:end] = chosen_positions
377
+ selected_mask_flat_view[start:end] = chosen_mask
378
+ selected_scores_flat_view[start:end] = chosen_scores
379
+ selected_fov_flat_view[start:end] = chosen_fov
380
+ selected_plucker_flat_view[start:end] = chosen_plucker
381
+ selected_gap_flat_view[start:end] = chosen_gap
382
+ best_fov_flat_view[start:end] = best_fov_flat
383
+ best_plucker_flat_view[start:end] = best_plucker_flat
384
+ best_gap_flat_view[start:end] = best_gap_flat
385
+
386
+ return BatchedRevisitSelectionResult(
387
+ selected_positions=selected_positions,
388
+ selected_mask=selected_mask,
389
+ selected_scores=selected_scores,
390
+ selected_fov_overlap=selected_fov_overlap,
391
+ selected_plucker_overlap=selected_plucker_overlap,
392
+ selected_gap_frames=selected_gap_frames,
393
+ best_selected_fov_overlap=best_fov,
394
+ best_selected_plucker_overlap=best_plucker,
395
+ best_selected_gap_frames=best_gap,
396
+ )
397
+
398
+
399
  def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
400
  if int(record.frame_indices.numel()) != 1:
401
  return None
 
421
  plucker_grid_w: int,
422
  plucker_focal_length: float,
423
  pose_preselect_topk: Optional[int],
424
+ ) -> list[RevisitCandidateLabel]:
 
 
 
 
 
 
 
 
 
 
425
  if not records:
426
+ return []
427
 
428
  target_poses = _pose_rows(target_pose)
429
  if target_poses is None:
 
465
  ]
466
  ranked.sort()
467
  selected_indices = [idx for *_, idx in ranked[:topk]]
 
 
 
468
 
469
  selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
470
  selected_records = [records[idx] for idx in selected_indices]
 
500
  if fov_overlap_threshold is not None:
501
  valid_mask = fov_values >= float(fov_overlap_threshold)
502
 
 
503
  fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
504
  plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
505
  valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]
 
524
  best_frame_fov_overlap=fov_overlap,
525
  )
526
  )
527
+ return labels
528
 
529
 
530
  def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
 
621
  )
622
 
623
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  def _record_with_selected_frame_metadata(
625
  label: RevisitCandidateLabel,
626
  *,
 
641
  return replace(label.record, metadata=metadata)
642
 
643
 
644
+
645
+ def _record_frame_id(record: MemoryRecord) -> int:
646
+ return int(record.source_end) - 1
647
+
648
+
649
  def deterministic_revisit_retrieval(
650
  records: list[MemoryRecord],
651
  target_frame: int,
 
679
  for record in causal_records
680
  if int(record.source_end) <= target_frame - exclude_local_context_frames
681
  ]
682
+ labels = _vectorized_frame_candidate_labels(
683
  score_records,
684
  target_frame=target_frame,
685
  target_pose=target_pose,
 
695
  plucker_focal_length=plucker_focal_length,
696
  pose_preselect_topk=pose_preselect_topk,
697
  )
 
698
  valid_labels = [label for label in labels if label.valid]
699
+ selected_labels, selected_scores, _ = _select_greedy_coverage(
700
  valid_labels,
701
  topk=topk,
702
  plucker_weight=float(plucker_weight),
703
  )
704
  best_selected = _best_selected_label(selected_labels)
 
705
  best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
706
  best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
707
  best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
 
 
 
708
  selected_records = [
709
  _record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
710
  for label in selected_labels
 
712
  score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
713
  scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)
714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
715
  return RevisitRetrievalResult(
716
  records=selected_records,
717
  scores=scores,
718
  selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
719
+ best_selected_fov_overlap=torch.as_tensor(best_selected_fov, dtype=torch.float32, device=score_device),
720
+ best_selected_plucker_overlap=torch.as_tensor(best_selected_plucker, dtype=torch.float32, device=score_device),
721
+ best_selected_gap_frames=torch.as_tensor(float(best_selected_gap), dtype=torch.float32, device=score_device),
722
  )
algorithms/worldmem/dememwm/schedules.py CHANGED
@@ -8,8 +8,6 @@ import torch
8
 
9
  from .types import StreamGateState
10
 
11
- NOISE_BUCKETS = ("high", "mid", "low")
12
- NOISE_BUCKET_TO_ID = {name: idx for idx, name in enumerate(NOISE_BUCKETS)}
13
  EVAL_ABLATION_BRANCHES = (
14
  "memory_off",
15
  "A_only",
@@ -41,46 +39,6 @@ def _clamp01(value: float) -> float:
41
  return max(0.0, min(1.0, float(value)))
42
 
43
 
44
- def noise_bucket_from_denoising_fraction(denoising_fraction: float | None) -> str:
45
- if denoising_fraction is None:
46
- return "mid"
47
- frac = _clamp01(float(denoising_fraction))
48
- if frac < (1.0 / 3.0):
49
- return "high"
50
- if frac < (2.0 / 3.0):
51
- return "mid"
52
- return "low"
53
-
54
-
55
- def noise_bucket_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> str:
56
- if noise_levels is None or timesteps is None or int(timesteps) <= 1:
57
- return "mid"
58
- noise_fraction = _clamp01(float(noise_levels.detach().float().mean().item()) / float(int(timesteps) - 1))
59
- if noise_fraction >= (2.0 / 3.0):
60
- return "high"
61
- if noise_fraction >= (1.0 / 3.0):
62
- return "mid"
63
- return "low"
64
-
65
-
66
- def noise_bucket_ids_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> torch.Tensor | None:
67
- if noise_levels is None or timesteps is None or int(timesteps) <= 1:
68
- return None
69
- noise_fraction = noise_levels.detach().float() / float(int(timesteps) - 1)
70
- bucket_ids = torch.full_like(noise_levels, NOISE_BUCKET_TO_ID["mid"], dtype=torch.long)
71
- bucket_ids = torch.where(
72
- noise_fraction >= (2.0 / 3.0),
73
- torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["high"]),
74
- bucket_ids,
75
- )
76
- bucket_ids = torch.where(
77
- noise_fraction < (1.0 / 3.0),
78
- torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["low"]),
79
- bucket_ids,
80
- )
81
- return bucket_ids
82
-
83
-
84
  def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
85
  if noise_levels is None or timesteps is None or int(timesteps) <= 1:
86
  return None
@@ -97,12 +55,6 @@ def normalize_eval_ablation_branch(branch: str | None) -> str:
97
  return branch
98
 
99
 
100
- def normalize_noise_bucket(noise_bucket: str | None) -> str:
101
- if noise_bucket in NOISE_BUCKET_TO_ID:
102
- return str(noise_bucket)
103
- return "mid"
104
-
105
-
106
  _STAGE_ENABLES = {
107
  'stage_1': (True, True, True),
108
  'stage_2': (True, True, True),
@@ -131,22 +83,6 @@ class CurriculumState:
131
  def dit_full_trainable(self) -> bool:
132
  return self.dit_train_state == "full"
133
 
134
- def diagnostics(self) -> dict[str, Any]:
135
- return {
136
- "dememwm_global_step": self.global_step,
137
- "dememwm_curriculum_enabled": self.enabled,
138
- "dememwm_stage": self.stage,
139
- "curriculum_anchor_enabled": self.anchor_enabled,
140
- "curriculum_dynamic_enabled": self.dynamic_enabled,
141
- "curriculum_revisit_enabled": self.revisit_enabled,
142
- "dit_train_state": self.dit_train_state,
143
- "dit_full_trainable": self.dit_full_trainable,
144
- "freeze_vae": self.freeze_vae,
145
- "lr_dememwm_modules": self.dememwm_lr,
146
- "lr_memory_adapters": self.memory_adapter_lr,
147
- "lr_full_dit": self.full_dit_lr,
148
- }
149
-
150
 
151
  def _cfg_get(obj: Any, name: str, default: Any) -> Any:
152
  return getattr(obj, name, default) if obj is not None else default
@@ -209,9 +145,7 @@ DeMemWMCurriculumState = CurriculumState
209
  resolve_dememwm_curriculum = resolve_curriculum
210
 
211
 
212
- def compute_stream_gates(stage: str, denoising_fraction: float | None = None, debug_force_all_streams: bool = False, anchor_gate: float = 1.0, dynamic_gate: float = 1.0, revisit_gate: float = 1.0) -> StreamGateState:
213
- if debug_force_all_streams:
214
- return StreamGateState(True, True, True, float(anchor_gate), float(dynamic_gate), float(revisit_gate), stage, "debug_force_all_streams")
215
  if stage not in _STAGE_ENABLES:
216
  raise ValueError(f"unknown DeMemWM stage: {stage}")
217
  a_on, d_on, r_on = _STAGE_ENABLES[stage]
 
8
 
9
  from .types import StreamGateState
10
 
 
 
11
  EVAL_ABLATION_BRANCHES = (
12
  "memory_off",
13
  "A_only",
 
39
  return max(0.0, min(1.0, float(value)))
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
43
  if noise_levels is None or timesteps is None or int(timesteps) <= 1:
44
  return None
 
55
  return branch
56
 
57
 
 
 
 
 
 
 
58
  _STAGE_ENABLES = {
59
  'stage_1': (True, True, True),
60
  'stage_2': (True, True, True),
 
83
  def dit_full_trainable(self) -> bool:
84
  return self.dit_train_state == "full"
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def _cfg_get(obj: Any, name: str, default: Any) -> Any:
88
  return getattr(obj, name, default) if obj is not None else default
 
145
  resolve_dememwm_curriculum = resolve_curriculum
146
 
147
 
148
+ def compute_stream_gates(stage: str, denoising_fraction: float | None = None, anchor_gate: float = 1.0, dynamic_gate: float = 1.0, revisit_gate: float = 1.0) -> StreamGateState:
 
 
149
  if stage not in _STAGE_ENABLES:
150
  raise ValueError(f"unknown DeMemWM stage: {stage}")
151
  a_on, d_on, r_on = _STAGE_ENABLES[stage]
algorithms/worldmem/dememwm/types.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  from __future__ import annotations
3
 
4
  from dataclasses import dataclass, field
@@ -74,8 +73,9 @@ class MemoryStreamTensors:
74
  revisit_gate: torch.Tensor | float
75
  revisit_gate_raw: torch.Tensor | None = None
76
  valid_revisit_mask: torch.Tensor | None = None
77
- no_valid_revisit_mask: torch.Tensor | None = None
78
- diagnostics: dict[str, Any] = field(default_factory=dict)
 
79
 
80
 
81
  @dataclass(frozen=True)
@@ -95,4 +95,6 @@ class RevisitRetrievalResult:
95
  records: list[MemoryRecord]
96
  scores: torch.Tensor
97
  selected_frame_ids: list[int]
98
- diagnostics: dict[str, Any]
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
 
73
  revisit_gate: torch.Tensor | float
74
  revisit_gate_raw: torch.Tensor | None = None
75
  valid_revisit_mask: torch.Tensor | None = None
76
+ revisit_best_selected_fov_overlap: torch.Tensor | None = None
77
+ revisit_best_selected_plucker_overlap: torch.Tensor | None = None
78
+ revisit_selected_gap_frames: torch.Tensor | None = None
79
 
80
 
81
  @dataclass(frozen=True)
 
95
  records: list[MemoryRecord]
96
  scores: torch.Tensor
97
  selected_frame_ids: list[int]
98
+ best_selected_fov_overlap: torch.Tensor
99
+ best_selected_plucker_overlap: torch.Tensor
100
+ best_selected_gap_frames: torch.Tensor
algorithms/worldmem/models/dit.py CHANGED
@@ -166,12 +166,6 @@ class MemoryTokenCrossAttention(nn.Module):
166
  self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
167
  self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
168
  self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
169
- self.last_gate_mean = None
170
- self.last_delta_ratio = None
171
- self.last_valid_fraction = None
172
- self.last_type_gate_mean = None
173
- for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
174
- setattr(self, f"last_type_gate_{type_name}_mean", None)
175
  nn.init.normal_(self.memory_type_embed.weight, std=0.02)
176
  self.reset_identity_init()
177
 
@@ -236,19 +230,11 @@ class MemoryTokenCrossAttention(nn.Module):
236
  type_scale = type_scale.unsqueeze(0)
237
  return memory_tokens * type_scale + type_embed
238
 
239
- def _store_type_gate_diagnostics(self, stage_gate):
240
- with torch.no_grad():
241
- detached = stage_gate.detach().float()
242
- self.last_type_gate_mean = detached.mean()
243
- for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
244
- setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())
245
-
246
  def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
247
  if memory_type_ids is None:
248
  return None
249
  memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
250
  stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
251
- self._store_type_gate_diagnostics(stage_gate)
252
  if memory_tokens.dim() == 4:
253
  batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
254
  if memory_type_ids.dim() == 1:
@@ -329,33 +315,6 @@ class MemoryTokenCrossAttention(nn.Module):
329
  gate_tensor = gate_tensor.unsqueeze(-1)
330
  return gate_tensor
331
 
332
- def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
333
- with torch.no_grad():
334
- batch_size, num_frames = base.shape[:2]
335
- gate_values = torch.cat(
336
- [gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
337
- dim=-1,
338
- )
339
- gate_mask = self._gate_valid_mask(
340
- valid_rows,
341
- batch_size,
342
- num_frames,
343
- dtype=gate_values.dtype,
344
- device=gate_values.device,
345
- )
346
- if gate_mask is not None:
347
- gate_values = gate_values * gate_mask
348
- self.last_valid_fraction = valid_rows.detach().float().mean()
349
- valid_count = (gate_mask.sum() * gate_values.shape[-1]).clamp_min(1.0)
350
- self.last_gate_mean = gate_values.sum() / valid_count
351
- else:
352
- self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
353
- self.last_gate_mean = gate_values.mean()
354
-
355
- delta_norm = (output.detach().float() - base.detach().float()).norm()
356
- base_norm = base.detach().float().norm()
357
- self.last_delta_ratio = delta_norm / (base_norm + 1e-6)
358
-
359
  def forward(
360
  self,
361
  x,
@@ -437,7 +396,6 @@ class MemoryTokenCrossAttention(nn.Module):
437
  if residual_gate_tensor is not None:
438
  mlp_delta = mlp_delta * residual_gate_tensor
439
  output = output + mlp_delta
440
- self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
441
  if return_delta:
442
  return attn_delta + mlp_delta
443
  return output
@@ -767,38 +725,6 @@ class DiT(nn.Module):
767
  if memory_adapter is not None:
768
  memory_adapter.reset_identity_init()
769
 
770
- def memory_adapter_delta_diagnostics(self):
771
- diagnostics = {}
772
- ratios = []
773
- type_gate_values = {type_name: [] for type_name in MEMORY_TYPE_NAMES}
774
- shared_type_gate_values = []
775
- for block in self.blocks:
776
- adapter = getattr(block, "memory_token_cross_attn", None)
777
- if adapter is None:
778
- continue
779
- ratio = getattr(adapter, "last_delta_ratio", None)
780
- if ratio is not None:
781
- ratios.append(torch.as_tensor(ratio).detach().float())
782
- type_gate = getattr(adapter, "last_type_gate_mean", None)
783
- if type_gate is not None:
784
- shared_type_gate_values.append(torch.as_tensor(type_gate).detach().float())
785
- for type_name in MEMORY_TYPE_NAMES:
786
- value = getattr(adapter, f"last_type_gate_{type_name}_mean", None)
787
- if value is not None:
788
- type_gate_values[type_name].append(torch.as_tensor(value).detach().float())
789
- if ratios:
790
- values = torch.stack(ratios)
791
- diagnostics["memory_adapter_delta_ratio_max"] = float(values.max().item())
792
- diagnostics["memory_adapter_delta_ratio_mean"] = float(values.mean().item())
793
- if shared_type_gate_values:
794
- values = torch.stack(shared_type_gate_values)
795
- diagnostics["memory_adapter_type_gate_mean"] = float(values.mean().item())
796
- for type_name, values_list in type_gate_values.items():
797
- if values_list:
798
- values = torch.stack(values_list)
799
- diagnostics[f"memory_adapter_type_gate_{type_name}_mean"] = float(values.mean().item())
800
- return diagnostics
801
-
802
  def unpatchify(self, x):
803
  """
804
  x: (N, H, W, patch_size**2 * C)
 
166
  self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
167
  self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
168
  self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
 
 
 
 
 
 
169
  nn.init.normal_(self.memory_type_embed.weight, std=0.02)
170
  self.reset_identity_init()
171
 
 
230
  type_scale = type_scale.unsqueeze(0)
231
  return memory_tokens * type_scale + type_embed
232
 
 
 
 
 
 
 
 
233
  def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
234
  if memory_type_ids is None:
235
  return None
236
  memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
237
  stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
 
238
  if memory_tokens.dim() == 4:
239
  batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
240
  if memory_type_ids.dim() == 1:
 
315
  gate_tensor = gate_tensor.unsqueeze(-1)
316
  return gate_tensor
317
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
  def forward(
319
  self,
320
  x,
 
396
  if residual_gate_tensor is not None:
397
  mlp_delta = mlp_delta * residual_gate_tensor
398
  output = output + mlp_delta
 
399
  if return_delta:
400
  return attn_delta + mlp_delta
401
  return output
 
725
  if memory_adapter is not None:
726
  memory_adapter.reset_identity_init()
727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  def unpatchify(self, x):
729
  """
730
  x: (N, H, W, patch_size**2 * C)
configurations/algorithm/dememwm_memory_dit.yaml CHANGED
@@ -15,7 +15,6 @@ log_video: false
15
  dememwm:
16
  enabled: true
17
  training_stage: stage_1 # fallback only when curriculum.enabled=false
18
- debug_force_all_streams: false
19
  curriculum:
20
  enabled: true
21
  full_stage_start_step: 60000
@@ -66,8 +65,6 @@ dememwm:
66
  plucker_focal_length: 0.35
67
  compress:
68
  downsample_ratio: 4
69
- stage_policy:
70
- noise_bucket_logging: true
71
  eval_ablation:
72
  enabled: false
73
  branch: A_plus_D_plus_R_normal
 
15
  dememwm:
16
  enabled: true
17
  training_stage: stage_1 # fallback only when curriculum.enabled=false
 
18
  curriculum:
19
  enabled: true
20
  full_stage_start_step: 60000
 
65
  plucker_focal_length: 0.35
66
  compress:
67
  downsample_ratio: 4
 
 
68
  eval_ablation:
69
  enabled: false
70
  branch: A_plus_D_plus_R_normal
scripts/dememwm_full_eval.slurm CHANGED
@@ -118,7 +118,6 @@ EVAL_ARGS=(
118
  "++algorithm.context_frames=${CONTEXT_FRAMES}"
119
  "++algorithm.log_video=${LOG_VIDEO}"
120
  "++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}"
121
- "++algorithm.dememwm.debug_force_all_streams=false"
122
  "++algorithm.dememwm.training_stage=stage_2"
123
  "++algorithm.dememwm.anchor.enabled=true"
124
  "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
@@ -139,7 +138,6 @@ EVAL_ARGS=(
139
  "++algorithm.dememwm.revisit.plucker_weight=0.10"
140
  "++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}"
141
  "++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}"
142
- "++algorithm.dememwm.stage_policy.noise_bucket_logging=true"
143
  "++algorithm.dememwm.eval_ablation.enabled=true"
144
  "++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}"
145
  "++algorithm.dememwm.cache.enabled=true"
 
118
  "++algorithm.context_frames=${CONTEXT_FRAMES}"
119
  "++algorithm.log_video=${LOG_VIDEO}"
120
  "++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}"
 
121
  "++algorithm.dememwm.training_stage=stage_2"
122
  "++algorithm.dememwm.anchor.enabled=true"
123
  "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]"
 
138
  "++algorithm.dememwm.revisit.plucker_weight=0.10"
139
  "++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}"
140
  "++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}"
 
141
  "++algorithm.dememwm.eval_ablation.enabled=true"
142
  "++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}"
143
  "++algorithm.dememwm.cache.enabled=true"
scripts/dememwm_full_train.slurm CHANGED
@@ -52,7 +52,6 @@ srun python -m main \
52
  ++algorithm.context_frames=100 \
53
  ++algorithm.log_video=true \
54
  ++algorithm.diffusion.sampling_timesteps=20 \
55
- ++algorithm.dememwm.debug_force_all_streams=false \
56
  ++algorithm.dememwm.generated_history_proxy.enabled=true \
57
  ++algorithm.dememwm.generated_history_proxy.start_step=40000 \
58
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
@@ -77,7 +76,6 @@ srun python -m main \
77
  ++algorithm.dememwm.revisit.plucker_weight=0.10 \
78
  ++algorithm.dememwm.revisit.max_frames=2 \
79
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
80
- ++algorithm.dememwm.stage_policy.noise_bucket_logging=true \
81
  ++algorithm.dememwm.cache.enabled=true \
82
  ++algorithm.dememwm.cache.device=cpu \
83
  ++algorithm.dememwm.cache.keep_raw_latents=all \
 
52
  ++algorithm.context_frames=100 \
53
  ++algorithm.log_video=true \
54
  ++algorithm.diffusion.sampling_timesteps=20 \
 
55
  ++algorithm.dememwm.generated_history_proxy.enabled=true \
56
  ++algorithm.dememwm.generated_history_proxy.start_step=40000 \
57
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
 
76
  ++algorithm.dememwm.revisit.plucker_weight=0.10 \
77
  ++algorithm.dememwm.revisit.max_frames=2 \
78
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
 
79
  ++algorithm.dememwm.cache.enabled=true \
80
  ++algorithm.dememwm.cache.device=cpu \
81
  ++algorithm.dememwm.cache.keep_raw_latents=all \
tests/test_dememwm_compression.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from dememwm_import_helper import install_dememwm_namespace
3
 
4
  install_dememwm_namespace()
@@ -18,16 +19,77 @@ def small_compressor(**kwargs):
18
  )
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def test_dynamic_compressor_shapes_and_budget():
22
  comp = small_compressor(exclude_latest_local_frames=0)
23
  latents = torch.randn(4, 2, 3, 2, 2)
24
  frame_indices = torch.arange(4)[:, None].repeat(1, 2)
25
  target = torch.tensor([[1, 2], [4, 4]])
26
- tokens, mask, diag = comp(latents, frame_indices, None, target)
27
  assert tokens.shape == (2, 2, 2, 8)
28
  assert mask.shape == (2, 2, 2)
29
  assert mask[0, 0].any()
30
- assert diag["selected_source_count"].max().item() <= 4
31
 
32
 
33
  def test_dynamic_compressor_abstains_without_old_enough_sources():
@@ -35,23 +97,23 @@ def test_dynamic_compressor_abstains_without_old_enough_sources():
35
  latents = torch.randn(2, 1, 3, 2, 2)
36
  frame_indices = torch.tensor([[5], [6]])
37
  target = torch.tensor([[8]])
38
- tokens, mask, diag = comp(latents, frame_indices, None, target)
39
  assert tokens.shape == (1, 1, 2, 8)
 
40
  assert not mask.any()
41
- assert diag["max_source_frame"].item() == -1
42
- assert diag["dynamic_min_gap_to_target_per_target"].item() == -1
43
 
44
 
45
- def test_dynamic_compressor_reports_generated_fraction_and_no_future():
46
  comp = small_compressor(exclude_latest_local_frames=0)
47
  latents = torch.randn(3, 1, 3, 2, 2)
48
  frame_indices = torch.tensor([[0], [2], [5]])
49
  generated = torch.tensor([[False], [True], [True]])
50
  target = torch.tensor([[3]])
51
- _, mask, diag = comp(latents, frame_indices, None, target, generated)
 
52
  assert mask.any()
53
- assert diag["max_source_frame"].item() == 2
54
- assert 0.0 < diag["generated_source_fraction"].item() < 1.0
55
 
56
 
57
  def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
@@ -59,13 +121,12 @@ def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
59
  latents = torch.randn(5, 1, 3, 2, 2)
60
  frame_indices = torch.tensor([[0], [1], [2], [3], [4]])
61
  target = torch.tensor([[5]])
62
- tokens, mask, diag = comp(latents, frame_indices, None, target)
 
63
  assert tokens.shape == (1, 1, 2, 8)
64
  assert mask.any()
65
- assert diag["max_source_frame"].item() == 2
66
- assert diag["dynamic_min_gap_to_target_per_target"].item() == 3
67
- assert diag["dynamic_max_gap_to_target_per_target"].item() == 5
68
- assert diag["dynamic_exclude_latest_local_frames"] == 2
69
 
70
 
71
  def test_cache_materialize_raw_latents_excludes_c_short_overlap():
@@ -91,7 +152,7 @@ def test_dynamic_compressor_preserves_grad_to_trainable_parts():
91
  latents = torch.randn(4, 1, 3, 2, 2)
92
  frame_indices = torch.arange(4)[:, None]
93
  target = torch.tensor([[4]])
94
- tokens, mask, _ = comp(latents, frame_indices, None, target)
95
  assert mask.any()
96
  tokens.square().sum().backward()
97
  grads = [
@@ -107,6 +168,54 @@ def test_dynamic_compressor_selects_only_recent_valid_sources():
107
  latents = torch.randn(20, 1, 3, 2, 2)
108
  frame_indices = torch.arange(20)[:, None]
109
  target = torch.tensor([[10]])
110
- _, mask, diag = comp(latents, frame_indices, None, target)
 
111
  assert mask.any()
112
- assert diag["selected_source_count"].item() == 4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ import torch.nn.functional as F
3
  from dememwm_import_helper import install_dememwm_namespace
4
 
5
  install_dememwm_namespace()
 
19
  )
20
 
21
 
22
+ def legacy_dynamic_forward(comp, latents, frame_indices, target_frame_indices, source_is_generated=None, exclude_latest_local_frames=None):
23
+ del source_is_generated
24
+ exclude_latest_local_frames = (
25
+ comp.exclude_latest_local_frames
26
+ if exclude_latest_local_frames is None
27
+ else int(exclude_latest_local_frames)
28
+ )
29
+ T_src, B, C, H, W = latents.shape
30
+ if target_frame_indices.ndim == 1:
31
+ target_frame_indices = target_frame_indices[:, None].expand(-1, B)
32
+ T_tgt = target_frame_indices.shape[0]
33
+ device = latents.device
34
+ n_spatial = (H // comp.patch_size) * (W // comp.patch_size)
35
+ T_out = comp._temporal_output_count()
36
+ num_slots = T_out * n_spatial
37
+ output_time_idx = comp._output_time_indices(device)
38
+ output_rows, mask_rows = [], []
39
+ for b in range(B):
40
+ src_frames_b = frame_indices[:, b]
41
+ tgt_outputs, tgt_masks = [], []
42
+ for j in range(T_tgt):
43
+ target = int(target_frame_indices[j, b].item())
44
+ valid_idx = (src_frames_b < target - exclude_latest_local_frames).nonzero(as_tuple=False).flatten()
45
+ if valid_idx.numel() == 0:
46
+ tgt_outputs.append(latents.new_zeros(num_slots, comp.dit_hidden_size))
47
+ tgt_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
48
+ continue
49
+ selected_frames = src_frames_b.index_select(0, valid_idx)
50
+ order = torch.argsort(selected_frames)
51
+ valid_idx = valid_idx.index_select(0, order)[-comp.max_source_frames:]
52
+ chunk = latents[valid_idx, b]
53
+ real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
54
+ if chunk.shape[0] < comp.max_source_frames:
55
+ pad = chunk.new_zeros(comp.max_source_frames - chunk.shape[0], C, H, W)
56
+ chunk = torch.cat([pad, chunk], dim=0)
57
+ real_mask = torch.cat([
58
+ torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
59
+ real_mask,
60
+ ])
61
+ inp = chunk.clone()
62
+ inp[1:] = chunk[1:] - chunk[:-1]
63
+ x = inp.permute(1, 0, 2, 3).unsqueeze(0)
64
+ x = F.pad(x, (0, 0, 0, 0, comp.causal_pad, 0))
65
+ x = comp.conv3d(x)
66
+ x = x.squeeze(0).permute(1, 2, 3, 0)
67
+ x = comp.out_norm(x)
68
+ tokens = x.reshape(num_slots, comp.dit_hidden_size)
69
+ clamped_time_idx = output_time_idx.clamp(min=0, max=comp.max_source_frames - 1)
70
+ temporal_mask = (
71
+ (output_time_idx >= 0)
72
+ & (output_time_idx < comp.max_source_frames)
73
+ & real_mask.index_select(0, clamped_time_idx)
74
+ )
75
+ mask = temporal_mask[:, None].expand(T_out, n_spatial).reshape(num_slots)
76
+ tgt_outputs.append(tokens)
77
+ tgt_masks.append(mask)
78
+ output_rows.append(torch.stack(tgt_outputs))
79
+ mask_rows.append(torch.stack(tgt_masks))
80
+ return torch.stack(output_rows), torch.stack(mask_rows)
81
+
82
+
83
  def test_dynamic_compressor_shapes_and_budget():
84
  comp = small_compressor(exclude_latest_local_frames=0)
85
  latents = torch.randn(4, 2, 3, 2, 2)
86
  frame_indices = torch.arange(4)[:, None].repeat(1, 2)
87
  target = torch.tensor([[1, 2], [4, 4]])
88
+ tokens, mask = comp(latents, frame_indices, None, target)
89
  assert tokens.shape == (2, 2, 2, 8)
90
  assert mask.shape == (2, 2, 2)
91
  assert mask[0, 0].any()
92
+ assert mask.sum(dim=-1).max().item() <= tokens.shape[2]
93
 
94
 
95
  def test_dynamic_compressor_abstains_without_old_enough_sources():
 
97
  latents = torch.randn(2, 1, 3, 2, 2)
98
  frame_indices = torch.tensor([[5], [6]])
99
  target = torch.tensor([[8]])
100
+ tokens, mask = comp(latents, frame_indices, None, target)
101
  assert tokens.shape == (1, 1, 2, 8)
102
+ assert not tokens.any()
103
  assert not mask.any()
 
 
104
 
105
 
106
+ def test_dynamic_compressor_ignores_generated_flags_and_excludes_future_sources():
107
  comp = small_compressor(exclude_latest_local_frames=0)
108
  latents = torch.randn(3, 1, 3, 2, 2)
109
  frame_indices = torch.tensor([[0], [2], [5]])
110
  generated = torch.tensor([[False], [True], [True]])
111
  target = torch.tensor([[3]])
112
+ tokens, mask = comp(latents, frame_indices, None, target, generated)
113
+ expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target, generated)
114
  assert mask.any()
115
+ assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
116
+ assert torch.equal(mask, expected_mask)
117
 
118
 
119
  def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
 
121
  latents = torch.randn(5, 1, 3, 2, 2)
122
  frame_indices = torch.tensor([[0], [1], [2], [3], [4]])
123
  target = torch.tensor([[5]])
124
+ tokens, mask = comp(latents, frame_indices, None, target)
125
+ expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target)
126
  assert tokens.shape == (1, 1, 2, 8)
127
  assert mask.any()
128
+ assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
129
+ assert torch.equal(mask, expected_mask)
 
 
130
 
131
 
132
  def test_cache_materialize_raw_latents_excludes_c_short_overlap():
 
152
  latents = torch.randn(4, 1, 3, 2, 2)
153
  frame_indices = torch.arange(4)[:, None]
154
  target = torch.tensor([[4]])
155
+ tokens, mask = comp(latents, frame_indices, None, target)
156
  assert mask.any()
157
  tokens.square().sum().backward()
158
  grads = [
 
168
  latents = torch.randn(20, 1, 3, 2, 2)
169
  frame_indices = torch.arange(20)[:, None]
170
  target = torch.tensor([[10]])
171
+ tokens, mask = comp(latents, frame_indices, None, target)
172
+ expected_tokens, expected_mask = legacy_dynamic_forward(comp, latents, frame_indices, target)
173
  assert mask.any()
174
+ assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
175
+ assert torch.equal(mask, expected_mask)
176
+
177
+
178
+ def test_dynamic_compressor_batched_matches_legacy_loop():
179
+ torch.manual_seed(11)
180
+ comp = small_compressor(exclude_latest_local_frames=2)
181
+ latents = torch.randn(6, 3, 3, 2, 2)
182
+ frame_indices = torch.tensor([
183
+ [0, 4, 1],
184
+ [3, 1, 5],
185
+ [7, 8, 9],
186
+ [2, 6, 11],
187
+ [12, 3, 13],
188
+ [15, 10, 4],
189
+ ])
190
+ generated = torch.tensor([
191
+ [False, True, False],
192
+ [True, False, False],
193
+ [False, True, True],
194
+ [True, True, False],
195
+ [False, False, True],
196
+ [True, False, True],
197
+ ])
198
+ targets = torch.tensor([
199
+ [5, 5, 6],
200
+ [9, 9, 12],
201
+ [20, 11, 3],
202
+ ])
203
+
204
+ expected_tokens, expected_mask = legacy_dynamic_forward(
205
+ comp, latents, frame_indices, targets, generated
206
+ )
207
+ tokens, mask = comp(latents, frame_indices, None, targets, generated)
208
+
209
+ assert torch.allclose(tokens, expected_tokens, atol=1e-6, rtol=1e-6)
210
+ assert torch.equal(mask, expected_mask)
211
+
212
+
213
+ def test_dynamic_compressor_handles_empty_source_tensor():
214
+ comp = small_compressor(exclude_latest_local_frames=2)
215
+ latents = torch.randn(0, 2, 3, 2, 2)
216
+ frame_indices = torch.empty(0, 2, dtype=torch.long)
217
+ target = torch.tensor([[5, 6], [7, 8]])
218
+ tokens, mask = comp(latents, frame_indices, None, target)
219
+ assert tokens.shape == (2, 2, 2, 8)
220
+ assert not tokens.any()
221
+ assert not mask.any()
tests/test_dememwm_config_static.py CHANGED
@@ -7,7 +7,8 @@ def test_config_is_distinct_standalone_memory_dit_path():
7
  assert "base_video_dit" in text
8
  assert "memory_token_cross_attention: true" in text
9
  assert "dememwm:" in text
10
- assert "debug_force_all_streams" in text
 
11
  assert "ssm_memory" not in text
12
  assert "ssm_memory_ckpt_path" not in text
13
 
@@ -42,7 +43,6 @@ def test_current_config_contract_is_explicit_and_has_no_stale_sections():
42
  "plucker_grid_h: 4",
43
  "plucker_grid_w: 4",
44
  "plucker_focal_length: 0.35",
45
- "noise_bucket_logging: true",
46
  "eval_ablation:",
47
  "branch: A_plus_D_plus_R_normal",
48
  "generated_history_proxy:",
@@ -64,6 +64,8 @@ def test_current_config_contract_is_explicit_and_has_no_stale_sections():
64
  "min_gap_frames",
65
  "max_chunks",
66
  "chunk_frames",
 
 
67
  ):
68
  assert forbidden not in text
69
 
@@ -77,7 +79,6 @@ def test_full_scripts_use_consumed_contract_overrides():
77
  "algorithm.dememwm.revisit.fov_pitch_samples=20",
78
  "algorithm.dememwm.revisit.fov_depth_samples=20",
79
  "algorithm.dememwm.revisit.plucker_weight=0.10",
80
- "algorithm.dememwm.stage_policy.noise_bucket_logging=true",
81
  "algorithm.dememwm.cache.keep_compressed_records=true",
82
  ]
83
  stale = [
@@ -95,6 +96,8 @@ def test_full_scripts_use_consumed_contract_overrides():
95
  "algorithm.dememwm.revisit.min_score",
96
  "algorithm.dememwm.revisit.generated_penalty",
97
  "algorithm.dememwm.rollout.",
 
 
98
  ]
99
  expected_by_script = {
100
  "scripts/dememwm_full_train.slurm": [
@@ -120,10 +123,9 @@ def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
120
  "_validate_config_contract",
121
  "deterministic_pose_retrieval",
122
  "exclude_latest_local_frames",
123
- "noise_bucket_logging",
124
  "anchor_effective_enabled",
125
  "dynamic_effective_enabled",
126
- "revisit_effective_enabled",
127
  "stale DeMemWM config fields",
128
  "revisit_retrieval_kwargs",
129
  "fov_half_h",
@@ -135,56 +137,58 @@ def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
135
  assert token in text
136
  assert '_cfg_get(revisit_cfg, "topk"' not in text
137
  assert "lambda_abstain" not in text
 
 
138
 
139
 
140
  def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
141
  retrieval = Path("algorithms/worldmem/dememwm/retrieval.py").read_text()
142
  labels = Path("algorithms/worldmem/dememwm/labels.py").read_text()
143
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
144
- diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
145
  for token in [
146
  "exclude_local_context_frames",
147
  "fov_overlap_threshold",
148
  "plucker_weight",
149
  "high_quality_fov_threshold",
150
- "best_selected_frame_fov_overlap",
 
 
 
151
  "deterministic_fov_coverage_plucker",
152
  "valid_revisit_mask",
153
- "revisit_candidate_frame_count",
154
- "valid_candidate_label_count",
155
- "valid_revisit_frame_count",
156
- "no_valid_revisit_count",
157
- "revisit_selected_frame_count",
158
- "revisit_frame_fov_overlap",
159
- "revisit_abstained_count",
160
  ]:
161
- assert token in retrieval + labels + algorithm + diagnostics
162
  assert "same_video" not in retrieval + labels
163
  assert "wrong_video" not in retrieval + labels
164
  for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
165
  assert f'self._cfg_get(revisit_cfg, "{stale}"' not in algorithm
166
 
167
 
168
- def test_dynamic_compressor_excludes_c_short_contract():
169
  compression = Path("algorithms/worldmem/dememwm/compression.py").read_text()
170
  cache = Path("algorithms/worldmem/dememwm/cache.py").read_text()
171
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
172
  for token in [
173
  "exclude_latest_local_frames",
174
- "src_frames_b < target - exclude_latest_local_frames",
 
 
 
 
175
  "dynamic_min_gap_to_target_per_target",
176
  "dynamic_max_gap_to_target_per_target",
177
  "dynamic_exclude_latest_local_frames",
178
- "_local_context_exclusion_frames",
 
179
  ]:
180
- assert token in compression + cache + algorithm
181
  assert "src_frames_b < target, as_tuple=False" not in compression
182
  assert "src < int(target), as_tuple=False" not in cache
183
 
184
 
185
- def test_eval_ablation_and_noise_bucket_logging_contracts():
186
  schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
187
- diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
188
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
189
  for branch in [
190
  "memory_off",
@@ -202,15 +206,17 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
202
  "local_context_overlap_fake_revisit",
203
  ]:
204
  assert branch in schedules
205
- for token in [
206
- "noise_bucket_from_denoising_fraction",
207
- "noise_bucket_from_noise_levels",
208
  "summarize_noise_bucket_diagnostics",
209
- "noise_bucket_id",
210
  "summarize_eval_ablation_diagnostics",
211
  "eval_bucket_true_revisit_count",
212
  "eval_bucket_no_valid_revisit_count",
213
  "eval_bucket_corrupted_memory_count",
214
- "apply_revisit_eval_corruption",
215
  ]:
216
- assert token in schedules + diagnostics + algorithm
 
 
 
 
 
7
  assert "base_video_dit" in text
8
  assert "memory_token_cross_attention: true" in text
9
  assert "dememwm:" in text
10
+ assert "diagnostics:" not in text
11
+ assert "noise_bucket_logging" not in text
12
  assert "ssm_memory" not in text
13
  assert "ssm_memory_ckpt_path" not in text
14
 
 
43
  "plucker_grid_h: 4",
44
  "plucker_grid_w: 4",
45
  "plucker_focal_length: 0.35",
 
46
  "eval_ablation:",
47
  "branch: A_plus_D_plus_R_normal",
48
  "generated_history_proxy:",
 
64
  "min_gap_frames",
65
  "max_chunks",
66
  "chunk_frames",
67
+ "diagnostics:",
68
+ "noise_bucket_logging",
69
  ):
70
  assert forbidden not in text
71
 
 
79
  "algorithm.dememwm.revisit.fov_pitch_samples=20",
80
  "algorithm.dememwm.revisit.fov_depth_samples=20",
81
  "algorithm.dememwm.revisit.plucker_weight=0.10",
 
82
  "algorithm.dememwm.cache.keep_compressed_records=true",
83
  ]
84
  stale = [
 
96
  "algorithm.dememwm.revisit.min_score",
97
  "algorithm.dememwm.revisit.generated_penalty",
98
  "algorithm.dememwm.rollout.",
99
+ "algorithm.dememwm.diagnostics",
100
+ "algorithm.dememwm.stage_policy.noise_bucket_logging",
101
  ]
102
  expected_by_script = {
103
  "scripts/dememwm_full_train.slurm": [
 
123
  "_validate_config_contract",
124
  "deterministic_pose_retrieval",
125
  "exclude_latest_local_frames",
 
126
  "anchor_effective_enabled",
127
  "dynamic_effective_enabled",
128
+ "revisit_stage_config_enabled",
129
  "stale DeMemWM config fields",
130
  "revisit_retrieval_kwargs",
131
  "fov_half_h",
 
137
  assert token in text
138
  assert '_cfg_get(revisit_cfg, "topk"' not in text
139
  assert "lambda_abstain" not in text
140
+ assert "noise_bucket" not in text
141
+ assert "diagnostics" not in text
142
 
143
 
144
  def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
145
  retrieval = Path("algorithms/worldmem/dememwm/retrieval.py").read_text()
146
  labels = Path("algorithms/worldmem/dememwm/labels.py").read_text()
147
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
 
148
  for token in [
149
  "exclude_local_context_frames",
150
  "fov_overlap_threshold",
151
  "plucker_weight",
152
  "high_quality_fov_threshold",
153
+ "dememwm_selected_frame_fov_overlap",
154
+ "best_selected_fov_overlap",
155
+ "best_selected_plucker_overlap",
156
+ "best_selected_gap_frames",
157
  "deterministic_fov_coverage_plucker",
158
  "valid_revisit_mask",
159
+ "batched_revisit_select_positions",
 
 
 
 
 
 
160
  ]:
161
+ assert token in retrieval + labels + algorithm
162
  assert "same_video" not in retrieval + labels
163
  assert "wrong_video" not in retrieval + labels
164
  for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
165
  assert f'self._cfg_get(revisit_cfg, "{stale}"' not in algorithm
166
 
167
 
168
+ def test_dynamic_compressor_excludes_c_short_contract_without_diagnostics():
169
  compression = Path("algorithms/worldmem/dememwm/compression.py").read_text()
170
  cache = Path("algorithms/worldmem/dememwm/cache.py").read_text()
171
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
172
  for token in [
173
  "exclude_latest_local_frames",
174
+ "source_frames[:, None, :] < (target_frames[:, :, None] - int(exclude_latest_local_frames))",
175
+ "_local_context_exclusion_frames",
176
+ ]:
177
+ assert token in compression + cache + algorithm
178
+ for removed in [
179
  "dynamic_min_gap_to_target_per_target",
180
  "dynamic_max_gap_to_target_per_target",
181
  "dynamic_exclude_latest_local_frames",
182
+ "selected_source_count",
183
+ "generated_source_fraction",
184
  ]:
185
+ assert removed not in compression + algorithm
186
  assert "src_frames_b < target, as_tuple=False" not in compression
187
  assert "src < int(target), as_tuple=False" not in cache
188
 
189
 
190
+ def test_eval_ablation_contracts_without_diagnostic_summaries():
191
  schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
 
192
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
193
  for branch in [
194
  "memory_off",
 
206
  "local_context_overlap_fake_revisit",
207
  ]:
208
  assert branch in schedules
209
+ assert "apply_revisit_eval_corruption" in algorithm
210
+ for removed in [
 
211
  "summarize_noise_bucket_diagnostics",
 
212
  "summarize_eval_ablation_diagnostics",
213
  "eval_bucket_true_revisit_count",
214
  "eval_bucket_no_valid_revisit_count",
215
  "eval_bucket_corrupted_memory_count",
216
+ "noise_bucket_id",
217
  ]:
218
+ assert removed not in schedules + algorithm
219
+
220
+
221
+ def test_dememwm_diagnostics_module_removed():
222
+ assert not Path("algorithms/worldmem/dememwm/diagnostics.py").exists()
tests/test_dememwm_dit_extension_static.py CHANGED
@@ -73,17 +73,15 @@ def test_shared_memory_attention_zero_revisit_gate_matches_anchor_only():
73
  assert torch.allclose(out_anchor, out_packed, atol=1e-5)
74
 
75
 
76
- def test_memory_cross_attention_type_ids_record_stage_gates():
77
  attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
78
  x = torch.randn(1, 2, 1, 1, 8)
79
  c = torch.randn(1, 2, 8)
80
  mem = torch.randn(1, 2, 3, 8)
81
  mask = torch.ones(1, 2, 3, dtype=torch.bool)
82
  type_ids = torch.tensor([0, 1, 2])
83
- _ = attn(x, c, mem, mask, memory_type_ids=type_ids, memory_token_gate=torch.ones(1, 2, 3))
84
- assert torch.is_tensor(attn.last_type_gate_anchor_mean)
85
- assert torch.is_tensor(attn.last_type_gate_dynamic_mean)
86
- assert torch.is_tensor(attn.last_type_gate_revisit_mean)
87
 
88
 
89
  def test_dit_accepts_dynamic_rank4_tokens_and_all_false_masks_without_nan():
@@ -120,7 +118,7 @@ def test_diffusion_methods_accept_option_c_kwargs_by_signature():
120
 
121
 
122
 
123
- def test_fresh_memory_cross_attention_identity_init_delta_ratio_is_zero():
124
  attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
125
  x = torch.randn(1, 2, 1, 1, 8)
126
  c = torch.randn(1, 2, 8)
@@ -128,10 +126,9 @@ def test_fresh_memory_cross_attention_identity_init_delta_ratio_is_zero():
128
  mask = torch.ones(1, 2, 3, dtype=torch.bool)
129
  delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.ones(1, 2, 1))
130
  assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6)
131
- assert float(attn.last_delta_ratio.item()) <= 1e-7
132
 
133
 
134
- def test_fresh_dit_memory_on_matches_memory_off_and_reports_delta_ratio():
135
  model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True)
136
  x = torch.randn(1, 2, 2, 4, 4)
137
  t = torch.zeros(1, 2, dtype=torch.long)
@@ -152,6 +149,4 @@ def test_fresh_dit_memory_on_matches_memory_off_and_reports_delta_ratio():
152
  memory_dynamic_gate=torch.ones(1, 2, 1),
153
  memory_retrieval_gate=torch.ones(1, 2, 1),
154
  )
155
- diagnostics = model.memory_adapter_delta_diagnostics()
156
  assert torch.allclose(out_on, out_off, atol=1e-6)
157
- assert diagnostics["memory_adapter_delta_ratio_max"] <= 1e-7
 
73
  assert torch.allclose(out_anchor, out_packed, atol=1e-5)
74
 
75
 
76
+ def test_memory_cross_attention_type_ids_apply_stage_gates():
77
  attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
78
  x = torch.randn(1, 2, 1, 1, 8)
79
  c = torch.randn(1, 2, 8)
80
  mem = torch.randn(1, 2, 3, 8)
81
  mask = torch.ones(1, 2, 3, dtype=torch.bool)
82
  type_ids = torch.tensor([0, 1, 2])
83
+ out = attn(x, c, mem, mask, memory_type_ids=type_ids, memory_token_gate=torch.ones(1, 2, 3))
84
+ assert out.shape == x.shape
 
 
85
 
86
 
87
  def test_dit_accepts_dynamic_rank4_tokens_and_all_false_masks_without_nan():
 
118
 
119
 
120
 
121
+ def test_fresh_memory_cross_attention_identity_init_delta_is_zero():
122
  attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2)
123
  x = torch.randn(1, 2, 1, 1, 8)
124
  c = torch.randn(1, 2, 8)
 
126
  mask = torch.ones(1, 2, 3, dtype=torch.bool)
127
  delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.ones(1, 2, 1))
128
  assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6)
 
129
 
130
 
131
+ def test_fresh_dit_memory_on_matches_memory_off():
132
  model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True)
133
  x = torch.randn(1, 2, 2, 4, 4)
134
  t = torch.zeros(1, 2, dtype=torch.long)
 
149
  memory_dynamic_gate=torch.ones(1, 2, 1),
150
  memory_retrieval_gate=torch.ones(1, 2, 1),
151
  )
 
152
  assert torch.allclose(out_on, out_off, atol=1e-6)
 
tests/test_dememwm_eval_ablation.py CHANGED
@@ -7,7 +7,6 @@ from dememwm_import_helper import install_dememwm_namespace
7
  install_dememwm_namespace()
8
  from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
9
  from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
10
- from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics
11
  from algorithms.worldmem.dememwm.schedules import (
12
  EVAL_ABLATION_BRANCHES,
13
  EVAL_ABLATION_BRANCH_TO_ID,
@@ -48,25 +47,6 @@ def test_wave9_branch_registry_is_exact_and_validated():
48
  normalize_eval_ablation_branch("ratio_sweep")
49
 
50
 
51
- def test_eval_ablation_diagnostics_bucket_counts():
52
- diag = summarize_eval_ablation_diagnostics(
53
- enabled=True,
54
- branch="wrong_pose",
55
- valid_revisit_mask=torch.tensor([[True, True, True, False]]),
56
- no_valid_revisit_mask=torch.tensor([[False, False, False, True]]),
57
- eval_corrupted_revisit_mask=torch.tensor([[False, True, True, False]]),
58
- )
59
- assert diag["eval_ablation_enabled"] is True
60
- assert diag["eval_ablation_branch"] == "wrong_pose"
61
- assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"]
62
- assert diag["eval_bucket_true_revisit_count"] == 1
63
- assert diag["eval_bucket_no_valid_revisit_count"] == 1
64
- assert diag["eval_bucket_corrupted_memory_count"] == 2
65
- assert diag["eval_bucket_true_revisit_fraction"] == pytest.approx(0.25)
66
- assert diag["eval_bucket_no_valid_revisit_fraction"] == pytest.approx(0.25)
67
- assert diag["eval_bucket_corrupted_memory_fraction"] == pytest.approx(0.5)
68
-
69
-
70
  class ConstantGate(torch.nn.Module):
71
  def __init__(self, value: float):
72
  super().__init__()
@@ -83,7 +63,6 @@ class DummyDeMemWM(MemoryDiTMixin):
83
  dememwm=types.SimpleNamespace(
84
  enabled=True,
85
  training_stage="stage_2",
86
- debug_force_all_streams=False,
87
  token_patch_size=2,
88
  curriculum=types.SimpleNamespace(enabled=False),
89
  anchor=types.SimpleNamespace(
@@ -108,7 +87,6 @@ class DummyDeMemWM(MemoryDiTMixin):
108
  max_frames=2,
109
  compress=types.SimpleNamespace(pool_h=1, pool_w=1),
110
  ),
111
- stage_policy=types.SimpleNamespace(noise_bucket_logging=True),
112
  eval_ablation=types.SimpleNamespace(enabled=True, branch=branch),
113
  generated_history_proxy=types.SimpleNamespace(enabled=False),
114
  injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
@@ -190,12 +168,11 @@ def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch():
190
  assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25))
191
  assert torch.count_nonzero(forced_off.revisit_gate).item() == 0
192
  assert torch.equal(forced_on.revisit_gate, forced_on.valid_revisit_mask.to(dtype=forced_on.revisit_gate.dtype))
193
- assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on"
194
 
195
 
196
  def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
197
  wrong_pose = _streams("wrong_pose")
198
  assert wrong_pose.valid_revisit_mask.all()
 
199
  assert torch.allclose(wrong_pose.revisit_gate, torch.full_like(wrong_pose.revisit_gate, 0.25))
200
- assert wrong_pose.diagnostics["eval_bucket_corrupted_memory_count"] == int(wrong_pose.valid_revisit_mask.numel())
201
- assert wrong_pose.diagnostics["eval_bucket_true_revisit_count"] == 0
 
7
  install_dememwm_namespace()
8
  from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
9
  from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
 
10
  from algorithms.worldmem.dememwm.schedules import (
11
  EVAL_ABLATION_BRANCHES,
12
  EVAL_ABLATION_BRANCH_TO_ID,
 
47
  normalize_eval_ablation_branch("ratio_sweep")
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  class ConstantGate(torch.nn.Module):
51
  def __init__(self, value: float):
52
  super().__init__()
 
63
  dememwm=types.SimpleNamespace(
64
  enabled=True,
65
  training_stage="stage_2",
 
66
  token_patch_size=2,
67
  curriculum=types.SimpleNamespace(enabled=False),
68
  anchor=types.SimpleNamespace(
 
87
  max_frames=2,
88
  compress=types.SimpleNamespace(pool_h=1, pool_w=1),
89
  ),
 
90
  eval_ablation=types.SimpleNamespace(enabled=True, branch=branch),
91
  generated_history_proxy=types.SimpleNamespace(enabled=False),
92
  injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
 
168
  assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25))
169
  assert torch.count_nonzero(forced_off.revisit_gate).item() == 0
170
  assert torch.equal(forced_on.revisit_gate, forced_on.valid_revisit_mask.to(dtype=forced_on.revisit_gate.dtype))
 
171
 
172
 
173
  def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
174
  wrong_pose = _streams("wrong_pose")
175
  assert wrong_pose.valid_revisit_mask.all()
176
+ assert wrong_pose.revisit_mask.any()
177
  assert torch.allclose(wrong_pose.revisit_gate, torch.full_like(wrong_pose.revisit_gate, 0.25))
178
+ assert wrong_pose.revisit_best_selected_fov_overlap.shape == wrong_pose.valid_revisit_mask.shape
 
tests/test_dememwm_freeze_policy.py CHANGED
@@ -62,8 +62,6 @@ def test_dit_freeze_keeps_requires_grad_stable_and_zeroes_optimizer_lr():
62
  assert frozen_state.dit_train_state == "frozen"
63
  assert full_dit_params
64
  assert all(param.requires_grad for param in full_dit_params)
65
- assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == 0
66
- assert model._last_dememwm_freeze_diagnostics["requires_grad_tensors_full_dit"] == len(full_dit_params)
67
  assert all(not param.requires_grad for param in model.vae.parameters())
68
 
69
  for param in full_dit_params:
@@ -85,7 +83,6 @@ def test_dit_freeze_keeps_requires_grad_stable_and_zeroes_optimizer_lr():
85
 
86
  assert full_state.dit_train_state == "full"
87
  assert all(param.requires_grad for param in full_dit_params)
88
- assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == len(full_dit_params)
89
  assert lr_by_name["full_dit"] == 1.0e-5
90
  assert all(not param.requires_grad for param in model.vae.parameters())
91
 
 
62
  assert frozen_state.dit_train_state == "frozen"
63
  assert full_dit_params
64
  assert all(param.requires_grad for param in full_dit_params)
 
 
65
  assert all(not param.requires_grad for param in model.vae.parameters())
66
 
67
  for param in full_dit_params:
 
83
 
84
  assert full_state.dit_train_state == "full"
85
  assert all(param.requires_grad for param in full_dit_params)
 
86
  assert lr_by_name["full_dit"] == 1.0e-5
87
  assert all(not param.requires_grad for param in model.vae.parameters())
88
 
tests/test_dememwm_generated_history_proxy.py CHANGED
@@ -68,7 +68,7 @@ def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_
68
  source_is_generated = torch.zeros(4, 1, dtype=torch.bool)
69
 
70
  torch.manual_seed(123)
71
- corrupted, generated, diagnostics = model._apply_generated_history_proxy(
72
  source_latents,
73
  source_is_generated,
74
  )
@@ -77,8 +77,7 @@ def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_
77
  assert not torch.equal(corrupted, source_latents)
78
  assert generated.all()
79
  assert not source_is_generated.any()
80
- assert diagnostics["generated_history_proxy_frame_count"] == 4
81
- assert diagnostics["generated_history_proxy_frame_fraction"] == 1.0
82
 
83
 
84
 
@@ -88,7 +87,7 @@ def test_generated_history_proxy_respects_context_prefix_and_target_window_bound
88
  source_is_generated = torch.zeros(8, 1, dtype=torch.bool)
89
 
90
  torch.manual_seed(123)
91
- corrupted, generated, diagnostics = model._apply_generated_history_proxy(
92
  source_latents,
93
  source_is_generated,
94
  context_frame_count=3,
@@ -104,8 +103,7 @@ def test_generated_history_proxy_respects_context_prefix_and_target_window_bound
104
  assert torch.equal(corrupted[:3], source_latents[:3])
105
  assert not torch.equal(corrupted[3:6], source_latents[3:6])
106
  assert torch.equal(corrupted[6:], source_latents[6:])
107
- assert diagnostics["generated_history_proxy_frame_count"] == 3
108
- assert diagnostics["generated_history_proxy_frame_fraction"] == 3 / 8
109
 
110
 
111
  def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
 
68
  source_is_generated = torch.zeros(4, 1, dtype=torch.bool)
69
 
70
  torch.manual_seed(123)
71
+ corrupted, generated = model._apply_generated_history_proxy(
72
  source_latents,
73
  source_is_generated,
74
  )
 
77
  assert not torch.equal(corrupted, source_latents)
78
  assert generated.all()
79
  assert not source_is_generated.any()
80
+ assert generated.sum().item() == 4
 
81
 
82
 
83
 
 
87
  source_is_generated = torch.zeros(8, 1, dtype=torch.bool)
88
 
89
  torch.manual_seed(123)
90
+ corrupted, generated = model._apply_generated_history_proxy(
91
  source_latents,
92
  source_is_generated,
93
  context_frame_count=3,
 
103
  assert torch.equal(corrupted[:3], source_latents[:3])
104
  assert not torch.equal(corrupted[3:6], source_latents[3:6])
105
  assert torch.equal(corrupted[6:], source_latents[6:])
106
+ assert generated.sum().item() == 3
 
107
 
108
 
109
  def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
tests/test_dememwm_injection_static.py CHANGED
@@ -18,23 +18,19 @@ def _streams(dtype=torch.float32):
18
  anchor_gate=1.0,
19
  dynamic_gate=torch.ones(2, 3, 1) * 0.5,
20
  revisit_gate=0.0,
21
- diagnostics={"selected_revisit_frame_record_ids": ["c1"], "dynamic_max_source_frame": torch.tensor(2)},
22
  )
23
 
24
 
25
- def test_injection_kwarg_names_masks_dtype_and_diagnostics():
26
- kwargs, diag = InjectionAdapter()(_streams(), dtype=torch.float64)
27
  assert set(kwargs) == {"memory_tokens", "memory_token_mask", "memory_dynamic_tokens", "memory_dynamic_mask", "memory_retrieval_tokens", "memory_retrieval_mask", "memory_anchor_gate", "memory_dynamic_gate", "memory_retrieval_gate"}
28
  assert kwargs["memory_tokens"].dtype == torch.float64
29
  assert kwargs["memory_dynamic_mask"].dtype == torch.bool
30
- assert diag["anchor_valid_tokens"] == 6
31
- assert diag["dynamic_valid_fraction"] > 0.0
32
- assert diag["selected_revisit_frame_record_ids"] == ["c1"]
33
- assert diag["max_source_frame"] == 2
34
 
35
 
36
  def test_injection_omit_disabled_streams():
37
- kwargs, _ = InjectionAdapter(omit_disabled=True)(_streams())
38
  assert kwargs["memory_retrieval_tokens"] is None
39
  assert kwargs["memory_retrieval_mask"] is None
40
  assert kwargs["memory_dynamic_tokens"] is not None
 
18
  anchor_gate=1.0,
19
  dynamic_gate=torch.ones(2, 3, 1) * 0.5,
20
  revisit_gate=0.0,
 
21
  )
22
 
23
 
24
+ def test_injection_kwarg_names_masks_and_dtype():
25
+ kwargs = InjectionAdapter()(_streams(), dtype=torch.float64)
26
  assert set(kwargs) == {"memory_tokens", "memory_token_mask", "memory_dynamic_tokens", "memory_dynamic_mask", "memory_retrieval_tokens", "memory_retrieval_mask", "memory_anchor_gate", "memory_dynamic_gate", "memory_retrieval_gate"}
27
  assert kwargs["memory_tokens"].dtype == torch.float64
28
  assert kwargs["memory_dynamic_mask"].dtype == torch.bool
29
+ assert kwargs["memory_retrieval_tokens"].dtype == torch.float64
 
 
 
30
 
31
 
32
  def test_injection_omit_disabled_streams():
33
+ kwargs = InjectionAdapter(omit_disabled=True)(_streams())
34
  assert kwargs["memory_retrieval_tokens"] is None
35
  assert kwargs["memory_retrieval_mask"] is None
36
  assert kwargs["memory_dynamic_tokens"] is not None
tests/test_dememwm_noise_bucket.py CHANGED
@@ -1,102 +1,19 @@
1
- import torch
2
- from dememwm_import_helper import install_dememwm_namespace
3
-
4
- install_dememwm_namespace()
5
- from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
6
- from algorithms.worldmem.dememwm.diagnostics import summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics
7
-
8
-
9
- def test_revisit_diagnostics_report_mean_counts_per_target():
10
- diag = summarize_revisit_diagnostics(
11
- [
12
- {"valid_revisit_frame_count": 6, "revisit_candidate_frame_count": 8, "revisit_selected_frame_count": 2},
13
- {"valid_revisit_frame_count": 3, "revisit_candidate_frame_count": 5, "revisit_selected_frame_count": 1},
14
- {"valid_revisit_frame_count": 0, "revisit_candidate_frame_count": 2, "revisit_selected_frame_count": 0, "no_valid_revisit_count": 1},
15
- ],
16
- valid_revisit_mask=torch.tensor([[True, True, False]]),
17
- )
18
- assert diag["revisit_candidate_frame_count"] == 5.0
19
- assert diag["revisit_candidate_count"] == 5.0
20
- assert diag["valid_revisit_frame_count"] == 3.0
21
- assert diag["valid_revisit_count"] == 3.0
22
- assert diag["revisit_selected_frame_count"] == 3
23
- assert diag["no_valid_revisit_count"] == 1
24
-
25
-
26
- def test_noise_bucket_diagnostics_include_valid_and_no_valid_counts():
27
- diag = summarize_noise_bucket_diagnostics(
28
- noise_bucket="high",
29
- valid_revisit_mask=torch.tensor([[True, True, False]]),
30
- no_valid_revisit_mask=torch.tensor([[False, False, True]]),
31
- )
32
- assert diag["noise_bucket"] == "high"
33
- assert diag["noise_bucket_id"] == 0
34
- assert diag["noise_bucket_is_high"] == 1
35
- assert diag["noise_bucket_is_mid"] == 0
36
- assert diag["noise_bucket_high_target_count"] == 3
37
- assert diag["noise_bucket_mid_target_count"] == 0
38
- assert diag["valid_revisit_noise_bucket_high_count"] == 2
39
- assert diag["no_valid_revisit_noise_bucket_high_count"] == 1
40
-
41
-
42
- def test_noise_bucket_diagnostics_count_per_target_bucket_ids():
43
- diag = summarize_noise_bucket_diagnostics(
44
- noise_bucket="mid",
45
- noise_bucket_ids=torch.tensor([[0, 1, 2]]),
46
- valid_revisit_mask=torch.tensor([[True, True, False]]),
47
- no_valid_revisit_mask=torch.tensor([[False, False, True]]),
48
- )
49
- assert diag["noise_bucket"] == "mid"
50
- assert diag["noise_bucket_id"] == 1
51
- assert diag["noise_bucket_is_mid"] == 1
52
- assert diag["noise_bucket_high_target_count"] == 1
53
- assert diag["noise_bucket_mid_target_count"] == 1
54
- assert diag["noise_bucket_low_target_count"] == 1
55
- assert diag["valid_revisit_noise_bucket_high_count"] == 1
56
- assert diag["valid_revisit_noise_bucket_mid_count"] == 1
57
- assert diag["valid_revisit_noise_bucket_low_count"] == 0
58
- assert diag["no_valid_revisit_noise_bucket_low_count"] == 1
59
-
60
-
61
- def test_noise_bucket_log_allowlist_keeps_target_counts_only():
62
- keys = MemoryDiTMixin._TRAIN_DIAGNOSTIC_LOG_KEYS
63
- for key in (
64
- "anchor_valid_fraction",
65
- "dynamic_valid_fraction",
66
- "revisit_valid_fraction",
67
- "valid_revisit_mask_fraction",
68
- "revisit_candidate_count",
69
- "valid_revisit_count",
70
- "revisit_selected_count",
71
- "revisit_fov_overlap_mean",
72
- "revisit_incremental_fov_overlap_mean",
73
- "revisit_plucker_overlap_mean",
74
- "causal_violation_count",
75
- "noise_bucket_id",
76
- "noise_bucket_is_high",
77
- "noise_bucket_is_mid",
78
- "noise_bucket_is_low",
79
- "revisit_raw_gate_mean",
80
- "valid_revisit_noise_bucket_high_count",
81
- "valid_revisit_noise_bucket_mid_count",
82
- "valid_revisit_noise_bucket_low_count",
83
- "no_valid_revisit_noise_bucket_high_count",
84
- "no_valid_revisit_noise_bucket_mid_count",
85
- "no_valid_revisit_noise_bucket_low_count",
86
- ):
87
- assert key not in keys
88
- for key in [
89
- "noise_bucket_target_count",
90
- "noise_bucket_high_target_count",
91
- "noise_bucket_mid_target_count",
92
- "noise_bucket_low_target_count",
93
- "revisit_candidate_frame_count",
94
- "valid_revisit_frame_count",
95
- "revisit_selected_frame_count",
96
- "revisit_frame_fov_overlap_mean",
97
- "revisit_best_selected_frame_fov_overlap_mean",
98
- "revisit_best_selected_plucker_overlap_mean",
99
- "revisit_best_selected_gap_frames_mean",
100
- "revisit_learned_gate_mean",
101
  ]:
102
- assert key in keys
 
1
+ from pathlib import Path
2
+
3
+
4
+ def test_training_logging_keeps_only_core_scalars():
5
+ algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
6
+ expected_logs = [
7
+ "training/loss",
8
+ "training/denoise_loss",
9
+ "training/revisit_gate",
10
+ ]
11
+ for key in expected_logs:
12
+ assert key in algorithm
13
+ for removed_key in [
14
+ "training/dynamic_max_source_frame",
15
+ "training/revisit_valid_count",
16
+ "training/noise_bucket_id",
17
+ "training/eval_bucket_true_revisit_count",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  ]:
19
+ assert removed_key not in algorithm
tests/test_dememwm_preselection.py CHANGED
@@ -107,7 +107,7 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
107
  frame_indices = torch.arange(8)[:, None]
108
  poses = torch.zeros((8, 1, 5), dtype=torch.float32)
109
  target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)
110
- anchor_banks, _, _, diag = harness._build_preselected_causal_memory_banks(
111
  committed_latents=latents,
112
  source_frame_indices=frame_indices,
113
  source_is_generated=None,
@@ -133,7 +133,7 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
133
  )
134
 
135
  assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
136
- assert diag["preselected_anchor_projected_frame_count"] == 2
137
 
138
 
139
  def test_streaming_diverse_anchor_selection_uses_context_frames():
@@ -168,7 +168,7 @@ def test_preselected_memory_banks_project_only_selected_frames():
168
  target_frame_indices = torch.tensor([[10], [11]])
169
  poses = torch.zeros((20, 1, 5), dtype=torch.float32)
170
  target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)
171
- anchor_banks, revisit_banks, tokens_per_frame, diag = harness._build_preselected_causal_memory_banks(
172
  committed_latents=latents,
173
  source_frame_indices=frame_indices,
174
  source_is_generated=None,
@@ -195,10 +195,10 @@ def test_preselected_memory_banks_project_only_selected_frames():
195
  assert tokens_per_frame == 1
196
  assert len(anchor_banks[0].records) == 4
197
  assert len(revisit_banks[0].records) == 3
198
- assert diag["preselected_anchor_projected_frame_count"] == 4
199
- assert diag["preselected_revisit_projected_frame_count"] == 3
200
- assert diag["preselected_revisit_projected_frame_record_count"] == 3
201
- assert harness.project_call_lengths == [4, 1, 1, 1]
202
  assert 20 not in harness.project_call_lengths
203
 
204
 
@@ -221,7 +221,7 @@ def test_preselected_revisit_projects_best_fov_frame_not_latest():
221
  )
222
  poses = pose_rows[:, None, :]
223
 
224
- _, revisit_banks, _, _ = harness._build_preselected_causal_memory_banks(
225
  committed_latents=latents,
226
  source_frame_indices=frame_indices,
227
  source_is_generated=None,
@@ -246,6 +246,7 @@ def test_preselected_revisit_projects_best_fov_frame_not_latest():
246
  token_patch_size=2,
247
  )
248
 
 
249
  assert len(revisit_banks[0].records) == 1
250
  assert revisit_banks[0].records[0].metadata["dememwm_selected_frame_index"] == 1
251
  assert harness.project_call_values == [[1.0]]
 
107
  frame_indices = torch.arange(8)[:, None]
108
  poses = torch.zeros((8, 1, 5), dtype=torch.float32)
109
  target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)
110
+ anchor_banks, _, _, _, _ = harness._build_preselected_causal_memory_banks(
111
  committed_latents=latents,
112
  source_frame_indices=frame_indices,
113
  source_is_generated=None,
 
133
  )
134
 
135
  assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
136
+ assert harness.project_call_lengths == [2]
137
 
138
 
139
  def test_streaming_diverse_anchor_selection_uses_context_frames():
 
168
  target_frame_indices = torch.tensor([[10], [11]])
169
  poses = torch.zeros((20, 1, 5), dtype=torch.float32)
170
  target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)
171
+ anchor_banks, revisit_banks, tokens_per_frame, selected_by_target, stats = harness._build_preselected_causal_memory_banks(
172
  committed_latents=latents,
173
  source_frame_indices=frame_indices,
174
  source_is_generated=None,
 
195
  assert tokens_per_frame == 1
196
  assert len(anchor_banks[0].records) == 4
197
  assert len(revisit_banks[0].records) == 3
198
+ assert selected_by_target is not None
199
+ assert stats is not None
200
+ assert [len(records) for records in selected_by_target[0]] == [2, 2]
201
+ assert harness.project_call_lengths == [4, 3]
202
  assert 20 not in harness.project_call_lengths
203
 
204
 
 
221
  )
222
  poses = pose_rows[:, None, :]
223
 
224
+ _, revisit_banks, _, selected_by_target, _ = harness._build_preselected_causal_memory_banks(
225
  committed_latents=latents,
226
  source_frame_indices=frame_indices,
227
  source_is_generated=None,
 
246
  token_patch_size=2,
247
  )
248
 
249
+ assert selected_by_target is not None
250
  assert len(revisit_banks[0].records) == 1
251
  assert revisit_banks[0].records[0].metadata["dememwm_selected_frame_index"] == 1
252
  assert harness.project_call_values == [[1.0]]
tests/test_dememwm_retrieval.py CHANGED
@@ -91,21 +91,19 @@ def test_revisit_candidates_require_causal_c_short_gap():
91
  exclude_local_context_frames=4,
92
  )
93
  assert [r.max_source_frame for r in result.records] == [1]
94
- assert result.diagnostics["revisit_candidate_frame_count"] == 2
95
- assert result.diagnostics["revisit_candidate_count"] == 2
96
- assert result.diagnostics["valid_revisit_frame_count"] == 1
97
- assert result.diagnostics["valid_revisit_count"] == 1
98
- assert result.diagnostics["valid_candidate_label_count"] == 1
99
- assert result.diagnostics["revisit_min_gap_to_target"] == 5
100
- assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
101
 
102
 
103
  def test_revisit_abstains_when_no_valid_candidate():
104
  result = deterministic_revisit_retrieval([rec(2, 2), rec(3, 3)], target_frame=6, topk=2, exclude_local_context_frames=4)
105
  assert result.records == []
106
- assert result.diagnostics["abstained"] is True
107
- assert result.diagnostics["valid_revisit_mask"] == 0
108
- assert result.diagnostics["no_valid_revisit_count"] == 1
 
 
109
 
110
 
111
  def test_revisit_retrieval_rejects_non_vectorized_inputs():
@@ -150,14 +148,11 @@ def test_fov_threshold_filters_candidates_without_action():
150
  exclude_local_context_frames=4,
151
  topk=4,
152
  )
153
- assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
154
- assert result.diagnostics["valid_revisit_frame_count"] == 1
155
- assert result.diagnostics["valid_revisit_count"] == 1
156
- assert result.diagnostics["best_selected_fov_overlap"] == 1.0
157
- assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
158
- assert result.diagnostics["best_selected_gap_frames"] == 10
159
- assert result.diagnostics["revisit_fov_overlap_max"] == 1.0
160
- assert result.diagnostics["revisit_plucker_overlap_max"] > 0.0
161
 
162
 
163
  def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
@@ -177,13 +172,8 @@ def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
177
  pose_preselect_topk=1,
178
  )
179
 
180
- assert result.diagnostics["selected_frame_record_ids"] == ["near_same_direction"]
181
- assert result.diagnostics["revisit_pose_preselect_input_count"] == 3
182
- assert result.diagnostics["revisit_pose_preselect_scored_count"] == 3
183
- assert result.diagnostics["revisit_pose_preselect_selected_count"] == 1
184
- assert result.diagnostics["revisit_exact_fov_candidate_count"] == 1
185
- assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
186
- assert abs(result.diagnostics["revisit_pose_preselect_min_distance"] - (1.0 / 30.0)) < 1e-6
187
 
188
 
189
  def test_selected_frame_carries_frame_metadata_for_projection():
@@ -197,15 +187,14 @@ def test_selected_frame_carries_frame_metadata_for_projection():
197
  topk=1,
198
  )
199
 
200
- assert result.diagnostics["selected_frame_record_ids"] == ["frame_1"]
201
  assert result.selected_frame_ids == [1]
202
  assert result.records[0].metadata["dememwm_selected_frame_index"] == 1
 
203
  assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
204
- assert result.diagnostics["best_selected_frame_index"] == 1
205
- assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
206
 
207
 
208
- def test_high_quality_threshold_is_selected_target_diagnostic_only():
209
  result = deterministic_revisit_retrieval(
210
  [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
211
  target_frame=10,
@@ -215,9 +204,10 @@ def test_high_quality_threshold_is_selected_target_diagnostic_only():
215
  exclude_local_context_frames=4,
216
  topk=1,
217
  )
218
- assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
219
- assert result.diagnostics["valid_revisit_count"] == 1
220
- assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70
 
221
 
222
 
223
  def test_video_metadata_does_not_filter_revisit_candidates():
@@ -233,8 +223,8 @@ def test_video_metadata_does_not_filter_revisit_candidates():
233
  exclude_local_context_frames=4,
234
  topk=4,
235
  )
236
- assert result.diagnostics["selected_frame_record_ids"] == ["c1", "c0"]
237
- assert result.diagnostics["valid_revisit_count"] == 2
238
 
239
 
240
  def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
@@ -244,4 +234,4 @@ def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
244
  rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"),
245
  ]
246
  result = deterministic_revisit_retrieval(records, target_frame=10, target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), exclude_local_context_frames=4, topk=3)
247
- assert result.diagnostics["selected_frame_record_ids"] == ["c", "a", "b"]
 
91
  exclude_local_context_frames=4,
92
  )
93
  assert [r.max_source_frame for r in result.records] == [1]
94
+ assert result.selected_frame_ids == [1]
95
+ assert result.scores.numel() == 1
96
+ assert result.best_selected_gap_frames.item() == pytest.approx(5.0)
 
 
 
 
97
 
98
 
99
  def test_revisit_abstains_when_no_valid_candidate():
100
  result = deterministic_revisit_retrieval([rec(2, 2), rec(3, 3)], target_frame=6, topk=2, exclude_local_context_frames=4)
101
  assert result.records == []
102
+ assert result.selected_frame_ids == []
103
+ assert result.scores.numel() == 0
104
+ assert result.best_selected_fov_overlap.item() == pytest.approx(0.0)
105
+ assert result.best_selected_plucker_overlap.item() == pytest.approx(0.0)
106
+ assert result.best_selected_gap_frames.item() == pytest.approx(-1.0)
107
 
108
 
109
  def test_revisit_retrieval_rejects_non_vectorized_inputs():
 
148
  exclude_local_context_frames=4,
149
  topk=4,
150
  )
151
+ assert [record.chunk_id for record in result.records] == ["c0"]
152
+ assert result.selected_frame_ids == [0]
153
+ assert result.best_selected_fov_overlap.item() == pytest.approx(1.0)
154
+ assert result.best_selected_plucker_overlap.item() > 0.0
155
+ assert result.best_selected_gap_frames.item() == pytest.approx(10.0)
 
 
 
156
 
157
 
158
  def test_pose_preselect_uses_local_position_and_view_direction_before_fov():
 
172
  pose_preselect_topk=1,
173
  )
174
 
175
+ assert [record.chunk_id for record in result.records] == ["near_same_direction"]
176
+ assert result.selected_frame_ids == [2]
 
 
 
 
 
177
 
178
 
179
  def test_selected_frame_carries_frame_metadata_for_projection():
 
187
  topk=1,
188
  )
189
 
190
+ assert [record.chunk_id for record in result.records] == ["frame_1"]
191
  assert result.selected_frame_ids == [1]
192
  assert result.records[0].metadata["dememwm_selected_frame_index"] == 1
193
+ assert result.records[0].metadata["dememwm_selected_frame_fov_overlap"] == pytest.approx(1.0)
194
  assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
 
 
195
 
196
 
197
+ def test_high_quality_threshold_marks_selected_frame_metadata():
198
  result = deterministic_revisit_retrieval(
199
  [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])],
200
  target_frame=10,
 
204
  exclude_local_context_frames=4,
205
  topk=1,
206
  )
207
+ assert [record.chunk_id for record in result.records] == ["c0"]
208
+ assert len(result.records) == 1
209
+ assert 0.30 <= result.best_selected_fov_overlap.item() < 0.70
210
+ assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is False
211
 
212
 
213
  def test_video_metadata_does_not_filter_revisit_candidates():
 
223
  exclude_local_context_frames=4,
224
  topk=4,
225
  )
226
+ assert [record.chunk_id for record in result.records] == ["c1", "c0"]
227
+ assert result.selected_frame_ids == [1, 0]
228
 
229
 
230
  def test_tie_breaking_is_overlap_then_age_then_source_then_record_id():
 
234
  rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"),
235
  ]
236
  result = deterministic_revisit_retrieval(records, target_frame=10, target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), exclude_local_context_frames=4, topk=3)
237
+ assert [record.chunk_id for record in result.records] == ["c", "a", "b"]
tests/test_dememwm_schedules.py CHANGED
@@ -4,9 +4,7 @@ from types import SimpleNamespace
4
 
5
  from algorithms.worldmem.dememwm.schedules import (
6
  compute_stream_gates,
7
- noise_bucket_from_denoising_fraction,
8
- noise_bucket_from_noise_levels,
9
- noise_bucket_ids_from_noise_levels,
10
  resolve_curriculum,
11
  )
12
 
@@ -43,37 +41,16 @@ def test_two_stage_curriculum_switches_at_full_stage_start():
43
  assert stage_1.anchor_enabled and stage_1.dynamic_enabled and stage_1.revisit_enabled
44
  assert stage_1.dit_train_state == "frozen"
45
  assert not hasattr(stage_1, "dit_late_blocks_trainable")
46
- assert all("late" not in key for key in stage_1.diagnostics())
47
  assert stage_2.stage == "stage_2"
48
  assert stage_2.anchor_enabled and stage_2.dynamic_enabled and stage_2.revisit_enabled
49
  assert stage_2.dit_train_state == "full"
50
 
51
 
52
- def test_debug_force_all_streams_overrides_stage():
53
- gates = compute_stream_gates("stage_1", debug_force_all_streams=True)
54
- assert gates.anchor_enabled and gates.dynamic_enabled and gates.revisit_enabled
55
- assert gates.reason == "debug_force_all_streams"
56
-
57
-
58
  def test_unknown_stage_fails():
59
  with pytest.raises(ValueError):
60
  compute_stream_gates("unknown")
61
 
62
 
63
- def test_noise_bucket_from_denoising_fraction():
64
- assert noise_bucket_from_denoising_fraction(0.0) == "high"
65
- assert noise_bucket_from_denoising_fraction(0.5) == "mid"
66
- assert noise_bucket_from_denoising_fraction(1.0) == "low"
67
-
68
-
69
- def test_noise_bucket_from_training_noise_levels():
70
- import torch
71
- assert noise_bucket_from_noise_levels(torch.tensor([9, 8]), 10) == "high"
72
- assert noise_bucket_from_noise_levels(torch.tensor([5, 4]), 10) == "mid"
73
- assert noise_bucket_from_noise_levels(torch.tensor([1, 0]), 10) == "low"
74
-
75
-
76
- def test_noise_bucket_ids_from_training_noise_levels():
77
  import torch
78
- bucket_ids = noise_bucket_ids_from_noise_levels(torch.tensor([[9, 4, 0]]), 10)
79
- assert bucket_ids.tolist() == [[0, 1, 2]]
 
4
 
5
  from algorithms.worldmem.dememwm.schedules import (
6
  compute_stream_gates,
7
+ denoising_fraction_from_noise_levels,
 
 
8
  resolve_curriculum,
9
  )
10
 
 
41
  assert stage_1.anchor_enabled and stage_1.dynamic_enabled and stage_1.revisit_enabled
42
  assert stage_1.dit_train_state == "frozen"
43
  assert not hasattr(stage_1, "dit_late_blocks_trainable")
 
44
  assert stage_2.stage == "stage_2"
45
  assert stage_2.anchor_enabled and stage_2.dynamic_enabled and stage_2.revisit_enabled
46
  assert stage_2.dit_train_state == "full"
47
 
48
 
 
 
 
 
 
 
49
  def test_unknown_stage_fails():
50
  with pytest.raises(ValueError):
51
  compute_stream_gates("unknown")
52
 
53
 
54
+ def test_denoising_fraction_from_training_noise_levels():
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  import torch
56
+ assert denoising_fraction_from_noise_levels(torch.tensor([9, 0]), 10) == pytest.approx(0.5)
 
train_dememwm_full_berzelius.sh CHANGED
@@ -51,7 +51,6 @@ srun python -m main \
51
  ++algorithm.context_frames=100 \
52
  ++algorithm.log_video=true \
53
  ++algorithm.diffusion.sampling_timesteps=20 \
54
- ++algorithm.dememwm.debug_force_all_streams=false \
55
  ++algorithm.dememwm.generated_history_proxy.enabled=true \
56
  ++algorithm.dememwm.generated_history_proxy.start_step=40000 \
57
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
@@ -76,7 +75,6 @@ srun python -m main \
76
  ++algorithm.dememwm.revisit.plucker_weight=0.10 \
77
  ++algorithm.dememwm.revisit.max_frames=2 \
78
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
79
- ++algorithm.dememwm.stage_policy.noise_bucket_logging=true \
80
  ++algorithm.dememwm.cache.enabled=true \
81
  ++algorithm.dememwm.cache.device=cpu \
82
  ++algorithm.dememwm.cache.keep_raw_latents=all \
 
51
  ++algorithm.context_frames=100 \
52
  ++algorithm.log_video=true \
53
  ++algorithm.diffusion.sampling_timesteps=20 \
 
54
  ++algorithm.dememwm.generated_history_proxy.enabled=true \
55
  ++algorithm.dememwm.generated_history_proxy.start_step=40000 \
56
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
 
75
  ++algorithm.dememwm.revisit.plucker_weight=0.10 \
76
  ++algorithm.dememwm.revisit.max_frames=2 \
77
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
 
78
  ++algorithm.dememwm.cache.enabled=true \
79
  ++algorithm.dememwm.cache.device=cpu \
80
  ++algorithm.dememwm.cache.keep_raw_latents=all \