# DeMemWM / tests / test_dememwm_generated_history_proxy.py
# BonanDing's picture
# Initial commit
# b47a1ce
from pathlib import Path
from types import SimpleNamespace
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.types import MemorySourceType
class DummyDeMemWM(MemoryDiTMixin):
    """Bare-bones ``MemoryDiTMixin`` subclass used as the object under test.

    Only two attributes are populated: ``cfg`` (a nested namespace whose
    ``dememwm.generated_history_proxy`` entry is the supplied proxy config)
    and ``global_step``. ``__init__`` does not call ``super().__init__()``.
    """

    def __init__(self, proxy_cfg=None, step=0):
        # Mirror the attribute path ``cfg.dememwm.generated_history_proxy``.
        dememwm_section = SimpleNamespace(generated_history_proxy=proxy_cfg)
        self.cfg = SimpleNamespace(dememwm=dememwm_section)
        self.global_step = step
def _proxy_cfg(**overrides):
values = {
"enabled": True,
"start_step": 0,
"ramp_steps": 0,
"max_prob": 1.0,
"noise_std": 0.5,
"dropout_prob": 0.0,
}
values.update(overrides)
return SimpleNamespace(**values)
def test_generated_history_proxy_config_defaults_disabled_and_train_script_enables_explicit_values():
    """Check the proxy settings spelled out in the YAML config and train script.

    The YAML text must contain the ``generated_history_proxy`` section name and
    the listed default entries; the SLURM script text must contain the listed
    explicit command-line overrides. Both checks are plain substring matches.
    """
    # Both paths are relative, so they resolve against the current working
    # directory. Read as UTF-8 explicitly rather than relying on the locale's
    # default encoding.
    config = Path("configurations/algorithm/dememwm_memory_dit.yaml").read_text(
        encoding="utf-8"
    )
    train_script = Path("scripts/dememwm_full_train.slurm").read_text(
        encoding="utf-8"
    )

    config_tokens = [
        "generated_history_proxy:",
        "enabled: false",
        "start_step: 0",
        "ramp_steps: 1",
        "max_prob: 0.0",
        "noise_std: 0.25",
        "dropout_prob: 0.0",
    ]
    # Gather every absent token so one failure reports all of them at once.
    missing_from_config = [token for token in config_tokens if token not in config]
    assert not missing_from_config, (
        f"config is missing expected entries: {missing_from_config}"
    )

    script_tokens = [
        "algorithm.dememwm.generated_history_proxy.enabled=true",
        "algorithm.dememwm.generated_history_proxy.start_step=40000",
        "algorithm.dememwm.generated_history_proxy.ramp_steps=40000",
        "algorithm.dememwm.generated_history_proxy.max_prob=0.25",
        "algorithm.dememwm.generated_history_proxy.noise_std=0.25",
        "algorithm.dememwm.generated_history_proxy.dropout_prob=0.0",
    ]
    missing_from_script = [
        token for token in script_tokens if token not in train_script
    ]
    assert not missing_from_script, (
        f"train script is missing expected overrides: {missing_from_script}"
    )
def test_generated_history_proxy_probability_ramps_after_start_step():
    """Check the proxy probability at fixed steps around the ramp window.

    With ``start_step=10``, ``ramp_steps=10`` and ``max_prob=0.5`` the expected
    value is 0.0 at steps 9 and 10, 0.25 at step 15, and 0.5 at steps 20 and 30.
    """
    model = DummyDeMemWM(_proxy_cfg(start_step=10, ramp_steps=10, max_prob=0.5))

    # (step, expected probability) pairs, checked in ascending step order.
    expected_by_step = (
        (9, 0.0),
        (10, 0.0),
        (15, 0.25),
        (20, 0.5),
        (30, 0.5),
    )
    for step, expected in expected_by_step:
        assert model._generated_history_proxy_prob(step=step) == expected
def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_frames():
    """Check that the proxy returns corrupted copies and leaves its inputs alone.

    With ``max_prob=1.0`` and four all-zero frames, the returned latents must
    differ from the input, every returned generated-flag must be set, the input
    tensors must be unchanged, and the diagnostics must report 4 frames and a
    fraction of 1.0.
    """
    proxy_cfg = _proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0)
    model = DummyDeMemWM(proxy_cfg, step=0)

    frame_shape = (4, 1, 1, 2, 2)
    source_latents = torch.zeros(*frame_shape)
    source_is_generated = torch.zeros(4, 1, dtype=torch.bool)
    pristine_latents = source_latents.clone()

    # Fix the RNG state immediately before the call under test.
    torch.manual_seed(123)
    result = model._apply_generated_history_proxy(source_latents, source_is_generated)
    corrupted, generated, diagnostics = result

    # Inputs are untouched...
    assert torch.equal(source_latents, pristine_latents)
    assert not source_is_generated.any()
    # ...while the returned tensors carry the corruption and the flags.
    assert not torch.equal(corrupted, source_latents)
    assert generated.all()

    assert diagnostics["generated_history_proxy_frame_count"] == 4
    assert diagnostics["generated_history_proxy_frame_fraction"] == 1.0
