"""Gradient-flow test for ``MemoryDiTMixin._records_to_stream``."""

import torch
from dememwm_import_helper import install_dememwm_namespace

# Called before the ``algorithms.worldmem.dememwm`` imports below; presumably
# it registers the namespace those imports resolve through (TODO confirm).
install_dememwm_namespace()

from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin  # noqa: E402
from algorithms.worldmem.dememwm.types import (  # noqa: E402
    MemoryRecord,
    MemorySourceType,
)


def test_records_to_stream_preserves_grad_to_record_tokens():
    """Packing a record into a stream must keep it in the autograd graph.

    A single record with a (2, 4) token tensor and a 2-entry all-True mask
    is packed into a 4-slot stream. The test checks three things:

    * the returned mask is ``[True, True, False, False]``;
    * the third return value equals 0;
    * calling ``backward()`` on the packed tokens puts a non-zero
      gradient on the record's original token tensor.
    """
    # Leaf tensor; autograd fills in ``.grad`` after backward().
    source_tokens = torch.full((2, 4), 3.0, requires_grad=True)

    grad_record = MemoryRecord(
        tokens=source_tokens,
        mask=torch.ones(2, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="grad",
    )

    # Called unbound with a bare ``object()`` standing in for ``self``, so
    # this only works while the method reads no instance attributes.
    packed, packed_mask, max_src = MemoryDiTMixin._records_to_stream(
        object(),
        [grad_record],
        target_slots=4,
        hidden_size=4,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )

    # First two slots flagged, remaining two not.
    assert packed_mask.tolist() == [True] * 2 + [False] * 2
    assert max_src == 0

    packed.sum().backward()
    received = source_tokens.grad
    assert received is not None
    assert received.abs().sum().item() > 0