File size: 4,331 Bytes
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from dememwm_import_helper import install_dememwm_namespace

install_dememwm_namespace()
from algorithms.worldmem.dememwm.cache import StreamingCache
from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor


def small_compressor(**kwargs):
    """Build a tiny ``CausalConv3DDynamicCompressor`` for fast CPU tests.

    The fixed hyperparameters below keep every tensor small. Any extra keyword
    arguments (e.g. ``exclude_latest_local_frames``) are forwarded unchanged.
    Passing one of the fixed names again raises ``TypeError`` because it
    would be supplied twice.
    """
    fixed_hparams = dict(
        latent_channels=3,
        dit_hidden_size=8,
        patch_size=2,
        conv_kernel_t=3,
        conv_stride_t=2,
        max_source_frames=4,
    )
    return CausalConv3DDynamicCompressor(**fixed_hparams, **kwargs)


def test_dynamic_compressor_shapes_and_budget():
    """Output shapes, a non-empty mask for the first target, and the source budget."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    # 4 leading steps x 2 batch slots; layout presumably (T, B, C, H, W) -- TODO confirm.
    source_latents = torch.randn(4, 2, 3, 2, 2)
    # Row t holds frame index t for both batch slots.
    source_frames = torch.arange(4).unsqueeze(-1).repeat(1, 2)
    target_frames = torch.tensor([[1, 2], [4, 4]])

    out_tokens, out_mask, info = compressor(
        source_latents, source_frames, None, target_frames
    )

    assert out_tokens.shape == (2, 2, 2, 8)
    assert out_mask.shape == (2, 2, 2)
    assert out_mask[0, 0].any()
    # small_compressor caps selection at max_source_frames=4.
    assert info["selected_source_count"].max().item() <= 4


def test_dynamic_compressor_abstains_without_old_enough_sources():
    """When every source falls in the excluded local window, nothing is selected."""
    compressor = small_compressor(exclude_latest_local_frames=4)
    source_latents = torch.randn(2, 1, 3, 2, 2)
    source_frames = torch.tensor([[5], [6]])
    target_frames = torch.tensor([[8]])

    out_tokens, out_mask, info = compressor(
        source_latents, source_frames, None, target_frames
    )

    # The token tensor keeps its full shape even though the mask is empty.
    assert out_tokens.shape == (1, 1, 2, 8)
    assert not out_mask.any()
    # -1 is the value expected when no source frame was used.
    assert info["max_source_frame"].item() == -1
    assert info["dynamic_min_gap_to_target_per_target"].item() == -1


def test_dynamic_compressor_reports_generated_fraction_and_no_future():
    """Future frames are ignored and the generated-source fraction is reported."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(3, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [2], [5]])
    is_generated = torch.tensor([[False], [True], [True]])
    target_frames = torch.tensor([[3]])

    _, out_mask, info = compressor(
        source_latents, source_frames, None, target_frames, is_generated
    )

    assert out_mask.any()
    # Frame 5 lies after target 3, so the newest usable source is frame 2.
    assert info["max_source_frame"].item() == 2
    # Frames 0 (real) and 2 (generated) are used, so the fraction is strictly inside (0, 1).
    generated_fraction = info["generated_source_fraction"].item()
    assert 0.0 < generated_fraction < 1.0


def test_dynamic_compressor_excludes_c_short_overlap_and_keeps_shape():
    """The latest local frames are dropped while the output shape is unchanged."""
    compressor = small_compressor(exclude_latest_local_frames=2)
    source_latents = torch.randn(5, 1, 3, 2, 2)
    source_frames = torch.tensor([[0], [1], [2], [3], [4]])
    target_frames = torch.tensor([[5]])

    out_tokens, out_mask, info = compressor(
        source_latents, source_frames, None, target_frames
    )

    assert out_tokens.shape == (1, 1, 2, 8)
    assert out_mask.any()
    # With target 5 and 2 excluded local frames, frames 3 and 4 are dropped.
    assert info["max_source_frame"].item() == 2
    # Gaps are measured from target 5: nearest used frame 2 (gap 3), oldest 0 (gap 5).
    assert info["dynamic_min_gap_to_target_per_target"].item() == 3
    assert info["dynamic_max_gap_to_target_per_target"].item() == 5
    assert info["dynamic_exclude_latest_local_frames"] == 2


def test_cache_materialize_raw_latents_excludes_c_short_overlap():
    """Materializing raw latents drops the frames that overlap the local window."""
    cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=False)
    stored_latents = torch.randn(6, 1, 3, 2, 2)
    stored_frames = torch.arange(6).reshape(6, 1)
    cache.add_raw_latents(stored_latents, stored_frames)

    materialized = cache.materialize_raw_latents(
        device=torch.device("cpu"),
        dtype=stored_latents.dtype,
        max_recent_frames=8,
        target_frame_indices=torch.tensor([[6]]),
        exclude_latest_local_frames=4,
    )
    kept_latents, kept_frames, kept_generated, kept_pose = materialized

    # No pose was stored, so none comes back.
    assert kept_pose is None
    # Target 6 with 4 excluded local frames leaves only frames 0 and 1.
    assert kept_latents.shape[0] == 2
    assert kept_generated.shape == kept_frames.shape
    assert kept_frames.flatten().tolist() == [0, 1]


def test_dynamic_compressor_preserves_grad_to_trainable_parts():
    """Backprop through the tokens reaches the conv and output-norm weights."""
    compressor = small_compressor(exclude_latest_local_frames=0)
    source_latents = torch.randn(4, 1, 3, 2, 2)
    source_frames = torch.arange(4).unsqueeze(1)
    target_frames = torch.tensor([[4]])

    out_tokens, out_mask, _ = compressor(
        source_latents, source_frames, None, target_frames
    )
    assert out_mask.any()

    loss = out_tokens.square().sum()
    loss.backward()

    watched = (compressor.conv3d.weight, compressor.out_norm.weight)
    grads = [param.grad for param in watched]
    # Every watched weight must receive a gradient, and that gradient must be non-zero.
    assert all(g is not None for g in grads)
    assert all(g.abs().sum().item() > 0 for g in grads)


def test_dynamic_compressor_selects_only_recent_valid_sources():
    """Selection honours both the exclusion window and the source budget.

    Frames 0..19 are offered for target frame 10 with
    ``exclude_latest_local_frames=2``. Using the same convention as the
    c_short-overlap test (exclude=2 means the newest usable frame is
    target - 3), only frames 0..7 are eligible. Frames 8..19 sit in the
    excluded local window or at/after the target.

    The budget (``max_source_frames=4``) caps the count. The newest usable
    frame is expected to be among those picked.
    """
    comp = small_compressor(exclude_latest_local_frames=2)
    latents = torch.randn(20, 1, 3, 2, 2)
    frame_indices = torch.arange(20)[:, None]
    target = torch.tensor([[10]])
    _, mask, diag = comp(latents, frame_indices, None, target)
    assert mask.any()
    assert diag["selected_source_count"].item() == 4
    # No excluded or future frame may leak in, and recency selection keeps
    # the newest eligible frame (7), i.e. a minimum gap of 3 to target 10.
    assert diag["max_source_frame"].item() == 7
    assert diag["dynamic_min_gap_to_target_per_target"].item() == 3
    assert diag["dynamic_exclude_latest_local_frames"] == 2