Clean DeMemWM deterministic memory slot handling
Browse files
- algorithms/worldmem/dememwm/algorithm.py +70 -49
- algorithms/worldmem/dememwm/cache.py +1 -6
- algorithms/worldmem/dememwm/diagnostics.py +0 -2
- algorithms/worldmem/dememwm/memory.py +6 -14
- algorithms/worldmem/dememwm/retrieval.py +0 -1
- configurations/algorithm/dememwm_memory_dit.yaml +0 -1
- scripts/dememwm_full_eval.slurm +0 -1
- scripts/dememwm_full_train.slurm +1 -3
- tests/test_dememwm_config_static.py +11 -7
- tests/test_dememwm_memory.py +37 -3
- tests/test_dememwm_noise_bucket.py +0 -1
- tests/test_dememwm_preselection.py +69 -0
- tests/test_dememwm_retrieval.py +0 -5
- tests/test_dememwm_stream_grad.py +1 -1
- train_dememwm_full_berzelius.sh +5 -7
algorithms/worldmem/dememwm/algorithm.py
CHANGED
|
@@ -42,7 +42,6 @@ class MemoryDiTMixin:
|
|
| 42 |
"revisit_pose_preselect_selected_count",
|
| 43 |
"revisit_exact_fov_candidate_count",
|
| 44 |
"valid_revisit_frame_count",
|
| 45 |
-
"valid_revisit_target_count",
|
| 46 |
"no_valid_revisit_count",
|
| 47 |
"revisit_selected_frame_count",
|
| 48 |
"revisit_frame_fov_overlap_mean",
|
|
@@ -171,7 +170,7 @@ class MemoryDiTMixin:
|
|
| 171 |
if self._cfg_has(memory_cfg, name)
|
| 172 |
]
|
| 173 |
if ratio_fields:
|
| 174 |
-
raise ValueError(f"standalone DeMemWM
|
| 175 |
|
| 176 |
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
|
| 177 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
|
@@ -247,7 +246,6 @@ class MemoryDiTMixin:
|
|
| 247 |
"revisit_deterministic_fov_plucker_retrieval": True,
|
| 248 |
"revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
|
| 249 |
"revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
|
| 250 |
-
"revisit_high_quality_fov_threshold": high_quality_fov_threshold,
|
| 251 |
"revisit_plucker_weight": plucker_weight,
|
| 252 |
"stage_policy_noise_bucket_logging": True,
|
| 253 |
}
|
|
@@ -764,13 +762,28 @@ class MemoryDiTMixin:
|
|
| 764 |
if source_positions.numel() <= num_anchors or pose is None:
|
| 765 |
return source_positions[:num_anchors]
|
| 766 |
poses = pose.float()
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 770 |
farthest = int(dists.argmax().item())
|
|
|
|
|
|
|
| 771 |
selected.append(farthest)
|
| 772 |
-
|
|
|
|
| 773 |
dists = torch.minimum(dists, d_new)
|
|
|
|
| 774 |
return source_positions[torch.tensor(sorted(selected), device=source_positions.device)]
|
| 775 |
|
| 776 |
def _build_streaming_cache_records(
|
|
@@ -867,10 +880,14 @@ class MemoryDiTMixin:
|
|
| 867 |
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
|
| 868 |
if source_positions.numel() > 0:
|
| 869 |
if anchor_diverse:
|
| 870 |
-
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 874 |
else:
|
| 875 |
selected_list = []
|
| 876 |
for anchor_idx in anchor_indices:
|
|
@@ -1339,28 +1356,36 @@ class MemoryDiTMixin:
|
|
| 1339 |
}
|
| 1340 |
return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
|
| 1341 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1342 |
def _records_to_stream(
|
| 1343 |
self,
|
| 1344 |
records,
|
| 1345 |
-
|
| 1346 |
hidden_size: int,
|
| 1347 |
device: torch.device,
|
| 1348 |
dtype: torch.dtype,
|
| 1349 |
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
| 1350 |
-
|
| 1351 |
record_list = list(records)
|
| 1352 |
-
stacked_tokens, stacked_mask = stack_record_tokens(record_list,
|
| 1353 |
max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1)
|
| 1354 |
-
if stacked_tokens is None or stacked_mask is None or
|
| 1355 |
-
tokens = torch.zeros((
|
| 1356 |
-
mask = torch.zeros((
|
| 1357 |
return tokens, mask, max_source_frame
|
| 1358 |
-
n = min(
|
| 1359 |
filled = stacked_tokens[:n].to(device=device, dtype=dtype)
|
| 1360 |
filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool)
|
| 1361 |
-
if n <
|
| 1362 |
-
pad = filled.new_zeros(
|
| 1363 |
-
pad_mask = torch.zeros(
|
| 1364 |
tokens = torch.cat([filled, pad], dim=0)
|
| 1365 |
mask = torch.cat([filled_mask, pad_mask], dim=0)
|
| 1366 |
else:
|
|
@@ -1520,10 +1545,10 @@ class MemoryDiTMixin:
|
|
| 1520 |
revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
|
| 1521 |
revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
|
| 1522 |
)
|
| 1523 |
-
|
| 1524 |
recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
|
| 1525 |
-
|
| 1526 |
-
|
| 1527 |
fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
|
| 1528 |
high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
|
| 1529 |
plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
|
|
@@ -1568,7 +1593,7 @@ class MemoryDiTMixin:
|
|
| 1568 |
dtype=stream_dtype,
|
| 1569 |
max_recent_frames=recent_frames,
|
| 1570 |
target_frame_indices=target_frame_indices,
|
| 1571 |
-
exclude_latest_local_frames=
|
| 1572 |
)
|
| 1573 |
if raw_latents is not None:
|
| 1574 |
dynamic_latents = raw_latents
|
|
@@ -1631,7 +1656,7 @@ class MemoryDiTMixin:
|
|
| 1631 |
revisit_pool_h,
|
| 1632 |
revisit_pool_w,
|
| 1633 |
revisit_max_frames,
|
| 1634 |
-
|
| 1635 |
fov_overlap_threshold,
|
| 1636 |
plucker_weight,
|
| 1637 |
revisit_retrieval_kwargs,
|
|
@@ -1641,7 +1666,7 @@ class MemoryDiTMixin:
|
|
| 1641 |
|
| 1642 |
T_tgt = target_frame_indices.shape[0]
|
| 1643 |
anchor_slots = max(0, anchor_num_tokens)
|
| 1644 |
-
revisit_slots = max(0,
|
| 1645 |
anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
|
| 1646 |
anchor_include_generated = allow_generated_anchor
|
| 1647 |
anchor_token_rows = []
|
|
@@ -1659,7 +1684,6 @@ class MemoryDiTMixin:
|
|
| 1659 |
source_type=anchor_source_type,
|
| 1660 |
include_generated=anchor_include_generated,
|
| 1661 |
max_records=len(anchor_indices),
|
| 1662 |
-
max_slots=anchor_slots,
|
| 1663 |
)
|
| 1664 |
)
|
| 1665 |
anchor_bank.assert_causal(target_frame, records)
|
|
@@ -1693,7 +1717,7 @@ class MemoryDiTMixin:
|
|
| 1693 |
"dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1694 |
"dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1695 |
"dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
|
| 1696 |
-
"dynamic_exclude_latest_local_frames":
|
| 1697 |
}
|
| 1698 |
else:
|
| 1699 |
# Pre-select dynamic source frame positions using only frame index metadata
|
|
@@ -1705,7 +1729,7 @@ class MemoryDiTMixin:
|
|
| 1705 |
for _b in range(B):
|
| 1706 |
for _j in range(T_tgt):
|
| 1707 |
_target = int(target_frame_indices[_j, _b].item())
|
| 1708 |
-
_valid = (_dfi[:, _b] < _target -
|
| 1709 |
_needed.extend(_valid[-_max_src:].tolist())
|
| 1710 |
if _needed:
|
| 1711 |
_needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
|
|
@@ -1727,7 +1751,7 @@ class MemoryDiTMixin:
|
|
| 1727 |
_dynamic_pose_small,
|
| 1728 |
target_frame_indices,
|
| 1729 |
_dynamic_gen_small,
|
| 1730 |
-
exclude_latest_local_frames=
|
| 1731 |
)
|
| 1732 |
|
| 1733 |
dynamic_min_gap_tensor = torch.as_tensor(
|
|
@@ -1785,7 +1809,6 @@ class MemoryDiTMixin:
|
|
| 1785 |
revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1786 |
revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1787 |
revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
|
| 1788 |
-
valid_revisit_target_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
| 1789 |
eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
| 1790 |
revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
|
| 1791 |
eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
|
|
@@ -1801,7 +1824,10 @@ class MemoryDiTMixin:
|
|
| 1801 |
for target_idx in range(T_tgt):
|
| 1802 |
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
|
| 1803 |
if use_cache_revisit_records:
|
| 1804 |
-
candidate_records =
|
|
|
|
|
|
|
|
|
|
| 1805 |
else:
|
| 1806 |
candidate_records = revisit_bank.query(
|
| 1807 |
MemoryBankQuery(
|
|
@@ -1815,7 +1841,7 @@ class MemoryDiTMixin:
|
|
| 1815 |
target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
|
| 1816 |
target_summary=None,
|
| 1817 |
topk=revisit_max_frames,
|
| 1818 |
-
exclude_local_context_frames=
|
| 1819 |
fov_overlap_threshold=fov_overlap_threshold,
|
| 1820 |
plucker_weight=plucker_weight,
|
| 1821 |
target_video_id=_target_video_id_or_none(batch_idx, target_idx),
|
|
@@ -1840,7 +1866,6 @@ class MemoryDiTMixin:
|
|
| 1840 |
revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
|
| 1841 |
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
|
| 1842 |
revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
|
| 1843 |
-
valid_revisit_target_mask[batch_idx, target_idx] = bool(result.diagnostics.get("valid_revisit_target_count", 0))
|
| 1844 |
revisit_bank.assert_causal(target_frame, selected_records)
|
| 1845 |
if selected_records:
|
| 1846 |
valid_revisit_mask[batch_idx, target_idx] = True
|
|
@@ -1860,10 +1885,9 @@ class MemoryDiTMixin:
|
|
| 1860 |
target_frame=target_frame,
|
| 1861 |
)
|
| 1862 |
eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
|
| 1863 |
-
actual_max_source_frame = max((int(record.max_source_frame) for record in selected_records), default=max_source_frame)
|
| 1864 |
batch_token_rows.append(stream_tokens)
|
| 1865 |
batch_mask_rows.append(stream_mask)
|
| 1866 |
-
batch_max_rows.append(torch.as_tensor(
|
| 1867 |
revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
|
| 1868 |
revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
|
| 1869 |
revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
|
|
@@ -1872,15 +1896,15 @@ class MemoryDiTMixin:
|
|
| 1872 |
revisit_max = torch.stack(revisit_max_rows, dim=0)
|
| 1873 |
|
| 1874 |
if anchor_tokens.shape[-2] != anchor_num_tokens:
|
| 1875 |
-
raise AssertionError(f"anchor
|
| 1876 |
if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
|
| 1877 |
_expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
|
| 1878 |
int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
|
| 1879 |
)
|
| 1880 |
if dynamic_tokens.shape[-2] != _expected_dyn:
|
| 1881 |
-
raise AssertionError(f"dynamic
|
| 1882 |
-
if revisit_tokens.shape[-2]
|
| 1883 |
-
raise AssertionError(f"revisit
|
| 1884 |
anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
|
| 1885 |
dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
|
| 1886 |
gate_module = getattr(self, "dememwm_revisit_gate", None)
|
|
@@ -1913,7 +1937,6 @@ class MemoryDiTMixin:
|
|
| 1913 |
revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
|
| 1914 |
revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
|
| 1915 |
revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
|
| 1916 |
-
valid_revisit_target_mask = torch.zeros_like(valid_revisit_target_mask)
|
| 1917 |
eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
|
| 1918 |
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
|
| 1919 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
|
@@ -1948,7 +1971,6 @@ class MemoryDiTMixin:
|
|
| 1948 |
"revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
|
| 1949 |
"no_valid_revisit_mask": no_valid_revisit_mask,
|
| 1950 |
"valid_revisit_eff_mask": valid_revisit_eff_mask,
|
| 1951 |
-
"valid_revisit_target_mask": valid_revisit_target_mask,
|
| 1952 |
"revisit_candidate_frame_count_per_target": revisit_candidate_count,
|
| 1953 |
"revisit_selected_frame_count_per_target": revisit_selected_count,
|
| 1954 |
"revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
|
|
@@ -1972,18 +1994,17 @@ class MemoryDiTMixin:
|
|
| 1972 |
"token_patch_size": token_patch_size,
|
| 1973 |
"tokens_per_frame": tokens_per_frame,
|
| 1974 |
"anchor_token_slots": int(anchor_tokens.shape[-2]),
|
| 1975 |
-
"
|
| 1976 |
"anchor_pool_h": anchor_pool_h,
|
| 1977 |
"anchor_pool_w": anchor_pool_w,
|
| 1978 |
"dynamic_token_slots": int(dynamic_tokens.shape[-2]),
|
| 1979 |
-
"
|
| 1980 |
"dynamic_min_gap_to_target": dynamic_min_gap_to_target,
|
| 1981 |
"dynamic_max_gap_to_target": dynamic_max_gap_to_target,
|
| 1982 |
-
"dynamic_exclude_latest_local_frames":
|
| 1983 |
"revisit_token_slots": int(revisit_tokens.shape[-2]),
|
| 1984 |
-
"
|
| 1985 |
-
"revisit_local_context_exclusion_frames":
|
| 1986 |
-
"revisit_high_quality_fov_threshold": high_quality_fov_threshold,
|
| 1987 |
"revisit_pool_h": revisit_pool_h,
|
| 1988 |
"revisit_pool_w": revisit_pool_w,
|
| 1989 |
"revisit_max_frames": revisit_max_frames,
|
|
|
|
| 42 |
"revisit_pose_preselect_selected_count",
|
| 43 |
"revisit_exact_fov_candidate_count",
|
| 44 |
"valid_revisit_frame_count",
|
|
|
|
| 45 |
"no_valid_revisit_count",
|
| 46 |
"revisit_selected_frame_count",
|
| 47 |
"revisit_frame_fov_overlap_mean",
|
|
|
|
| 170 |
if self._cfg_has(memory_cfg, name)
|
| 171 |
]
|
| 172 |
if ratio_fields:
|
| 173 |
+
raise ValueError(f"standalone DeMemWM derives stream slots from latent shape and compression settings, not ratio fields: {ratio_fields}")
|
| 174 |
|
| 175 |
anchor_cfg = self._cfg_get(memory_cfg, "anchor", None)
|
| 176 |
dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
|
|
|
|
| 246 |
"revisit_deterministic_fov_plucker_retrieval": True,
|
| 247 |
"revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(),
|
| 248 |
"revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold,
|
|
|
|
| 249 |
"revisit_plucker_weight": plucker_weight,
|
| 250 |
"stage_policy_noise_bucket_logging": True,
|
| 251 |
}
|
|
|
|
| 762 |
if source_positions.numel() <= num_anchors or pose is None:
|
| 763 |
return source_positions[:num_anchors]
|
| 764 |
poses = pose.float()
|
| 765 |
+
pairwise = torch.cdist(poses, poses)
|
| 766 |
+
if not bool((pairwise > 0).any().item()):
|
| 767 |
+
return source_positions[:num_anchors]
|
| 768 |
+
available = torch.ones((int(source_positions.numel()),), device=poses.device, dtype=torch.bool)
|
| 769 |
+
if num_anchors == 1:
|
| 770 |
+
selected = [int(pairwise.mean(dim=1).argmax().item())]
|
| 771 |
+
else:
|
| 772 |
+
first, second = divmod(int(pairwise.argmax().item()), int(pairwise.shape[1]))
|
| 773 |
+
selected = [int(first), int(second)]
|
| 774 |
+
for idx in selected:
|
| 775 |
+
available[idx] = False
|
| 776 |
+
dists = pairwise[selected].min(dim=0).values
|
| 777 |
+
dists = dists.masked_fill(~available, float("-inf"))
|
| 778 |
+
for _ in range(num_anchors - len(selected)):
|
| 779 |
farthest = int(dists.argmax().item())
|
| 780 |
+
if not bool(available[farthest].item()):
|
| 781 |
+
break
|
| 782 |
selected.append(farthest)
|
| 783 |
+
available[farthest] = False
|
| 784 |
+
d_new = pairwise[farthest]
|
| 785 |
dists = torch.minimum(dists, d_new)
|
| 786 |
+
dists = dists.masked_fill(~available, float("-inf"))
|
| 787 |
return source_positions[torch.tensor(sorted(selected), device=source_positions.device)]
|
| 788 |
|
| 789 |
def _build_streaming_cache_records(
|
|
|
|
| 880 |
source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()
|
| 881 |
if source_positions.numel() > 0:
|
| 882 |
if anchor_diverse:
|
| 883 |
+
anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
|
| 884 |
+
if anchor_source_positions.numel() > 0:
|
| 885 |
+
anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
|
| 886 |
+
selected_anchor_positions = self._select_diverse_anchor_positions(
|
| 887 |
+
anchor_source_positions, anchor_pose, len(anchor_indices)
|
| 888 |
+
)
|
| 889 |
+
else:
|
| 890 |
+
selected_anchor_positions = source_positions[:0]
|
| 891 |
else:
|
| 892 |
selected_list = []
|
| 893 |
for anchor_idx in anchor_indices:
|
|
|
|
| 1356 |
}
|
| 1357 |
return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
|
| 1358 |
|
| 1359 |
+
def _causal_cached_revisit_records(
|
| 1360 |
+
self,
|
| 1361 |
+
records: Iterable[MemoryRecord],
|
| 1362 |
+
target_frame: int,
|
| 1363 |
+
) -> list[MemoryRecord]:
|
| 1364 |
+
target_frame = int(target_frame)
|
| 1365 |
+
return [record for record in records if int(record.source_end) <= target_frame]
|
| 1366 |
+
|
| 1367 |
def _records_to_stream(
|
| 1368 |
self,
|
| 1369 |
records,
|
| 1370 |
+
target_slots: int,
|
| 1371 |
hidden_size: int,
|
| 1372 |
device: torch.device,
|
| 1373 |
dtype: torch.dtype,
|
| 1374 |
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
| 1375 |
+
target_slots = max(0, int(target_slots))
|
| 1376 |
record_list = list(records)
|
| 1377 |
+
stacked_tokens, stacked_mask = stack_record_tokens(record_list, target_slots=target_slots)
|
| 1378 |
max_source_frame = max((int(record.max_source_frame) for record in record_list), default=-1)
|
| 1379 |
+
if stacked_tokens is None or stacked_mask is None or target_slots == 0:
|
| 1380 |
+
tokens = torch.zeros((target_slots, hidden_size), device=device, dtype=dtype)
|
| 1381 |
+
mask = torch.zeros((target_slots,), device=device, dtype=torch.bool)
|
| 1382 |
return tokens, mask, max_source_frame
|
| 1383 |
+
n = min(target_slots, stacked_tokens.shape[0])
|
| 1384 |
filled = stacked_tokens[:n].to(device=device, dtype=dtype)
|
| 1385 |
filled_mask = stacked_mask[:n].to(device=device, dtype=torch.bool)
|
| 1386 |
+
if n < target_slots:
|
| 1387 |
+
pad = filled.new_zeros(target_slots - n, hidden_size)
|
| 1388 |
+
pad_mask = torch.zeros(target_slots - n, device=device, dtype=torch.bool)
|
| 1389 |
tokens = torch.cat([filled, pad], dim=0)
|
| 1390 |
mask = torch.cat([filled_mask, pad_mask], dim=0)
|
| 1391 |
else:
|
|
|
|
| 1545 |
revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size(
|
| 1546 |
revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8
|
| 1547 |
)
|
| 1548 |
+
revisit_target_slots = revisit_max_frames * revisit_pool_h * revisit_pool_w
|
| 1549 |
recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
|
| 1550 |
+
dynamic_recent_exclusion_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4))
|
| 1551 |
+
revisit_context_window_exclusion_frames = self._local_context_exclusion_frames()
|
| 1552 |
fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30)
|
| 1553 |
high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70))
|
| 1554 |
plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10))
|
|
|
|
| 1593 |
dtype=stream_dtype,
|
| 1594 |
max_recent_frames=recent_frames,
|
| 1595 |
target_frame_indices=target_frame_indices,
|
| 1596 |
+
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
|
| 1597 |
)
|
| 1598 |
if raw_latents is not None:
|
| 1599 |
dynamic_latents = raw_latents
|
|
|
|
| 1656 |
revisit_pool_h,
|
| 1657 |
revisit_pool_w,
|
| 1658 |
revisit_max_frames,
|
| 1659 |
+
revisit_context_window_exclusion_frames,
|
| 1660 |
fov_overlap_threshold,
|
| 1661 |
plucker_weight,
|
| 1662 |
revisit_retrieval_kwargs,
|
|
|
|
| 1666 |
|
| 1667 |
T_tgt = target_frame_indices.shape[0]
|
| 1668 |
anchor_slots = max(0, anchor_num_tokens)
|
| 1669 |
+
revisit_slots = max(0, revisit_target_slots)
|
| 1670 |
anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT
|
| 1671 |
anchor_include_generated = allow_generated_anchor
|
| 1672 |
anchor_token_rows = []
|
|
|
|
| 1684 |
source_type=anchor_source_type,
|
| 1685 |
include_generated=anchor_include_generated,
|
| 1686 |
max_records=len(anchor_indices),
|
|
|
|
| 1687 |
)
|
| 1688 |
)
|
| 1689 |
anchor_bank.assert_causal(target_frame, records)
|
|
|
|
| 1717 |
"dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1718 |
"dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device),
|
| 1719 |
"dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device),
|
| 1720 |
+
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
|
| 1721 |
}
|
| 1722 |
else:
|
| 1723 |
# Pre-select dynamic source frame positions using only frame index metadata
|
|
|
|
| 1729 |
for _b in range(B):
|
| 1730 |
for _j in range(T_tgt):
|
| 1731 |
_target = int(target_frame_indices[_j, _b].item())
|
| 1732 |
+
_valid = (_dfi[:, _b] < _target - dynamic_recent_exclusion_frames).nonzero(as_tuple=False).flatten()
|
| 1733 |
_needed.extend(_valid[-_max_src:].tolist())
|
| 1734 |
if _needed:
|
| 1735 |
_needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long)
|
|
|
|
| 1751 |
_dynamic_pose_small,
|
| 1752 |
target_frame_indices,
|
| 1753 |
_dynamic_gen_small,
|
| 1754 |
+
exclude_latest_local_frames=dynamic_recent_exclusion_frames,
|
| 1755 |
)
|
| 1756 |
|
| 1757 |
dynamic_min_gap_tensor = torch.as_tensor(
|
|
|
|
| 1809 |
revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1810 |
revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32)
|
| 1811 |
revisit_selected_gap_frames = torch.full((B, T_tgt), -1.0, device=stream_device, dtype=torch.float32)
|
|
|
|
| 1812 |
eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool)
|
| 1813 |
revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long)
|
| 1814 |
eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES)
|
|
|
|
| 1824 |
for target_idx in range(T_tgt):
|
| 1825 |
target_frame = int(target_frame_indices[target_idx, batch_idx].item())
|
| 1826 |
if use_cache_revisit_records:
|
| 1827 |
+
candidate_records = self._causal_cached_revisit_records(
|
| 1828 |
+
revisit_record_batches[batch_idx],
|
| 1829 |
+
target_frame,
|
| 1830 |
+
)
|
| 1831 |
else:
|
| 1832 |
candidate_records = revisit_bank.query(
|
| 1833 |
MemoryBankQuery(
|
|
|
|
| 1841 |
target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx),
|
| 1842 |
target_summary=None,
|
| 1843 |
topk=revisit_max_frames,
|
| 1844 |
+
exclude_local_context_frames=revisit_context_window_exclusion_frames,
|
| 1845 |
fov_overlap_threshold=fov_overlap_threshold,
|
| 1846 |
plucker_weight=plucker_weight,
|
| 1847 |
target_video_id=_target_video_id_or_none(batch_idx, target_idx),
|
|
|
|
| 1866 |
revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0))
|
| 1867 |
revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0))
|
| 1868 |
revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1))
|
|
|
|
| 1869 |
revisit_bank.assert_causal(target_frame, selected_records)
|
| 1870 |
if selected_records:
|
| 1871 |
valid_revisit_mask[batch_idx, target_idx] = True
|
|
|
|
| 1885 |
target_frame=target_frame,
|
| 1886 |
)
|
| 1887 |
eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted)
|
|
|
|
| 1888 |
batch_token_rows.append(stream_tokens)
|
| 1889 |
batch_mask_rows.append(stream_mask)
|
| 1890 |
+
batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long))
|
| 1891 |
revisit_token_rows.append(torch.stack(batch_token_rows, dim=0))
|
| 1892 |
revisit_mask_rows.append(torch.stack(batch_mask_rows, dim=0))
|
| 1893 |
revisit_max_rows.append(torch.stack(batch_max_rows, dim=0))
|
|
|
|
| 1896 |
revisit_max = torch.stack(revisit_max_rows, dim=0)
|
| 1897 |
|
| 1898 |
if anchor_tokens.shape[-2] != anchor_num_tokens:
|
| 1899 |
+
raise AssertionError(f"anchor slot count mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}")
|
| 1900 |
if dynamic_latents is not None and dynamic_latents.shape[0] > 0:
|
| 1901 |
_expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target(
|
| 1902 |
int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1])
|
| 1903 |
)
|
| 1904 |
if dynamic_tokens.shape[-2] != _expected_dyn:
|
| 1905 |
+
raise AssertionError(f"dynamic slot count mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}")
|
| 1906 |
+
if revisit_tokens.shape[-2] != revisit_target_slots:
|
| 1907 |
+
raise AssertionError(f"revisit slot count mismatch: got {revisit_tokens.shape[-2]}, expected {revisit_target_slots}")
|
| 1908 |
anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0
|
| 1909 |
dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0
|
| 1910 |
gate_module = getattr(self, "dememwm_revisit_gate", None)
|
|
|
|
| 1937 |
revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap)
|
| 1938 |
revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap)
|
| 1939 |
revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0)
|
|
|
|
| 1940 |
eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask)
|
| 1941 |
valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask)
|
| 1942 |
revisit_gate_raw = torch.zeros_like(revisit_gate_raw)
|
|
|
|
| 1971 |
"revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)),
|
| 1972 |
"no_valid_revisit_mask": no_valid_revisit_mask,
|
| 1973 |
"valid_revisit_eff_mask": valid_revisit_eff_mask,
|
|
|
|
| 1974 |
"revisit_candidate_frame_count_per_target": revisit_candidate_count,
|
| 1975 |
"revisit_selected_frame_count_per_target": revisit_selected_count,
|
| 1976 |
"revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap,
|
|
|
|
| 1994 |
"token_patch_size": token_patch_size,
|
| 1995 |
"tokens_per_frame": tokens_per_frame,
|
| 1996 |
"anchor_token_slots": int(anchor_tokens.shape[-2]),
|
| 1997 |
+
"anchor_target_slots": anchor_num_tokens,
|
| 1998 |
"anchor_pool_h": anchor_pool_h,
|
| 1999 |
"anchor_pool_w": anchor_pool_w,
|
| 2000 |
"dynamic_token_slots": int(dynamic_tokens.shape[-2]),
|
| 2001 |
+
"dynamic_target_slots": int(dynamic_tokens.shape[-2]),
|
| 2002 |
"dynamic_min_gap_to_target": dynamic_min_gap_to_target,
|
| 2003 |
"dynamic_max_gap_to_target": dynamic_max_gap_to_target,
|
| 2004 |
+
"dynamic_exclude_latest_local_frames": dynamic_recent_exclusion_frames,
|
| 2005 |
"revisit_token_slots": int(revisit_tokens.shape[-2]),
|
| 2006 |
+
"revisit_target_slots": revisit_target_slots,
|
| 2007 |
+
"revisit_local_context_exclusion_frames": revisit_context_window_exclusion_frames,
|
|
|
|
| 2008 |
"revisit_pool_h": revisit_pool_h,
|
| 2009 |
"revisit_pool_w": revisit_pool_w,
|
| 2010 |
"revisit_max_frames": revisit_max_frames,
|
algorithms/worldmem/dememwm/cache.py
CHANGED
|
@@ -39,7 +39,6 @@ class StreamingCache:
|
|
| 39 |
no_evict: bool = True,
|
| 40 |
clear_between_videos: bool = True,
|
| 41 |
max_records: Optional[int] = None,
|
| 42 |
-
max_slots: Optional[int] = None,
|
| 43 |
on_capacity_exceeded: str = "warn",
|
| 44 |
) -> None:
|
| 45 |
self.enabled = bool(enabled)
|
|
@@ -51,7 +50,6 @@ class StreamingCache:
|
|
| 51 |
self.no_evict = bool(no_evict)
|
| 52 |
self.clear_between_videos = bool(clear_between_videos)
|
| 53 |
self.max_records = max_records
|
| 54 |
-
self.max_slots = max_slots
|
| 55 |
self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
|
| 56 |
if self.eviction_policy != "none" or not self.no_evict:
|
| 57 |
raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
|
|
@@ -92,7 +90,6 @@ class StreamingCache:
|
|
| 92 |
no_evict=bool(get("no_evict", True)),
|
| 93 |
clear_between_videos=bool(get("clear_between_videos", True)),
|
| 94 |
max_records=get("max_records", None),
|
| 95 |
-
max_slots=get("max_slots", None),
|
| 96 |
on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
|
| 97 |
)
|
| 98 |
|
|
@@ -213,14 +210,12 @@ class StreamingCache:
|
|
| 213 |
exceeded = False
|
| 214 |
if self.max_records is not None and self.record_count > int(self.max_records):
|
| 215 |
exceeded = True
|
| 216 |
-
if self.max_slots is not None and self.slot_count > int(self.max_slots):
|
| 217 |
-
exceeded = True
|
| 218 |
if not exceeded:
|
| 219 |
return
|
| 220 |
self.capacity_exceeded_count += 1
|
| 221 |
msg = (
|
| 222 |
"DeMemWMStreamingCache capacity exceeded "
|
| 223 |
-
f"records={self.record_count}/{self.max_records}
|
| 224 |
"no eviction performed because no_evict=true"
|
| 225 |
)
|
| 226 |
if self.on_capacity_exceeded == "error":
|
|
|
|
| 39 |
no_evict: bool = True,
|
| 40 |
clear_between_videos: bool = True,
|
| 41 |
max_records: Optional[int] = None,
|
|
|
|
| 42 |
on_capacity_exceeded: str = "warn",
|
| 43 |
) -> None:
|
| 44 |
self.enabled = bool(enabled)
|
|
|
|
| 50 |
self.no_evict = bool(no_evict)
|
| 51 |
self.clear_between_videos = bool(clear_between_videos)
|
| 52 |
self.max_records = max_records
|
|
|
|
| 53 |
self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
|
| 54 |
if self.eviction_policy != "none" or not self.no_evict:
|
| 55 |
raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
|
|
|
|
| 90 |
no_evict=bool(get("no_evict", True)),
|
| 91 |
clear_between_videos=bool(get("clear_between_videos", True)),
|
| 92 |
max_records=get("max_records", None),
|
|
|
|
| 93 |
on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
|
| 94 |
)
|
| 95 |
|
|
|
|
| 210 |
exceeded = False
|
| 211 |
if self.max_records is not None and self.record_count > int(self.max_records):
|
| 212 |
exceeded = True
|
|
|
|
|
|
|
| 213 |
if not exceeded:
|
| 214 |
return
|
| 215 |
self.capacity_exceeded_count += 1
|
| 216 |
msg = (
|
| 217 |
"DeMemWMStreamingCache capacity exceeded "
|
| 218 |
+
f"records={self.record_count}/{self.max_records}; "
|
| 219 |
"no eviction performed because no_evict=true"
|
| 220 |
)
|
| 221 |
if self.on_capacity_exceeded == "error":
|
algorithms/worldmem/dememwm/diagnostics.py
CHANGED
|
@@ -64,7 +64,6 @@ def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], vali
|
|
| 64 |
exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
|
| 65 |
valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
|
| 66 |
valid_count_mean = float(valid_count / target_count) if target_count else 0.0
|
| 67 |
-
valid_target_count = sum(int(diag.get("valid_revisit_target_count", diag.get("high_quality_selected_revisit", 0))) for diag in result_diagnostics)
|
| 68 |
selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
|
| 69 |
no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
|
| 70 |
abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
|
|
@@ -78,7 +77,6 @@ def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], vali
|
|
| 78 |
"revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
|
| 79 |
"valid_revisit_frame_count": valid_count_mean,
|
| 80 |
"valid_revisit_count": valid_count_mean,
|
| 81 |
-
"valid_revisit_target_count": int(valid_target_count),
|
| 82 |
"no_valid_revisit_count": int(no_valid_count),
|
| 83 |
"valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
|
| 84 |
"revisit_selected_frame_count": int(selected_count),
|
|
|
|
| 64 |
exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
|
| 65 |
valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
|
| 66 |
valid_count_mean = float(valid_count / target_count) if target_count else 0.0
|
|
|
|
| 67 |
selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
|
| 68 |
no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
|
| 69 |
abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
|
|
|
|
| 77 |
"revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
|
| 78 |
"valid_revisit_frame_count": valid_count_mean,
|
| 79 |
"valid_revisit_count": valid_count_mean,
|
|
|
|
| 80 |
"no_valid_revisit_count": int(no_valid_count),
|
| 81 |
"valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
|
| 82 |
"revisit_selected_frame_count": int(selected_count),
|
algorithms/worldmem/dememwm/memory.py
CHANGED
|
@@ -15,15 +15,13 @@ class MemoryBankQuery:
|
|
| 15 |
source_type: Optional[MemorySourceType] = None
|
| 16 |
include_generated: bool = True
|
| 17 |
max_records: Optional[int] = None
|
| 18 |
-
max_slots: Optional[int] = None
|
| 19 |
|
| 20 |
|
| 21 |
class CausalMemoryBank:
|
| 22 |
"""Small causal memory bank for DeMemWM records."""
|
| 23 |
|
| 24 |
-
def __init__(self, max_records: Optional[int] = None, max_slots: Optional[int] = None):
|
| 25 |
self.max_records = max_records
|
| 26 |
-
self.max_slots = max_slots
|
| 27 |
self._records: list[MemoryRecord] = []
|
| 28 |
|
| 29 |
def __len__(self) -> int:
|
|
@@ -172,7 +170,6 @@ class CausalMemoryBank:
|
|
| 172 |
if isinstance(query, int):
|
| 173 |
query = MemoryBankQuery(target_frame=query, **kwargs)
|
| 174 |
out: list[MemoryRecord] = []
|
| 175 |
-
used_slots = 0
|
| 176 |
for record in self._records:
|
| 177 |
if int(record.source_end) > int(query.target_frame):
|
| 178 |
continue
|
|
@@ -180,15 +177,9 @@ class CausalMemoryBank:
|
|
| 180 |
continue
|
| 181 |
if not query.include_generated and record.is_generated:
|
| 182 |
continue
|
| 183 |
-
if query.max_slots is not None and used_slots >= query.max_slots:
|
| 184 |
-
break
|
| 185 |
out.append(record)
|
| 186 |
-
if query.max_slots is not None:
|
| 187 |
-
used_slots += record.valid_slots
|
| 188 |
if query.max_records is not None and len(out) >= query.max_records:
|
| 189 |
break
|
| 190 |
-
if query.max_slots is not None and used_slots >= query.max_slots:
|
| 191 |
-
break
|
| 192 |
return out
|
| 193 |
|
| 194 |
def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
|
|
@@ -197,12 +188,13 @@ class CausalMemoryBank:
|
|
| 197 |
raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
|
| 198 |
|
| 199 |
|
| 200 |
-
def stack_record_tokens(records: list[MemoryRecord],
|
| 201 |
if not records:
|
| 202 |
return None, None
|
| 203 |
tokens = torch.cat([r.tokens for r in records], dim=0)
|
| 204 |
mask = torch.cat([r.mask.bool() for r in records], dim=0)
|
| 205 |
-
if
|
| 206 |
-
|
| 207 |
-
|
|
|
|
| 208 |
return tokens, mask
|
|
|
|
| 15 |
source_type: Optional[MemorySourceType] = None
|
| 16 |
include_generated: bool = True
|
| 17 |
max_records: Optional[int] = None
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
class CausalMemoryBank:
|
| 21 |
"""Small causal memory bank for DeMemWM records."""
|
| 22 |
|
| 23 |
+
def __init__(self, max_records: Optional[int] = None):
|
| 24 |
self.max_records = max_records
|
|
|
|
| 25 |
self._records: list[MemoryRecord] = []
|
| 26 |
|
| 27 |
def __len__(self) -> int:
|
|
|
|
| 170 |
if isinstance(query, int):
|
| 171 |
query = MemoryBankQuery(target_frame=query, **kwargs)
|
| 172 |
out: list[MemoryRecord] = []
|
|
|
|
| 173 |
for record in self._records:
|
| 174 |
if int(record.source_end) > int(query.target_frame):
|
| 175 |
continue
|
|
|
|
| 177 |
continue
|
| 178 |
if not query.include_generated and record.is_generated:
|
| 179 |
continue
|
|
|
|
|
|
|
| 180 |
out.append(record)
|
|
|
|
|
|
|
| 181 |
if query.max_records is not None and len(out) >= query.max_records:
|
| 182 |
break
|
|
|
|
|
|
|
| 183 |
return out
|
| 184 |
|
| 185 |
def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
|
|
|
|
| 188 |
raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
|
| 189 |
|
| 190 |
|
| 191 |
+
def stack_record_tokens(records: list[MemoryRecord], target_slots: int | None = None):
|
| 192 |
if not records:
|
| 193 |
return None, None
|
| 194 |
tokens = torch.cat([r.tokens for r in records], dim=0)
|
| 195 |
mask = torch.cat([r.mask.bool() for r in records], dim=0)
|
| 196 |
+
if target_slots is not None:
|
| 197 |
+
valid_idx = mask.nonzero(as_tuple=False).flatten()
|
| 198 |
+
tokens = tokens.index_select(0, valid_idx)[:target_slots]
|
| 199 |
+
mask = mask.index_select(0, valid_idx)[:target_slots]
|
| 200 |
return tokens, mask
|
algorithms/worldmem/dememwm/retrieval.py
CHANGED
|
@@ -427,7 +427,6 @@ def deterministic_revisit_retrieval(
|
|
| 427 |
"revisit_candidate_count": len(causal_records),
|
| 428 |
"valid_revisit_frame_count": len(valid_labels),
|
| 429 |
"valid_revisit_count": len(valid_labels),
|
| 430 |
-
"valid_revisit_target_count": high_quality_selected,
|
| 431 |
"no_valid_revisit_count": int(len(valid_labels) == 0),
|
| 432 |
"valid_revisit_mask": int(len(valid_labels) > 0),
|
| 433 |
"revisit_abstained_count": int(len(selected_records) == 0),
|
|
|
|
| 427 |
"revisit_candidate_count": len(causal_records),
|
| 428 |
"valid_revisit_frame_count": len(valid_labels),
|
| 429 |
"valid_revisit_count": len(valid_labels),
|
|
|
|
| 430 |
"no_valid_revisit_count": int(len(valid_labels) == 0),
|
| 431 |
"valid_revisit_mask": int(len(valid_labels) > 0),
|
| 432 |
"revisit_abstained_count": int(len(selected_records) == 0),
|
configurations/algorithm/dememwm_memory_dit.yaml
CHANGED
|
@@ -93,7 +93,6 @@ dememwm:
|
|
| 93 |
no_evict: true
|
| 94 |
clear_between_videos: true
|
| 95 |
max_records: null
|
| 96 |
-
max_slots: null
|
| 97 |
on_capacity_exceeded: warn
|
| 98 |
checkpoint:
|
| 99 |
strict_dememwm_eval_load: true
|
|
|
|
| 93 |
no_evict: true
|
| 94 |
clear_between_videos: true
|
| 95 |
max_records: null
|
|
|
|
| 96 |
on_capacity_exceeded: warn
|
| 97 |
checkpoint:
|
| 98 |
strict_dememwm_eval_load: true
|
scripts/dememwm_full_eval.slurm
CHANGED
|
@@ -150,7 +150,6 @@ EVAL_ARGS=(
|
|
| 150 |
"++algorithm.dememwm.cache.no_evict=true"
|
| 151 |
"++algorithm.dememwm.cache.clear_between_videos=true"
|
| 152 |
"++algorithm.dememwm.cache.max_records=null"
|
| 153 |
-
"++algorithm.dememwm.cache.max_slots=null"
|
| 154 |
"++algorithm.dememwm.cache.on_capacity_exceeded=warn"
|
| 155 |
"experiment.validation.batch_size=${VAL_BATCH_SIZE}"
|
| 156 |
"experiment.validation.limit_batch=${VAL_LIMIT}"
|
|
|
|
| 150 |
"++algorithm.dememwm.cache.no_evict=true"
|
| 151 |
"++algorithm.dememwm.cache.clear_between_videos=true"
|
| 152 |
"++algorithm.dememwm.cache.max_records=null"
|
|
|
|
| 153 |
"++algorithm.dememwm.cache.on_capacity_exceeded=warn"
|
| 154 |
"experiment.validation.batch_size=${VAL_BATCH_SIZE}"
|
| 155 |
"experiment.validation.limit_batch=${VAL_LIMIT}"
|
scripts/dememwm_full_train.slurm
CHANGED
|
@@ -69,8 +69,7 @@ srun python -m main \
|
|
| 69 |
++algorithm.dememwm.dynamic.recent_frames=4 \
|
| 70 |
++algorithm.dememwm.revisit.enabled=true \
|
| 71 |
++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
|
| 72 |
-
++algorithm.dememwm.revisit.fov_overlap_threshold=0.30 \
|
| 73 |
-
++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \
|
| 74 |
++algorithm.dememwm.revisit.pose_preselect_topk=64 \
|
| 75 |
++algorithm.dememwm.revisit.fov_yaw_samples=25 \
|
| 76 |
++algorithm.dememwm.revisit.fov_pitch_samples=20 \
|
|
@@ -87,7 +86,6 @@ srun python -m main \
|
|
| 87 |
++algorithm.dememwm.cache.no_evict=true \
|
| 88 |
++algorithm.dememwm.cache.clear_between_videos=true \
|
| 89 |
++algorithm.dememwm.cache.max_records=null \
|
| 90 |
-
++algorithm.dememwm.cache.max_slots=null \
|
| 91 |
++algorithm.dememwm.cache.on_capacity_exceeded=warn \
|
| 92 |
++algorithm.dememwm.curriculum.enabled=true \
|
| 93 |
++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
|
|
|
|
| 69 |
++algorithm.dememwm.dynamic.recent_frames=4 \
|
| 70 |
++algorithm.dememwm.revisit.enabled=true \
|
| 71 |
++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
|
| 72 |
+
++algorithm.dememwm.revisit.fov_overlap_threshold=0.60 \
|
|
|
|
| 73 |
++algorithm.dememwm.revisit.pose_preselect_topk=64 \
|
| 74 |
++algorithm.dememwm.revisit.fov_yaw_samples=25 \
|
| 75 |
++algorithm.dememwm.revisit.fov_pitch_samples=20 \
|
|
|
|
| 86 |
++algorithm.dememwm.cache.no_evict=true \
|
| 87 |
++algorithm.dememwm.cache.clear_between_videos=true \
|
| 88 |
++algorithm.dememwm.cache.max_records=null \
|
|
|
|
| 89 |
++algorithm.dememwm.cache.on_capacity_exceeded=warn \
|
| 90 |
++algorithm.dememwm.curriculum.enabled=true \
|
| 91 |
++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
|
tests/test_dememwm_config_static.py
CHANGED
|
@@ -72,8 +72,6 @@ def test_full_scripts_use_consumed_contract_overrides():
|
|
| 72 |
required = [
|
| 73 |
"algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
|
| 74 |
"algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
|
| 75 |
-
"algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
|
| 76 |
-
"algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
|
| 77 |
"algorithm.dememwm.revisit.pose_preselect_topk=64",
|
| 78 |
"algorithm.dememwm.revisit.fov_yaw_samples=25",
|
| 79 |
"algorithm.dememwm.revisit.fov_pitch_samples=20",
|
|
@@ -98,9 +96,18 @@ def test_full_scripts_use_consumed_contract_overrides():
|
|
| 98 |
"algorithm.dememwm.revisit.generated_penalty",
|
| 99 |
"algorithm.dememwm.rollout.",
|
| 100 |
]
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
text = Path(rel).read_text()
|
| 103 |
-
for token in required:
|
| 104 |
assert token in text, f"{token} missing from {rel}"
|
| 105 |
for token in stale:
|
| 106 |
assert token not in text, f"stale {token} override remains in {rel}"
|
|
@@ -145,7 +152,6 @@ def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
|
|
| 145 |
"valid_revisit_mask",
|
| 146 |
"revisit_candidate_frame_count",
|
| 147 |
"valid_candidate_label_count",
|
| 148 |
-
"valid_revisit_target_count",
|
| 149 |
"valid_revisit_frame_count",
|
| 150 |
"no_valid_revisit_count",
|
| 151 |
"revisit_selected_frame_count",
|
|
@@ -180,7 +186,6 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
|
|
| 180 |
schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
|
| 181 |
diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
|
| 182 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
| 183 |
-
matrix = Path("scripts/dememwm_eval_ablation_matrix.sh").read_text()
|
| 184 |
for branch in [
|
| 185 |
"memory_off",
|
| 186 |
"A_only",
|
|
@@ -197,7 +202,6 @@ def test_eval_ablation_and_noise_bucket_logging_contracts():
|
|
| 197 |
"local_context_overlap_fake_revisit",
|
| 198 |
]:
|
| 199 |
assert branch in schedules
|
| 200 |
-
assert branch in matrix
|
| 201 |
for token in [
|
| 202 |
"noise_bucket_from_denoising_fraction",
|
| 203 |
"noise_bucket_from_noise_levels",
|
|
|
|
| 72 |
required = [
|
| 73 |
"algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
|
| 74 |
"algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
|
|
|
|
|
|
|
| 75 |
"algorithm.dememwm.revisit.pose_preselect_topk=64",
|
| 76 |
"algorithm.dememwm.revisit.fov_yaw_samples=25",
|
| 77 |
"algorithm.dememwm.revisit.fov_pitch_samples=20",
|
|
|
|
| 96 |
"algorithm.dememwm.revisit.generated_penalty",
|
| 97 |
"algorithm.dememwm.rollout.",
|
| 98 |
]
|
| 99 |
+
expected_by_script = {
|
| 100 |
+
"scripts/dememwm_full_train.slurm": [
|
| 101 |
+
"algorithm.dememwm.revisit.fov_overlap_threshold=0.60",
|
| 102 |
+
],
|
| 103 |
+
"scripts/dememwm_full_eval.slurm": [
|
| 104 |
+
"algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
|
| 105 |
+
"algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
|
| 106 |
+
],
|
| 107 |
+
}
|
| 108 |
+
for rel, script_specific_required in expected_by_script.items():
|
| 109 |
text = Path(rel).read_text()
|
| 110 |
+
for token in required + script_specific_required:
|
| 111 |
assert token in text, f"{token} missing from {rel}"
|
| 112 |
for token in stale:
|
| 113 |
assert token not in text, f"stale {token} override remains in {rel}"
|
|
|
|
| 152 |
"valid_revisit_mask",
|
| 153 |
"revisit_candidate_frame_count",
|
| 154 |
"valid_candidate_label_count",
|
|
|
|
| 155 |
"valid_revisit_frame_count",
|
| 156 |
"no_valid_revisit_count",
|
| 157 |
"revisit_selected_frame_count",
|
|
|
|
| 186 |
schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text()
|
| 187 |
diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text()
|
| 188 |
algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
|
|
|
|
| 189 |
for branch in [
|
| 190 |
"memory_off",
|
| 191 |
"A_only",
|
|
|
|
| 202 |
"local_context_overlap_fake_revisit",
|
| 203 |
]:
|
| 204 |
assert branch in schedules
|
|
|
|
| 205 |
for token in [
|
| 206 |
"noise_bucket_from_denoising_fraction",
|
| 207 |
"noise_bucket_from_noise_levels",
|
tests/test_dememwm_memory.py
CHANGED
|
@@ -44,12 +44,46 @@ def test_all_false_masks_are_valid_abstention_outputs():
|
|
| 44 |
assert mask.sum().item() == 0
|
| 45 |
|
| 46 |
|
| 47 |
-
def
|
| 48 |
bank = CausalMemoryBank(max_records=10)
|
| 49 |
for f in range(6):
|
| 50 |
bank.add_record(_record(f, slots=2))
|
| 51 |
-
records = bank.query(MemoryBankQuery(target_frame=10, max_records=2
|
| 52 |
assert len(records) == 2
|
| 53 |
-
tokens, mask = stack_record_tokens(records,
|
| 54 |
assert tokens.shape[0] == 3
|
| 55 |
assert mask.shape[0] == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
assert mask.sum().item() == 0
|
| 45 |
|
| 46 |
|
| 47 |
+
def test_query_caps_records_and_stack_uses_target_slots():
|
| 48 |
bank = CausalMemoryBank(max_records=10)
|
| 49 |
for f in range(6):
|
| 50 |
bank.add_record(_record(f, slots=2))
|
| 51 |
+
records = bank.query(MemoryBankQuery(target_frame=10, max_records=2))
|
| 52 |
assert len(records) == 2
|
| 53 |
+
tokens, mask = stack_record_tokens(records, target_slots=3)
|
| 54 |
assert tokens.shape[0] == 3
|
| 55 |
assert mask.shape[0] == 3
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def test_target_slots_ignore_masked_slots_when_stacking_records():
|
| 59 |
+
invalid = MemoryRecord(
|
| 60 |
+
tokens=torch.ones(4, 4),
|
| 61 |
+
mask=torch.zeros(4, dtype=torch.bool),
|
| 62 |
+
source_start=0,
|
| 63 |
+
source_end=1,
|
| 64 |
+
frame_indices=torch.tensor([0]),
|
| 65 |
+
pose=None,
|
| 66 |
+
source_type=MemorySourceType.REVISIT,
|
| 67 |
+
is_generated=False,
|
| 68 |
+
chunk_id="invalid",
|
| 69 |
+
)
|
| 70 |
+
valid = MemoryRecord(
|
| 71 |
+
tokens=torch.ones(2, 4) * 2,
|
| 72 |
+
mask=torch.ones(2, dtype=torch.bool),
|
| 73 |
+
source_start=1,
|
| 74 |
+
source_end=2,
|
| 75 |
+
frame_indices=torch.tensor([1]),
|
| 76 |
+
pose=None,
|
| 77 |
+
source_type=MemorySourceType.REVISIT,
|
| 78 |
+
is_generated=False,
|
| 79 |
+
chunk_id="valid",
|
| 80 |
+
)
|
| 81 |
+
bank = CausalMemoryBank()
|
| 82 |
+
bank.add_record(invalid)
|
| 83 |
+
bank.add_record(valid)
|
| 84 |
+
|
| 85 |
+
records = bank.query(MemoryBankQuery(target_frame=3))
|
| 86 |
+
tokens, mask = stack_record_tokens(records, target_slots=2)
|
| 87 |
+
|
| 88 |
+
assert mask.tolist() == [True, True]
|
| 89 |
+
assert torch.equal(tokens, torch.ones(2, 4) * 2)
|
tests/test_dememwm_noise_bucket.py
CHANGED
|
@@ -92,7 +92,6 @@ def test_noise_bucket_log_allowlist_keeps_target_counts_only():
|
|
| 92 |
"noise_bucket_low_target_count",
|
| 93 |
"revisit_candidate_frame_count",
|
| 94 |
"valid_revisit_frame_count",
|
| 95 |
-
"valid_revisit_target_count",
|
| 96 |
"revisit_selected_frame_count",
|
| 97 |
"revisit_frame_fov_overlap_mean",
|
| 98 |
"revisit_best_selected_frame_fov_overlap_mean",
|
|
|
|
| 92 |
"noise_bucket_low_target_count",
|
| 93 |
"revisit_candidate_frame_count",
|
| 94 |
"valid_revisit_frame_count",
|
|
|
|
| 95 |
"revisit_selected_frame_count",
|
| 96 |
"revisit_frame_fov_overlap_mean",
|
| 97 |
"revisit_best_selected_frame_fov_overlap_mean",
|
tests/test_dememwm_preselection.py
CHANGED
|
@@ -56,6 +56,50 @@ def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack():
|
|
| 56 |
assert harness._local_context_exclusion_frames() == 8
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
|
| 60 |
harness = Harness()
|
| 61 |
harness.context_frames = 2
|
|
@@ -92,6 +136,31 @@ def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
|
|
| 92 |
assert diag["preselected_anchor_projected_frame_count"] == 2
|
| 93 |
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
def test_preselected_memory_banks_project_only_selected_frames():
|
| 96 |
harness = Harness()
|
| 97 |
latents = torch.randn(20, 1, 3, 2, 2)
|
|
|
|
| 56 |
assert harness._local_context_exclusion_frames() == 8
|
| 57 |
|
| 58 |
|
| 59 |
+
def test_diverse_anchor_selection_does_not_repeat_tied_pose_indices():
|
| 60 |
+
harness = Harness()
|
| 61 |
+
source_positions = torch.arange(5)
|
| 62 |
+
poses = torch.zeros((5, 5), dtype=torch.float32)
|
| 63 |
+
|
| 64 |
+
selected = harness._select_diverse_anchor_positions(source_positions, poses, 4)
|
| 65 |
+
|
| 66 |
+
assert selected.tolist() == [0, 1, 2, 3]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_diverse_anchor_selection_seeds_from_widest_pose_pair():
|
| 70 |
+
harness = Harness()
|
| 71 |
+
source_positions = torch.arange(4)
|
| 72 |
+
poses = torch.tensor([[0.0], [-10.0], [10.0], [0.1]], dtype=torch.float32)
|
| 73 |
+
|
| 74 |
+
selected = harness._select_diverse_anchor_positions(source_positions, poses, 2)
|
| 75 |
+
|
| 76 |
+
assert selected.tolist() == [1, 2]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def test_cached_revisit_prefilter_keeps_only_causal_records():
|
| 80 |
+
harness = Harness()
|
| 81 |
+
|
| 82 |
+
def record(frame: int) -> MemoryRecord:
|
| 83 |
+
return MemoryRecord(
|
| 84 |
+
tokens=torch.zeros((1, 8)),
|
| 85 |
+
mask=torch.ones(1, dtype=torch.bool),
|
| 86 |
+
source_start=frame,
|
| 87 |
+
source_end=frame + 1,
|
| 88 |
+
frame_indices=torch.tensor([frame]),
|
| 89 |
+
pose=None,
|
| 90 |
+
source_type=MemorySourceType.REVISIT,
|
| 91 |
+
is_generated=False,
|
| 92 |
+
chunk_id=f"revisit_{frame}",
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
selected = harness._causal_cached_revisit_records(
|
| 96 |
+
(record(0), record(2), record(5)),
|
| 97 |
+
target_frame=3,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
assert [record.source_start for record in selected] == [0, 2]
|
| 101 |
+
|
| 102 |
+
|
| 103 |
def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
|
| 104 |
harness = Harness()
|
| 105 |
harness.context_frames = 2
|
|
|
|
| 136 |
assert diag["preselected_anchor_projected_frame_count"] == 2
|
| 137 |
|
| 138 |
|
| 139 |
+
def test_streaming_diverse_anchor_selection_uses_context_frames():
|
| 140 |
+
harness = Harness()
|
| 141 |
+
harness.context_frames = 2
|
| 142 |
+
latents = torch.randn(8, 1, 3, 2, 2)
|
| 143 |
+
frame_indices = torch.arange(8)[:, None]
|
| 144 |
+
poses = torch.zeros((8, 1, 5), dtype=torch.float32)
|
| 145 |
+
|
| 146 |
+
anchor_banks, _ = harness._build_streaming_cache_records(
|
| 147 |
+
source_latents=latents,
|
| 148 |
+
source_frame_indices=frame_indices,
|
| 149 |
+
source_is_generated=None,
|
| 150 |
+
pose=poses,
|
| 151 |
+
action=None,
|
| 152 |
+
allow_generated_anchor=False,
|
| 153 |
+
anchor_indices=[0, 1, 2, 3],
|
| 154 |
+
anchor_pool_h=1,
|
| 155 |
+
anchor_pool_w=1,
|
| 156 |
+
anchor_diverse=True,
|
| 157 |
+
token_patch_size=2,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1]
|
| 161 |
+
assert harness.project_call_lengths == [2]
|
| 162 |
+
|
| 163 |
+
|
| 164 |
def test_preselected_memory_banks_project_only_selected_frames():
|
| 165 |
harness = Harness()
|
| 166 |
latents = torch.randn(20, 1, 3, 2, 2)
|
tests/test_dememwm_retrieval.py
CHANGED
|
@@ -96,7 +96,6 @@ def test_revisit_candidates_require_causal_c_short_gap():
|
|
| 96 |
assert result.diagnostics["valid_revisit_frame_count"] == 1
|
| 97 |
assert result.diagnostics["valid_revisit_count"] == 1
|
| 98 |
assert result.diagnostics["valid_candidate_label_count"] == 1
|
| 99 |
-
assert result.diagnostics["valid_revisit_target_count"] == 1
|
| 100 |
assert result.diagnostics["revisit_min_gap_to_target"] == 5
|
| 101 |
assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
|
| 102 |
|
|
@@ -106,7 +105,6 @@ def test_revisit_abstains_when_no_valid_candidate():
|
|
| 106 |
assert result.records == []
|
| 107 |
assert result.diagnostics["abstained"] is True
|
| 108 |
assert result.diagnostics["valid_revisit_mask"] == 0
|
| 109 |
-
assert result.diagnostics["valid_revisit_target_count"] == 0
|
| 110 |
assert result.diagnostics["no_valid_revisit_count"] == 1
|
| 111 |
|
| 112 |
|
|
@@ -155,7 +153,6 @@ def test_fov_threshold_filters_candidates_without_action():
|
|
| 155 |
assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
|
| 156 |
assert result.diagnostics["valid_revisit_frame_count"] == 1
|
| 157 |
assert result.diagnostics["valid_revisit_count"] == 1
|
| 158 |
-
assert result.diagnostics["valid_revisit_target_count"] == 1
|
| 159 |
assert result.diagnostics["best_selected_fov_overlap"] == 1.0
|
| 160 |
assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
|
| 161 |
assert result.diagnostics["best_selected_gap_frames"] == 10
|
|
@@ -206,7 +203,6 @@ def test_selected_frame_carries_frame_metadata_for_projection():
|
|
| 206 |
assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
|
| 207 |
assert result.diagnostics["best_selected_frame_index"] == 1
|
| 208 |
assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
|
| 209 |
-
assert result.diagnostics["valid_revisit_target_count"] == 1
|
| 210 |
|
| 211 |
|
| 212 |
def test_high_quality_threshold_is_selected_target_diagnostic_only():
|
|
@@ -221,7 +217,6 @@ def test_high_quality_threshold_is_selected_target_diagnostic_only():
|
|
| 221 |
)
|
| 222 |
assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
|
| 223 |
assert result.diagnostics["valid_revisit_count"] == 1
|
| 224 |
-
assert result.diagnostics["valid_revisit_target_count"] == 0
|
| 225 |
assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70
|
| 226 |
|
| 227 |
|
|
|
|
| 96 |
assert result.diagnostics["valid_revisit_frame_count"] == 1
|
| 97 |
assert result.diagnostics["valid_revisit_count"] == 1
|
| 98 |
assert result.diagnostics["valid_candidate_label_count"] == 1
|
|
|
|
| 99 |
assert result.diagnostics["revisit_min_gap_to_target"] == 5
|
| 100 |
assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1
|
| 101 |
|
|
|
|
| 105 |
assert result.records == []
|
| 106 |
assert result.diagnostics["abstained"] is True
|
| 107 |
assert result.diagnostics["valid_revisit_mask"] == 0
|
|
|
|
| 108 |
assert result.diagnostics["no_valid_revisit_count"] == 1
|
| 109 |
|
| 110 |
|
|
|
|
| 153 |
assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
|
| 154 |
assert result.diagnostics["valid_revisit_frame_count"] == 1
|
| 155 |
assert result.diagnostics["valid_revisit_count"] == 1
|
|
|
|
| 156 |
assert result.diagnostics["best_selected_fov_overlap"] == 1.0
|
| 157 |
assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0
|
| 158 |
assert result.diagnostics["best_selected_gap_frames"] == 10
|
|
|
|
| 203 |
assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True
|
| 204 |
assert result.diagnostics["best_selected_frame_index"] == 1
|
| 205 |
assert result.diagnostics["best_selected_frame_fov_overlap"] == 1.0
|
|
|
|
| 206 |
|
| 207 |
|
| 208 |
def test_high_quality_threshold_is_selected_target_diagnostic_only():
|
|
|
|
| 217 |
)
|
| 218 |
assert result.diagnostics["selected_frame_record_ids"] == ["c0"]
|
| 219 |
assert result.diagnostics["valid_revisit_count"] == 1
|
|
|
|
| 220 |
assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70
|
| 221 |
|
| 222 |
|
tests/test_dememwm_stream_grad.py
CHANGED
|
@@ -23,7 +23,7 @@ def test_records_to_stream_preserves_grad_to_record_tokens():
|
|
| 23 |
tokens, mask, max_source = MemoryDiTMixin._records_to_stream(
|
| 24 |
object(),
|
| 25 |
[record],
|
| 26 |
-
|
| 27 |
hidden_size=4,
|
| 28 |
device=torch.device("cpu"),
|
| 29 |
dtype=torch.float32,
|
|
|
|
| 23 |
tokens, mask, max_source = MemoryDiTMixin._records_to_stream(
|
| 24 |
object(),
|
| 25 |
[record],
|
| 26 |
+
target_slots=4,
|
| 27 |
hidden_size=4,
|
| 28 |
device=torch.device("cpu"),
|
| 29 |
dtype=torch.float32,
|
train_dememwm_full_berzelius.sh
CHANGED
|
@@ -23,10 +23,10 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
|
| 23 |
export WANDB_DISABLED=true
|
| 24 |
export HYDRA_FULL_ERROR=1
|
| 25 |
|
| 26 |
-
OUTPUT_DIR=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/
|
| 27 |
|
| 28 |
srun python -m main \
|
| 29 |
-
+name=
|
| 30 |
+output_dir="${OUTPUT_DIR}/" \
|
| 31 |
auto_resume=true \
|
| 32 |
experiment.tasks=[training] \
|
|
@@ -40,7 +40,7 @@ srun python -m main \
|
|
| 40 |
dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft \
|
| 41 |
dataset.precomputed_feature_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft/vae_features \
|
| 42 |
dataset.n_frames=1000 \
|
| 43 |
-
+dataset.n_frames_valid=
|
| 44 |
+dataset.customized_validation=true \
|
| 45 |
+dataset.memory_condition_length=0 \
|
| 46 |
+dataset.wo_updown=false \
|
|
@@ -68,8 +68,7 @@ srun python -m main \
|
|
| 68 |
++algorithm.dememwm.dynamic.recent_frames=4 \
|
| 69 |
++algorithm.dememwm.revisit.enabled=true \
|
| 70 |
++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
|
| 71 |
-
++algorithm.dememwm.revisit.fov_overlap_threshold=0.
|
| 72 |
-
++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \
|
| 73 |
++algorithm.dememwm.revisit.pose_preselect_topk=64 \
|
| 74 |
++algorithm.dememwm.revisit.fov_yaw_samples=25 \
|
| 75 |
++algorithm.dememwm.revisit.fov_pitch_samples=20 \
|
|
@@ -86,7 +85,6 @@ srun python -m main \
|
|
| 86 |
++algorithm.dememwm.cache.no_evict=true \
|
| 87 |
++algorithm.dememwm.cache.clear_between_videos=true \
|
| 88 |
++algorithm.dememwm.cache.max_records=null \
|
| 89 |
-
++algorithm.dememwm.cache.max_slots=null \
|
| 90 |
++algorithm.dememwm.cache.on_capacity_exceeded=warn \
|
| 91 |
++algorithm.dememwm.curriculum.enabled=true \
|
| 92 |
++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
|
|
@@ -95,7 +93,7 @@ srun python -m main \
|
|
| 95 |
++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5 \
|
| 96 |
++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5 \
|
| 97 |
++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5 \
|
| 98 |
-
experiment.training.batch_size=
|
| 99 |
experiment.training.optim.accumulate_grad_batches=1 \
|
| 100 |
experiment.validation.batch_size=1 \
|
| 101 |
experiment.validation.limit_batch=8 \
|
|
|
|
| 23 |
export WANDB_DISABLED=true
|
| 24 |
export HYDRA_FULL_ERROR=1
|
| 25 |
|
| 26 |
+
OUTPUT_DIR=/proj/cvl/users/x_fahkh2/WorldMem_Repro/checkpoints/dememwm_full_berzelius_8a100_bs16_global128_350k
|
| 27 |
|
| 28 |
srun python -m main \
|
| 29 |
+
+name=train_dememwm_full_berzelius_8a100_bs16_global128_350k \
|
| 30 |
+output_dir="${OUTPUT_DIR}/" \
|
| 31 |
auto_resume=true \
|
| 32 |
experiment.tasks=[training] \
|
|
|
|
| 40 |
dataset.save_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft \
|
| 41 |
dataset.precomputed_feature_dir=/proj/cvl/users/x_fahkh2/WorldMem_Repro/datasets/minecraft/vae_features \
|
| 42 |
dataset.n_frames=1000 \
|
| 43 |
+
+dataset.n_frames_valid=700 \
|
| 44 |
+dataset.customized_validation=true \
|
| 45 |
+dataset.memory_condition_length=0 \
|
| 46 |
+dataset.wo_updown=false \
|
|
|
|
| 68 |
++algorithm.dememwm.dynamic.recent_frames=4 \
|
| 69 |
++algorithm.dememwm.revisit.enabled=true \
|
| 70 |
++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
|
| 71 |
+
++algorithm.dememwm.revisit.fov_overlap_threshold=0.60 \
|
|
|
|
| 72 |
++algorithm.dememwm.revisit.pose_preselect_topk=64 \
|
| 73 |
++algorithm.dememwm.revisit.fov_yaw_samples=25 \
|
| 74 |
++algorithm.dememwm.revisit.fov_pitch_samples=20 \
|
|
|
|
| 85 |
++algorithm.dememwm.cache.no_evict=true \
|
| 86 |
++algorithm.dememwm.cache.clear_between_videos=true \
|
| 87 |
++algorithm.dememwm.cache.max_records=null \
|
|
|
|
| 88 |
++algorithm.dememwm.cache.on_capacity_exceeded=warn \
|
| 89 |
++algorithm.dememwm.curriculum.enabled=true \
|
| 90 |
++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
|
|
|
|
| 93 |
++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5 \
|
| 94 |
++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5 \
|
| 95 |
++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5 \
|
| 96 |
+
experiment.training.batch_size=16 \
|
| 97 |
experiment.training.optim.accumulate_grad_batches=1 \
|
| 98 |
experiment.validation.batch_size=1 \
|
| 99 |
experiment.validation.limit_batch=8 \
|