"""Tests for the dynamic causal Conv3D compressor and the streaming cache."""

import torch

from dememwm_import_helper import install_dememwm_namespace

# NOTE: this call sits between the helper import and the ``algorithms.*``
# imports in the original file; presumably it makes that namespace
# importable, so the ordering is kept as-is.
install_dememwm_namespace()

from algorithms.worldmem.dememwm.cache import StreamingCache
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor


def small_compressor(**kwargs):
    """Build a small ``CausalConv3DDynamicCompressor`` for the tests below.

    The fixed settings are defined here; any extra keyword arguments are
    forwarded to the constructor unchanged. Passing a keyword that is
    already fixed here raises ``TypeError`` (duplicate keyword argument).
    """
    fixed_settings = {
        "latent_channels": 3,
        "dit_hidden_size": 8,
        "patch_size": 2,
        "conv_kernel_t": 3,
        "conv_stride_t": 2,
        "max_source_frames": 4,
    }
    return CausalConv3DDynamicCompressor(**fixed_settings, **kwargs)


def test_dynamic_compressor_shapes_and_budget():
    """Output token/mask shapes are as expected and at most 4 sources are selected."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 2, 3, 2, 2)
    source_frames = torch.arange(4)[:, None].repeat(1, 2)
    target_frames = torch.tensor([[1, 2], [4, 4]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert tokens.shape == (2, 2, 2, 8)
    assert mask.shape == (2, 2, 2)
    assert mask[0, 0].any()
    # 4 equals the ``max_source_frames`` value that ``small_compressor`` passes.
    assert diag["selected_source_count"].max().item() <= 4


def test_dynamic_compressor_abstains_without_old_enough_sources():
    """With ``exclude_latest_local_frames=4`` and sources 5, 6 for target 8, nothing is selected.

    The mask must be all-False and both diagnostics must report ``-1``.
    """
    compressor = small_compressor(exclude_latest_local_frames=4)
    source_latents = torch.randn(2, 1, 3, 2, 2)
    source_frames = torch.tensor([[5], [6]])
    target_frames = torch.tensor([[8]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert tokens.shape == (1, 1, 2, 8)
    assert not mask.any()
    assert diag["max_source_frame"].item() == -1
    assert diag["dynamic_min_gap_to_target_per_target"].item() == -1


def test_dynamic_compressor_reports_generated_fraction_and_no_future():
    """Source frame 5 (after target 3) is not used and the generated fraction is strictly in (0, 1)."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(3, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [2], [5]])
    is_generated = torch.tensor([[False], [True], [True]])
    target_frames = torch.tensor([[3]])

    _, mask, diag = compressor(
        source_latents, source_frames, None, target_frames, is_generated
    )

    assert mask.any()
    assert diag["max_source_frame"].item() == 2
    generated_fraction = diag["generated_source_fraction"].item()
    assert 0.0 < generated_fraction < 1.0


def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
    """With ``exclude_latest_local_frames=2`` and target 5, the newest used source is frame 2.

    Also checks the output token shape and the reported min/max gap to the
    target (3 and 5) plus the echoed exclusion setting.
    """
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(5, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [1], [2], [3], [4]])
    target_frames = torch.tensor([[5]])

    tokens, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert tokens.shape == (1, 1, 2, 8)
    assert mask.any()
    assert diag["max_source_frame"].item() == 2
    assert diag["dynamic_min_gap_to_target_per_target"].item() == 3
    assert diag["dynamic_max_gap_to_target_per_target"].item() == 5
    assert diag["dynamic_exclude_latest_local_frames"] == 2


def test_cache_materialize_raw_latents_excludes_c_short_overlap():
    """``materialize_raw_latents`` returns only frames 0 and 1 for target 6 with exclusion 4."""
    cache = StreamingCache(
        enabled=True,
        keep_raw_latents="all",
        keep_compressed_records=False,
    )
    stored_latents = torch.randn(6, 1, 3, 2, 2)
    stored_frames = torch.arange(6).view(6, 1)
    cache.add_raw_latents(stored_latents, stored_frames)

    out_latents, out_frames, out_generated, out_pose = cache.materialize_raw_latents(
        device=torch.device("cpu"),
        dtype=stored_latents.dtype,
        max_recent_frames=8,
        target_frame_indices=torch.tensor([[6]]),
        exclude_latest_local_frames=4,
    )

    assert out_pose is None
    assert out_latents.shape[0] == 2
    assert out_generated.shape == out_frames.shape
    assert out_frames.flatten().tolist() == [0, 1]


def test_dynamic_compressor_preserves_grad_to_trainable_parts():
    """Backprop through the tokens yields non-zero grads on ``conv3d`` and ``out_norm`` weights."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 1, 3, 2, 2)
    source_frames = torch.arange(4)[:, None]
    target_frames = torch.tensor([[4]])

    tokens, mask, _ = compressor(source_latents, source_frames, None, target_frames)
    assert mask.any()

    tokens.square().sum().backward()

    weight_grads = (
        compressor.conv3d.weight.grad,
        compressor.out_norm.weight.grad,
    )
    # First make sure every gradient exists, then that each is non-zero.
    for grad in weight_grads:
        assert grad is not None
    for grad in weight_grads:
        assert grad.abs().sum().item() > 0


def test_dynamic_compressor_selects_only_recent_valid_sources():
    """Given 20 candidate frames and target 10, exactly 4 sources are selected."""
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(20, 1, 3, 2, 2)
    source_frames = torch.arange(20)[:, None]
    target_frames = torch.tensor([[10]])

    _, mask, diag = compressor(source_latents, source_frames, None, target_frames)

    assert mask.any()
    assert diag["selected_source_count"].item() == 4