import types

import pytest
import torch

from dememwm_import_helper import install_dememwm_namespace

# Must run before the `algorithms.worldmem.dememwm` imports below; presumably it
# registers the namespace those modules are resolved from (TODO confirm).
install_dememwm_namespace()

from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin  # noqa: E402
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor  # noqa: E402
from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics  # noqa: E402
from algorithms.worldmem.dememwm.schedules import (  # noqa: E402
    EVAL_ABLATION_BRANCH_TO_ID,
    EVAL_ABLATION_BRANCHES,
    EVAL_CORRUPTION_BRANCHES,
    normalize_eval_ablation_branch,
)

# Stream-enable and forced-revisit branches, in registry order.
_WAVE9_STREAM_BRANCHES = (
    "memory_off",
    "A_only",
    "D_only",
    "A_plus_D",
    "A_plus_D_plus_R_normal",
    "R_forced_off",
    "R_forced_on",
)
# Memory-corruption branches; in the registry they follow the stream branches.
_WAVE9_CORRUPTION_BRANCHES = (
    "wrong_pose",
    "time_shuffle",
    "source_matched_random",
    "pose_shuffle",
    "wrong_video",
    "local_context_overlap_fake_revisit",
)
# The full expected branch registry, in order.
WAVE9_BRANCHES = _WAVE9_STREAM_BRANCHES + _WAVE9_CORRUPTION_BRANCHES


def _device():
    """Return the device all tensors and modules in these tests are placed on."""
    return torch.device("cpu")


def test_wave9_branch_registry_is_exact_and_validated():
    """The branch registry equals WAVE9_BRANCHES and unknown names are rejected."""
    last_index = len(WAVE9_BRANCHES) - 1
    first_corruption_index = len(_WAVE9_STREAM_BRANCHES)

    assert EVAL_ABLATION_BRANCHES == WAVE9_BRANCHES
    assert EVAL_ABLATION_BRANCH_TO_ID["memory_off"] == 0
    assert EVAL_ABLATION_BRANCH_TO_ID["local_context_overlap_fake_revisit"] == last_index
    assert EVAL_CORRUPTION_BRANCHES == WAVE9_BRANCHES[first_corruption_index:]

    # None falls back to the full A+D+R branch; known names pass through unchanged.
    assert normalize_eval_ablation_branch(None) == "A_plus_D_plus_R_normal"
    assert normalize_eval_ablation_branch("wrong_pose") == "wrong_pose"
    with pytest.raises(ValueError):
        normalize_eval_ablation_branch("ratio_sweep")


def test_eval_ablation_diagnostics_bucket_counts():
    """Bucket counts/fractions for a 4-slot row.

    The row has three valid-revisit slots (two of them corrupted) and one slot
    without a valid revisit.
    """
    valid = torch.tensor([[True, True, True, False]])
    diag = summarize_eval_ablation_diagnostics(
        enabled=True,
        branch="wrong_pose",
        valid_revisit_mask=valid,
        no_valid_revisit_mask=~valid,
        eval_corrupted_revisit_mask=torch.tensor([[False, True, True, False]]),
    )

    assert diag["eval_ablation_enabled"] is True
    assert diag["eval_ablation_branch"] == "wrong_pose"
    assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"]

    expected_counts = {
        "eval_bucket_true_revisit_count": 1,
        "eval_bucket_no_valid_revisit_count": 1,
        "eval_bucket_corrupted_memory_count": 2,
    }
    expected_fractions = {
        "eval_bucket_true_revisit_fraction": 0.25,
        "eval_bucket_no_valid_revisit_fraction": 0.25,
        "eval_bucket_corrupted_memory_fraction": 0.5,
    }
    for key, count in expected_counts.items():
        assert diag[key] == count, key
    for key, fraction in expected_fractions.items():
        assert diag[key] == pytest.approx(fraction), key


class ConstantGate(torch.nn.Module):
    """Revisit-gate stub that ignores its inputs and emits one fixed value.

    The output is a float32 tensor shaped like ``best_selected_fov_overlap``.
    """

    def __init__(self, value: float):
        super().__init__()
        self.value = float(value)

    def forward(
        self,
        *,
        valid_revisit_mask,
        best_selected_fov_overlap,
        best_selected_plucker_overlap,
        selected_gap_frames,
    ):
        # Only the FOV-overlap tensor is used (as a shape/device template); the
        # other keywords are accepted, presumably to match how the gate is called.
        _ = (valid_revisit_mask, best_selected_plucker_overlap, selected_gap_frames)
        return torch.full_like(best_selected_fov_overlap, self.value, dtype=torch.float32)