def test_generated_history_proxy_respects_context_prefix_and_target_window_bounds():
    """Check that only frames between the context prefix and the target are hit.

    With eight all-zero frames, ``context_frame_count=3`` and
    ``target_start_frame=6``, only frames 3, 4 and 5 may be flagged as generated
    and altered; frames 0-2 and 6-7 must come back equal to the input. The
    diagnostics must report 3 frames and a fraction of 3/8.
    """
    proxy_cfg = _proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0)
    model = DummyDeMemWM(proxy_cfg, step=0)

    total_frames = 8
    source_latents = torch.zeros(total_frames, 1, 1, 2, 2)
    source_is_generated = torch.zeros(total_frames, 1, dtype=torch.bool)

    # Fix the RNG state immediately before the call under test.
    torch.manual_seed(123)
    corrupted, generated, diagnostics = model._apply_generated_history_proxy(
        source_latents,
        source_is_generated,
        context_frame_count=3,
        target_start_frame=6,
    )

    # Expected mask: True only on frames 3..5.
    expected_generated = torch.zeros(total_frames, 1, dtype=torch.bool)
    expected_generated[3:6] = True

    # The input latents must still be all zeros after the call.
    assert torch.equal(source_latents, torch.zeros_like(source_latents))
    assert torch.equal(generated, expected_generated)

    prefix, window, tail = slice(None, 3), slice(3, 6), slice(6, None)
    assert torch.equal(corrupted[prefix], source_latents[prefix])
    assert not torch.equal(corrupted[window], source_latents[window])
    assert torch.equal(corrupted[tail], source_latents[tail])

    assert diagnostics["generated_history_proxy_frame_count"] == 3
    assert diagnostics["generated_history_proxy_frame_fraction"] == 3 / total_frames
def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources():
    """Check how a generated frame is treated by the anchor and revisit banks.

    Four single-value frames are projected, with only frame 0 flagged as
    generated. Built with ``allow_generated_anchor=False`` and
    ``anchor_indices=[0, 1]``, the first anchor bank must hold frames 1 and 2,
    while the first revisit bank must hold exactly one generated record, typed
    ``MemorySourceType.GENERATED`` and pointing at frame 0.
    """
    model = DummyDeMemWM(_proxy_cfg(enabled=False))
    model.dememwm_anchor_proj = torch.nn.Linear(1, 2, bias=False)
    model.dememwm_revisit_proj = torch.nn.Linear(1, 2, bias=False)
    # Replace the random initial weights with ones in both projections.
    with torch.no_grad():
        for projection in (model.dememwm_anchor_proj, model.dememwm_revisit_proj):
            projection.weight.fill_(1.0)

    frame_total = 4
    latents = torch.arange(frame_total, dtype=torch.float32).reshape(frame_total, 1, 1, 1, 1)
    frame_indices = torch.arange(frame_total).reshape(frame_total, 1)
    source_is_generated = torch.tensor([[True], [False], [False], [False]])

    anchor_projected = model._project_latent_patch_tokens(
        latents, model.dememwm_anchor_proj, patch_size=1
    )
    revisit_projected = model._project_latent_patch_tokens(
        latents, model.dememwm_revisit_proj, patch_size=1
    )

    anchor_banks, revisit_banks = model._build_causal_memory_banks(
        anchor_projected,
        revisit_projected,
        frame_indices,
        source_is_generated,
        pose=None,
        action=None,
        allow_generated_anchor=False,
        anchor_indices=[0, 1],
        anchor_pool_h=1,
        anchor_pool_w=1,
        revisit_pool_h=1,
        revisit_pool_w=1,
        src_h=1,
        src_w=1,
    )

    # Anchor side: collect the frame index stored on each record of bank 0.
    anchor_frames = []
    for record in anchor_banks[0].records:
        anchor_frames.append(int(record.frame_indices.item()))
    assert anchor_frames == [1, 2]

    # Revisit side: keep only the records of bank 0 flagged as generated.
    generated_revisit = list(
        filter(lambda record: record.is_generated, revisit_banks[0].records)
    )
    assert len(generated_revisit) == 1
    only_generated = generated_revisit[0]
    assert only_generated.source_type == MemorySourceType.GENERATED
    assert only_generated.frame_indices.tolist() == [0]