"""Tests for ``InjectionAdapter``: kwarg naming, mask/dtype handling, stream
omission and input validation."""
import torch

from dememwm_import_helper import install_dememwm_namespace

# The ``algorithms.worldmem.dememwm`` imports below are only made after this
# call; presumably it is what makes that namespace importable, so keep it first.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.injection import InjectionAdapter  # noqa: E402
from algorithms.worldmem.dememwm.types import MemoryStreamTensors  # noqa: E402


def _streams(dtype=torch.float32):
    """Return a small ``MemoryStreamTensors`` fixture with random token values.

    Every stream tensor has leading dims (2, 3) (presumably batch x frames --
    confirm against ``MemoryStreamTensors``) and a feature width of 4.
    The anchor and revisit streams carry one token per entry; the dynamic
    stream carries two.
    Only the token tensors use ``dtype``; masks and gates keep torch defaults.

    Stream state encoded here:
      * anchor: every token valid, scalar gate 1.0.
      * dynamic: a partially valid integer mask and a per-entry gate of 0.5.
      * revisit: no valid tokens and scalar gate 0.0, i.e. disabled.
    """
    batch, frames, width = 2, 3, 4
    # 0/1 integer mask (not bool) over the two dynamic slots of each entry.
    dynamic_valid = torch.tensor(
        [
            [[1, 0], [1, 1], [0, 0]],
            [[1, 1], [0, 0], [1, 0]],
        ]
    )
    extra_diagnostics = {
        "selected_revisit_frame_record_ids": ["c1"],
        "dynamic_max_source_frame": torch.tensor(2),
    }
    # Keyword arguments are evaluated left to right, so the randn draws keep
    # the anchor -> dynamic -> revisit order.
    return MemoryStreamTensors(
        anchor_tokens=torch.randn(batch, frames, 1, width, dtype=dtype),
        anchor_mask=torch.ones(batch, frames, 1),
        dynamic_tokens=torch.randn(batch, frames, 2, width, dtype=dtype),
        dynamic_mask=dynamic_valid,
        revisit_tokens=torch.randn(batch, frames, 1, width, dtype=dtype),
        revisit_mask=torch.zeros(batch, frames, 1),
        anchor_gate=1.0,
        dynamic_gate=0.5 * torch.ones(batch, frames, 1),
        revisit_gate=0.0,
        diagnostics=extra_diagnostics,
    )


def test_injection_kwarg_names_masks_dtype_and_diagnostics():
    """Check the emitted kwarg names, dtype casting, mask dtype and diagnostics."""
    adapter = InjectionAdapter()
    kwargs, diag = adapter(_streams(), dtype=torch.float64)

    expected_names = {
        "memory_tokens",
        "memory_token_mask",
        "memory_dynamic_tokens",
        "memory_dynamic_mask",
        "memory_retrieval_tokens",
        "memory_retrieval_mask",
        "memory_anchor_gate",
        "memory_dynamic_gate",
        "memory_retrieval_gate",
    }
    assert set(kwargs) == expected_names

    # Tokens follow the requested dtype; the integer dynamic mask comes back as bool.
    assert kwargs["memory_tokens"].dtype == torch.float64
    assert kwargs["memory_dynamic_mask"].dtype == torch.bool

    # The anchor mask is all ones over a 2 x 3 x 1 grid, giving 6 valid tokens.
    assert diag["anchor_valid_tokens"] == 6
    assert diag["dynamic_valid_fraction"] > 0.0
    # Presumably forwarded from the fixture's ``diagnostics`` entries -- confirm.
    assert diag["selected_revisit_frame_record_ids"] == ["c1"]
    assert diag["max_source_frame"] == 2


def test_injection_omit_disabled_streams():
    """With ``omit_disabled=True`` the disabled revisit stream is reported as None."""
    kwargs, _ = InjectionAdapter(omit_disabled=True)(_streams())
    for dropped in ("memory_retrieval_tokens", "memory_retrieval_mask"):
        assert kwargs[dropped] is None
    # The dynamic stream is active in the fixture, so it must survive.
    assert kwargs["memory_dynamic_tokens"] is not None


def test_injection_rejects_bad_mask_shape():
    """A dynamic mask that does not match ``dynamic_tokens`` raises ValueError."""
    streams = _streams()
    # Three mask slots per entry, while dynamic_tokens only carries two.
    streams.dynamic_mask = torch.ones(2, 3, 3)

    caught = None
    try:
        InjectionAdapter()(streams)
    except ValueError as err:
        caught = err

    if caught is None:
        raise AssertionError("expected bad mask shape to fail")
    assert "dynamic mask" in str(caught)