# DeMemWM / tests/test_dememwm_eval_ablation.py
# Commit b47a1ce ("Initial commit") by BonanDing
import types
import pytest
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics
from algorithms.worldmem.dememwm.schedules import (
EVAL_ABLATION_BRANCHES,
EVAL_ABLATION_BRANCH_TO_ID,
EVAL_CORRUPTION_BRANCHES,
normalize_eval_ablation_branch,
)
# Expected eval-ablation branch registry, in order. The tests below require
# the schedules module to export exactly this tuple, with ids running from 0
# (first entry) to len - 1 (last entry).
WAVE9_BRANCHES = (
    # Stream-enable branches: which of anchor (A) / dynamic (D) memory is on.
    "memory_off", "A_only", "D_only", "A_plus_D",
    # Full memory with the normal revisit (R) gate; the default branch.
    "A_plus_D_plus_R_normal",
    # Revisit-gate overrides.
    "R_forced_off", "R_forced_on",
    # Corruption branches: everything from "wrong_pose" to the end.
    "wrong_pose", "time_shuffle", "source_matched_random",
    "pose_shuffle", "wrong_video", "local_context_overlap_fake_revisit",
)
def _device():
return torch.device("cpu")
def test_wave9_branch_registry_is_exact_and_validated():
    """Registry, id map, corruption subset and normalizer all agree with WAVE9_BRANCHES."""
    # The exported registry must match the expected tuple exactly, order included.
    assert EVAL_ABLATION_BRANCHES == WAVE9_BRANCHES

    # Ids are pinned at both ends of the registry.
    last_id = len(WAVE9_BRANCHES) - 1
    assert EVAL_ABLATION_BRANCH_TO_ID["memory_off"] == 0
    assert EVAL_ABLATION_BRANCH_TO_ID["local_context_overlap_fake_revisit"] == last_id

    # Corruption branches are the tail of the registry starting at "wrong_pose".
    first_corruption = WAVE9_BRANCHES.index("wrong_pose")
    assert EVAL_CORRUPTION_BRANCHES == WAVE9_BRANCHES[first_corruption:]

    # None falls back to the default branch; a known name passes through unchanged.
    default_branch = normalize_eval_ablation_branch(None)
    assert default_branch == "A_plus_D_plus_R_normal"
    assert normalize_eval_ablation_branch("wrong_pose") == "wrong_pose"

    # A name outside the registry is rejected.
    with pytest.raises(ValueError):
        normalize_eval_ablation_branch("ratio_sweep")
def test_eval_ablation_diagnostics_bucket_counts():
    """Bucket counts and fractions reported for a 4-step window under "wrong_pose"."""
    # Three valid steps and one invalid step; two of the valid steps are corrupted.
    valid = torch.tensor([[True, True, True, False]])
    no_valid = torch.tensor([[False, False, False, True]])
    corrupted = torch.tensor([[False, True, True, False]])

    diag = summarize_eval_ablation_diagnostics(
        enabled=True,
        branch="wrong_pose",
        valid_revisit_mask=valid,
        no_valid_revisit_mask=no_valid,
        eval_corrupted_revisit_mask=corrupted,
    )

    # Branch identity is echoed back together with its registry id.
    assert diag["eval_ablation_enabled"] is True
    assert diag["eval_ablation_branch"] == "wrong_pose"
    assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"]

    # Only one step is expected to count as a true revisit, which is consistent
    # with "valid and not corrupted". Fractions are taken over all 4 steps.
    expected_counts = {
        "eval_bucket_true_revisit_count": 1,
        "eval_bucket_no_valid_revisit_count": 1,
        "eval_bucket_corrupted_memory_count": 2,
    }
    for key, count in expected_counts.items():
        assert diag[key] == count

    expected_fractions = {
        "eval_bucket_true_revisit_fraction": 0.25,
        "eval_bucket_no_valid_revisit_fraction": 0.25,
        "eval_bucket_corrupted_memory_fraction": 0.5,
    }
    for key, fraction in expected_fractions.items():
        assert diag[key] == pytest.approx(fraction)
