| from pathlib import Path |
| from types import SimpleNamespace |
|
|
| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
|
|
| from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin |
| from algorithms.worldmem.dememwm.types import MemorySourceType |
|
|
|
|
class DummyDeMemWM(MemoryDiTMixin):
    """Lightweight host object used to exercise ``MemoryDiTMixin`` methods in tests.

    Only two attributes are set: ``cfg.dememwm.generated_history_proxy`` (the
    ``proxy_cfg`` argument) and ``global_step`` (the ``step`` argument).
    ``MemoryDiTMixin.__init__`` is intentionally not invoked.
    """

    def __init__(self, proxy_cfg=None, step=0):
        dememwm_section = SimpleNamespace(generated_history_proxy=proxy_cfg)
        self.cfg = SimpleNamespace(dememwm=dememwm_section)
        self.global_step = step
|
|
|
|
| def _proxy_cfg(**overrides): |
| values = { |
| "enabled": True, |
| "start_step": 0, |
| "ramp_steps": 0, |
| "max_prob": 1.0, |
| "noise_std": 0.5, |
| "dropout_prob": 0.0, |
| } |
| values.update(overrides) |
| return SimpleNamespace(**values) |
|
|
|
|
def test_generated_history_proxy_config_defaults_disabled_and_train_script_enables_explicit_values():
    """Pin the proxy's YAML defaults and the explicit overrides in the SLURM train script.

    The algorithm config must contain the ``generated_history_proxy`` section
    with ``enabled: false`` and ``max_prob: 0.0``. The full-training script
    must set every proxy field explicitly through Hydra overrides.

    Both paths are relative, so they resolve against the current working
    directory; run this test from the repository root.
    """
    # Explicit UTF-8: without ``encoding`` the locale's preferred encoding is
    # used, which differs across platforms.
    config = Path("configurations/algorithm/dememwm_memory_dit.yaml").read_text(encoding="utf-8")
    train_script = Path("scripts/dememwm_full_train.slurm").read_text(encoding="utf-8")

    expected_config_tokens = [
        "generated_history_proxy:",
        "enabled: false",
        "start_step: 0",
        "ramp_steps: 1",
        "max_prob: 0.0",
        "noise_std: 0.25",
        "dropout_prob: 0.0",
    ]
    expected_script_tokens = [
        "algorithm.dememwm.generated_history_proxy.enabled=true",
        "algorithm.dememwm.generated_history_proxy.start_step=40000",
        "algorithm.dememwm.generated_history_proxy.ramp_steps=40000",
        "algorithm.dememwm.generated_history_proxy.max_prob=0.25",
        "algorithm.dememwm.generated_history_proxy.noise_std=0.25",
        "algorithm.dememwm.generated_history_proxy.dropout_prob=0.0",
    ]

    # Collect every missing token so one failing run reports all of them.
    missing_config = [token for token in expected_config_tokens if token not in config]
    missing_script = [token for token in expected_script_tokens if token not in train_script]
    assert not missing_config, f"missing from dememwm_memory_dit.yaml: {missing_config}"
    assert not missing_script, f"missing from dememwm_full_train.slurm: {missing_script}"
|
|
|
|
def test_generated_history_proxy_probability_ramps_after_start_step():
    """Check the proxy probability schedule around ``start_step`` and ``ramp_steps``.

    The schedule uses ``start_step=10``, ``ramp_steps=10`` and ``max_prob=0.5``.
    The probability must be 0.0 up to and including step 10, 0.25 midway
    (step 15), and stay at 0.5 from step 20 onward.
    """
    model = DummyDeMemWM(_proxy_cfg(start_step=10, ramp_steps=10, max_prob=0.5))
    expected_by_step = {9: 0.0, 10: 0.0, 15: 0.25, 20: 0.5, 30: 0.5}
    for query_step, expected_prob in expected_by_step.items():
        assert model._generated_history_proxy_prob(step=query_step) == expected_prob
