|
|
| import pytest |
| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens |
| from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType |
|
|
|
|
def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2):
    """Build a one-frame ``MemoryRecord`` fixture for the bank tests.

    The record spans source frames ``[frame, frame + 1)`` and has ``slots``
    token rows of width 4, all filled with the value ``frame``. Every slot is
    marked valid, there is no pose, and the chunk id is ``f"r{frame}"``.
    ``source_type`` and ``generated`` are passed through unchanged.
    """
    slot_tokens = torch.ones(slots, 4) * frame
    slot_mask = torch.ones(slots, dtype=torch.bool)
    return MemoryRecord(
        tokens=slot_tokens,
        mask=slot_mask,
        source_start=frame,
        source_end=frame + 1,
        frame_indices=torch.tensor([frame]),
        pose=None,
        source_type=source_type,
        is_generated=generated,
        chunk_id=f"r{frame}",
    )
|
|
|
|
def test_prefix_anchors_are_prefix_gt_records():
    """Prefix anchors are stored as non-generated PREFIX_GT records."""
    bank = CausalMemoryBank()
    anchor_tokens = torch.randn(2, 3, 4)
    anchor_mask = torch.ones(2, 3, dtype=torch.bool)
    anchor_frames = torch.tensor([0, 4])
    bank.add_prefix_anchors(anchor_tokens, anchor_mask, anchor_frames)

    observed_types = [rec.source_type for rec in bank.records]
    assert observed_types == [MemorySourceType.PREFIX_GT] * 2
    # None of the anchors may be flagged as model-generated.
    assert all(not rec.is_generated for rec in bank.records)
|
|
|
|
def test_generated_records_are_not_prefix_gt_by_default():
    """Generated records default to GENERATED and may not be tagged PREFIX_GT."""
    bank = CausalMemoryBank()
    bank.add_generated_records(
        torch.randn(1, 2, 4),
        torch.ones(1, 2, dtype=torch.bool),
        torch.tensor([3]),
    )
    first = bank.records[0]
    assert first.source_type == MemorySourceType.GENERATED
    assert first.is_generated

    # Labelling generated content as ground-truth prefix must be rejected.
    with pytest.raises(ValueError):
        bank.add_generated_records(
            torch.randn(1, 2, 4),
            torch.ones(1, 2, dtype=torch.bool),
            torch.tensor([4]),
            source_type=MemorySourceType.PREFIX_GT,
        )
|
|
|
|
def test_query_never_returns_future_sources():
    """A query at frame 5 yields only the frame-0 and frame-2 records and passes the causality check."""
    bank = CausalMemoryBank()
    for frame in (0, 2, 5, 7):
        bank.add_record(_record(frame))

    target = 5
    returned = bank.query(target)
    assert [rec.max_source_frame for rec in returned] == [0, 2]
    bank.assert_causal(target, returned)
|
|
|
|
def test_all_false_masks_are_valid_abstention_outputs():
    """A record whose mask is entirely False still stacks cleanly, with no valid slots."""
    empty = MemoryRecord(
        tokens=torch.zeros(3, 4),
        mask=torch.zeros(3, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
    )
    assert empty.valid_slots == 0

    stacked_tokens, stacked_mask = stack_record_tokens([empty])
    assert tuple(stacked_tokens.shape) == (3, 4)
    assert int(stacked_mask.sum()) == 0
|
|
|
|
def test_query_caps_records_and_stack_uses_target_slots():
    """The query's ``max_records`` caps the result, and stacking yields exactly ``target_slots`` slots.

    Six two-slot records (frames 0..5) are added to the bank. A query at frame
    10 limited to two records must return exactly two causal records, and
    stacking them with ``target_slots=3`` must give a ``(3, 4)`` token tensor
    (``_record`` uses a token width of 4) with one mask flag per slot.
    """
    bank = CausalMemoryBank(max_records=10)
    for frame in range(6):
        bank.add_record(_record(frame, slots=2))

    records = bank.query(MemoryBankQuery(target_frame=10, max_records=2))
    assert len(records) == 2
    # The cap must not let a non-causal record slip through.
    bank.assert_causal(10, records)

    tokens, mask = stack_record_tokens(records, target_slots=3)
    # Check the full shapes rather than only dim 0, so a stacking bug that
    # drops or reshapes the token width is caught.
    assert tuple(tokens.shape) == (3, 4)
    assert tuple(mask.shape) == (3,)
|
|
|
|
def test_target_slots_ignore_masked_slots_when_stacking_records():
    """Fully masked records must not use up any of the ``target_slots`` budget.

    One record has four masked-out slots and the other has two valid slots
    whose tokens equal 2. Stacking the query result into two slots must yield
    exactly the two valid slots.
    """

    def make_revisit(tokens, mask, start, chunk_id):
        # Single-frame, pose-free, non-generated REVISIT record.
        return MemoryRecord(
            tokens=tokens,
            mask=mask,
            source_start=start,
            source_end=start + 1,
            frame_indices=torch.tensor([start]),
            pose=None,
            source_type=MemorySourceType.REVISIT,
            is_generated=False,
            chunk_id=chunk_id,
        )

    invalid = make_revisit(torch.ones(4, 4), torch.zeros(4, dtype=torch.bool), 0, "invalid")
    valid = make_revisit(torch.ones(2, 4) * 2, torch.ones(2, dtype=torch.bool), 1, "valid")

    bank = CausalMemoryBank()
    for rec in (invalid, valid):
        bank.add_record(rec)

    queried = bank.query(MemoryBankQuery(target_frame=3))
    stacked_tokens, stacked_mask = stack_record_tokens(queried, target_slots=2)

    assert stacked_mask.tolist() == [True, True]
    assert torch.equal(stacked_tokens, torch.ones(2, 4) * 2)
|
|