File size: 2,748 Bytes
b47a1ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82

import ast
from pathlib import Path


def test_entrypoint_class_uses_standalone_mixin_and_base_video_dit():
    """Check the DeMemWM entrypoint module's class layout and public alias.

    Parses ``algorithms/worldmem/dememwm_memory_dit.py`` (resolved against the
    current working directory) and asserts that:

    * the top-level class ``DeMemWMMinecraft`` lists exactly
      ``MemoryDiTMixin`` followed by ``BaseVideoDiTMinecraft`` as its bases;
    * the source contains the alias ``DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft``;
    * the text ``SSM`` does not occur anywhere in the file.
    """
    path = Path("algorithms/worldmem/dememwm_memory_dit.py")
    # Python source is UTF-8 by default; don't depend on the platform locale.
    src = path.read_text(encoding="utf-8")
    tree = ast.parse(src)

    # Map each top-level class to the simple names of its bases: a plain name
    # yields ``Name.id``, a dotted base (``pkg.Base``) yields its final
    # attribute, and anything else (e.g. a subscripted generic) yields "".
    classes = {
        node.name: [
            base.id if isinstance(base, ast.Name) else getattr(base, "attr", "")
            for base in node.bases
        ]
        for node in tree.body
        if isinstance(node, ast.ClassDef)
    }
    assert "DeMemWMMinecraft" in classes, (
        f"top-level class DeMemWMMinecraft not found in {path}; "
        f"classes: {sorted(classes)}"
    )
    assert classes["DeMemWMMinecraft"] == ["MemoryDiTMixin", "BaseVideoDiTMinecraft"]

    # Text-level checks: the backward-compatible alias must be present and no
    # "SSM" remnant may appear (comments and strings included).
    assert "DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft" in src
    assert "SSM" not in src


def test_algorithm_mixin_has_strict_checkpoint_helper_and_no_old_imports():
    """Check ``dememwm/algorithm.py`` for the strict-checkpoint helper and the
    absence of old SSM-memory identifiers.

    All checks are plain substring tests on the source text, so occurrences
    inside comments or string literals count as well.
    """
    path = Path("algorithms/worldmem/dememwm/algorithm.py")
    # Python source is UTF-8 by default; don't depend on the platform locale.
    src = path.read_text(encoding="utf-8")

    assert "strict_checkpoint_key_check" in src
    assert "strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check" in src

    # "ssm_memory" alone already covers the two prefixed names; they are
    # checked first so a failure names the specific old identifier found.
    for old_name in ("spatial_ssm_memory", "df_video_ssm_memory", "ssm_memory"):
        assert old_name not in src, f"old identifier {old_name!r} found in {path}"


def test_algorithm_mixin_wires_standalone_memory_retrieval_surface():
    """Check that ``MemoryDiTMixin`` in ``dememwm/algorithm.py`` is wired to the
    standalone memory-retrieval API.

    Using the module's AST, asserts that:

    * every name in ``required_imports`` is bound by some import statement
      (anywhere in the module, including function-local imports);
    * the top-level class ``MemoryDiTMixin`` defines the causal-memory-bank
      helper methods;
    * ``MemoryDiTMixin.build_memory_streams`` calls
      ``_build_preselected_causal_memory_banks``,
      ``deterministic_revisit_retrieval`` and ``dememwm_dynamic_compressor``
      (either as plain-name or attribute calls).
    """
    path = Path("algorithms/worldmem/dememwm/algorithm.py")
    # Python source is UTF-8 by default; don't depend on the platform locale.
    src = path.read_text(encoding="utf-8")
    tree = ast.parse(src)

    required_imports = {
        "CausalMemoryBank",
        "MemoryBankQuery",
        "stack_record_tokens",
        "deterministic_revisit_retrieval",
        "MemorySourceType",
    }

    # Collect the names each import actually binds in its scope:
    # ``import a.b`` binds ``a`` (not ``b``); ``import a.b as c`` and
    # ``from m import x as c`` bind ``c``; ``from m import x`` binds ``x``.
    imported_names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            imported_names.update(
                alias.asname or alias.name.partition(".")[0] for alias in node.names
            )
        elif isinstance(node, ast.ImportFrom):
            imported_names.update(alias.asname or alias.name for alias in node.names)
    missing_imports = required_imports - imported_names
    assert not missing_imports, (
        f"names not imported in {path}: {sorted(missing_imports)}"
    )

    # Use a default so a missing class fails as an assertion, not StopIteration.
    mixin = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.ClassDef) and node.name == "MemoryDiTMixin"
        ),
        None,
    )
    assert mixin is not None, f"top-level class MemoryDiTMixin not found in {path}"

    method_names = {
        node.name
        for node in mixin.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }
    missing_methods = {
        "_build_causal_memory_banks",
        "_build_preselected_causal_memory_banks",
        "_records_to_stream",
    } - method_names
    assert not missing_methods, (
        f"MemoryDiTMixin is missing methods: {sorted(missing_methods)}"
    )

    build_method = next(
        (
            node
            for node in mixin.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == "build_memory_streams"
        ),
        None,
    )
    assert build_method is not None, (
        "MemoryDiTMixin.build_memory_streams not found"
    )

    # Record the callee's simple name for every call in the method body:
    # ``f(...)`` yields ``f`` and ``obj.f(...)`` / ``self.f(...)`` yields ``f``.
    call_names = set()
    for node in ast.walk(build_method):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                call_names.add(func.id)
            elif isinstance(func, ast.Attribute):
                call_names.add(func.attr)
    missing_calls = {
        "_build_preselected_causal_memory_banks",
        "deterministic_revisit_retrieval",
        "dememwm_dynamic_compressor",
    } - call_names
    assert not missing_calls, (
        f"build_memory_streams does not call: {sorted(missing_calls)}"
    )