| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin |
| from algorithms.worldmem.dememwm.cache import StreamingCache |
| from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType |
|
|
|
|
class Harness(MemoryDiTMixin):
    """Minimal host object for exercising ``MemoryDiTMixin`` helpers in tests.

    ``__init__`` does not call ``super().__init__()``. It sets only the
    attributes these tests rely on: ``n_tokens``, ``context_frames``,
    ``frame_stack`` and the two projection layers. It also wraps
    ``_project_latent_patch_tokens`` so tests can inspect which latent frames
    were actually projected.
    """

    def __init__(self):
        self.n_tokens = 8
        self.context_frames = 0
        self.frame_stack = 1
        # NOTE(review): in_features=12 presumably equals 3 latent channels * a
        # 2x2 patch (the tests use C=3 latents with token_patch_size=2), and
        # out_features=8 matches the (1, 8) token shape used for MemoryRecords
        # in this file -- confirm against MemoryDiTMixin.
        self.dememwm_anchor_proj = torch.nn.Linear(12, 8)
        self.dememwm_revisit_proj = torch.nn.Linear(12, 8)
        # Per projection call: the number of frames passed in (latents.shape[0]).
        self.project_call_lengths = []
        # Per projection call: latents[:, 0, 0, 0, 0] for each frame, as floats.
        # Tests fill each frame with its own index, so this identifies the frames.
        self.project_call_values = []

    def _project_latent_patch_tokens(self, latents, projection, patch_size):
        """Record what is being projected, then delegate to the real implementation.

        Uses ``super()`` instead of naming ``MemoryDiTMixin`` directly, so the
        call still follows the MRO if another mixin is ever layered in.
        """
        self.project_call_lengths.append(int(latents.shape[0]))
        self.project_call_values.append(latents[:, 0, 0, 0, 0].detach().cpu().tolist())
        return super()._project_latent_patch_tokens(latents, projection, patch_size)
|
|
|
|
def test_training_window_bounds_samples_inside_long_clip():
    """Sampled windows in a 128-frame clip are 8 frames long and not pinned to the end."""
    harness = Harness()
    torch.manual_seed(0)
    sampled_starts = []
    for _ in range(20):
        lo, hi = harness._training_window_bounds(128, torch.device("cpu"))
        sampled_starts.append(lo)
        assert hi - lo == 8
        assert 0 <= lo <= 120
    # 120 is the last valid start (128 - 8); sampling must not always land there.
    assert not all(s == 120 for s in sampled_starts)
|
|
|
|
def test_training_window_bounds_respects_context_frames():
    """With context_frames=100, window starts stay within [100, 120] and still vary."""
    harness = Harness()
    harness.context_frames = 100
    torch.manual_seed(0)
    sampled_starts = []
    for _ in range(20):
        lo, hi = harness._training_window_bounds(128, torch.device("cpu"))
        sampled_starts.append(lo)
        assert hi - lo == 8
        assert harness.context_frames <= lo <= 120
    # Sampling must not collapse onto the last valid start.
    assert not all(s == 120 for s in sampled_starts)
|
|
|
|
def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack():
    """The exclusion span equals n_tokens * frame_stack (8), not context_frames (100)."""
    harness = Harness()
    harness.n_tokens, harness.frame_stack, harness.context_frames = 4, 2, 100
    assert harness._local_context_exclusion_frames() == 4 * 2
|
|
|
|
def test_diverse_anchor_selection_does_not_repeat_tied_pose_indices():
    """With identical poses, diverse selection returns distinct positions in order."""
    harness = Harness()
    num_frames = 5
    positions = torch.arange(num_frames)
    tied_poses = torch.zeros((num_frames, 5), dtype=torch.float32)

    picked = harness._select_diverse_anchor_positions(positions, tied_poses, 4)

    assert picked.tolist() == list(range(4))
