# DeMemWM / tests/test_dememwm_injection_static.py
# Page metadata: BonanDing — "Initial commit" (b47a1ce)
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.injection import InjectionAdapter
from algorithms.worldmem.dememwm.types import MemoryStreamTensors
def _streams(dtype=torch.float32):
    """Return a small ``MemoryStreamTensors`` fixture shared by the tests below.

    Layout of the fixture (leading dims 2 x 3, token feature size 4 — the
    meaning of those axes is assumed, TODO confirm against MemoryStreamTensors):

    * anchor stream: one token per slot, mask all ones, gate 1.0.
    * dynamic stream: two tokens per slot, an integer 0/1 mask with 6 of the
      12 entries set, and a per-slot gate of 0.5.
    * revisit stream: one token per slot, mask all zeros, gate 0.0 — i.e. the
      stream the tests treat as disabled.

    Args:
        dtype: dtype of the random token tensors only; the masks and the
            dynamic gate keep torch's default dtypes.
    """
    batch, frames, dim = 2, 3, 4

    # Random token draws happen in anchor -> dynamic -> revisit order.
    anchor_tokens = torch.randn(batch, frames, 1, dim, dtype=dtype)
    anchor_mask = torch.ones(batch, frames, 1)

    dynamic_tokens = torch.randn(batch, frames, 2, dim, dtype=dtype)
    # Deliberately an integer (non-bool) mask so mask casting is exercised.
    dynamic_mask = torch.tensor(
        [
            [[1, 0], [1, 1], [0, 0]],
            [[1, 1], [0, 0], [1, 0]],
        ]
    )

    revisit_tokens = torch.randn(batch, frames, 1, dim, dtype=dtype)
    revisit_mask = torch.zeros(batch, frames, 1)

    diagnostics = {
        "selected_revisit_frame_record_ids": ["c1"],
        "dynamic_max_source_frame": torch.tensor(2),
    }

    return MemoryStreamTensors(
        anchor_tokens=anchor_tokens,
        anchor_mask=anchor_mask,
        dynamic_tokens=dynamic_tokens,
        dynamic_mask=dynamic_mask,
        revisit_tokens=revisit_tokens,
        revisit_mask=revisit_mask,
        anchor_gate=1.0,
        dynamic_gate=torch.full((batch, frames, 1), 0.5),
        revisit_gate=0.0,
        diagnostics=diagnostics,
    )
def test_injection_kwarg_names_masks_dtype_and_diagnostics():
    """The adapter emits the expected kwargs, casts dtypes/masks and reports diagnostics."""
    expected_keys = {
        "memory_tokens",
        "memory_token_mask",
        "memory_dynamic_tokens",
        "memory_dynamic_mask",
        "memory_retrieval_tokens",
        "memory_retrieval_mask",
        "memory_anchor_gate",
        "memory_dynamic_gate",
        "memory_retrieval_gate",
    }

    adapter = InjectionAdapter()
    kwargs, diag = adapter(_streams(), dtype=torch.float64)

    assert set(kwargs) == expected_keys
    # Tokens are requested as float64 even though the fixture builds float32.
    assert kwargs["memory_tokens"].dtype == torch.float64
    # The fixture's dynamic mask is an integer tensor; a bool mask is expected back.
    assert kwargs["memory_dynamic_mask"].dtype == torch.bool
    # The fixture's anchor mask is all ones over 2 x 3 x 1 slots.
    assert diag["anchor_valid_tokens"] == 6
    assert diag["dynamic_valid_fraction"] > 0.0
    # Fixture diagnostics are expected to be forwarded / renamed.
    assert diag["selected_revisit_frame_record_ids"] == ["c1"]
    assert diag["max_source_frame"] == 2
def test_injection_omit_disabled_streams():
    """With ``omit_disabled=True`` the fixture's disabled stream comes back as None.

    The fixture disables its revisit stream (all-zero mask, 0.0 gate); this
    test expects the ``memory_retrieval_*`` kwargs to be None while the
    enabled dynamic stream is still passed through.
    """
    adapter = InjectionAdapter(omit_disabled=True)
    kwargs, _ = adapter(_streams())

    for omitted_key in ("memory_retrieval_tokens", "memory_retrieval_mask"):
        assert kwargs[omitted_key] is None
    assert kwargs["memory_dynamic_tokens"] is not None
def test_injection_rejects_bad_mask_shape():
    """A dynamic mask whose shape disagrees with the dynamic tokens raises ValueError.

    The error message is expected to mention "dynamic mask".
    """
    streams = _streams()
    # The fixture has 2 dynamic tokens per slot; give the mask 3 entries instead.
    streams.dynamic_mask = torch.ones(2, 3, 3)

    caught = None
    try:
        InjectionAdapter()(streams)
    except ValueError as err:
        caught = err

    if caught is None:
        raise AssertionError("expected bad mask shape to fail")
    assert "dynamic mask" in str(caught)