# Source: DeMemWM/tests/test_dememwm_compression.py
# Repository page header: BonanDing, "Initial commit" (b47a1ce)
import torch
from dememwm_import_helper import install_dememwm_namespace
install_dememwm_namespace()
from algorithms.worldmem.dememwm.cache import StreamingCache
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor
def small_compressor(**kwargs):
    """Build a tiny ``CausalConv3DDynamicCompressor`` shared by the tests here.

    The sizes below are fixed; any extra keyword arguments (for example
    ``exclude_latest_local_frames``) are forwarded to the constructor as-is.
    """
    fixed_settings = dict(
        latent_channels=3,
        dit_hidden_size=8,
        patch_size=2,
        conv_kernel_t=3,
        conv_stride_t=2,
        max_source_frames=4,
    )
    return CausalConv3DDynamicCompressor(**fixed_settings, **kwargs)
def test_dynamic_compressor_shapes_and_budget():
    """Check output shapes and that the selected-source count stays within 4.

    Four source frames with a second dimension of 2 are compressed against a
    (2, 2) tensor of target frame indices; nothing is excluded as "latest local".
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 2, 3, 2, 2)
    # Each row holds the same frame index twice, once per column.
    source_frames = torch.arange(4).unsqueeze(1).repeat(1, 2)
    target_frames = torch.tensor([[1, 2], [4, 4]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    # The trailing 8 presumably mirrors dit_hidden_size=8 - TODO confirm.
    assert tokens.shape == (2, 2, 2, 8)
    assert mask.shape == (2, 2, 2)
    assert mask[0, 0].any()
    largest_selection = diag["selected_source_count"].max().item()
    assert largest_selection <= 4
def test_dynamic_compressor_abstains_without_old_enough_sources():
    """Expect an all-False mask and -1 diagnostics when no source qualifies.

    Sources sit at frames 5 and 6, the target is frame 8, and the four latest
    local frames are excluded, so the test expects nothing to be selected.
    """
    compressor = small_compressor(exclude_latest_local_frames=4)
    source_latents = torch.randn(2, 1, 3, 2, 2)
    source_frames = torch.tensor([[5], [6]])
    target_frames = torch.tensor([[8]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    # The token tensor keeps its shape even though every entry is masked out.
    assert tokens.shape == (1, 1, 2, 8)
    assert not mask.any()
    # -1 is the value the diagnostics are expected to carry when nothing is selected.
    newest_source = diag["max_source_frame"].item()
    smallest_gap = diag["dynamic_min_gap_to_target_per_target"].item()
    assert newest_source == -1
    assert smallest_gap == -1
def test_dynamic_compressor_reports_generated_fraction_and_no_future():
    """Expect frame 5 to be left out for target 3 and a mixed generated fraction.

    Sources are at frames 0, 2 and 5 with generated flags False, True, True.
    The newest reported source must be frame 2, and the reported generated
    fraction must lie strictly between 0 and 1.
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(3, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [2], [5]])
    is_generated = torch.tensor([[False], [True], [True]])
    target_frames = torch.tensor([[3]])

    _, mask, diag = compressor(
        source_latents, source_frames, None, target_frames, is_generated
    )

    assert mask.any()
    newest_source = diag["max_source_frame"].item()
    assert newest_source == 2
    generated_share = diag["generated_source_fraction"].item()
    assert 0.0 < generated_share < 1.0
def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
    """Expect the two latest local frames to be excluded without changing shape.

    Sources cover frames 0..4 and the target is frame 5. With
    ``exclude_latest_local_frames=2`` the newest reported source must be
    frame 2, giving reported gaps to the target of 3 (min) and 5 (max).
    """
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(5, 1, 3, 2, 2)
    source_frames = torch.arange(5).unsqueeze(1)
    target_frames = torch.tensor([[5]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert tokens.shape == (1, 1, 2, 8)
    assert mask.any()

    newest_source = diag["max_source_frame"].item()
    smallest_gap = diag["dynamic_min_gap_to_target_per_target"].item()
    largest_gap = diag["dynamic_max_gap_to_target_per_target"].item()
    assert newest_source == 2
    assert smallest_gap == 3
    assert largest_gap == 5
    # This entry is compared directly, without .item().
    assert diag["dynamic_exclude_latest_local_frames"] == 2
def test_cache_materialize_raw_latents_excludes_c_short_overlap():
    """Expect ``materialize_raw_latents`` to return only frames 0 and 1.

    Six raw frames (0..5) are cached; the target is frame 6 and the four
    latest local frames are excluded. No pose is expected back.
    """
    cache = StreamingCache(
        enabled=True,
        keep_raw_latents="all",
        keep_compressed_records=False,
    )
    stored_latents = torch.randn(6, 1, 3, 2, 2)
    stored_frames = torch.arange(6).unsqueeze(1)
    cache.add_raw_latents(stored_latents, stored_frames)

    materialized = cache.materialize_raw_latents(
        device=torch.device("cpu"),
        dtype=stored_latents.dtype,
        max_recent_frames=8,
        target_frame_indices=torch.tensor([[6]]),
        exclude_latest_local_frames=4,
    )
    raw_latents, raw_frames, raw_generated, raw_pose = materialized

    assert raw_pose is None
    assert raw_latents.shape[0] == 2
    assert raw_generated.shape == raw_frames.shape
    surviving_frames = raw_frames.flatten().tolist()
    assert surviving_frames == [0, 1]
def test_dynamic_compressor_preserves_grad_to_trainable_parts():
    """Expect non-zero gradients on the conv3d and out_norm weights.

    A sum-of-squares loss over the output tokens is backpropagated, then the
    ``.grad`` of both weights is checked for presence and non-zero magnitude.
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 1, 3, 2, 2)
    source_frames = torch.arange(4).unsqueeze(1)
    target_frames = torch.tensor([[4]])

    tokens, mask, _ = compressor(source_latents, source_frames, None, target_frames)
    assert mask.any()

    loss = tokens.square().sum()
    loss.backward()

    weight_grads = (
        compressor.conv3d.weight.grad,
        compressor.out_norm.weight.grad,
    )
    # Presence is checked for every gradient before any magnitude is inspected.
    for grad in weight_grads:
        assert grad is not None
    for grad in weight_grads:
        assert grad.abs().sum().item() > 0
def test_dynamic_compressor_selects_only_recent_valid_sources():
    """Expect exactly four sources to be selected out of twenty candidates.

    Sources cover frames 0..19 while the target is frame 10, with the two
    latest local frames excluded. The expected count of 4 equals the
    ``max_source_frames=4`` that ``small_compressor`` passes in.
    """
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(20, 1, 3, 2, 2)
    source_frames = torch.arange(20).unsqueeze(1)
    target_frames = torch.tensor([[10]])

    _, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert mask.any()
    selected = diag["selected_source_count"].item()
    assert selected == 4