"""Tests for the generated-history proxy helpers exposed by ``MemoryDiTMixin``."""

from pathlib import Path
from types import SimpleNamespace

import torch

from dememwm_import_helper import install_dememwm_namespace

# The namespace must be installed before the ``algorithms.worldmem`` imports
# below are resolved, so this call intentionally sits between import lines.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.types import MemorySourceType


class DummyDeMemWM(MemoryDiTMixin):
    """Bare-bones host for ``MemoryDiTMixin`` methods.

    Only ``cfg.dememwm.generated_history_proxy`` and ``global_step`` are set;
    no parent initializer is invoked.
    """

    def __init__(self, proxy_cfg=None, step=0):
        dememwm_section = SimpleNamespace(generated_history_proxy=proxy_cfg)
        self.cfg = SimpleNamespace(dememwm=dememwm_section)
        self.global_step = step


def _proxy_cfg(**overrides):
    """Return a proxy config namespace: enabled defaults merged with ``overrides``."""
    defaults = dict(
        enabled=True,
        start_step=0,
        ramp_steps=0,
        max_prob=1.0,
        noise_std=0.5,
        dropout_prob=0.0,
    )
    return SimpleNamespace(**{**defaults, **overrides})


def test_generated_history_proxy_config_defaults_disabled_and_train_script_enables_explicit_values():
    """The YAML config ships the proxy disabled; the SLURM script overrides it explicitly."""
    # Both files are read up front so a missing file surfaces before any
    # token assertion runs.
    config = Path("configurations/algorithm/dememwm_memory_dit.yaml").read_text()
    train_script = Path("scripts/dememwm_full_train.slurm").read_text()

    expected_config_tokens = (
        "generated_history_proxy:",
        "enabled: false",
        "start_step: 0",
        "ramp_steps: 1",
        "max_prob: 0.0",
        "noise_std: 0.25",
        "dropout_prob: 0.0",
    )
    expected_script_tokens = (
        "algorithm.dememwm.generated_history_proxy.enabled=true",
        "algorithm.dememwm.generated_history_proxy.start_step=40000",
        "algorithm.dememwm.generated_history_proxy.ramp_steps=40000",
        "algorithm.dememwm.generated_history_proxy.max_prob=0.25",
        "algorithm.dememwm.generated_history_proxy.noise_std=0.25",
        "algorithm.dememwm.generated_history_proxy.dropout_prob=0.0",
    )

    for token in expected_config_tokens:
        assert token in config
    for token in expected_script_tokens:
        assert token in train_script


def test_generated_history_proxy_probability_ramps_after_start_step():
    """Probability is 0 through ``start_step``, then ramps up and holds at ``max_prob``."""
    model = DummyDeMemWM(_proxy_cfg(start_step=10, ramp_steps=10, max_prob=0.5))

    # (step, expected probability) pairs, checked in increasing step order.
    expected_by_step = (
        (9, 0.0),
        (10, 0.0),
        (15, 0.25),
        (20, 0.5),
        (30, 0.5),
    )
    for step, expected in expected_by_step:
        assert model._generated_history_proxy_prob(step=step) == expected


def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_frames():
    """Inputs stay untouched while the returned copies are corrupted and flagged."""
    cfg = _proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0)
    model = DummyDeMemWM(cfg, step=0)

    source_latents = torch.zeros(4, 1, 1, 2, 2)
    source_is_generated = torch.zeros(4, 1, dtype=torch.bool)

    torch.manual_seed(123)
    result = model._apply_generated_history_proxy(
        source_latents,
        source_is_generated,
    )
    corrupted, generated, diagnostics = result

    # The caller's tensors must not have been modified in place.
    assert torch.equal(source_latents, torch.zeros_like(source_latents))
    assert not torch.equal(corrupted, source_latents)
    assert generated.all()
    assert not source_is_generated.any()
    assert diagnostics["generated_history_proxy_frame_count"] == 4
    assert diagnostics["generated_history_proxy_frame_fraction"] == 1.0


def test_generated_history_proxy_respects_context_prefix_and_target_window_bounds():
    """Only frames between the context prefix and the target start are corrupted."""
    cfg = _proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0)
    model = DummyDeMemWM(cfg, step=0)

    source_latents = torch.zeros(8, 1, 1, 2, 2)
    source_is_generated = torch.zeros(8, 1, dtype=torch.bool)

    torch.manual_seed(123)
    corrupted, generated, diagnostics = model._apply_generated_history_proxy(
        source_latents,
        source_is_generated,
        context_frame_count=3,
        target_start_frame=6,
    )

    # Frames 3, 4 and 5 are the only ones expected to be marked as generated.
    expected_generated = torch.zeros(8, 1, dtype=torch.bool)
    expected_generated[3:6] = True

    assert torch.equal(source_latents, torch.zeros_like(source_latents))
    assert torch.equal(generated, expected_generated)
    assert torch.equal(corrupted[:3], source_latents[:3])
    assert not torch.equal(corrupted[3:6], source_latents[3:6])
    assert torch.equal(corrupted[6:], source_latents[6:])
    assert diagnostics["generated_history_proxy_frame_count"] == 3
    assert diagnostics["generated_history_proxy_frame_fraction"] == 3 / 8


def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
    """A generated frame is kept out of the anchor bank yet still appears as a revisit record."""
    model = DummyDeMemWM(_proxy_cfg(enabled=False))
    model.dememwm_anchor_proj = torch.nn.Linear(1, 2, bias=False)
    model.dememwm_revisit_proj = torch.nn.Linear(1, 2, bias=False)
    # Fix both projection weights to ones so the projections are deterministic.
    with torch.no_grad():
        for projection in (model.dememwm_anchor_proj, model.dememwm_revisit_proj):
            projection.weight.fill_(1.0)

    latents = torch.arange(4, dtype=torch.float32).reshape(4, 1, 1, 1, 1)
    frame_indices = torch.arange(4).reshape(4, 1)
    # Only frame 0 is flagged as generated.
    source_is_generated = torch.tensor([[True], [False], [False], [False]])

    anchor_projected = model._project_latent_patch_tokens(
        latents, model.dememwm_anchor_proj, patch_size=1
    )
    revisit_projected = model._project_latent_patch_tokens(
        latents, model.dememwm_revisit_proj, patch_size=1
    )

    anchor_banks, revisit_banks = model._build_causal_memory_banks(
        anchor_projected,
        revisit_projected,
        frame_indices,
        source_is_generated,
        pose=None,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1],
        anchor_pool_h=1,
        anchor_pool_w=1,
        revisit_pool_h=1,
        revisit_pool_w=1,
        src_h=1,
        src_w=1,
    )

    # With generated anchors disallowed, the anchors are expected to come
    # from frames 1 and 2 rather than the generated frame 0.
    anchor_frames = []
    for record in anchor_banks[0].records:
        anchor_frames.append(int(record.frame_indices.item()))
    assert anchor_frames == [1, 2]

    generated_revisit = []
    for record in revisit_banks[0].records:
        if record.is_generated:
            generated_revisit.append(record)
    assert len(generated_revisit) == 1

    only_generated = generated_revisit[0]
    assert only_generated.source_type == MemorySourceType.GENERATED
    assert only_generated.frame_indices.tolist() == [0]