|
|
| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
| install_dememwm_namespace() |
|
|
| from algorithms.worldmem.dememwm.injection import InjectionAdapter |
| from algorithms.worldmem.dememwm.types import MemoryStreamTensors |
|
|
|
|
def _streams(dtype=torch.float32):
    """Build a small fixed-shape ``MemoryStreamTensors`` fixture for the injection tests.

    Token tensors are random and use ``dtype``; masks, gates and diagnostics are fixed.
    The anchor mask is all ones, the dynamic mask has 6 of its 12 entries set, and the
    revisit mask is all zeros with a revisit gate of 0.0.
    """
    # Leading dims shared by every stream; presumably (batch, frames) -- unconfirmed.
    prefix = (2, 3)
    width = 4
    # Random draws happen in anchor -> dynamic -> revisit order.
    anchor = torch.randn(*prefix, 1, width, dtype=dtype)
    dynamic = torch.randn(*prefix, 2, width, dtype=dtype)
    revisit = torch.randn(*prefix, 1, width, dtype=dtype)
    # Integer-valued mask over the 2 dynamic slots: 3 entries set in each leading row.
    dynamic_valid = torch.tensor(
        [
            [[1, 0], [1, 1], [0, 0]],
            [[1, 1], [0, 0], [1, 0]],
        ]
    )
    diagnostics = {
        "selected_revisit_frame_record_ids": ["c1"],
        "dynamic_max_source_frame": torch.tensor(2),
    }
    return MemoryStreamTensors(
        anchor_tokens=anchor,
        anchor_mask=torch.ones(*prefix, 1),
        dynamic_tokens=dynamic,
        dynamic_mask=dynamic_valid,
        revisit_tokens=revisit,
        revisit_mask=torch.zeros(*prefix, 1),
        anchor_gate=1.0,
        dynamic_gate=torch.ones(*prefix, 1) * 0.5,
        revisit_gate=0.0,
        diagnostics=diagnostics,
    )
|
|
|
|
def test_injection_kwarg_names_masks_dtype_and_diagnostics():
    """The adapter emits the expected kwargs, casts dtypes, and reports diagnostics."""
    adapter = InjectionAdapter()
    kwargs, diag = adapter(_streams(), dtype=torch.float64)

    expected_keys = {
        "memory_tokens",
        "memory_token_mask",
        "memory_dynamic_tokens",
        "memory_dynamic_mask",
        "memory_retrieval_tokens",
        "memory_retrieval_mask",
        "memory_anchor_gate",
        "memory_dynamic_gate",
        "memory_retrieval_gate",
    }
    assert set(kwargs) == expected_keys

    # Token dtype follows the requested dtype; masks come back as bool.
    assert kwargs["memory_tokens"].dtype == torch.float64
    assert kwargs["memory_dynamic_mask"].dtype == torch.bool

    # The fixture's anchor mask is all ones over 2 * 3 * 1 entries.
    assert diag["anchor_valid_tokens"] == 6
    assert diag["dynamic_valid_fraction"] > 0.0
    # These values are passed through from the fixture's diagnostics dict.
    assert diag["selected_revisit_frame_record_ids"] == ["c1"]
    assert diag["max_source_frame"] == 2
|
|
|
|
def test_injection_omit_disabled_streams():
    """With ``omit_disabled=True`` the retrieval kwargs are None, but dynamic tokens remain.

    In the fixture, the revisit stream has a 0.0 gate and an all-zero mask.
    """
    adapter = InjectionAdapter(omit_disabled=True)
    kwargs, _ = adapter(_streams())
    for key in ("memory_retrieval_tokens", "memory_retrieval_mask"):
        assert kwargs[key] is None
    assert kwargs["memory_dynamic_tokens"] is not None
|
|
|
|
def test_injection_rejects_bad_mask_shape():
    """A dynamic mask whose shape disagrees with the dynamic tokens raises ValueError.

    The error message must mention "dynamic mask".
    """
    streams = _streams()
    # The fixture's dynamic tokens have 2 slots on dim 2; a last dim of 3 forces a mismatch.
    streams.dynamic_mask = torch.ones(2, 3, 3)

    caught = None
    try:
        InjectionAdapter()(streams)
    except ValueError as err:
        caught = err

    if caught is None:
        raise AssertionError("expected bad mask shape to fail")
    assert "dynamic mask" in str(caught)
|
|