File size: 3,605 Bytes
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93d7b0a
b47a1ce
 
 
93d7b0a
b47a1ce
93d7b0a
b47a1ce
 
93d7b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

import pytest
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType


def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2):
    """Build a one-frame ``MemoryRecord`` fixture for the tests in this module.

    The record has ``slots`` token rows of width 4, each filled with the value
    ``frame``, an all-True mask, a source span of ``[frame, frame + 1)``, no
    pose, and the chunk id ``"r<frame>"``.
    """
    fields = dict(
        tokens=torch.ones(slots, 4) * frame,
        mask=torch.ones(slots, dtype=torch.bool),
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None,
        source_type=source_type,
        is_generated=generated,
        chunk_id=f"r{frame}",
    )
    return MemoryRecord(**fields)


def test_prefix_anchors_are_prefix_gt_records():
    """Records created by ``add_prefix_anchors`` are PREFIX_GT and not generated."""
    bank = CausalMemoryBank()
    bank.add_prefix_anchors(
        torch.randn(2, 3, 4),
        torch.ones(2, 3, dtype=torch.bool),
        torch.tensor([0, 4]),
    )

    stored = bank.records
    observed_types = [entry.source_type for entry in stored]
    assert observed_types == [MemorySourceType.PREFIX_GT] * 2

    # No anchor may carry the generated flag.
    assert all(not entry.is_generated for entry in bank.records)


def test_generated_records_are_not_prefix_gt_by_default():
    """``add_generated_records`` tags records GENERATED and rejects PREFIX_GT."""
    bank = CausalMemoryBank()
    bank.add_generated_records(
        torch.randn(1, 2, 4),
        torch.ones(1, 2, dtype=torch.bool),
        torch.tensor([3]),
    )

    first = bank.records[0]
    assert first.source_type == MemorySourceType.GENERATED
    assert first.is_generated

    # Explicitly asking for a PREFIX_GT source type on this path must fail.
    with pytest.raises(ValueError):
        bank.add_generated_records(
            torch.randn(1, 2, 4),
            torch.ones(1, 2, dtype=torch.bool),
            torch.tensor([4]),
            source_type=MemorySourceType.PREFIX_GT,
        )


def test_query_never_returns_future_sources():
    """Querying at frame 5 yields only the records sourced at frames 0 and 2."""
    target_frame = 5
    bank = CausalMemoryBank()
    for frame in (0, 2, 5, 7):
        bank.add_record(_record(frame))

    visible = bank.query(target_frame)
    newest_frames = [entry.max_source_frame for entry in visible]

    assert newest_frames == [0, 2]
    bank.assert_causal(target_frame, visible)


def test_all_false_masks_are_valid_abstention_outputs():
    """A record whose mask is entirely False can be built and stacked."""
    num_slots, width = 3, 4
    rec = MemoryRecord(
        tokens=torch.zeros(num_slots, width),
        mask=torch.zeros(num_slots, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
    )
    assert rec.valid_slots == 0

    stacked_tokens, stacked_mask = stack_record_tokens([rec])
    assert stacked_tokens.shape == (num_slots, width)
    assert stacked_mask.sum().item() == 0


def test_query_caps_records_and_stack_uses_target_slots():
    """``max_records`` limits the query result; ``target_slots`` sets the stacked length."""
    record_cap = 2
    slot_budget = 3

    bank = CausalMemoryBank(max_records=10)
    for frame in range(6):
        bank.add_record(_record(frame, slots=2))

    request = MemoryBankQuery(target_frame=10, max_records=record_cap)
    selected = bank.query(request)
    assert len(selected) == record_cap

    stacked_tokens, stacked_mask = stack_record_tokens(selected, target_slots=slot_budget)
    assert stacked_tokens.shape[0] == slot_budget
    assert stacked_mask.shape[0] == slot_budget


def test_target_slots_ignore_masked_slots_when_stacking_records():
    """Stacking with ``target_slots=2`` returns the two unmasked slots, not masked ones.

    The bank holds a fully masked 4-slot record followed by a fully valid
    2-slot record whose tokens are all 2; the stacked output must equal the
    valid record's tokens.
    """
    # Fields shared by both REVISIT records built below.
    common = dict(pose=None, source_type=MemorySourceType.REVISIT, is_generated=False)

    invalid = MemoryRecord(
        tokens=torch.ones(4, 4),
        mask=torch.zeros(4, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        chunk_id="invalid",
        **common,
    )
    valid = MemoryRecord(
        tokens=torch.ones(2, 4) * 2,
        mask=torch.ones(2, dtype=torch.bool),
        source_start=1,
        source_end=2,
        frame_indices=torch.tensor([1]),
        chunk_id="valid",
        **common,
    )

    bank = CausalMemoryBank()
    for entry in (invalid, valid):
        bank.add_record(entry)

    request = MemoryBankQuery(target_frame=3)
    stacked_tokens, stacked_mask = stack_record_tokens(bank.query(request), target_slots=2)

    expected_tokens = torch.ones(2, 4) * 2
    assert stacked_mask.tolist() == [True, True]
    assert torch.equal(stacked_tokens, expected_tokens)