|
|
|
|
def test_diverse_anchor_selection_seeds_from_widest_pose_pair():
    """Choosing two anchors returns the positions of the most distant poses (-10 and +10)."""
    harness = Harness()
    positions = torch.arange(4)
    # One-dimensional poses, shape (4, 1).
    poses = torch.tensor([0.0, -10.0, 10.0, 0.1], dtype=torch.float32).unsqueeze(-1)

    picked = harness._select_diverse_anchor_positions(positions, poses, 2)

    assert picked.tolist() == [1, 2]
|
|
|
|
def test_cached_revisit_prefilter_keeps_only_causal_records():
    """For target frame 3, cached records at frames 0 and 2 are kept and frame 5 is dropped."""
    harness = Harness()

    def make_record(frame: int) -> MemoryRecord:
        # Single-token REVISIT record spanning source frames [frame, frame + 1).
        return MemoryRecord(
            tokens=torch.zeros((1, 8)),
            mask=torch.ones(1, dtype=torch.bool),
            source_start=frame,
            source_end=frame + 1,
            frame_indices=torch.tensor([frame]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=f"revisit_{frame}",
        )

    cached = tuple(make_record(frame) for frame in (0, 2, 5))
    kept = harness._causal_cached_revisit_records(cached, target_frame=3)

    assert [rec.source_start for rec in kept] == [0, 2]
|
|
|
|
def test_diverse_anchor_selection_uses_context_frames_not_literal_limit():
    """Diverse preselected anchors are capped at context_frames (2), not len(anchor_indices) (4)."""
    harness = Harness()
    harness.context_frames = 2
    num_frames = 8
    latents = torch.randn(num_frames, 1, 3, 2, 2)
    frame_indices = torch.arange(num_frames)[:, None]
    poses = torch.zeros((num_frames, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)
    target_frames = torch.tensor([[6]])

    anchor_banks, _, _, diag = harness._build_preselected_causal_memory_banks(
        # Source clip.
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        # Target being generated.
        target_frame_indices=target_frames,
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        # Anchor selection.
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        # Revisit retrieval (disabled via revisit_max_frames=0).
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=0,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.0,
        plucker_weight=0.1,
        revisit_retrieval_kwargs=None,
        token_patch_size=2,
    )

    anchor_frames = [int(rec.frame_indices.item()) for rec in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    assert diag["preselected_anchor_projected_frame_count"] == 2
|
|
|
|
def test_streaming_diverse_anchor_selection_uses_context_frames():
    """The streaming anchor path also caps diverse anchors at context_frames and projects once."""
    harness = Harness()
    harness.context_frames = 2
    num_frames = 8
    latents = torch.randn(num_frames, 1, 3, 2, 2)
    frame_indices = torch.arange(num_frames)[:, None]
    poses = torch.zeros((num_frames, 1, 5), dtype=torch.float32)

    anchor_banks, _ = harness._build_streaming_cache_records(
        source_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=True,
        token_patch_size=2,
    )

    anchor_frames = [int(rec.frame_indices.item()) for rec in anchor_banks[0].records]
    assert anchor_frames == [0, 1]
    # Exactly one projection call, covering only the two selected frames.
    assert harness.project_call_lengths == [2]
|
|
|
|
def test_preselected_memory_banks_project_only_selected_frames():
    """Projection touches only the preselected frames, never the whole 20-frame clip."""
    harness = Harness()
    num_frames = 20
    latents = torch.randn(num_frames, 1, 3, 2, 2)
    frame_indices = torch.arange(num_frames)[:, None]
    target_frame_indices = torch.tensor([[10], [11]])
    poses = torch.zeros((num_frames, 1, 5), dtype=torch.float32)
    target_pose = torch.zeros((2, 1, 5), dtype=torch.float32)

    banks = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=target_frame_indices,
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1, 2, 3],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=False,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=2,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.0,
        plucker_weight=0.1,
        revisit_retrieval_kwargs=None,
        token_patch_size=2,
    )
    anchor_banks, revisit_banks, tokens_per_frame, diag = banks

    assert tokens_per_frame == 1
    assert len(anchor_banks[0].records) == 4
    assert len(revisit_banks[0].records) == 3
    assert diag["preselected_anchor_projected_frame_count"] == 4
    assert diag["preselected_revisit_projected_frame_count"] == 3
    assert diag["preselected_revisit_projected_frame_record_count"] == 3
    # Expected calls: one 4-frame call (matching the 4 anchors), then three
    # single-frame calls (matching the 3 revisit records).
    assert harness.project_call_lengths == [4, 1, 1, 1]
    assert num_frames not in harness.project_call_lengths
|
|
|
|
def test_preselected_revisit_projects_best_fov_frame_not_latest():
    """The single revisit pick is frame 1 (pose matches the target), not the newest frame."""
    harness = Harness()
    num_frames = 8
    # Each frame's latent is filled with its own index, so recorded projection
    # values identify which frame was projected.
    latents = (
        torch.arange(num_frames, dtype=torch.float32)
        .view(num_frames, 1, 1, 1, 1)
        .expand(num_frames, 1, 3, 2, 2)
        .clone()
    )
    frame_indices = torch.arange(num_frames)[:, None]
    # The last pose component is 180.0 for every frame except frame 1, where it
    # is 0.0, equal to the target pose below.
    pose_rows = torch.zeros((num_frames, 5), dtype=torch.float32)
    pose_rows[:, 4] = 180.0
    pose_rows[1, 4] = 0.0
    poses = pose_rows[:, None, :]
    target_pose = torch.zeros((1, 1, 5), dtype=torch.float32)

    _, revisit_banks, _, _ = harness._build_preselected_causal_memory_banks(
        committed_latents=latents,
        source_frame_indices=frame_indices,
        source_is_generated=None,
        pose=poses,
        action=None,
        target_frame_indices=torch.tensor([[8]]),
        target_pose=target_pose,
        target_action=None,
        target_video_ids=None,
        allow_generated_anchor=False,
        anchor_indices=[],
        anchor_pool_h=1,
        anchor_pool_w=1,
        anchor_diverse=False,
        revisit_pool_h=1,
        revisit_pool_w=1,
        revisit_max_frames=1,
        exclude_local_context_frames=4,
        fov_overlap_threshold=0.30,
        plucker_weight=0.1,
        revisit_retrieval_kwargs={"high_quality_fov_threshold": 0.70},
        token_patch_size=2,
    )

    (only_record,) = revisit_banks[0].records
    assert only_record.metadata["dememwm_selected_frame_index"] == 1
    assert harness.project_call_values == [[1.0]]
|
|
|
|
def test_streaming_revisit_projection_uses_selected_frame_metadata():
    """A metadata-only record is projected from the frame named in its metadata (frame 1)."""
    harness = Harness()
    cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=True)
    num_frames = 4
    # Each frame's latent is filled with its own index.
    latents = (
        torch.arange(num_frames, dtype=torch.float32)
        .view(num_frames, 1, 1, 1, 1)
        .expand(num_frames, 1, 3, 2, 2)
        .clone()
    )
    cache.add_raw_latents(latents, torch.arange(num_frames)[:, None])

    selection_metadata = {
        "dememwm_revisit_metadata_only": True,
        "dememwm_selected_frame_index": 1,
    }
    stub_record = MemoryRecord(
        tokens=torch.zeros((1, 8)),
        mask=torch.ones(1, dtype=torch.bool),
        source_start=0,
        source_end=num_frames,
        frame_indices=torch.arange(num_frames),
        pose=None,
        source_type=MemorySourceType.PREFIX_GT,
        is_generated=False,
        chunk_id="frame",
        metadata=selection_metadata,
    )

    projected = harness._project_streaming_revisit_records(
        cache=cache,
        batch_idx=0,
        records=[stub_record],
        device=torch.device("cpu"),
        dtype=torch.float32,
        token_patch_size=2,
        revisit_pool_h=1,
        revisit_pool_w=1,
        projection_cache={},
    )

    (only_record,) = projected
    assert only_record.metadata["dememwm_selected_frame_index"] == 1
    assert harness.project_call_values == [[1.0]]
|
|