|
|
|
|
def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_frames():
    """With ``max_prob=1.0`` the proxy corrupts every frame, but only in its outputs.

    The input latents and generated-mask must stay unmodified. The returned
    latents must differ from the input, every returned frame must be flagged
    as generated, and the diagnostics must report all 4 frames (fraction 1.0).
    """
    num_frames = 4
    model = DummyDeMemWM(_proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0), step=0)
    latents_in = torch.zeros(num_frames, 1, 1, 2, 2)
    generated_in = torch.zeros(num_frames, 1, dtype=torch.bool)

    torch.manual_seed(123)
    latents_out, generated_out, stats = model._apply_generated_history_proxy(latents_in, generated_in)

    # Inputs must not have been mutated in place.
    assert torch.equal(latents_in, torch.zeros_like(latents_in))
    assert not generated_in.any()
    # Outputs carry the corruption and the generated flags.
    assert not torch.equal(latents_out, latents_in)
    assert generated_out.all()
    assert stats["generated_history_proxy_frame_count"] == num_frames
    assert stats["generated_history_proxy_frame_fraction"] == 1.0
|
|
|
|
|
|
def test_generated_history_proxy_respects_context_prefix_and_target_window_bounds():
    """Only frames in ``[context_frame_count, target_start_frame)`` may be corrupted.

    With 8 frames, a 3-frame context prefix and the target starting at frame 6,
    frames 3-5 must be corrupted and flagged. Frames 0-2 and 6-7 must come
    back equal to the input.
    """
    total_frames, context_end, target_start = 8, 3, 6
    model = DummyDeMemWM(_proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0), step=0)
    latents_in = torch.zeros(total_frames, 1, 1, 2, 2)
    generated_in = torch.zeros(total_frames, 1, dtype=torch.bool)

    torch.manual_seed(123)
    latents_out, generated_out, stats = model._apply_generated_history_proxy(
        latents_in,
        generated_in,
        context_frame_count=context_end,
        target_start_frame=target_start,
    )

    window = slice(context_end, target_start)
    window_size = target_start - context_end
    expected_mask = torch.zeros(total_frames, 1, dtype=torch.bool)
    expected_mask[window] = True

    assert torch.equal(latents_in, torch.zeros_like(latents_in))
    assert torch.equal(generated_out, expected_mask)
    assert torch.equal(latents_out[:context_end], latents_in[:context_end])
    assert not torch.equal(latents_out[window], latents_in[window])
    assert torch.equal(latents_out[target_start:], latents_in[target_start:])
    assert stats["generated_history_proxy_frame_count"] == window_size
    assert stats["generated_history_proxy_frame_fraction"] == window_size / total_frames
|
|
|
|
def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
    """A generated frame is skipped as an anchor but is kept as a revisit source.

    Frame 0 is flagged as generated and ``allow_generated_anchor=False`` is
    passed. With ``anchor_indices=[0, 1]`` the anchor bank must hold frames 1
    and 2. The revisit bank must hold exactly one generated record, for
    frame 0, typed ``MemorySourceType.GENERATED``.
    """
    model = DummyDeMemWM(_proxy_cfg(enabled=False))
    model.dememwm_anchor_proj = torch.nn.Linear(1, 2, bias=False)
    model.dememwm_revisit_proj = torch.nn.Linear(1, 2, bias=False)
    projections = (model.dememwm_anchor_proj, model.dememwm_revisit_proj)
    with torch.no_grad():
        for projection in projections:
            projection.weight.fill_(1.0)

    num_frames = 4
    latents = torch.arange(num_frames, dtype=torch.float32).reshape(num_frames, 1, 1, 1, 1)
    frame_indices = torch.arange(num_frames).reshape(num_frames, 1)
    source_is_generated = torch.zeros(num_frames, 1, dtype=torch.bool)
    source_is_generated[0] = True
    anchor_projected, revisit_projected = [
        model._project_latent_patch_tokens(latents, projection, patch_size=1)
        for projection in projections
    ]

    unit_pooling = dict(
        anchor_pool_h=1,
        anchor_pool_w=1,
        revisit_pool_h=1,
        revisit_pool_w=1,
        src_h=1,
        src_w=1,
    )
    anchor_banks, revisit_banks = model._build_causal_memory_banks(
        anchor_projected,
        revisit_projected,
        frame_indices,
        source_is_generated,
        pose=None,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1],
        **unit_pooling,
    )

    assert [int(rec.frame_indices.item()) for rec in anchor_banks[0].records] == [1, 2]
    generated_records = [rec for rec in revisit_banks[0].records if rec.is_generated]
    assert len(generated_records) == 1
    only_generated = generated_records[0]
    assert only_generated.source_type == MemorySourceType.GENERATED
    assert only_generated.frame_indices.tolist() == [0]
|
|