BonanDing commited on
Commit
93d7b0a
·
1 Parent(s): 79bf398

Clean DeMemWM deterministic memory slot handling

Browse files
algorithms/worldmem/dememwm/algorithm.py CHANGED
@@ -42,7 +42,6 @@ class MemoryDiTMixin:
42
  "revisit_pose_preselect_selected_count",
43
  "revisit_exact_fov_candidate_count",
44
  "valid_revisit_frame_count",
45
- "valid_revisit_target_count",
46
  "no_valid_revisit_count",
47
  "revisit_selected_frame_count",
48
  "revisit_frame_fov_overlap_mean",
@@ -171,7 +170,7 @@ class MemoryDiTMixin:
171
  if self._cfg_has(memory_cfg, name)
172
  ]
173
  if ratio_fields:
174
- raise ValueError(f"standalone DeMemWM uses fixed manual token budgets, not ratio fields: {ratio_fields}")
175
 
176
  anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
177
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
@@ -247,7 +246,6 @@ class MemoryDiTMixin:
247
  "revisit_deterministic_fov_plucker_retrieval": True,
248
  "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
249
  "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
250
- "revisit_high_quality_fov_threshold": high_quality_fov_threshold,
251
  "revisit_plucker_weight": plucker_weight,
252
  "stage_policy_noise_bucket_logging": True,
253
  }
@@ -764,13 +762,28 @@ class MemoryDiTMixin:
764
  if source_positions.numel() <= num_anchors or pose is None:
765
  return source_positions[:num_anchors]
766
  poses = pose.float()
767
- selected = [0]
768
- dists = torch.cdist(poses[0:1], poses).squeeze(0)
769
- for _ in range(num_anchors - 1):
 
 
 
 
 
 
 
 
 
 
 
770
  farthest = int(dists.argmax().item())
 
 
771
  selected.append(farthest)
772
- d_new = torch.cdist(poses[farthest:farthest + 1], poses).squeeze(0)
 
773
  dists = torch.minimum(dists, d_new)
 
774
  return source_positions[torch.tensor(sorted(selected), device=source_positions.device)]
775
 
776
  def _build_streaming_cache_records(
@@ -867,10 +880,14 @@ class MemoryDiTMixin:
867
  source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
868
  if source_positions.numel() > 0:
869
  if anchor_diverse:
870
- anchor_pose = _pose_subset(source_positions, batch_idx)
871
- selected_anchor_positions = self._select_diverse_anchor_positions(
872
- source_positions, anchor_pose, len(anchor_indices)
873
- )
 
 
 
 
874
  else:
875
  selected_list = []
876
  for anchor_idx in anchor_indices:
@@ -1339,28 +1356,36 @@ class MemoryDiTMixin:
1339
  }
1340
  return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
1341
 
 
 
 
 
 
 
 
 
1342
  def _records_to_stream(
1343
  self,
1344
  records,
1345
- max_tokens: int,
1346
  hidden_size: int,
1347
  device: torch.device,
1348
  dtype: torch.dtype,
1349
  ) -> tuple[torch.Tensor, torch.Tensor, int]:
1350
- max_tokens = max(0, int(max_tokens))
1351
  record_list = list(records)
1352
- stacked_tokens, stacked_mask = stack_record_tokens(record_list, max_slots=max_tokens)
1353
  max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1)
1354
- if stacked_tokens is None or stacked_mask is None or max_tokens == 0:
1355
- tokens = torch.zeros((max_tokens, hidden_size), device=device, dtype=dtype)
1356
- mask = torch.zeros((max_tokens,), device=device, dtype=torch.bool)
1357
  return tokens, mask, max_source_frame
1358
- n = min(max_tokens, stacked_tokens.shape[0])
1359
  filled = stacked_tokens[:n].to(device=device, dtype=dtype)
1360
  filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool)
1361
- if n < max_tokens:
1362
- pad = filled.new_zeros(max_tokens - n, hidden_size)
1363
- pad_mask = torch.zeros(max_tokens - n, device=device, dtype=torch.bool)
1364
  tokens = torch.cat([filled, pad], dim=0)
1365
  mask = torch.cat([filled_mask, pad_mask], dim=0)
1366
  else:
@@ -1520,10 +1545,10 @@ class MemoryDiTMixin:
1520
  revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
1521
  revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
1522
  )
1523
- revisit_max_tokens = revisit_max_frames * revisit_pool_h * revisit_pool_w
1524
  recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
1525
- exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
1526
- local_context_exclusion_frames = self._local_context_exclusion_frames()
1527
  fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
1528
  high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
1529
  plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
@@ -1568,7 +1593,7 @@ class MemoryDiTMixin:
1568
  dtype=stream_dtype,
1569
  max_recent_frames=recent_frames,
1570
  target_frame_indices=target_frame_indices,
1571
- exclude_latest_local_frames=exclude_latest_local_frames,
1572
  )
1573
  if raw_latents is not None:
1574
  dynamic_latents = raw_latents
@@ -1631,7 +1656,7 @@ class MemoryDiTMixin:
1631
  revisit_pool_h,
1632
  revisit_pool_w,
1633
  revisit_max_frames,
1634
- local_context_exclusion_frames,
1635
  fov_overlap_threshold,
1636
  plucker_weight,
1637
  revisit_retrieval_kwargs,
@@ -1641,7 +1666,7 @@ class MemoryDiTMixin:
1641
 
1642
  T_tgt = target_frame_indices.shape[0]
1643
  anchor_slots = max(0, anchor_num_tokens)
1644
- revisit_slots = max(0, revisit_max_tokens)
1645
  anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
1646
  anchor_include_generated = allow_generated_anchor
1647
  anchor_token_rows = []
@@ -1659,7 +1684,6 @@ class MemoryDiTMixin:
1659
  source_type=anchor_source_type,
1660
  include_generated=anchor_include_generated,
1661
  max_records=len(anchor_indices),
1662
- max_slots=anchor_slots,
1663
  )
1664
  )
1665
  anchor_bank.assert_causal(target_frame, records)
@@ -1693,7 +1717,7 @@ class MemoryDiTMixin:
1693
  "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1694
  "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1695
  "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
1696
- "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
1697
  }
1698
  else:
1699
  # Pre-select dynamic source frame positions using only frame index metadata
@@ -1705,7 +1729,7 @@ class MemoryDiTMixin:
1705
  for _b in range(B):
1706
  for _j in range(T_tgt):
1707
  _target = int(target_frame_indices[_j, _b].item())
1708
- _valid = (_dfi[:, _b] < _target - exclude_latest_local_frames).nonzero(as_tuple=False).flatten()
1709
  _needed.extend(_valid[-_max_src:].tolist())
1710
  if _needed:
1711
  _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
