File size: 1,881 Bytes
b47a1ce 1dae740 b47a1ce 1dae740 b47a1ce 1dae740 b47a1ce | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.injection import InjectionAdapter
from algorithms.worldmem.dememwm.types import MemoryStreamTensors
def _streams(dtype=torch.float32):
    """Build a small, fixed-layout ``MemoryStreamTensors`` for the injection tests.

    The token tensors are random and created with ``dtype``. Masks and gates
    are fixed so each test can rely on which streams are enabled:

    * anchor: mask all ones, scalar gate ``1.0``
    * dynamic: mask mixes on/off slots, gate tensor filled with ``0.5``
    * revisit: mask all zeros, scalar gate ``0.0``

    Args:
        dtype: dtype of the three token tensors only. Masks and the dynamic
            gate keep torch's default dtypes.

    Returns:
        A ``MemoryStreamTensors`` instance.
    """
    # Leading dims shared by every tensor (presumably batch x step; TODO confirm).
    lead = (2, 3)
    width = 4
    dynamic_slot_mask = torch.tensor(
        [
            [[1, 0], [1, 1], [0, 0]],
            [[1, 1], [0, 0], [1, 0]],
        ]
    )
    return MemoryStreamTensors(
        anchor_tokens=torch.randn(*lead, 1, width, dtype=dtype),
        anchor_mask=torch.ones(*lead, 1),
        dynamic_tokens=torch.randn(*lead, 2, width, dtype=dtype),
        dynamic_mask=dynamic_slot_mask,
        revisit_tokens=torch.randn(*lead, 1, width, dtype=dtype),
        revisit_mask=torch.zeros(*lead, 1),
        anchor_gate=1.0,
        dynamic_gate=torch.full((*lead, 1), 0.5),
        revisit_gate=0.0,
    )
def test_injection_kwarg_names_masks_and_dtype():
    """Check the adapter's output keys, token dtype casting, and boolean masks."""
    expected_keys = {
        "memory_tokens",
        "memory_token_mask",
        "memory_dynamic_tokens",
        "memory_dynamic_mask",
        "memory_retrieval_tokens",
        "memory_retrieval_mask",
        "memory_anchor_gate",
        "memory_dynamic_gate",
        "memory_retrieval_gate",
    }
    adapter = InjectionAdapter()
    out = adapter(_streams(), dtype=torch.float64)

    assert set(out) == expected_keys
    # Token tensors follow the requested dtype; the dynamic mask comes back as bool
    # even though the fixture builds it as an integer tensor.
    assert out["memory_tokens"].dtype == torch.float64
    assert out["memory_dynamic_mask"].dtype == torch.bool
    assert out["memory_retrieval_tokens"].dtype == torch.float64
def test_injection_omit_disabled_streams():
    """With ``omit_disabled=True``, the disabled retrieval stream is returned as None.

    In the ``_streams`` fixture the revisit stream has an all-zero mask and a
    0.0 gate, while the dynamic stream is partly enabled.
    """
    adapter = InjectionAdapter(omit_disabled=True)
    out = adapter(_streams())

    assert out["memory_retrieval_tokens"] is None
    assert out["memory_retrieval_mask"] is None
    # The dynamic stream still has active slots, so it must not be dropped.
    assert out["memory_dynamic_tokens"] is not None
def test_injection_rejects_bad_mask_shape():
    """A dynamic mask that does not match the dynamic tokens raises ValueError.

    The fixture's dynamic tokens have 2 entries on axis 2. The mask is replaced
    with one that has 3 entries there, and the error message must mention
    "dynamic mask".
    """
    streams = _streams()
    streams.dynamic_mask = torch.ones(2, 3, 3)

    try:
        InjectionAdapter()(streams)
    except ValueError as exc:
        # `exc` is unbound after the except clause, so keep its text.
        message = str(exc)
    else:
        raise AssertionError("expected bad mask shape to fail")

    assert "dynamic mask" in message
|