# DeMemWM / tests / test_dememwm_stream_grad.py
# Last commit: 93d7b0a by BonanDing — "Clean DeMemWM deterministic memory slot handling"
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin
from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType
def test_records_to_stream_preserves_grad_to_record_tokens():
    """Packing records into a memory stream must keep them in the autograd graph.

    A single two-token REVISIT record is packed into a four-slot stream with
    ``MemoryDiTMixin._records_to_stream``. The test checks three things:

    * the returned slot mask marks the two real tokens valid and the two
      padding slots invalid;
    * the reported maximum source index equals the record's ``source_start``
      (0);
    * back-propagating ``sum()`` of the packed stream produces a non-zero
      gradient on the original record tokens.
    """
    num_tokens, token_dim, target_slots = 2, 4, 4

    # Leaf tensor that should receive gradient through the packed stream.
    record_tokens = torch.full((num_tokens, token_dim), 3.0, requires_grad=True)

    record = MemoryRecord(
        tokens=record_tokens,
        mask=torch.ones(num_tokens, dtype=torch.bool),
        source_start=0,
        source_end=1,
        frame_indices=torch.tensor([0]),
        pose=None,
        source_type=MemorySourceType.REVISIT,
        is_generated=False,
        chunk_id="grad",
    )

    stream_options = {
        "target_slots": target_slots,
        "hidden_size": token_dim,
        "device": torch.device("cpu"),
        "dtype": torch.float32,
    }
    # The helper is called unbound, with a bare ``object()`` standing in for
    # ``self``. This presumably only works while it reads no mixin instance
    # attributes.
    stream, slot_mask, max_source = MemoryDiTMixin._records_to_stream(
        object(), [record], **stream_options
    )

    expected_mask = [True] * num_tokens + [False] * (target_slots - num_tokens)
    assert slot_mask.tolist() == expected_mask
    assert max_source == 0

    stream.sum().backward()
    grad = record_tokens.grad
    assert grad is not None
    assert grad.abs().sum().item() > 0