@@ -1727,7 +1751,7 @@ class MemoryDiTMixin:
1727
  _dynamic_pose_small,
1728
  target_frame_indices,
1729
  _dynamic_gen_small,
1730
- exclude_latest_local_frames=exclude_latest_local_frames,
1731
  )
1732
 
1733
  dynamic_min_gap_tensor = torch.as_tensor(
@@ -1785,7 +1809,6 @@ class MemoryDiTMixin:
1785
  revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1786
  revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1787
  revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
1788
- valid_revisit_target_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
1789
  eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
1790
  revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
1791
  eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
@@ -1801,7 +1824,10 @@ class MemoryDiTMixin:
1801
  for target_idx in range(T_tgt):
1802
  target_frame = int(target_frame_indices[target_idx, batch_idx].item())
1803
  if use_cache_revisit_records:
1804
- candidate_records = list(revisit_record_batches[batch_idx])
 
 
 
1805
  else:
1806
  candidate_records = revisit_bank.query(
1807
  MemoryBankQuery(
@@ -1815,7 +1841,7 @@ class MemoryDiTMixin:
1815
  target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
1816
  target_summary=None,
1817
  topk=revisit_max_frames,
1818
- exclude_local_context_frames=local_context_exclusion_frames,
1819
  fov_overlap_threshold=fov_overlap_threshold,
1820
  plucker_weight=plucker_weight,
1821
  target_video_id=_target_video_id_or_none(batch_idx, target_idx),
@@ -1840,7 +1866,6 @@ class MemoryDiTMixin:
1840
  revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
1841
  revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
1842
  revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
1843
- valid_revisit_target_mask[batch_idx, target_idx] = bool(result.diagnostics.get("valid_revisit_target_count", 0))
1844
  revisit_bank.assert_causal(target_frame, selected_records)
1845
  if selected_records:
1846
  valid_revisit_mask[batch_idx, target_idx] = True
@@ -1860,10 +1885,9 @@ class MemoryDiTMixin:
1860
  target_frame=target_frame,
1861
  )
1862
  eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
1863
- actual_max_source_frame = max((int(record.max_source_frame) for record in selected_records), default=max_source_frame)
1864
  batch_token_rows.append(stream_tokens)
1865
  batch_mask_rows.append(stream_mask)
1866
- batch_max_rows.append(torch.as_tensor(actual_max_source_frame, device=stream_device, dtype=torch.long))
1867
  revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
1868
  revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
1869
  revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
@@ -1872,15 +1896,15 @@ class MemoryDiTMixin:
1872
  revisit_max = torch.stack(revisit_max_rows, dim=0)
1873
 
1874
  if anchor_tokens.shape[-2] != anchor_num_tokens:
1875
- raise AssertionError(f"anchor token budget mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}")
1876
  if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
1877
  _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
1878
  int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
1879
  )
1880
  if dynamic_tokens.shape[-2] != _expected_dyn:
1881
- raise AssertionError(f"dynamic token budget mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}")
1882
- if revisit_tokens.shape[-2] > revisit_max_tokens:
1883
- raise AssertionError(f"revisit token cap exceeded: got {revisit_tokens.shape[-2]}, cap {revisit_max_tokens}")
1884
  anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
1885
  dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
1886
  gate_module = getattr(self, "dememwm_revisit_gate", None)
@@ -1913,7 +1937,6 @@ class MemoryDiTMixin:
1913
  revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
1914
  revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
1915
  revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
1916
- valid_revisit_target_mask = torch.zeros_like(valid_revisit_target_mask)
1917
  eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
1918
  valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
1919
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
@@ -1948,7 +1971,6 @@ class MemoryDiTMixin:
1948
  "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
1949
  "no_valid_revisit_mask": no_valid_revisit_mask,
1950
  "valid_revisit_eff_mask": valid_revisit_eff_mask,
1951
- "valid_revisit_target_mask": valid_revisit_target_mask,
1952
  "revisit_candidate_frame_count_per_target": revisit_candidate_count,
1953
  "revisit_selected_frame_count_per_target": revisit_selected_count,
1954
  "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
@@ -1972,18 +1994,17 @@ class MemoryDiTMixin:
1972
  "token_patch_size": token_patch_size,
1973
  "tokens_per_frame": tokens_per_frame,
1974
  "anchor_token_slots": int(anchor_tokens.shape[-2]),
1975
- "anchor_budget_tokens": anchor_num_tokens,
1976
  "anchor_pool_h": anchor_pool_h,
1977
  "anchor_pool_w": anchor_pool_w,
1978
  "dynamic_token_slots": int(dynamic_tokens.shape[-2]),
1979
- "dynamic_budget_tokens": int(dynamic_tokens.shape[-2]),
1980
  "dynamic_min_gap_to_target": dynamic_min_gap_to_target,
1981
  "dynamic_max_gap_to_target": dynamic_max_gap_to_target,
1982
- "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
1983
  "revisit_token_slots": int(revisit_tokens.shape[-2]),
1984
- "revisit_max_tokens": revisit_max_tokens,
1985
- "revisit_local_context_exclusion_frames": local_context_exclusion_frames,
1986
- "revisit_high_quality_fov_threshold": high_quality_fov_threshold,
1987
  "revisit_pool_h": revisit_pool_h,
1988
  "revisit_pool_w": revisit_pool_w,
1989
  "revisit_max_frames": revisit_max_frames,
 
42
  "revisit_pose_preselect_selected_count",
43
  "revisit_exact_fov_candidate_count",
44
  "valid_revisit_frame_count",
 
45
  "no_valid_revisit_count",
46
  "revisit_selected_frame_count",
47
  "revisit_frame_fov_overlap_mean",
 
170
  if self._cfg_has(memory_cfg, name)
171
  ]
172
  if ratio_fields:
173
+ raise ValueError(f"standalone DeMemWM derives stream slots from latent shape and compression settings, not ratio fields: {ratio_fields}")
174
 
175
  anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
176
  dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
 
246
  "revisit_deterministic_fov_plucker_retrieval": True,
247
  "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
248
  "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
 
249
  "revisit_plucker_weight": plucker_weight,
250
  "stage_policy_noise_bucket_logging": True,
251
  }
 
762
  if source_positions.numel() <= num_anchors or pose is None:
763
  return source_positions[:num_anchors]
764
  poses = pose.float()
765
+ pairwise = torch.cdist(poses, poses)
766
+ if not bool((pairwise > 0).any().item()):
767
+ return source_positions[:num_anchors]
768
+ available = torch.ones((int(source_positions.numel()),), device=poses.device, dtype=torch.bool)
769
+ if num_anchors == 1:
770
+ selected = [int(pairwise.mean(dim=1).argmax().item())]
771
+ else:
772
+ first, second = divmod(int(pairwise.argmax().item()), int(pairwise.shape[1]))
773
+ selected = [int(first), int(second)]
774
+ for idx in selected:
775
+ available[idx] = False
776
+ dists = pairwise[selected].min(dim=0).values
777
+ dists = dists.masked_fill(~available, float("-inf"))
778
+ for _ in range(num_anchors - len(selected)):
779
  farthest = int(dists.argmax().item())
