File size: 1,881 Bytes
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dae740
 
b47a1ce
 
 
1dae740
b47a1ce
 
 
1dae740
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()

from algorithms.worldmem.dememwm.injection import InjectionAdapter
from algorithms.worldmem.dememwm.types import MemoryStreamTensors


def _streams(dtype=torch.float32):
    """Build a small ``MemoryStreamTensors`` fixture for the adapter tests.

    Token tensors are random and created with ``dtype``; masks and gates are
    fixed so individual tests can rely on them:

    * anchor:  mask of all ones, scalar gate ``1.0``
    * dynamic: hand-written 0/1 integer mask, tensor gate filled with ``0.5``
    * revisit: mask of all zeros, scalar gate ``0.0``
    """
    lead = (2, 3)  # leading dimensions shared by every tensor in the fixture
    width = 4      # size of the trailing token dimension

    # Draw the random tensors in the order anchor -> dynamic -> revisit so the
    # RNG is consumed in a fixed sequence.
    anchor = torch.randn(*lead, 1, width, dtype=dtype)
    dynamic = torch.randn(*lead, 2, width, dtype=dtype)
    revisit = torch.randn(*lead, 1, width, dtype=dtype)

    dynamic_keep = torch.tensor(
        [
            [[1, 0], [1, 1], [0, 0]],
            [[1, 1], [0, 0], [1, 0]],
        ]
    )
    half_gate = torch.ones(*lead, 1) * 0.5

    return MemoryStreamTensors(
        anchor_tokens=anchor,
        anchor_mask=torch.ones(*lead, 1),
        dynamic_tokens=dynamic,
        dynamic_mask=dynamic_keep,
        revisit_tokens=revisit,
        revisit_mask=torch.zeros(*lead, 1),
        anchor_gate=1.0,
        dynamic_gate=half_gate,
        revisit_gate=0.0,
    )


def test_injection_kwarg_names_masks_and_dtype():
    """The adapter output has the expected keys, token dtype and mask dtype.

    Calls a default ``InjectionAdapter`` with ``dtype=torch.float64`` and
    checks that the result exposes exactly the nine expected keyword names,
    that the anchor and retrieval token tensors come back as float64, and
    that the dynamic mask comes back as a bool tensor.
    """
    expected_names = {
        "memory_tokens",
        "memory_token_mask",
        "memory_dynamic_tokens",
        "memory_dynamic_mask",
        "memory_retrieval_tokens",
        "memory_retrieval_mask",
        "memory_anchor_gate",
        "memory_dynamic_gate",
        "memory_retrieval_gate",
    }
    wanted = torch.float64

    adapter = InjectionAdapter()
    injected = adapter(_streams(), dtype=wanted)

    assert set(injected) == expected_names
    assert injected["memory_tokens"].dtype == wanted
    assert injected["memory_dynamic_mask"].dtype == torch.bool
    assert injected["memory_retrieval_tokens"].dtype == wanted


def test_injection_omit_disabled_streams():
    """With ``omit_disabled=True`` the retrieval entries are ``None``.

    The fixture's revisit stream has an all-zero mask and a ``0.0`` gate;
    presumably that is what makes the adapter treat it as disabled (confirm
    against ``InjectionAdapter``). The dynamic tokens must still be present.
    """
    adapter = InjectionAdapter(omit_disabled=True)
    injected = adapter(_streams())

    for dropped in ("memory_retrieval_tokens", "memory_retrieval_mask"):
        assert injected[dropped] is None
    assert injected["memory_dynamic_tokens"] is not None


def test_injection_rejects_bad_mask_shape():
    """A dynamic mask of the wrong shape makes the adapter raise ``ValueError``.

    The fixture's dynamic mask is replaced with a ``(2, 3, 3)`` tensor, which
    no longer matches the ``(2, 3, 2, 4)`` dynamic tokens. The raised message
    must mention ``"dynamic mask"``.
    """
    broken = _streams()
    broken.dynamic_mask = torch.ones(2, 3, 3)

    failure = None
    try:
        InjectionAdapter()(broken)
    except ValueError as err:
        failure = err

    # No exception at all is a test failure, not a pass.
    if failure is None:
        raise AssertionError("expected bad mask shape to fail")
    assert "dynamic mask" in str(failure)