# DeMemWM/tests/test_dememwm_memory.py
# Last commit by BonanDing (93d7b0a): Clean DeMemWM deterministic memory slot handling
import pytest
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType
def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2):
    """Return a fully-valid synthetic ``MemoryRecord`` for a single source frame.

    The record spans ``[frame, frame + 1)``, carries ``slots`` token rows of
    width 4 whose every value equals ``frame`` (so records are easy to tell
    apart), marks every slot as valid, and uses ``"r{frame}"`` as its chunk id.
    """
    tokens = frame * torch.ones(slots, 4)
    all_valid = torch.ones(slots, dtype=torch.bool)
    return MemoryRecord(
        tokens=tokens,
        mask=all_valid,
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None,
        source_type=source_type,
        is_generated=generated,
        chunk_id=f"r{frame}",
    )
def test_prefix_anchors_are_prefix_gt_records():
    """Prefix anchors are stored as PREFIX_GT records that are not marked generated."""
    bank = CausalMemoryBank()
    anchor_tokens = torch.randn(2, 3, 4)
    anchor_mask = torch.ones(2, 3, dtype=torch.bool)
    anchor_frames = torch.tensor([0, 4])
    bank.add_prefix_anchors(anchor_tokens, anchor_mask, anchor_frames)

    kinds = [record.source_type for record in bank.records]
    assert kinds == [MemorySourceType.PREFIX_GT] * 2
    assert all(not record.is_generated for record in bank.records)
def test_generated_records_are_not_prefix_gt_by_default():
    """Generated records default to GENERATED and may not be labelled PREFIX_GT."""
    bank = CausalMemoryBank()
    bank.add_generated_records(
        torch.randn(1, 2, 4),
        torch.ones(1, 2, dtype=torch.bool),
        torch.tensor([3]),
    )
    first = bank.records[0]
    assert first.source_type == MemorySourceType.GENERATED
    assert first.is_generated

    # Explicitly tagging generated content as prefix ground truth must be rejected.
    with pytest.raises(ValueError):
        bank.add_generated_records(
            torch.randn(1, 2, 4),
            torch.ones(1, 2, dtype=torch.bool),
            torch.tensor([4]),
            source_type=MemorySourceType.PREFIX_GT,
        )
def test_query_never_returns_future_sources():
    """A query at frame 5 yields only records whose sources lie strictly before frame 5."""
    bank = CausalMemoryBank()
    for frame in (0, 2, 5, 7):
        bank.add_record(_record(frame))

    visible = bank.query(5)
    assert [rec.max_source_frame for rec in visible] == [0, 2]
    bank.assert_causal(5, visible)
def test_all_false_masks_are_valid_abstention_outputs():
    """A record whose slots are all masked out is legal and stacks to an all-False mask."""
    rec = MemoryRecord(
        tokens=torch.zeros(3, 4),
        mask=torch.zeros(3, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
    )
    assert rec.valid_slots == 0

    stacked_tokens, stacked_mask = stack_record_tokens([rec])
    assert stacked_tokens.shape == (3, 4)
    assert int(stacked_mask.sum()) == 0
def test_query_caps_records_and_stack_uses_target_slots():
    """``MemoryBankQuery.max_records`` caps the query result, and stacking with
    ``target_slots`` produces exactly that many slots of the original token width.
    """
    # The bank-level cap (10) exceeds the number of records added, so only the
    # per-query cap below can limit the result.
    bank = CausalMemoryBank(max_records=10)
    for frame in range(6):
        bank.add_record(_record(frame, slots=2))

    records = bank.query(MemoryBankQuery(target_frame=10, max_records=2))
    assert len(records) == 2

    # Two records x two slots give four candidate slots; target_slots=3 keeps three.
    tokens, mask = stack_record_tokens(records, target_slots=3)
    # Pin the full shapes, not just the leading dim: the token width (4, from
    # _record) must survive stacking and the mask must stay one-dimensional.
    assert tokens.shape == (3, 4)
    assert mask.shape == (3,)
def test_target_slots_ignore_masked_slots_when_stacking_records():
    """A fully-masked record must not use up the ``target_slots`` budget when stacking."""

    def revisit(tokens, mask, start, chunk_id):
        # Single-frame REVISIT record covering [start, start + 1).
        return MemoryRecord(
            tokens=tokens,
            mask=mask,
            source_start=start,
            source_end=start + 1,
            frame_indices=torch.tensor([start]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=chunk_id,
        )

    invalid = revisit(torch.ones(4, 4), torch.zeros(4, dtype=torch.bool), 0, "invalid")
    valid = revisit(torch.ones(2, 4) * 2, torch.ones(2, dtype=torch.bool), 1, "valid")

    bank = CausalMemoryBank()
    for rec in (invalid, valid):
        bank.add_record(rec)

    queried = bank.query(MemoryBankQuery(target_frame=3))
    tokens, mask = stack_record_tokens(queried, target_slots=2)
    # Only the valid record's two slots (all values 2) should be selected.
    assert mask.tolist() == [True, True]
    assert torch.equal(tokens, torch.ones(2, 4) * 2)