780
+ if not bool(available[farthest].item()):
781
+ break
782
  selected.append(farthest)
783
+ available[farthest] = False
784
+ d_new = pairwise[farthest]
785
  dists = torch.minimum(dists, d_new)
786
+ dists = dists.masked_fill(~available, float("-inf"))
787
  return source_positions[torch.tensor(sorted(selected), device=source_positions.device)]
788
 
789
  def _build_streaming_cache_records(
 
880
  source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
881
  if source_positions.numel() > 0:
882
  if anchor_diverse:
883
+ anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
884
+ if anchor_source_positions.numel() > 0:
885
+ anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
886
+ selected_anchor_positions = self._select_diverse_anchor_positions(
887
+ anchor_source_positions, anchor_pose, len(anchor_indices)
888
+ )
889
+ else:
890
+ selected_anchor_positions = source_positions[:0]
891
  else:
892
  selected_list = []
893
  for anchor_idx in anchor_indices:
 
1356
  }
1357
  return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
1358
 
1359
+ def _causal_cached_revisit_records(
1360
+ self,
1361
+ records: Iterable[MemoryRecord],
1362
+ target_frame: int,
1363
+ ) -> list[MemoryRecord]:
1364
+ target_frame = int(target_frame)
1365
+ return [record for record in records if int(record.source_end) <= target_frame]
1366
+
1367
  def _records_to_stream(
1368
  self,
1369
  records,
1370
+ target_slots: int,
1371
  hidden_size: int,
1372
  device: torch.device,
1373
  dtype: torch.dtype,
1374
  ) -> tuple[torch.Tensor, torch.Tensor, int]:
1375
+ target_slots = max(0, int(target_slots))
1376
  record_list = list(records)
1377
+ stacked_tokens, stacked_mask = stack_record_tokens(record_list, target_slots=target_slots)
1378
  max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1)
1379
+ if stacked_tokens is None or stacked_mask is None or target_slots == 0:
1380
+ tokens = torch.zeros((target_slots, hidden_size), device=device, dtype=dtype)
1381
+ mask = torch.zeros((target_slots,), device=device, dtype=torch.bool)
1382
  return tokens, mask, max_source_frame
1383
+ n = min(target_slots, stacked_tokens.shape[0])
1384
  filled = stacked_tokens[:n].to(device=device, dtype=dtype)
1385
  filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool)
1386
+ if n < target_slots:
1387
+ pad = filled.new_zeros(target_slots - n, hidden_size)
1388
+ pad_mask = torch.zeros(target_slots - n, device=device, dtype=torch.bool)
1389
  tokens = torch.cat([filled, pad], dim=0)
1390
  mask = torch.cat([filled_mask, pad_mask], dim=0)
1391
  else:
 
1545
  revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
1546
  revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
1547
  )
1548
+ revisit_target_slots = revisit_max_frames * revisit_pool_h * revisit_pool_w
1549
  recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
1550
+ dynamic_recent_exclusion_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
1551
+ revisit_context_window_exclusion_frames = self._local_context_exclusion_frames()
1552
  fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
1553
  high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
1554
  plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
 
1593
  dtype=stream_dtype,
1594
  max_recent_frames=recent_frames,
1595
  target_frame_indices=target_frame_indices,
1596
+ exclude_latest_local_frames=dynamic_recent_exclusion_frames,
1597
  )
1598
  if raw_latents is not None:
1599
  dynamic_latents = raw_latents
 
1656
  revisit_pool_h,
1657
  revisit_pool_w,
1658
  revisit_max_frames,
1659
+ revisit_context_window_exclusion_frames,
1660
  fov_overlap_threshold,
1661
  plucker_weight,
1662
  revisit_retrieval_kwargs,
 
1666
 
1667
  T_tgt = target_frame_indices.shape[0]
1668
  anchor_slots = max(0, anchor_num_tokens)
1669
+ revisit_slots = max(0, revisit_target_slots)
1670
  anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
1671
  anchor_include_generated = allow_generated_anchor
1672
  anchor_token_rows = []
 
1684
  source_type=anchor_source_type,
1685
  include_generated=anchor_include_generated,
1686
  max_records=len(anchor_indices),
 
1687
  )
1688
  )
1689
  anchor_bank.assert_causal(target_frame, records)
 
1717
  "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1718
  "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
1719
  "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
1720
+ "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
1721
  }
1722
  else:
1723
  # Pre-select dynamic source frame positions using only frame index metadata
 
1729
  for _b in range(B):
1730
  for _j in range(T_tgt):
1731
  _target = int(target_frame_indices[_j, _b].item())
1732
+ _valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
1733
  _needed.extend(_valid[-_max_src:].tolist())
1734
  if _needed:
1735
  _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
 
1751
  _dynamic_pose_small,
1752
  target_frame_indices,
1753
  _dynamic_gen_small,
1754
+ exclude_latest_local_frames=dynamic_recent_exclusion_frames,
1755
  )
1756
 
