"""Tests for the DeMemWM memory helpers provided by ``MemoryDiTMixin``.

The mixin is exercised through ``Harness``, which populates only the
attributes these helpers read and records every
``_project_latent_patch_tokens`` call so the tests can check exactly which
frames were projected.
"""

import torch
from dememwm_import_helper import install_dememwm_namespace

# Kept ahead of the ``algorithms.worldmem.dememwm`` imports below, which
# presumably rely on the namespace this installs.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.cache import StreamingCache
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType


class Harness(MemoryDiTMixin):
    """Bare ``MemoryDiTMixin`` host that logs patch-token projections.

    For each projection call, ``project_call_lengths`` receives the frame
    count (``latents.shape[0]``). ``project_call_values`` receives the list of
    ``latents[:, 0, 0, 0, 0]`` values, one per frame. The call is then
    forwarded to the mixin's own implementation.
    """

    def __init__(self):
        # No base-class initializer is invoked; the helpers under test only
        # see the fields assigned here.
        self.n_tokens = 8
        self.context_frames = 0
        self.frame_stack = 1
        self.dememwm_anchor_proj = torch.nn.Linear(12, 8)
        self.dememwm_revisit_proj = torch.nn.Linear(12, 8)
        self.project_call_lengths = []
        self.project_call_values = []

    def _project_latent_patch_tokens(self, latents, projection, patch_size):
        """Log the frame count and per-frame probe values, then delegate."""
        self.project_call_lengths.append(int(latents.shape[0]))
        probe = latents[:, 0, 0, 0, 0].detach().cpu().tolist()
        self.project_call_values.append(probe)
        return MemoryDiTMixin._project_latent_patch_tokens(
            self, latents, projection, patch_size
        )


def _sample_training_windows(harness, clip_length, draws=20):
    """Reseed torch with 0 and collect ``draws`` ``(start, end)`` training windows."""
    torch.manual_seed(0)
    cpu = torch.device("cpu")
    return [harness._training_window_bounds(clip_length, cpu) for _ in range(draws)]


def _frame_index_latents(num_frames):
    """Build ``(num_frames, 1, 3, 2, 2)`` float32 latents where frame ``i`` is all ``i``."""
    ramp = torch.arange(num_frames, dtype=torch.float32)
    return ramp.view(num_frames, 1, 1, 1, 1).expand(num_frames, 1, 3, 2, 2).clone()


def test_training_window_bounds_samples_inside_long_clip():
    """8-frame windows in a 128-frame clip are not always pinned to the end."""
    harness = Harness()
    windows = _sample_training_windows(harness, 128)
    for first, last in windows:
        assert last - first == 8
        assert 0 <= first <= 120
    assert any(first != 120 for first, _ in windows)


def test_training_window_bounds_respects_context_frames():
    """With ``context_frames=100`` every sampled window starts at frame 100 or later."""
    harness = Harness()
    harness.context_frames = 100
    windows = _sample_training_windows(harness, 128)
    for first, last in windows:
        assert last - first == 8
        assert 100 <= first <= 120
    assert any(first != 120 for first, _ in windows)


def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack():
    """Exclusion span is ``n_tokens * frame_stack`` (4 * 2), ignoring ``context_frames``."""
    harness = Harness()
    harness.n_tokens, harness.frame_stack, harness.context_frames = 4, 2, 100
    assert harness._local_context_exclusion_frames() == 8


def test_diverse_anchor_selection_does_not_repeat_tied_pose_indices():
    """All-zero (fully tied) poses still yield four distinct positions in index order."""
    harness = Harness()
    tied_poses = torch.zeros((5, 5), dtype=torch.float32)
    chosen = harness._select_diverse_anchor_positions(torch.arange(5), tied_poses, 4)
    assert chosen.tolist() == [0, 1, 2, 3]


def test_diverse_anchor_selection_seeds_from_widest_pose_pair():
    """Among poses 0, -10, 10, 0.1, picking two keeps the extreme pair (positions 1, 2)."""
    harness = Harness()
    scalar_poses = torch.tensor([[0.0], [-10.0], [10.0], [0.1]], dtype=torch.float32)
    chosen = harness._select_diverse_anchor_positions(torch.arange(4), scalar_poses, 2)
    assert chosen.tolist() == [1, 2]


