| import types |
|
|
| import pytest |
| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin |
| from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor |
| from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics |
| from algorithms.worldmem.dememwm.schedules import ( |
| EVAL_ABLATION_BRANCHES, |
| EVAL_ABLATION_BRANCH_TO_ID, |
| EVAL_CORRUPTION_BRANCHES, |
| normalize_eval_ablation_branch, |
| ) |
|
|
|
|
# Every eval-ablation branch, in registry order. The registry test compares this
# tuple against EVAL_ABLATION_BRANCHES element for element.
WAVE9_BRANCHES = (
    # Stream-enable branches: which of the anchor (A) / dynamic (D) streams run.
    ("memory_off", "A_only", "D_only", "A_plus_D")
    # Revisit (R) controls: gate from the revisit gate module, forced off, forced on.
    + ("A_plus_D_plus_R_normal", "R_forced_off", "R_forced_on")
    # Memory-corruption branches; the registry test expects this tail (from
    # "wrong_pose" onward) to equal EVAL_CORRUPTION_BRANCHES.
    + (
        "wrong_pose",
        "time_shuffle",
        "source_matched_random",
        "pose_shuffle",
        "wrong_video",
        "local_context_overlap_fake_revisit",
    )
)
|
|
|
|
| def _device(): |
| return torch.device("cpu") |
|
|
|
|
def test_wave9_branch_registry_is_exact_and_validated():
    """The schedules registry matches WAVE9_BRANCHES and rejects unknown names."""
    # The registry must list exactly the Wave 9 branches, in the same order.
    assert EVAL_ABLATION_BRANCHES == WAVE9_BRANCHES

    # Ids are pinned at both ends: first branch -> 0, last branch -> len - 1.
    last_id = len(WAVE9_BRANCHES) - 1
    assert EVAL_ABLATION_BRANCH_TO_ID["memory_off"] == 0
    assert EVAL_ABLATION_BRANCH_TO_ID["local_context_overlap_fake_revisit"] == last_id

    # Corruption branches are the tail of the registry that starts at "wrong_pose".
    corruption_start = WAVE9_BRANCHES.index("wrong_pose")
    assert EVAL_CORRUPTION_BRANCHES == WAVE9_BRANCHES[corruption_start:]

    # None normalizes to the default branch; a registered name passes through unchanged.
    assert normalize_eval_ablation_branch(None) == "A_plus_D_plus_R_normal"
    assert normalize_eval_ablation_branch("wrong_pose") == "wrong_pose"

    # A name outside the registry is rejected.
    with pytest.raises(ValueError):
        normalize_eval_ablation_branch("ratio_sweep")
|
|
|
|
def test_eval_ablation_diagnostics_bucket_counts():
    """Bucket counts and fractions reported for a four-position wrong_pose batch.

    Three positions are valid revisits, two of which are also marked corrupted,
    and one position has no valid revisit. Per the expected values below, the
    corrupted positions are counted in the corrupted bucket rather than the
    true-revisit bucket.
    """
    valid = torch.tensor([[True, True, True, False]])
    no_valid = torch.tensor([[False, False, False, True]])
    corrupted = torch.tensor([[False, True, True, False]])

    diag = summarize_eval_ablation_diagnostics(
        enabled=True,
        branch="wrong_pose",
        valid_revisit_mask=valid,
        no_valid_revisit_mask=no_valid,
        eval_corrupted_revisit_mask=corrupted,
    )

    # Branch identification.
    assert diag["eval_ablation_enabled"] is True
    assert diag["eval_ablation_branch"] == "wrong_pose"
    assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"]

    # Absolute bucket sizes.
    expected_counts = {
        "eval_bucket_true_revisit_count": 1,
        "eval_bucket_no_valid_revisit_count": 1,
        "eval_bucket_corrupted_memory_count": 2,
    }
    for key, count in expected_counts.items():
        assert diag[key] == count

    # Bucket sizes as a share of all four positions.
    expected_fractions = {
        "eval_bucket_true_revisit_fraction": 0.25,
        "eval_bucket_no_valid_revisit_fraction": 0.25,
        "eval_bucket_corrupted_memory_fraction": 0.5,
    }
    for key, fraction in expected_fractions.items():
        assert diag[key] == pytest.approx(fraction)
