File size: 3,605 Bytes
b47a1ce 93d7b0a b47a1ce 93d7b0a b47a1ce 93d7b0a b47a1ce 93d7b0a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import pytest
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType
def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2):
    """Build a single-frame ``MemoryRecord`` fixture for the tests below.

    Every token value equals ``frame`` and every one of the ``slots`` mask
    entries is ``True``. The record is stamped with ``source_start=frame``,
    ``source_end=frame + 1``, no pose, and the chunk id ``r<frame>``.
    """
    # Token content mirrors the frame index so records are distinguishable.
    token_block = torch.ones(slots, 4) * frame
    slot_mask = torch.ones(slots, dtype=torch.bool)
    return MemoryRecord(
        tokens=token_block,
        mask=slot_mask,
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None,
        source_type=source_type,
        is_generated=generated,
        chunk_id=f"r{frame}",
    )
def test_prefix_anchors_are_prefix_gt_records():
    """Records added via ``add_prefix_anchors`` are PREFIX_GT and not generated."""
    bank = CausalMemoryBank()

    anchor_tokens = torch.randn(2, 3, 4)
    anchor_mask = torch.ones(2, 3, dtype=torch.bool)
    anchor_frames = torch.tensor([0, 4])
    bank.add_prefix_anchors(anchor_tokens, anchor_mask, anchor_frames)

    observed_types = [record.source_type for record in bank.records]
    assert observed_types == [MemorySourceType.PREFIX_GT] * 2
    assert all(not record.is_generated for record in bank.records)
def test_generated_records_are_not_prefix_gt_by_default():
    """Generated records default to GENERATED; tagging them PREFIX_GT raises."""
    bank = CausalMemoryBank()
    bank.add_generated_records(
        torch.randn(1, 2, 4),
        torch.ones(1, 2, dtype=torch.bool),
        torch.tensor([3]),
    )

    first = bank.records[0]
    assert first.source_type == MemorySourceType.GENERATED
    assert first.is_generated

    # Explicitly labelling a generated record as ground-truth prefix must fail.
    with pytest.raises(ValueError):
        bank.add_generated_records(
            torch.randn(1, 2, 4),
            torch.ones(1, 2, dtype=torch.bool),
            torch.tensor([4]),
            source_type=MemorySourceType.PREFIX_GT,
        )
def test_query_never_returns_future_sources():
    """Querying at frame 5 yields only the records sourced from frames 0 and 2."""
    bank = CausalMemoryBank()
    for frame in (0, 2, 5, 7):
        bank.add_record(_record(frame))

    target_frame = 5
    visible = bank.query(target_frame)

    # The record at frame 5 itself, and the later one at 7, must be absent.
    assert [rec.max_source_frame for rec in visible] == [0, 2]
    bank.assert_causal(target_frame, visible)
def test_all_false_masks_are_valid_abstention_outputs():
    """A record whose mask is entirely ``False`` still stacks without error."""
    slot_count, width = 3, 4
    rec = MemoryRecord(
        tokens=torch.zeros(slot_count, width),
        mask=torch.zeros(slot_count, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
    )
    assert rec.valid_slots == 0

    stacked_tokens, stacked_mask = stack_record_tokens([rec])
    assert stacked_tokens.shape == (slot_count, width)
    assert stacked_mask.sum().item() == 0
def test_query_caps_records_and_stack_uses_target_slots():
    """``max_records`` on the query caps the result; stacking honours ``target_slots``."""
    bank = CausalMemoryBank(max_records=10)
    for frame in range(6):
        bank.add_record(_record(frame, slots=2))

    capped_query = MemoryBankQuery(target_frame=10, max_records=2)
    records = bank.query(capped_query)
    assert len(records) == 2

    wanted_slots = 3
    tokens, mask = stack_record_tokens(records, target_slots=wanted_slots)
    assert tokens.shape[0] == wanted_slots
    assert mask.shape[0] == wanted_slots
def test_target_slots_ignore_masked_slots_when_stacking_records():
    """With ``target_slots=2``, stacking returns only the two unmasked slots.

    The bank holds a fully masked 4-slot record (tokens of 1) followed by a
    fully valid 2-slot record (tokens of 2); the stacked output must contain
    only the latter's tokens.
    """

    def build(chunk_id, tokens, mask, start):
        # Both fixtures differ only in tokens, mask, start frame and chunk id.
        return MemoryRecord(
            tokens=tokens,
            mask=mask,
            source_start=start,
            source_end=start + 1,
            frame_indices=torch.tensor([start]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=chunk_id,
        )

    invalid = build("invalid", torch.ones(4, 4), torch.zeros(4, dtype=torch.bool), 0)
    valid = build("valid", torch.ones(2, 4) * 2, torch.ones(2, dtype=torch.bool), 1)

    bank = CausalMemoryBank()
    for record in (invalid, valid):
        bank.add_record(record)

    records = bank.query(MemoryBankQuery(target_frame=3))
    tokens, mask = stack_record_tokens(records, target_slots=2)

    assert mask.tolist() == [True, True]
    assert torch.equal(tokens, torch.ones(2, 4) * 2)
|