1757
  dynamic_min_gap_tensor = torch.as_tensor(
 
1809
  revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1810
  revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
1811
  revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
 
1812
  eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
1813
  revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
1814
  eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
 
1824
  for target_idx in range(T_tgt):
1825
  target_frame = int(target_frame_indices[target_idx, batch_idx].item())
1826
  if use_cache_revisit_records:
1827
+ candidate_records = self._causal_cached_revisit_records(
1828
+ revisit_record_batches[batch_idx],
1829
+ target_frame,
1830
+ )
1831
  else:
1832
  candidate_records = revisit_bank.query(
1833
  MemoryBankQuery(
 
1841
  target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
1842
  target_summary=None,
1843
  topk=revisit_max_frames,
1844
+ exclude_local_context_frames=revisit_context_window_exclusion_frames,
1845
  fov_overlap_threshold=fov_overlap_threshold,
1846
  plucker_weight=plucker_weight,
1847
  target_video_id=_target_video_id_or_none(batch_idx, target_idx),
 
1866
  revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
1867
  revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
1868
  revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
 
1869
  revisit_bank.assert_causal(target_frame, selected_records)
1870
  if selected_records:
1871
  valid_revisit_mask[batch_idx, target_idx] = True
 
1885
  target_frame=target_frame,
1886
  )
1887
  eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
 
1888
  batch_token_rows.append(stream_tokens)
1889
  batch_mask_rows.append(stream_mask)
1890
+ batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
1891
  revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
1892
  revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
1893
  revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
 
1896
  revisit_max = torch.stack(revisit_max_rows, dim=0)
1897
 
1898
  if anchor_tokens.shape[-2] != anchor_num_tokens:
1899
+ raise AssertionError(f"anchor slot count mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}")
1900
  if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
1901
  _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
1902
  int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
1903
  )
1904
  if dynamic_tokens.shape[-2] != _expected_dyn:
1905
+ raise AssertionError(f"dynamic slot count mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}")
1906
+ if revisit_tokens.shape[-2] != revisit_target_slots:
1907
+ raise AssertionError(f"revisit slot count mismatch: got {revisit_tokens.shape[-2]}, expected {revisit_target_slots}")
1908
  anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
1909
  dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
1910
  gate_module = getattr(self, "dememwm_revisit_gate", None)
 
1937
  revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
1938
  revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
1939
  revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
 
1940
  eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
1941
  valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
1942
  revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
 
1971
  "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
1972
  "no_valid_revisit_mask": no_valid_revisit_mask,
1973
  "valid_revisit_eff_mask": valid_revisit_eff_mask,
 
1974
  "revisit_candidate_frame_count_per_target": revisit_candidate_count,
1975
  "revisit_selected_frame_count_per_target": revisit_selected_count,
1976
  "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
 
1994
  "token_patch_size": token_patch_size,
1995
  "tokens_per_frame": tokens_per_frame,
1996
  "anchor_token_slots": int(anchor_tokens.shape[-2]),
1997
+ "anchor_target_slots": anchor_num_tokens,
1998
  "anchor_pool_h": anchor_pool_h,
1999
  "anchor_pool_w": anchor_pool_w,
2000
  "dynamic_token_slots": int(dynamic_tokens.shape[-2]),
2001
+ "dynamic_target_slots": int(dynamic_tokens.shape[-2]),
2002
  "dynamic_min_gap_to_target": dynamic_min_gap_to_target,
2003
  "dynamic_max_gap_to_target": dynamic_max_gap_to_target,
2004
+ "dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
2005
  "revisit_token_slots": int(revisit_tokens.shape[-2]),
2006
+ "revisit_target_slots": revisit_target_slots,
2007
+ "revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
 
2008
  "revisit_pool_h": revisit_pool_h,
2009
  "revisit_pool_w": revisit_pool_w,
2010
  "revisit_max_frames": revisit_max_frames,
algorithms/worldmem/dememwm/cache.py CHANGED
@@ -39,7 +39,6 @@ class StreamingCache:
39
  no_evict: bool = True,
40
  clear_between_videos: bool = True,
41
  max_records: Optional[int] = None,
42
- max_slots: Optional[int] = None,
43
  on_capacity_exceeded: str = "warn",
44
  ) -> None:
45
  self.enabled = bool(enabled)
@@ -51,7 +50,6 @@ class StreamingCache:
51
  self.no_evict = bool(no_evict)
52
  self.clear_between_videos = bool(clear_between_videos)
53
  self.max_records = max_records
54
- self.max_slots = max_slots
55
  self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
56
  if self.eviction_policy != "none" or not self.no_evict:
57
  raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
@@ -92,7 +90,6 @@ class StreamingCache:
92
  no_evict=bool(get("no_evict", True)),
93
  clear_between_videos=bool(get("clear_between_videos", True)),
94
  max_records=get("max_records", None),
95
- max_slots=get("max_slots", None),
96
  on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
97
  )
98
 
@@ -213,14 +210,12 @@ class StreamingCache:
213
  exceeded = False
214
  if self.max_records is not None and self.record_count > int(self.max_records):
215
  exceeded = True
216
- if self.max_slots is not None and self.slot_count > int(self.max_slots):
217
- exceeded = True
218
  if not exceeded:
219
  return
220
  self.capacity_exceeded_count += 1
221
  msg = (
222
  "DeMemWMStreamingCache capacity exceeded "
223
- f"records={self.record_count}/{self.max_records}, slots={self.slot_count}/{self.max_slots}; "
224
  "no eviction performed because no_evict=true"
225
  )
226
  if self.on_capacity_exceeded == "error":
 
39
  no_evict: bool = True,
40
  clear_between_videos: bool = True,
41
  max_records: Optional[int] = None,
 
42
  on_capacity_exceeded: str = "warn",
43
  ) -> None:
44
  self.enabled = bool(enabled)
 
50
  self.no_evict = bool(no_evict)
51
  self.clear_between_videos = bool(clear_between_videos)
52
  self.max_records = max_records
 
53
  self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
54
  if self.eviction_policy != "none" or not self.no_evict:
55
  raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
 
90
  no_evict=bool(get("no_evict", True)),
91
  clear_between_videos=bool(get("clear_between_videos", True)),
92
  max_records=get("max_records", None),
 
93
  on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
94
  )
95
 
 
210
  exceeded = False
211
  if self.max_records is not None and self.record_count > int(self.max_records):
212
  exceeded = True
 
 
213
  if not exceeded:
214
  return
215
  self.capacity_exceeded_count += 1
216
  msg = (
217
  "DeMemWMStreamingCache capacity exceeded "
218
+ f"records={self.record_count}/{self.max_records}; "
219
  "no eviction performed because no_evict=true"
220
  )
221
  if self.on_capacity_exceeded == "error":
algorithms/worldmem/dememwm/diagnostics.py CHANGED
@@ -64,7 +64,6 @@ def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], vali
64
  exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
65
  valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
66
  valid_count_mean = float(valid_count / target_count) if target_count else 0.0
67
- valid_target_count = sum(int(diag.get("valid_revisit_target_count", diag.get("high_quality_selected_revisit", 0))) for diag in result_diagnostics)
68
  selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
69
  no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
70
  abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
@@ -78,7 +77,6 @@ def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], vali
78
  "revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
79
  "valid_revisit_frame_count": valid_count_mean,
80
  "valid_revisit_count": valid_count_mean,
81
- "valid_revisit_target_count": int(valid_target_count),
82
  "no_valid_revisit_count": int(no_valid_count),
83
  "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
84
  "revisit_selected_frame_count": int(selected_count),
 
64
  exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
65
  valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
66
  valid_count_mean = float(valid_count / target_count) if target_count else 0.0
 
67
  selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
68
  no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
69
  abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
 
77
  "revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
78
  "valid_revisit_frame_count": valid_count_mean,
79
  "valid_revisit_count": valid_count_mean,
 
80
  "no_valid_revisit_count": int(no_valid_count),
81
  "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
82
  "revisit_selected_frame_count": int(selected_count),