class ConstantGate(torch.nn.Module):
    """Revisit-gate stand-in that outputs one fixed value everywhere.

    ``DummyDeMemWM`` installs it as ``dememwm_revisit_gate`` with value 0.25.
    That is the gate value the stream tests expect for the normal and
    corruption branches.

    Args:
        value: fill value; coerced to ``float`` and stored on ``self.value``.
    """

    def __init__(self, value: float):
        super().__init__()
        self.value = float(value)

    def forward(
        self,
        *,
        valid_revisit_mask,
        best_selected_fov_overlap,
        best_selected_plucker_overlap,
        selected_gap_frames,
    ):
        """Return ``self.value`` broadcast to the shape of ``best_selected_fov_overlap``.

        The result is float32 on the same device as ``best_selected_fov_overlap``.
        The other keyword arguments are accepted only so the call signature
        matches the caller's, and are ignored.
        """
        # Accepted for signature compatibility only.
        del valid_revisit_mask, best_selected_plucker_overlap, selected_gap_frames
        template = best_selected_fov_overlap
        return torch.full_like(template, fill_value=self.value, dtype=torch.float32)
class DummyDeMemWM(MemoryDiTMixin):
    """Minimal host for ``MemoryDiTMixin`` used to drive ``build_memory_streams`` in tests.

    It is not a full model. It carries a hand-built ``cfg`` namespace plus the
    projection layers, dynamic compressor and revisit gate stored on ``self``.
    ``MemoryDiTMixin.__init__`` is not called. NOTE(review): this presumably works
    because the mixin only needs the attributes set below; confirm this if the
    mixin gains an initializer.

    Args:
        branch: eval-ablation branch name, stored as ``cfg.dememwm.eval_ablation.branch``.
        device: device the projection layers, compressor and gate are moved to.
    """

    def __init__(self, branch: str, device: torch.device) -> None:
        self.cfg = types.SimpleNamespace(
            dememwm=types.SimpleNamespace(
                enabled=True,
                training_stage="stage_2",
                debug_force_all_streams=False,
                token_patch_size=2,
                curriculum=types.SimpleNamespace(enabled=False),
                # Anchor stream (the "A" in branch names such as "A_only").
                anchor=types.SimpleNamespace(
                    enabled=True,
                    anchor_indices=[0, 1],
                    allow_generated_as_anchor=False,
                    diverse_selection=False,
                    compress=types.SimpleNamespace(pool_h=1, pool_w=1),
                ),
                # Dynamic stream (the "D" in branch names). These values mirror the
                # CausalConv3DDynamicCompressor arguments further down.
                dynamic=types.SimpleNamespace(
                    enabled=True,
                    exclude_latest_local_frames=2,
                    recent_frames=4,
                    conv_kernel_t=3,
                    conv_stride_t=2,
                ),
                # Revisit stream (the "R" in branch names). NOTE(review): a
                # fov_overlap_threshold of 0.0 presumably admits every candidate;
                # the tests expect valid_revisit_mask to be all True.
                revisit=types.SimpleNamespace(
                    enabled=True,
                    deterministic_pose_retrieval=True,
                    fov_overlap_threshold=0.0,
                    plucker_weight=0.1,
                    max_frames=2,
                    compress=types.SimpleNamespace(pool_h=1, pool_w=1),
                ),
                stage_policy=types.SimpleNamespace(noise_bucket_logging=True),
                # Eval ablation is always enabled; the caller picks the branch.
                eval_ablation=types.SimpleNamespace(enabled=True, branch=branch),
                generated_history_proxy=types.SimpleNamespace(enabled=False),
                # All three base gates are 1.0. Hidden size 8 matches the projection
                # output width and the compressor's dit_hidden_size below.
                injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0),
                cache=types.SimpleNamespace(enabled=False),
                checkpoint=types.SimpleNamespace(strict_dememwm_eval_load=True),
            ),
            weight_decay=0.0,
            optimizer_beta=(0.9, 0.999),
        )
        self.global_step = 0
        # Matches the trailing (1, 4, 4) dims of the latents built in _streams.
        # NOTE(review): presumably (channels, height, width) of one latent frame; confirm.
        self.x_stacked_shape = (1, 4, 4)
        # Anchor/revisit projections to hidden size 8. NOTE(review): in_features=4 is
        # presumably one 2x2 patch (token_patch_size=2) of the single latent channel; confirm.
        self.dememwm_anchor_proj = torch.nn.Linear(4, 8, bias=False).to(device)
        self.dememwm_revisit_proj = torch.nn.Linear(4, 8, bias=False).to(device)
        # Kept in sync with cfg.dememwm.dynamic and token_patch_size above.
        self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
            latent_channels=1,
            dit_hidden_size=8,
            patch_size=2,
            conv_kernel_t=3,
            conv_stride_t=2,
            max_source_frames=4,
            exclude_latest_local_frames=2,
        ).to(device)
        # Constant gate, so the tests can expect a revisit gate of exactly 0.25.
        self.dememwm_revisit_gate = ConstantGate(0.25).to(device)