|
|
|
|
class ConstantGate(torch.nn.Module):
    """Stand-in revisit gate that emits one fixed value at every position.

    Args:
        value: The gate value; stored as a Python ``float``.
    """

    def __init__(self, value: float):
        super().__init__()
        self.value: float = float(value)

    def forward(self, *, valid_revisit_mask, best_selected_fov_overlap, best_selected_plucker_overlap, selected_gap_frames):
        """Return a float32 tensor shaped like ``best_selected_fov_overlap``, filled with ``self.value``.

        The remaining keyword arguments are accepted so the call signature
        matches the gate interface, but a constant gate never reads them.
        """
        del valid_revisit_mask, best_selected_plucker_overlap, selected_gap_frames
        template = best_selected_fov_overlap  # used only for its shape and device
        fill = self.value
        return torch.full_like(template, fill, dtype=torch.float32)
|
|
|
|
class DummyDeMemWM(MemoryDiTMixin):
    """Minimal host for MemoryDiTMixin built from a hand-written config.

    The config has eval ablation enabled with ``branch`` as the selected branch.
    The instance also carries the anchor/revisit projections, the dynamic
    compressor, and a ``ConstantGate(0.25)`` revisit gate, all moved to
    ``device`` and sized for tiny CPU tensors.
    """

    def __init__(self, branch: str, device: torch.device):
        # Sizes that appear both in the config namespace and in the modules
        # built below; naming them once keeps the two in sync.
        hidden = 8
        patch = 2
        kernel_t = 3
        stride_t = 2
        exclude_latest = 2
        # NOTE(review): the projections take 4 input features, which presumably
        # equals latent_channels * patch**2 (1 * 2 * 2) — confirm against the mixin.
        proj_in = 4

        ns = types.SimpleNamespace
        dememwm_cfg = ns(
            enabled=True,
            training_stage="stage_2",
            debug_force_all_streams=False,
            token_patch_size=patch,
            curriculum=ns(enabled=False),
            anchor=ns(
                enabled=True,
                anchor_indices=[0, 1],
                allow_generated_as_anchor=False,
                diverse_selection=False,
                compress=ns(pool_h=1, pool_w=1),
            ),
            dynamic=ns(
                enabled=True,
                exclude_latest_local_frames=exclude_latest,
                recent_frames=4,
                conv_kernel_t=kernel_t,
                conv_stride_t=stride_t,
            ),
            revisit=ns(
                enabled=True,
                deterministic_pose_retrieval=True,
                fov_overlap_threshold=0.0,
                plucker_weight=0.1,
                max_frames=2,
                compress=ns(pool_h=1, pool_w=1),
            ),
            stage_policy=ns(noise_bucket_logging=True),
            eval_ablation=ns(enabled=True, branch=branch),
            generated_history_proxy=ns(enabled=False),
            injection=ns(dit_hidden_size=hidden, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
            cache=ns(enabled=False),
            checkpoint=ns(strict_dememwm_eval_load=True),
        )
        self.cfg = ns(dememwm=dememwm_cfg, weight_decay=0.0, optimizer_beta=(0.9, 0.999))

        self.global_step = 0
        self.x_stacked_shape = (1, 4, 4)

        # Modules are created in a fixed order so random initialization draws
        # from the RNG in the same sequence on every run.
        self.dememwm_anchor_proj = torch.nn.Linear(proj_in, hidden, bias=False).to(device)
        self.dememwm_revisit_proj = torch.nn.Linear(proj_in, hidden, bias=False).to(device)
        self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
            latent_channels=1,
            dit_hidden_size=hidden,
            patch_size=patch,
            conv_kernel_t=kernel_t,
            conv_stride_t=stride_t,
            max_source_frames=4,
            exclude_latest_local_frames=exclude_latest,
        ).to(device)
        self.dememwm_revisit_gate = ConstantGate(0.25).to(device)
