"""Tests for ``CausalMemoryBank``, ``MemoryBankQuery`` and ``stack_record_tokens``."""

import pytest
import torch

from dememwm_import_helper import install_dememwm_namespace

# NOTE(review): presumably this has to run before the ``algorithms.*`` imports
# below so that they resolve -- keep this ordering.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.memory import (  # noqa: E402
    CausalMemoryBank,
    MemoryBankQuery,
    stack_record_tokens,
)
from algorithms.worldmem.dememwm.types import (  # noqa: E402
    MemoryRecord,
    MemorySourceType,
)


def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2):
    """Build a single-frame ``MemoryRecord`` fixture.

    The record has ``source_start=frame`` and ``source_end=frame + 1``, holds
    ``slots`` token rows of width 4 filled with the value ``frame``, uses an
    all-True mask, no pose, and the chunk id ``f"r{frame}"``.
    """
    slot_tokens = torch.ones(slots, 4) * frame
    slot_mask = torch.ones(slots, dtype=torch.bool)
    return MemoryRecord(
        tokens=slot_tokens,
        mask=slot_mask,
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None,
        source_type=source_type,
        is_generated=generated,
        chunk_id=f"r{frame}",
    )


def test_prefix_anchors_are_prefix_gt_records():
    """Prefix anchors are stored as non-generated ``PREFIX_GT`` records."""
    bank = CausalMemoryBank()
    anchor_tokens = torch.randn(2, 3, 4)
    anchor_mask = torch.ones(2, 3, dtype=torch.bool)
    anchor_frames = torch.tensor([0, 4])
    bank.add_prefix_anchors(anchor_tokens, anchor_mask, anchor_frames)

    observed_types = [rec.source_type for rec in bank.records]
    assert observed_types == [
        MemorySourceType.PREFIX_GT,
        MemorySourceType.PREFIX_GT,
    ]
    assert not any(rec.is_generated for rec in bank.records)


def test_generated_records_are_not_prefix_gt_by_default():
    """Generated records default to ``GENERATED``; ``PREFIX_GT`` is rejected."""
    bank = CausalMemoryBank()
    gen_tokens = torch.randn(1, 2, 4)
    gen_mask = torch.ones(1, 2, dtype=torch.bool)
    gen_frames = torch.tensor([3])
    bank.add_generated_records(gen_tokens, gen_mask, gen_frames)

    assert bank.records[0].source_type == MemorySourceType.GENERATED
    assert bank.records[0].is_generated

    # Explicitly labelling a generated record as PREFIX_GT must raise.
    with pytest.raises(ValueError):
        bank.add_generated_records(
            torch.randn(1, 2, 4),
            torch.ones(1, 2, dtype=torch.bool),
            torch.tensor([4]),
            source_type=MemorySourceType.PREFIX_GT,
        )


def test_query_never_returns_future_sources():
    """Querying at frame 5 yields only the records from frames 0 and 2."""
    bank = CausalMemoryBank()
    for frame in [0, 2, 5, 7]:
        bank.add_record(_record(frame))

    records = bank.query(5)

    returned_frames = [rec.max_source_frame for rec in records]
    assert returned_frames == [0, 2]
    bank.assert_causal(5, records)


def test_all_false_masks_are_valid_abstention_outputs():
    """A record with an all-False mask has zero valid slots and still stacks."""
    abstained = MemoryRecord(
        tokens=torch.zeros(3, 4),
        mask=torch.zeros(3, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
    )
    assert abstained.valid_slots == 0

    tokens, mask = stack_record_tokens([abstained])
    assert tokens.shape == (3, 4)
    assert mask.sum().item() == 0


def test_query_caps_records_and_stack_uses_target_slots():
    """``max_records`` caps the query result; ``target_slots`` sets the stacked length."""
    bank = CausalMemoryBank(max_records=10)
    for frame in range(6):
        bank.add_record(_record(frame, slots=2))

    request = MemoryBankQuery(target_frame=10, max_records=2)
    records = bank.query(request)
    assert len(records) == 2

    tokens, mask = stack_record_tokens(records, target_slots=3)
    assert tokens.shape[0] == 3
    assert mask.shape[0] == 3


def test_target_slots_ignore_masked_slots_when_stacking_records():
    """Stacking to ``target_slots=2`` keeps the two True-masked slots, not the masked ones."""
    invalid = MemoryRecord(
        tokens=torch.ones(4, 4),
        mask=torch.zeros(4, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="invalid",
    )
    valid = MemoryRecord(
        tokens=torch.ones(2, 4) * 2,
        mask=torch.ones(2, dtype=torch.bool),
        source_start=1,
        source_end=2,
        frame_indices=torch.tensor([1]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="valid",
    )

    bank = CausalMemoryBank()
    bank.add_record(invalid)
    bank.add_record(valid)

    request = MemoryBankQuery(target_frame=3)
    records = bank.query(request)
    tokens, mask = stack_record_tokens(records, target_slots=2)

    assert mask.tolist() == [True, True]
    # Only the "valid" record's tokens (all 2s) should survive.
    expected_tokens = torch.ones(2, 4) * 2
    assert torch.equal(tokens, expected_tokens)