| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin |
| from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType |
|
|
|
|
def test_records_to_stream_preserves_grad_to_record_tokens():
    """Packing one record into a stream must keep it attached to autograd.

    Builds a single ``MemoryRecord`` whose token tensor requires grad,
    packs it through ``MemoryDiTMixin._records_to_stream`` into a
    4-slot stream, and checks three things:

    * the returned mask marks only the first two slots;
    * the reported max source is 0;
    * backpropagating from the packed tokens yields a non-zero gradient
      on the record's original token tensor.
    """
    # Leaf tensor of shape (2, 4), filled with 3.0, tracked by autograd.
    leaf_tokens = torch.full((2, 4), 3.0, requires_grad=True)
    populated = torch.ones(2, dtype=torch.bool)

    single_record = MemoryRecord(
        tokens=leaf_tokens,
        mask=populated,
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="grad",
    )

    # The method is called unbound, with a bare ``object()`` as ``self``.
    # This presumes it reads no instance state.
    placeholder_self = object()
    stream_tokens, stream_mask, max_source = MemoryDiTMixin._records_to_stream(
        placeholder_self,
        [single_record],
        target_slots=4,
        hidden_size=4,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )

    assert stream_mask.tolist() == [True, True, False, False]
    assert max_source == 0

    # A scalar reduction lets us call backward() without a grad argument.
    stream_tokens.sum().backward()
    leaf_grad = leaf_tokens.grad
    assert leaf_grad is not None
    assert leaf_grad.abs().sum().item() > 0
|
|