algorithms/worldmem/dememwm/memory.py CHANGED
@@ -15,15 +15,13 @@ class MemoryBankQuery:
15
  source_type: Optional[MemorySourceType] = None
16
  include_generated: bool = True
17
  max_records: Optional[int] = None
18
- max_slots: Optional[int] = None
19
 
20
 
21
  class CausalMemoryBank:
22
  """Small causal memory bank for DeMemWM records."""
23
 
24
- def __init__(self, max_records: Optional[int] = None, max_slots: Optional[int] = None):
25
  self.max_records = max_records
26
- self.max_slots = max_slots
27
  self._records: list[MemoryRecord] = []
28
 
29
  def __len__(self) -> int:
@@ -172,7 +170,6 @@ class CausalMemoryBank:
172
  if isinstance(query, int):
173
  query = MemoryBankQuery(target_frame=query, **kwargs)
174
  out: list[MemoryRecord] = []
175
- used_slots = 0
176
  for record in self._records:
177
  if int(record.source_end) > int(query.target_frame):
178
  continue
@@ -180,15 +177,9 @@ class CausalMemoryBank:
180
  continue
181
  if not query.include_generated and record.is_generated:
182
  continue
183
- if query.max_slots is not None and used_slots >= query.max_slots:
184
- break
185
  out.append(record)
186
- if query.max_slots is not None:
187
- used_slots += record.valid_slots
188
  if query.max_records is not None and len(out) >= query.max_records:
189
  break
190
- if query.max_slots is not None and used_slots >= query.max_slots:
191
- break
192
  return out
193
 
194
  def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
@@ -197,12 +188,13 @@ class CausalMemoryBank:
197
  raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
198
 
199
 
200
- def stack_record_tokens(records: list[MemoryRecord], max_slots: int | None = None):
201
  if not records:
202
  return None, None
203
  tokens = torch.cat([r.tokens for r in records], dim=0)
204
  mask = torch.cat([r.mask.bool() for r in records], dim=0)
205
- if max_slots is not None:
206
- tokens = tokens[:max_slots]
207
- mask = mask[:max_slots]
 
208
  return tokens, mask
 
15
  source_type: Optional[MemorySourceType] = None
16
  include_generated: bool = True
17
  max_records: Optional[int] = None
 
18
 
19
 
20
  class CausalMemoryBank:
21
  """Small causal memory bank for DeMemWM records."""
22
 
23
+ def __init__(self, max_records: Optional[int] = None):
24
  self.max_records = max_records
 
25
  self._records: list[MemoryRecord] = []
26
 
27
  def __len__(self) -> int:
 
170
  if isinstance(query, int):
171
  query = MemoryBankQuery(target_frame=query, **kwargs)
172
  out: list[MemoryRecord] = []
 
173
  for record in self._records:
174
  if int(record.source_end) > int(query.target_frame):
175
  continue
 
177
  continue
178
  if not query.include_generated and record.is_generated:
179
  continue
 
 
180
  out.append(record)
 
 
181
  if query.max_records is not None and len(out) >= query.max_records:
182
  break
 
 
183
  return out
184
 
185
  def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
 
188
  raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
189
 
190
 
191
+ def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None):
192
  if not records:
193
  return None, None
194
  tokens = torch.cat([r.tokens for r in records], dim=0)
195
  mask = torch.cat([r.mask.bool() for r in records], dim=0)
196
+ if target_slots is not None:
197
+ valid_idx = mask.nonzero(as_tuple=False).flatten()
198
+ tokens = tokens.index_select(0, valid_idx)[:target_slots]
199
+ mask = mask.index_select(0, valid_idx)[:target_slots]
200
  return tokens, mask