def _dememwm_cfg(branch: str):
    """Build the nested config namespace that DummyDeMemWM exposes as ``cfg``.

    ``branch`` is written to ``cfg.dememwm.eval_ablation.branch``.
    """
    ns = types.SimpleNamespace
    dememwm = ns(
        enabled=True,
        training_stage="stage_2",
        debug_force_all_streams=False,
        token_patch_size=2,
        curriculum=ns(enabled=False),
        anchor=ns(
            enabled=True,
            anchor_indices=[0, 1],
            allow_generated_as_anchor=False,
            diverse_selection=False,
            compress=ns(pool_h=1, pool_w=1),
        ),
        dynamic=ns(
            enabled=True,
            exclude_latest_local_frames=2,
            recent_frames=4,
            conv_kernel_t=3,
            conv_stride_t=2,
        ),
        revisit=ns(
            enabled=True,
            deterministic_pose_retrieval=True,
            fov_overlap_threshold=0.0,
            plucker_weight=0.1,
            max_frames=2,
            compress=ns(pool_h=1, pool_w=1),
        ),
        stage_policy=ns(noise_bucket_logging=True),
        eval_ablation=ns(enabled=True, branch=branch),
        generated_history_proxy=ns(enabled=False),
        injection=ns(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
        cache=ns(enabled=False),
        checkpoint=ns(strict_dememwm_eval_load=True),
    )
    return ns(dememwm=dememwm, weight_decay=0.0, optimizer_beta=(0.9, 0.999))


class DummyDeMemWM(MemoryDiTMixin):
    """Minimal MemoryDiTMixin host used to call ``build_memory_streams``.

    It carries a config plus small projection, compression and gate modules,
    each moved to ``device``, in place of a full model.
    """

    def __init__(self, branch: str, device: torch.device):
        self.cfg = _dememwm_cfg(branch)
        self.global_step = 0
        self.x_stacked_shape = (1, 4, 4)
        # NOTE: keep this construction order; nn.Linear draws its initial
        # weights from torch's global RNG.
        self.dememwm_anchor_proj = torch.nn.Linear(4, 8, bias=False).to(device)
        self.dememwm_revisit_proj = torch.nn.Linear(4, 8, bias=False).to(device)
        self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
            latent_channels=1,
            dit_hidden_size=8,
            patch_size=2,
            conv_kernel_t=3,
            conv_stride_t=2,
            max_source_frames=4,
            exclude_latest_local_frames=2,
        ).to(device)
        # Fixed 0.25 gate so the revisit-gate assertions have a known value.
        self.dememwm_revisit_gate = ConstantGate(0.25).to(device)


def _streams(branch: str):
    """Run ``build_memory_streams`` on a DummyDeMemWM configured for ``branch``.

    Inputs:
    - latents of shape (12, 1, 1, 4, 4) with distinct values;
    - frame indices 0..11 of shape (12, 1);
    - all-zero poses;
    - two target rows, frame indices 8 and 12;
    - no actions.
    """
    device = _device()
    num_frames = 12
    model = DummyDeMemWM(branch, device)

    latents = (
        torch.arange(num_frames * 1 * 1 * 4 * 4, device=device, dtype=torch.float32)
        .reshape(num_frames, 1, 1, 4, 4)
        / 100.0
    )
    history_index = torch.arange(num_frames, device=device).unsqueeze(1)
    target_index = torch.tensor([[8], [12]], device=device)
    history_pose = torch.zeros((num_frames, 1, 5), device=device, dtype=torch.float32)
    target_pose = torch.zeros((2, 1, 5), device=device, dtype=torch.float32)

    return model.build_memory_streams(
        latents,
        history_index,
        target_frame_indices=target_index,
        pose=history_pose,
        target_pose=target_pose,
        action=None,
        target_action=None,
    )


def test_eval_ablation_stream_enable_branches_control_masks_and_gates():
    """memory_off / A_only / D_only / A_plus_D toggle the anchor and dynamic streams."""
    runs = {name: _streams(name) for name in ("memory_off", "A_only", "D_only", "A_plus_D")}

    # None of these branches enables the revisit stream.
    for name, streams in runs.items():
        assert not streams.revisit_mask.any(), name

    off = runs["memory_off"]
    assert off.anchor_gate == 0.0
    assert off.dynamic_gate == 0.0
    assert not off.anchor_mask.any()
    assert not off.dynamic_mask.any()

    anchor_only = runs["A_only"]
    assert anchor_only.anchor_gate == 1.0
    assert anchor_only.anchor_mask.any()
    assert not anchor_only.dynamic_mask.any()

    dynamic_only = runs["D_only"]
    assert dynamic_only.dynamic_gate == 1.0
    assert dynamic_only.dynamic_mask.any()
    assert dynamic_only.anchor_gate == 0.0
    assert not dynamic_only.anchor_mask.any()

    both = runs["A_plus_D"]
    assert both.anchor_mask.any()
    assert both.dynamic_mask.any()

    # memory_off and A_plus_D additionally zero every revisit gate entry.
    for streams in (off, both):
        assert torch.count_nonzero(streams.revisit_gate).item() == 0


def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch():
    """R_forced_off / R_forced_on override the (constant 0.25) revisit gate."""
    normal, forced_off, forced_on = (
        _streams(name) for name in ("A_plus_D_plus_R_normal", "R_forced_off", "R_forced_on")
    )

    # Normal branch: every slot is a valid revisit and keeps the ConstantGate value.
    assert normal.valid_revisit_mask.all()
    assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25))

    # Forced off: the gate is zero everywhere.
    assert torch.count_nonzero(forced_off.revisit_gate).item() == 0

    # Forced on: the gate equals the valid-revisit mask cast to the gate dtype.
    on_gate = forced_on.revisit_gate
    assert torch.equal(on_gate, forced_on.valid_revisit_mask.to(dtype=on_gate.dtype))
    assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on"


def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
    """wrong_pose keeps the revisit gate but counts every slot as corrupted memory."""
    streams = _streams("wrong_pose")
    slot_count = int(streams.valid_revisit_mask.numel())

    assert streams.valid_revisit_mask.all()
    assert torch.allclose(streams.revisit_gate, torch.full_like(streams.revisit_gate, 0.25))

    diagnostics = streams.diagnostics
    assert diagnostics["eval_bucket_corrupted_memory_count"] == slot_count
    assert diagnostics["eval_bucket_true_revisit_count"] == 0