|
|
|
|
def _streams(branch: str):
    """Run ``build_memory_streams`` for ``branch`` on a tiny deterministic fixture.

    Twelve source frames (ids 0..11) carry ramp-valued latents, all poses are
    zero, and targets are requested at frames 8 and 12. No actions are passed.
    Returns whatever ``build_memory_streams`` returns.
    """
    device = _device()
    model = DummyDeMemWM(branch, device)

    latent_shape = (12, 1, 1, 4, 4)
    n_values = torch.Size(latent_shape).numel()
    # Unique ramp values 0.00, 0.01, 0.02, ... laid out in latent_shape.
    latents = torch.arange(n_values, device=device, dtype=torch.float32).reshape(latent_shape) / 100.0
    source_frames = torch.arange(12, device=device).reshape(12, 1)
    target_frames = torch.tensor([[8], [12]], device=device)

    f32 = dict(device=device, dtype=torch.float32)
    pose = torch.zeros((12, 1, 5), **f32)
    target_pose = torch.zeros((2, 1, 5), **f32)

    return model.build_memory_streams(
        latents,
        source_frames,
        target_frame_indices=target_frames,
        pose=pose,
        target_pose=target_pose,
        action=None,
        target_action=None,
    )
|
|
|
|
def test_eval_ablation_stream_enable_branches_control_masks_and_gates():
    """Each stream-enable branch turns on exactly the streams it names.

    For every branch, both the anchor and the dynamic stream are checked:
    - An enabled stream must have a non-empty mask and a gate of 1.0, the value
      configured in DummyDeMemWM's injection settings.
    - A disabled stream must have an empty mask and a zero gate.

    Revisit is disabled in all four branches, so its mask must be empty and its
    per-position gate must be all zeros.

    Unlike a per-branch spot check, every branch asserts every stream, so a
    branch that leaks a gate for a stream it disabled is caught.
    """
    # branch -> (anchor stream enabled, dynamic stream enabled)
    expectations = {
        "memory_off": (False, False),
        "A_only": (True, False),
        "D_only": (False, True),
        "A_plus_D": (True, True),
    }
    for branch, (anchor_on, dynamic_on) in expectations.items():
        streams = _streams(branch)

        assert streams.anchor_gate == (1.0 if anchor_on else 0.0), branch
        assert bool(streams.anchor_mask.any()) is anchor_on, branch

        assert streams.dynamic_gate == (1.0 if dynamic_on else 0.0), branch
        assert bool(streams.dynamic_mask.any()) is dynamic_on, branch

        assert not streams.revisit_mask.any(), branch
        assert torch.count_nonzero(streams.revisit_gate).item() == 0, branch
|
|
|
|
def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch():
    """Compare the normal, forced-off and forced-on revisit branches on one fixture."""
    normal, forced_off, forced_on = map(
        _streams, ("A_plus_D_plus_R_normal", "R_forced_off", "R_forced_on")
    )

    # Normal: every revisit is valid, and the gate is the ConstantGate(0.25)
    # value that DummyDeMemWM installs.
    assert normal.valid_revisit_mask.all()
    expected_normal_gate = torch.full_like(normal.revisit_gate, 0.25)
    assert torch.allclose(normal.revisit_gate, expected_normal_gate)

    # Forced off: the gate is zero everywhere.
    assert torch.count_nonzero(forced_off.revisit_gate).item() == 0

    # Forced on: the gate equals the validity mask cast to the gate dtype.
    on_gate = forced_on.revisit_gate
    assert torch.equal(on_gate, forced_on.valid_revisit_mask.to(dtype=on_gate.dtype))
    assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on"
|
|
|
|
def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
    """wrong_pose keeps the revisit gate but reports every revisit as corrupted."""
    streams = _streams("wrong_pose")
    valid = streams.valid_revisit_mask
    diagnostics = streams.diagnostics

    # The gate is untouched: all revisits are valid and keep the 0.25 gate value.
    assert valid.all()
    assert torch.allclose(streams.revisit_gate, torch.full_like(streams.revisit_gate, 0.25))

    # The diagnostics count every revisit position as corrupted memory and none
    # as a true revisit.
    assert diagnostics["eval_bucket_corrupted_memory_count"] == int(valid.numel())
    assert diagnostics["eval_bucket_true_revisit_count"] == 0
|
|