def test_cached_revisit_prefilter_keeps_only_causal_records():
    """For target frame 3, cached records at frames 0 and 2 survive and frame 5 is dropped."""
    harness = Harness()

    def revisit_record(frame: int) -> MemoryRecord:
        """Single-frame REVISIT record spanning ``[frame, frame + 1)``."""
        return MemoryRecord(
            tokens=torch.zeros((1, 8)),
            mask=torch.ones(1, dtype=torch.bool),
            source_start=frame,
            source_end=frame + 1,
            frame_indices=torch.tensor([frame]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=f"revisit_{frame}",
        )

    survivors = harness._causal_cached_revisit_records(
        tuple(revisit_record(frame) for frame in (0, 2, 5)),
        target_frame=3,
    )
    assert [kept.source_start for kept in survivors] == [0, 2]


def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
    """Four anchor indices are requested, but ``context_frames=2`` caps them to frames 0 and 1."""
    harness = Harness()
    harness.context_frames = 2
    latents = torch.randn(8, 1, 3, 2, 2)
    frame_indices = torch.arange(8)[:, None]
    poses = torch.zeros((8, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)

    anchor_banks, _, _, diag = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=torch.tensor([[6]]),
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=0,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.0,
        plucker_weight=0.1,
        revisit_retrieval_kwargs=None,
        token_patch_size=2,
    )

    anchor_frames = [int(rec.frame_indices.item()) for rec in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    assert diag["preselected_anchor_projected_frame_count"] == 2


def test_streaming_diverse_anchor_selection_uses_context_frames():
    """Streaming anchors are capped at ``context_frames=2`` and projected in one 2-frame call."""
    harness = Harness()
    harness.context_frames = 2
    latents = torch.randn(8, 1, 3, 2, 2)
    frame_indices = torch.arange(8)[:, None]
    poses = torch.zeros((8, 1, 5), dtype=torch.float32)

    anchor_banks, _ = harness._build_streaming_cache_records(
        source_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        token_patch_size=2,
    )

    anchor_frames = [int(rec.frame_indices.item()) for rec in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    assert harness.project_call_lengths == [2]


def test_preselected_memory_banks_project_only_selected_frames():
    """Only the chosen frames are projected, never the whole 20-frame clip.

    Anchors go through one batched call of 4 frames. Each of the three
    revisit frames is projected in its own single-frame call.
    """
    harness = Harness()
    latents = torch.randn(20, 1, 3, 2, 2)
    frame_indices = torch.arange(20)[:, None]
    target_frame_indices = torch.tensor([[10], [11]])
    poses = torch.zeros((20, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)

    anchor_banks, revisit_banks, tokens_per_frame, diag = (
        harness._build_preselected_causal_memory_banks(
            committed_latents=latents,
            source_frame_indices=frame_indices,
            source_is_generated=None,
            pose=poses,
            action=None,
            target_frame_indices=target_frame_indices,
            target_pose=target_pose,
            target_action=None,
            target_video_ids=None,
            allow_generated_anchor=False,
            anchor_indices=[0, 1, 2, 3],
            anchor_pool_h=1,
            anchor_pool_w=1,
            anchor_diverse=False,
            revisit_pool_h=1,
            revisit_pool_w=1,
            revisit_max_frames=2,
            exclude_local_context_frames=4,
            fov_overlap_threshold=0.0,
            plucker_weight=0.1,
            revisit_retrieval_kwargs=None,
            token_patch_size=2,
        )
    )

    assert tokens_per_frame == 1
    assert len(anchor_banks[0].records) == 4
    assert len(revisit_banks[0].records) == 3
    assert diag["preselected_anchor_projected_frame_count"] == 4
    assert diag["preselected_revisit_projected_frame_count"] == 3
    assert diag["preselected_revisit_projected_frame_record_count"] == 3
    assert harness.project_call_lengths == [4, 1, 1, 1]
    # Redundant with the equality above, but states the intent explicitly.
    assert 20 not in harness.project_call_lengths


def test_preselected_revisit_projects_best_fov_frame_not_latest():
    """Revisit retrieval keeps pose-matching frame 1 rather than the latest frame.

    Every source frame except frame 1 has 180 in its last pose component, and
    the target pose is all zeros. Frame ``i`` is filled with ``i``, so the one
    recorded projection value identifies which frame was projected.
    """
    harness = Harness()
    latents = _frame_index_latents(8)
    frame_indices = torch.arange(8)[:, None]
    pose_rows = torch.zeros((8, 5), dtype=torch.float32)
    pose_rows[:, 4] = 180.0
    pose_rows[1, 4] = 0.0

    _, revisit_banks, _, _ = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=pose_rows[:, None, :],
        action=None,
        target_frame_indices=torch.tensor([[8]]),
        target_pose=torch.zeros((1, 1, 5)),
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=False,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=1,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.30,
        plucker_weight=0.1,
        revisit_retrieval_kwargs={"high_quality_fov_threshold": 0.70},
        token_patch_size=2,
    )

    revisit_records = revisit_banks[0].records
    assert len(revisit_records) == 1
    assert revisit_records[0].metadata["dememwm_selected_frame_index"] == 1
    assert harness.project_call_values == [[1.0]]


def test_streaming_revisit_projection_uses_selected_frame_metadata():
    """A metadata-only record spanning frames 0-3 is projected from frame 1 alone.

    ``dememwm_selected_frame_index`` names frame 1. Cached frame ``i`` is
    filled with ``i``, so the recorded projection value ``[[1.0]]`` shows that
    only that frame was used.
    """
    harness = Harness()
    cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=True)
    cache.add_raw_latents(_frame_index_latents(4), torch.arange(4)[:, None])
    window_record = MemoryRecord(
        tokens=torch.zeros((1, 8)),
        mask=torch.ones(1, dtype=torch.bool),
        source_start=0,
        source_end=4,
        frame_indices=torch.arange(4),
        pose=None,
        source_type=MemorySourceType.PREFIX_GT,
        is_generated=False,
        chunk_id="frame",
        metadata={
            "dememwm_revisit_metadata_only": True,
            "dememwm_selected_frame_index": 1,
        },
    )

    projected = harness._project_streaming_revisit_records(
        cache=cache,
        batch_idx=0,
        records=[window_record],
        device=torch.device("cpu"),
        dtype=torch.float32,
        token_patch_size=2,
        revisit_pool_h=1,
        revisit_pool_w=1,
        projection_cache={},
    )

    assert len(projected) == 1
    assert projected[0].metadata["dememwm_selected_frame_index"] == 1
    assert harness.project_call_values == [[1.0]]