| import torch |
| from dememwm_import_helper import install_dememwm_namespace |
|
|
| install_dememwm_namespace() |
| from algorithms.worldmem.dememwm.cache import StreamingCache |
| from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor |
|
|
|
|
def small_compressor(**kwargs):
    """Build a tiny ``CausalConv3DDynamicCompressor`` for the tests below.

    The fixed settings are hard-coded here; any additional keyword
    arguments are forwarded unchanged to the constructor.
    """
    fixed_settings = {
        "latent_channels": 3,
        "dit_hidden_size": 8,
        "patch_size": 2,
        "conv_kernel_t": 3,
        "conv_stride_t": 2,
        "max_source_frames": 4,
    }
    # Splatting both mappings keeps the original behaviour of raising a
    # TypeError when a caller repeats one of the fixed keywords.
    return CausalConv3DDynamicCompressor(**fixed_settings, **kwargs)
|
|
|
|
def test_dynamic_compressor_shapes_and_budget():
    """Check output shapes, first-slot mask validity, and the source cap.

    The selected source count reported in the diagnostics must not exceed
    4, the ``max_source_frames`` value used by ``small_compressor``.
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 2, 3, 2, 2)
    source_frames = torch.arange(4).unsqueeze(1).repeat(1, 2)
    target_frames = torch.tensor([[1, 2], [4, 4]])

    out_tokens, out_mask, diagnostics = compressor(
        source_latents, source_frames, None, target_frames
    )

    assert out_tokens.shape == (2, 2, 2, 8)
    assert out_mask.shape == (2, 2, 2)
    assert out_mask[0, 0].any()
    largest_count = diagnostics["selected_source_count"].max().item()
    assert largest_count <= 4
|
|
|
|
def test_dynamic_compressor_abstains_without_old_enough_sources():
    """Expect an all-false mask when no source survives the exclusion.

    Sources sit at frames 5 and 6, the target is frame 8, and
    ``exclude_latest_local_frames`` is 4; the diagnostics are expected to
    report -1 for both the max source frame and the minimum gap.
    """
    compressor = small_compressor(exclude_latest_local_frames=4)
    source_latents = torch.randn(2, 1, 3, 2, 2)
    source_frames = torch.tensor([[5], [6]])
    target_frames = torch.tensor([[8]])

    out_tokens, out_mask, diagnostics = compressor(
        source_latents, source_frames, None, target_frames
    )

    # The token tensor keeps its full shape even though nothing is selected.
    assert out_tokens.shape == (1, 1, 2, 8)
    assert not out_mask.any()
    assert diagnostics["max_source_frame"].item() == -1
    assert diagnostics["dynamic_min_gap_to_target_per_target"].item() == -1
|
|
|
|
def test_dynamic_compressor_reports_generated_fraction_and_no_future():
    """Check the generated-source fraction and that frame 5 is not used.

    Sources are at frames 0, 2 and 5 with generated flags
    False/True/True; the target is frame 3. The reported max source frame
    must be 2 and the generated fraction strictly between 0 and 1.
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(3, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [2], [5]])
    is_generated = torch.tensor([[False], [True], [True]])
    target_frames = torch.tensor([[3]])

    result = compressor(
        source_latents, source_frames, None, target_frames, is_generated
    )
    out_mask, diagnostics = result[1], result[2]

    assert out_mask.any()
    assert diagnostics["max_source_frame"].item() == 2
    generated_fraction = diagnostics["generated_source_fraction"].item()
    assert 0.0 < generated_fraction < 1.0
|
|
|
|
def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
    """Check that the two latest local frames are excluded from sources.

    Sources cover frames 0..4, the target is frame 5, and
    ``exclude_latest_local_frames`` is 2. The expected diagnostics are a
    max source frame of 2 and gaps to the target ranging from 3 to 5.
    """
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(5, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [1], [2], [3], [4]])
    target_frames = torch.tensor([[5]])

    out_tokens, out_mask, diagnostics = compressor(
        source_latents, source_frames, None, target_frames
    )

    assert out_tokens.shape == (1, 1, 2, 8)
    assert out_mask.any()
    assert diagnostics["max_source_frame"].item() == 2

    min_gap = diagnostics["dynamic_min_gap_to_target_per_target"].item()
    max_gap = diagnostics["dynamic_max_gap_to_target_per_target"].item()
    assert min_gap == 3
    assert max_gap == 5
    # This entry is compared directly, without ``.item()``.
    assert diagnostics["dynamic_exclude_latest_local_frames"] == 2
|
|
|
|
def test_cache_materialize_raw_latents_excludes_c_short_overlap():
    """Check that materialized raw latents omit the excluded latest frames.

    Six frames (0..5) are cached; with target frame 6 and
    ``exclude_latest_local_frames=4`` only frames 0 and 1 are expected.
    """
    cache = StreamingCache(
        enabled=True,
        keep_raw_latents="all",
        keep_compressed_records=False,
    )
    stored_latents = torch.randn(6, 1, 3, 2, 2)
    stored_frames = torch.arange(6).reshape(6, 1)
    cache.add_raw_latents(stored_latents, stored_frames)

    materialized = cache.materialize_raw_latents(
        device=torch.device("cpu"),
        dtype=stored_latents.dtype,
        max_recent_frames=8,
        target_frame_indices=torch.tensor([[6]]),
        exclude_latest_local_frames=4,
    )
    raw_latents, raw_frames, raw_generated, raw_pose = materialized

    assert raw_pose is None
    assert raw_latents.shape[0] == 2
    assert raw_generated.shape == raw_frames.shape
    assert raw_frames.flatten().tolist() == [0, 1]
|
|
|
|
def test_dynamic_compressor_preserves_grad_to_trainable_parts():
    """Check that backprop through the tokens reaches conv and norm weights.

    After backpropagating the sum of squared tokens, both
    ``conv3d.weight`` and ``out_norm.weight`` must hold non-zero gradients.
    """
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 1, 3, 2, 2)
    source_frames = torch.arange(4).unsqueeze(1)
    target_frames = torch.tensor([[4]])

    out_tokens, out_mask, _ = compressor(
        source_latents, source_frames, None, target_frames
    )
    assert out_mask.any()

    loss = out_tokens.square().sum()
    loss.backward()

    weight_grads = (
        compressor.conv3d.weight.grad,
        compressor.out_norm.weight.grad,
    )
    # First make sure every gradient exists, then that none is all-zero.
    for grad in weight_grads:
        assert grad is not None
    for grad in weight_grads:
        assert grad.abs().sum().item() > 0
|
|
|
|
def test_dynamic_compressor_selects_only_recent_valid_sources():
    """Check that exactly four sources are selected from twenty candidates.

    Sources cover frames 0..19, the target is frame 10, and
    ``exclude_latest_local_frames`` is 2; the diagnostics must report a
    selected source count of 4.
    """
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(20, 1, 3, 2, 2)
    source_frames = torch.arange(20).unsqueeze(1)
    target_frames = torch.tensor([[10]])

    result = compressor(source_latents, source_frames, None, target_frames)
    out_mask, diagnostics = result[1], result[2]

    assert out_mask.any()
    assert diagnostics["selected_source_count"].item() == 4
|
|