algorithms/worldmem/dememwm/retrieval.py CHANGED
@@ -427,7 +427,6 @@ def deterministic_revisit_retrieval(
427
  "revisit_candidate_count": len(causal_records),
428
  "valid_revisit_frame_count": len(valid_labels),
429
  "valid_revisit_count": len(valid_labels),
430
- "valid_revisit_target_count": high_quality_selected,
431
  "no_valid_revisit_count": int(len(valid_labels) == 0),
432
  "valid_revisit_mask": int(len(valid_labels) > 0),
433
  "revisit_abstained_count": int(len(selected_records) == 0),
 
427
  "revisit_candidate_count": len(causal_records),
428
  "valid_revisit_frame_count": len(valid_labels),
429
  "valid_revisit_count": len(valid_labels),
 
430
  "no_valid_revisit_count": int(len(valid_labels) == 0),
431
  "valid_revisit_mask": int(len(valid_labels) > 0),
432
  "revisit_abstained_count": int(len(selected_records) == 0),
configurations/algorithm/dememwm_memory_dit.yaml CHANGED
@@ -93,7 +93,6 @@ dememwm:
93
  no_evict: true
94
  clear_between_videos: true
95
  max_records: null
96
- max_slots: null
97
  on_capacity_exceeded: warn
98
  checkpoint:
99
  strict_dememwm_eval_load: true
 
93
  no_evict: true
94
  clear_between_videos: true
95
  max_records: null
 
96
  on_capacity_exceeded: warn
97
  checkpoint:
98
  strict_dememwm_eval_load: true
scripts/dememwm_full_eval.slurm CHANGED
@@ -150,7 +150,6 @@ EVAL_ARGS=(
150
  "++algorithm.dememwm.cache.no_evict=true"
151
  "++algorithm.dememwm.cache.clear_between_videos=true"
152
  "++algorithm.dememwm.cache.max_records=null"
153
- "++algorithm.dememwm.cache.max_slots=null"
154
  "++algorithm.dememwm.cache.on_capacity_exceeded=warn"
155
  "experiment.validation.batch_size=${VAL_BATCH_SIZE}"
156
  "experiment.validation.limit_batch=${VAL_LIMIT}"
 
150
  "++algorithm.dememwm.cache.no_evict=true"
151
  "++algorithm.dememwm.cache.clear_between_videos=true"
152
  "++algorithm.dememwm.cache.max_records=null"
 
153
  "++algorithm.dememwm.cache.on_capacity_exceeded=warn"
154
  "experiment.validation.batch_size=${VAL_BATCH_SIZE}"
155
  "experiment.validation.limit_batch=${VAL_LIMIT}"
scripts/dememwm_full_train.slurm CHANGED
@@ -69,8 +69,7 @@ srun python -m main \
69
  ++algorithm.dememwm.dynamic.recent_frames=4 \
70
  ++algorithm.dememwm.revisit.enabled=true \
71
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
72
- ++algorithm.dememwm.revisit.fov_overlap_threshold=0.30 \
73
- ++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \
74
  ++algorithm.dememwm.revisit.pose_preselect_topk=64 \
75
  ++algorithm.dememwm.revisit.fov_yaw_samples=25 \
76
  ++algorithm.dememwm.revisit.fov_pitch_samples=20 \
@@ -87,7 +86,6 @@ srun python -m main \
87
  ++algorithm.dememwm.cache.no_evict=true \
88
  ++algorithm.dememwm.cache.clear_between_videos=true \
89
  ++algorithm.dememwm.cache.max_records=null \
90
- ++algorithm.dememwm.cache.max_slots=null \
91
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn \
92
  ++algorithm.dememwm.curriculum.enabled=true \
93
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
 
69
  ++algorithm.dememwm.dynamic.recent_frames=4 \
70
  ++algorithm.dememwm.revisit.enabled=true \
71
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
72
+ ++algorithm.dememwm.revisit.fov_overlap_threshold=0.60 \
 
73
  ++algorithm.dememwm.revisit.pose_preselect_topk=64 \
74
  ++algorithm.dememwm.revisit.fov_yaw_samples=25 \
75
  ++algorithm.dememwm.revisit.fov_pitch_samples=20 \
 
86
  ++algorithm.dememwm.cache.no_evict=true \
87
  ++algorithm.dememwm.cache.clear_between_videos=true \
88
  ++algorithm.dememwm.cache.max_records=null \
 
89
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn \
90
  ++algorithm.dememwm.curriculum.enabled=true \
91
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
tests/test_dememwm_config_static.py CHANGED
@@ -72,8 +72,6 @@ def test_full_scripts_use_consumed_contract_overrides():
72
  required = [
73
  "algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
74
  "algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
75
- "algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
76
- "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
77
  "algorithm.dememwm.revisit.pose_preselect_topk=64",
78
  "algorithm.dememwm.revisit.fov_yaw_samples=25",
79
  "algorithm.dememwm.revisit.fov_pitch_samples=20",
@@ -98,9 +96,18 @@ def test_full_scripts_use_consumed_contract_overrides():
98
  "algorithm.dememwm.revisit.generated_penalty",
99
  "algorithm.dememwm.rollout.",
100
  ]
101
- for rel in ("scripts/dememwm_full_train.slurm", "scripts/dememwm_full_eval.slurm"):
 
 
 
 
 
 
 
 
 
102
  text = Path(rel).read_text()
103
- for token in required:
104
  assert token in text, f"{token} missing from {rel}"
105
  for token in stale:
106
  assert token not in text, f"stale {token} override remains in {rel}"
@@ -145,7 +152,6 @@ def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
145
  "valid_revisit_mask",
146
  "revisit_candidate_frame_count",
147
  "valid_candidate_label_count",
148
- "valid_revisit_target_count",
149
  "valid_revisit_frame_count",
150
  "no_valid_revisit_count",
151
  "revisit_selected_frame_count",
@@ -180,7 +186,6 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
180
  schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
181
  diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
182
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
183
- matrix = Path("scripts/dememwm_eval_ablation_matrix.sh").read_text()
184
  for branch in [
185
  "memory_off",
186
  "A_only",
@@ -197,7 +202,6 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
197
  "local_context_overlap_fake_revisit",
198
  ]:
199
  assert branch in schedules
200
- assert branch in matrix
201
  for token in [
202
  "noise_bucket_from_denoising_fraction",
203
  "noise_bucket_from_noise_levels",
 
72
  required = [
73
  "algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
74
  "algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
 
 
75
  "algorithm.dememwm.revisit.pose_preselect_topk=64",
76
  "algorithm.dememwm.revisit.fov_yaw_samples=25",
77
  "algorithm.dememwm.revisit.fov_pitch_samples=20",
 
96
  "algorithm.dememwm.revisit.generated_penalty",
97
  "algorithm.dememwm.rollout.",
98
  ]
99
+ expected_by_script = {
100
+ "scripts/dememwm_full_train.slurm": [
101
+ "algorithm.dememwm.revisit.fov_overlap_threshold=0.60",
102
+ ],
103
+ "scripts/dememwm_full_eval.slurm": [
104
+ "algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
105
+ "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
106
+ ],
107
+ }
108
+ for rel, script_specific_required in expected_by_script.items():
109
  text = Path(rel).read_text()
110
+ for token in required + script_specific_required:
111
  assert token in text, f"{token} missing from {rel}"
112
  for token in stale:
113
  assert token not in text, f"stale {token} override remains in {rel}"
 
152
  "valid_revisit_mask",
153
  "revisit_candidate_frame_count",
154
  "valid_candidate_label_count",
 
155
  "valid_revisit_frame_count",
156
  "no_valid_revisit_count",
157
  "revisit_selected_frame_count",
 
186
  schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
187
  diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
188
  algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
 
189
  for branch in [
190
  "memory_off",
191
  "A_only",
 
202
  "local_context_overlap_fake_revisit",
203
  ]:
204
  assert branch in schedules
 
205
  for token in [
206
  "noise_bucket_from_denoising_fraction",
207
  "noise_bucket_from_noise_levels",
tests/test_dememwm_memory.py CHANGED
@@ -44,12 +44,46 @@ def test_all_false_masks_are_valid_abstention_outputs():
44
  assert mask.sum().item() == 0
45
 
46
 
47
- def test_budgets_cap_records_and_slots():
48
  bank = CausalMemoryBank(max_records=10)
49
  for f in range(6):
50
  bank.add_record(_record(f, slots=2))
51
- records = bank.query(MemoryBankQuery(target_frame=10, max_records=2, max_slots=3))
52
  assert len(records) == 2
53
- tokens, mask = stack_record_tokens(records, max_slots=3)
54
  assert tokens.shape[0] == 3
55
  assert mask.shape[0] == 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  assert mask.sum().item() == 0
45
 
46
 
47
+ def test_query_caps_records_and_stack_uses_target_slots():
48
  bank = CausalMemoryBank(max_records=10)
49
  for f in range(6):
50
  bank.add_record(_record(f, slots=2))
51
+ records = bank.query(MemoryBankQuery(target_frame=10, max_records=2))
52
  assert len(records) == 2
53
+ tokens, mask = stack_record_tokens(records, target_slots=3)
54
  assert tokens.shape[0] == 3
55
  assert mask.shape[0] == 3
56
+
57
+
58
+ def test_target_slots_ignore_masked_slots_when_stacking_records():
59
+ invalid = MemoryRecord(
60
+ tokens=torch.ones(4, 4),
61
+ mask=torch.zeros(4, dtype=torch.bool),
62
+ source_start=0,
63
+ source_end=1,
64
+ frame_indices=torch.tensor([0]),
65
+ pose=None,
66
+ source_type=MemorySourceType.REVISIT,
67
+ is_generated=False,
68
+ chunk_id="invalid",
69
+ )
70
+ valid = MemoryRecord(
71
+ tokens=torch.ones(2, 4) * 2,
72
+ mask=torch.ones(2, dtype=torch.bool),
73
+ source_start=1,
74
+ source_end=2,
75
+ frame_indices=torch.tensor([1]),
76
+ pose=None,
77
+ source_type=MemorySourceType.REVISIT,
78
+ is_generated=False,
79
+ chunk_id="valid",
80
+ )
81
+ bank = CausalMemoryBank()
82
+ bank.add_record(invalid)
83
+ bank.add_record(valid)
84
+
85
+ records = bank.query(MemoryBankQuery(target_frame=3))
86
+ tokens, mask = stack_record_tokens(records, target_slots=2)
87
+
88
+ assert mask.tolist() == [True, True]
89
+ assert torch.equal(tokens, torch.ones(2, 4) * 2)
tests/test_dememwm_noise_bucket.py CHANGED
@@ -92,7 +92,6 @@ def test_noise_bucket_log_allowlist_keeps_target_counts_only():
92
  "noise_bucket_low_target_count",
93
  "revisit_candidate_frame_count",
94
  "valid_revisit_frame_count",
95
- "valid_revisit_target_count",
96
  "revisit_selected_frame_count",
97
  "revisit_frame_fov_overlap_mean",
98
  "revisit_best_selected_frame_fov_overlap_mean",
 
92
  "noise_bucket_low_target_count",
93
  "revisit_candidate_frame_count",
94
  "valid_revisit_frame_count",
 
95
  "revisit_selected_frame_count",
96
  "revisit_frame_fov_overlap_mean",
97
  "revisit_best_selected_frame_fov_overlap_mean",
tests/test_dememwm_preselection.py CHANGED
@@ -56,6 +56,50 @@ def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack():
56
  assert harness._local_context_exclusion_frames() == 8
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
60
  harness = Harness()
61
  harness.context_frames = 2
@@ -92,6 +136,31 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
92
  assert diag["preselected_anchor_projected_frame_count"] == 2
93
 
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def test_preselected_memory_banks_project_only_selected_frames():
96
  harness = Harness()
97
  latents = torch.randn(20, 1, 3, 2, 2)
 
56
  assert harness._local_context_exclusion_frames() == 8
57
 
58
 
59
+ def test_diverse_anchor_selection_does_not_repeat_tied_pose_indices():
60
+ harness = Harness()
61
+ source_positions = torch.arange(5)
62
+ poses = torch.zeros((5, 5), dtype=torch.float32)
63
+
64
+ selected = harness._select_diverse_anchor_positions(source_positions, poses, 4)
65
+
66
+ assert selected.tolist() == [0, 1, 2, 3]
67
+
68
+
69
+ def test_diverse_anchor_selection_seeds_from_widest_pose_pair():
70
+ harness = Harness()
71
+ source_positions = torch.arange(4)
72
+ poses = torch.tensor([[0.0], [-10.0], [10.0], [0.1]], dtype=torch.float32)
73
+
74
+ selected = harness._select_diverse_anchor_positions(source_positions, poses, 2)
75
+
76
+ assert selected.tolist() == [1, 2]
77
+
78
+
79
+ def test_cached_revisit_prefilter_keeps_only_causal_records():
80
+ harness = Harness()
81
+
82
+ def record(frame: int) -> MemoryRecord:
83
+ return MemoryRecord(
84
+ tokens=torch.zeros((1, 8)),
85
+ mask=torch.ones(1, dtype=torch.bool),
86
+ source_start=frame,
87
+ source_end=frame + 1,
88
+ frame_indices=torch.tensor([frame]),
89
+ pose=None,
90
+ source_type=MemorySourceType.REVISIT,
91
+ is_generated=False,
92
+ chunk_id=f"revisit_{frame}",
93
+ )
94
+
95
+ selected = harness._causal_cached_revisit_records(
96
+ (record(0), record(2), record(5)),
97
+ target_frame=3,
98
+ )
99
+
100
+ assert [record.source_start for record in selected] == [0, 2]
101
+
102
+
103
  def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
104
  harness = Harness()
105
  harness.context_frames = 2
 
136
  assert diag["preselected_anchor_projected_frame_count"] == 2
137
 
138
 
139
+ def test_streaming_diverse_anchor_selection_uses_context_frames():
140
+ harness = Harness()
141
+ harness.context_frames = 2
142
+ latents = torch.randn(8, 1, 3, 2, 2)
143
+ frame_indices = torch.arange(8)[:, None]
144
+ poses = torch.zeros((8, 1, 5), dtype=torch.float32)
145
+
146
+ anchor_banks, _ = harness._build_streaming_cache_records(
147
+ source_latents=latents,
148
+ source_frame_indices=frame_indices,
149
+ source_is_generated=None,
150
+ pose=poses,
151
+ action=None,
152
+ allow_generated_anchor=False,
153
+ anchor_indices=[0, 1, 2, 3],
154
+ anchor_pool_h=1,
155
+ anchor_pool_w=1,
156
+ anchor_diverse=True,
157
+ token_patch_size=2,
158
+ )
159
+
160
+ assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
161
+ assert harness.project_call_lengths == [2]
162
+
163
+
164
  def test_preselected_memory_banks_project_only_selected_frames():
165
  harness = Harness()
166
  latents = torch.randn(20, 1, 3, 2, 2)
tests/test_dememwm_retrieval.py CHANGED
@@ -96,7 +96,6 @@ def test_revisit_candidates_require_causal_c_short_gap():
96
  assert result.diagnostics["valid_revisit_frame_count"] == 1
97
  assert result.diagnostics["valid_revisit_count"] == 1
98
  assert result.diagnostics["valid_candidate_label_count"] == 1
99
- assert result.diagnostics["valid_revisit_target_count"] == 1
100
  assert result.diagnostics["revisit_min_gap_to_target"] == 5
101
  assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
102
 
@@ -106,7 +105,6 @@ def test_revisit_abstains_when_no_valid_candidate():
106
  assert result.records == []
107
  assert result.diagnostics["abstained"] is True
108
  assert result.diagnostics["valid_revisit_mask"] == 0
109
- assert result.diagnostics["valid_revisit_target_count"] == 0
110
  assert result.diagnostics["no_valid_revisit_count"] == 1
111
 
112
 
@@ -155,7 +153,6 @@ def test_fov_threshold_filters_candidates_without_action():
155
  assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
156
  assert result.diagnostics["valid_revisit_frame_count"] == 1
157
  assert result.diagnostics["valid_revisit_count"] == 1
158
- assert result.diagnostics["valid_revisit_target_count"] == 1
159
  assert result.diagnostics["best_selected_fov_overlap"] == 1.0
160
  assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
161
  assert result.diagnostics["best_selected_gap_frames"] == 10
@@ -206,7 +203,6 @@ def test_selected_frame_carries_frame_metadata_for_projection():
206
  assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
207
  assert result.diagnostics["best_selected_frame_index"] == 1
208
  assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
209
- assert result.diagnostics["valid_revisit_target_count"] == 1
210
 
211
 
212
  def test_high_quality_threshold_is_selected_target_diagnostic_only():
@@ -221,7 +217,6 @@ def test_high_quality_threshold_is_selected_target_diagnostic_only():
221
  )
222
  assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
223
  assert result.diagnostics["valid_revisit_count"] == 1
224
- assert result.diagnostics["valid_revisit_target_count"] == 0
225
  assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70
226
 
227
 
 
96
  assert result.diagnostics["valid_revisit_frame_count"] == 1
97
  assert result.diagnostics["valid_revisit_count"] == 1
98
  assert result.diagnostics["valid_candidate_label_count"] == 1
 
99
  assert result.diagnostics["revisit_min_gap_to_target"] == 5
100
  assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
101
 
 
105
  assert result.records == []
106
  assert result.diagnostics["abstained"] is True
107
  assert result.diagnostics["valid_revisit_mask"] == 0
 
108
  assert result.diagnostics["no_valid_revisit_count"] == 1
109
 
110
 
 
153
  assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
154
  assert result.diagnostics["valid_revisit_frame_count"] == 1
155
  assert result.diagnostics["valid_revisit_count"] == 1
 
156
  assert result.diagnostics["best_selected_fov_overlap"] == 1.0
157
  assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
158
  assert result.diagnostics["best_selected_gap_frames"] == 10
 
203
  assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
204
  assert result.diagnostics["best_selected_frame_index"] == 1
205
  assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
 
206
 
207
 
208
  def test_high_quality_threshold_is_selected_target_diagnostic_only():
 
217
  )
218
  assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
219
  assert result.diagnostics["valid_revisit_count"] == 1
 
220
  assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70
221
 
222
 
tests/test_dememwm_stream_grad.py CHANGED
@@ -23,7 +23,7 @@ def test_records_to_stream_preserves_grad_to_record_tokens():
23
  tokens, mask, max_source = MemoryDiTMixin._records_to_stream(
24
  object(),
25
  [record],
26
- max_tokens=4,
27
  hidden_size=4,
28
  device=torch.device("cpu"),
29
  dtype=torch.float32,
 
23
  tokens, mask, max_source = MemoryDiTMixin._records_to_stream(
24
  object(),
25
  [record],
26
+ target_slots=4,
27
  hidden_size=4,
28
  device=torch.device("cpu"),
29
  dtype=torch.float32,
train_dememwm_full_berzelius.sh CHANGED
@@ -23,10 +23,10 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
23
  export WANDB_DISABLED=true
24
  export HYDRA_FULL_ERROR=1
25
 
26
- OUTPUT_DIR=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/dememwm_full_berzelius_8a100_bs8_global64_350k
27
 
28
  srun python -m main \
29
- +name=train_dememwm_full_berzelius_8a100_bs8_global64_350k \
30
  +output_dir="${OUTPUT_DIR}/" \
31
  auto_resume=true \
32
  experiment.tasks=[training] \
@@ -40,7 +40,7 @@ srun python -m main \
40
  dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft \
41
  dataset.precomputed_feature_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft/vae_features \
42
  dataset.n_frames=1000 \
43
- +dataset.n_frames_valid=1100 \
44
  +dataset.customized_validation=true \
45
  +dataset.memory_condition_length=0 \
46
  +dataset.wo_updown=false \
@@ -68,8 +68,7 @@ srun python -m main \
68
  ++algorithm.dememwm.dynamic.recent_frames=4 \
69
  ++algorithm.dememwm.revisit.enabled=true \
70
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
71
- ++algorithm.dememwm.revisit.fov_overlap_threshold=0.30 \
72
- ++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \
73
  ++algorithm.dememwm.revisit.pose_preselect_topk=64 \
74
  ++algorithm.dememwm.revisit.fov_yaw_samples=25 \
75
  ++algorithm.dememwm.revisit.fov_pitch_samples=20 \
@@ -86,7 +85,6 @@ srun python -m main \
86
  ++algorithm.dememwm.cache.no_evict=true \
87
  ++algorithm.dememwm.cache.clear_between_videos=true \
88
  ++algorithm.dememwm.cache.max_records=null \
89
- ++algorithm.dememwm.cache.max_slots=null \
90
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn \
91
  ++algorithm.dememwm.curriculum.enabled=true \
92
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
@@ -95,7 +93,7 @@ srun python -m main \
95
  ++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5 \
96
  ++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5 \
97
  ++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5 \
98
- experiment.training.batch_size=8 \
99
  experiment.training.optim.accumulate_grad_batches=1 \
100
  experiment.validation.batch_size=1 \
101
  experiment.validation.limit_batch=8 \
 
23
  export WANDB_DISABLED=true
24
  export HYDRA_FULL_ERROR=1
25
 
26
+ OUTPUT_DIR=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/dememwm_full_berzelius_8a100_bs16_global128_350k
27
 
28
  srun python -m main \
29
+ +name=train_dememwm_full_berzelius_8a100_bs16_global128_350k \
30
  +output_dir="${OUTPUT_DIR}/" \
31
  auto_resume=true \
32
  experiment.tasks=[training] \
 
40
  dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft \
41
  dataset.precomputed_feature_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft/vae_features \
42
  dataset.n_frames=1000 \
43
+ +dataset.n_frames_valid=700 \
44
  +dataset.customized_validation=true \
45
  +dataset.memory_condition_length=0 \
46
  +dataset.wo_updown=false \
 
68
  ++algorithm.dememwm.dynamic.recent_frames=4 \
69
  ++algorithm.dememwm.revisit.enabled=true \
70
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
71
+ ++algorithm.dememwm.revisit.fov_overlap_threshold=0.60 \
 
72
  ++algorithm.dememwm.revisit.pose_preselect_topk=64 \
73
  ++algorithm.dememwm.revisit.fov_yaw_samples=25 \
74
  ++algorithm.dememwm.revisit.fov_pitch_samples=20 \
 
85
  ++algorithm.dememwm.cache.no_evict=true \
86
  ++algorithm.dememwm.cache.clear_between_videos=true \
87
  ++algorithm.dememwm.cache.max_records=null \
 
88
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn \
89
  ++algorithm.dememwm.curriculum.enabled=true \
90
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
 
93
  ++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5 \
94
  ++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5 \
95
  ++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5 \
96
+ experiment.training.batch_size=16 \
97
  experiment.training.optim.accumulate_grad_batches=1 \
98
  experiment.validation.batch_size=1 \
99
  experiment.validation.limit_batch=8 \