def _streams(branch: str):
    """Build memory streams for ``branch`` on a fixed synthetic clip.

    The clip has 12 source frames (indices 0..11). Its latents are a deterministic
    ramp 0.00, 0.01, ... with shape (12, 1, 1, 4, 4); presumably
    (frames, batch, channels, H, W), which should be confirmed. Poses are all zero,
    and there are two target frames at indices 8 and 12.
    Returns whatever ``DummyDeMemWM.build_memory_streams`` returns.
    """
    dev = _device()
    host = DummyDeMemWM(branch, dev)

    latent_shape = (12, 1, 1, 4, 4)
    element_count = torch.Size(latent_shape).numel()
    ramp = torch.arange(element_count, device=dev, dtype=torch.float32)
    latents = ramp.reshape(latent_shape) / 100.0

    source_frames = torch.arange(12, device=dev).reshape(12, 1)
    target_frames = torch.tensor([[8], [12]], device=dev)

    pose = torch.zeros((12, 1, 5), device=dev, dtype=torch.float32)
    target_pose = torch.zeros((2, 1, 5), device=dev, dtype=torch.float32)

    streams = host.build_memory_streams(
        latents,
        source_frames,
        target_frame_indices=target_frames,
        pose=pose,
        target_pose=target_pose,
        action=None,
        target_action=None,
    )
    return streams
def test_eval_ablation_stream_enable_branches_control_masks_and_gates():
    """Stream-enable branches switch the anchor/dynamic/revisit masks and gates."""

    def assert_no_entries(*masks):
        # Every given mask must have no True entry at all.
        for mask in masks:
            assert not mask.any()

    # memory_off: all gates are zero and every mask is empty.
    off = _streams("memory_off")
    assert off.anchor_gate == 0.0
    assert off.dynamic_gate == 0.0
    assert torch.count_nonzero(off.revisit_gate).item() == 0
    assert_no_entries(off.anchor_mask, off.dynamic_mask, off.revisit_mask)

    # A_only: the anchor stream is on, dynamic and revisit are off.
    anchor_only = _streams("A_only")
    assert anchor_only.anchor_gate == 1.0
    assert anchor_only.anchor_mask.any()
    assert_no_entries(anchor_only.dynamic_mask, anchor_only.revisit_mask)

    # D_only: the dynamic stream is on, anchor and revisit are off.
    dynamic_only = _streams("D_only")
    assert dynamic_only.dynamic_gate == 1.0
    assert dynamic_only.dynamic_mask.any()
    assert dynamic_only.anchor_gate == 0.0
    assert_no_entries(dynamic_only.anchor_mask, dynamic_only.revisit_mask)

    # A_plus_D: anchor and dynamic are on; revisit stays off and its gate is zero.
    anchor_dynamic = _streams("A_plus_D")
    assert anchor_dynamic.anchor_mask.any()
    assert anchor_dynamic.dynamic_mask.any()
    assert_no_entries(anchor_dynamic.revisit_mask)
    assert torch.count_nonzero(anchor_dynamic.revisit_gate).item() == 0
def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch():
    """R_forced_off/R_forced_on override the revisit gate; the normal branch keeps the gate module's output."""
    branch_names = ("A_plus_D_plus_R_normal", "R_forced_off", "R_forced_on")
    normal, forced_off, forced_on = (_streams(name) for name in branch_names)

    # Normal: every target has a valid revisit, and the gate is ConstantGate's 0.25.
    assert normal.valid_revisit_mask.all()
    expected_normal_gate = torch.full_like(normal.revisit_gate, 0.25)
    assert torch.allclose(normal.revisit_gate, expected_normal_gate)

    # Forced off: the gate is identically zero.
    assert torch.count_nonzero(forced_off.revisit_gate).item() == 0

    # Forced on: the gate is exactly the valid-revisit mask cast to the gate dtype.
    gate_on = forced_on.revisit_gate
    assert torch.equal(gate_on, forced_on.valid_revisit_mask.to(dtype=gate_on.dtype))
    assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on"
def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate():
    """A corruption branch keeps the revisit gate but reports every revisit as corrupted."""
    streams = _streams("wrong_pose")
    valid = streams.valid_revisit_mask
    gate = streams.revisit_gate
    diag = streams.diagnostics

    assert valid.all()
    # The gate is untouched: still ConstantGate's 0.25 everywhere.
    assert torch.allclose(gate, torch.full_like(gate, 0.25))

    # Every revisit slot lands in the corrupted bucket, and none counts as a true revisit.
    assert diag["eval_bucket_corrupted_memory_count"] == int(valid.numel())
    assert diag["eval_bucket_true_revisit_count"] == 0