|
|
| import ast |
| from pathlib import Path |
|
|
|
|
def test_entrypoint_class_uses_standalone_mixin_and_base_video_dit():
    """Check the class layout of the DeMemWM entrypoint module.

    Parses ``algorithms/worldmem/dememwm_memory_dit.py`` (relative to the
    current working directory) and asserts that:

    * top-level class ``DeMemWMMinecraft`` lists exactly
      ``MemoryDiTMixin`` then ``BaseVideoDiTMinecraft`` as its bases;
    * the source contains the ``DeMemWMMemoryDiTMinecraft`` alias assignment;
    * the substring ``SSM`` appears nowhere in the source.
    """
    source = Path("algorithms/worldmem/dememwm_memory_dit.py").read_text()
    module = ast.parse(source)

    def _base_name(base):
        # Bare names yield their identifier; dotted bases (``pkg.Cls``) yield
        # the trailing attribute; any other expression yields "".
        if isinstance(base, ast.Name):
            return base.id
        return getattr(base, "attr", "")

    # Only module-level classes are considered; a later duplicate definition
    # replaces an earlier one.
    bases_by_class = {}
    for stmt in module.body:
        if isinstance(stmt, ast.ClassDef):
            bases_by_class[stmt.name] = [_base_name(b) for b in stmt.bases]

    assert bases_by_class["DeMemWMMinecraft"] == [
        "MemoryDiTMixin",
        "BaseVideoDiTMinecraft",
    ]
    assert "DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft" in source
    assert "SSM" not in source
|
|
|
|
def test_algorithm_mixin_has_strict_checkpoint_helper_and_no_old_imports():
    """Check the checkpoint helper and the absence of legacy SSM references.

    Reads ``algorithms/worldmem/dememwm/algorithm.py`` (relative to the
    current working directory) as text. It requires the
    ``strict_checkpoint_key_check`` helper and its
    ``strict_dememwm_checkpoint_key_check`` alias assignment. It forbids any
    occurrence of the legacy ``*ssm_memory`` module names.
    """
    algorithm_src = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()

    for required in (
        "strict_checkpoint_key_check",
        "strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check",
    ):
        assert required in algorithm_src

    # "ssm_memory" alone would catch all three. The specific names are checked
    # first so that a failure points at the particular legacy module.
    for forbidden in ("spatial_ssm_memory", "df_video_ssm_memory", "ssm_memory"):
        assert forbidden not in algorithm_src
|
|
|
|
def test_algorithm_mixin_wires_standalone_memory_retrieval_surface():
    """Check that ``MemoryDiTMixin`` is wired to the standalone memory API.

    Statically inspects ``algorithms/worldmem/dememwm/algorithm.py``
    (relative to the current working directory) and asserts that:

    * every required memory-bank / retrieval name is bound by some import
      statement in the module. Nested imports count, because ``ast.walk``
      visits the whole tree.
    * the module-level ``MemoryDiTMixin`` class defines the causal
      memory-bank builder methods and ``_records_to_stream``.
    * ``MemoryDiTMixin.build_memory_streams`` contains calls to the
      preselected bank builder, ``deterministic_revisit_retrieval`` and
      ``dememwm_dynamic_compressor``. Calls are matched by bare name
      (``f(...)``) or by attribute name (``obj.f(...)``).
    """
    src = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
    tree = ast.parse(src)

    # --- Required imports ---------------------------------------------------
    required_imports = {
        "CausalMemoryBank",
        "MemoryBankQuery",
        "stack_record_tokens",
        "deterministic_revisit_retrieval",
        "MemorySourceType",
    }
    imported_names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            # ``from m import a as b`` binds ``b``; without ``as`` it binds ``a``.
            imported_names.update(alias.asname or alias.name for alias in node.names)
        elif isinstance(node, ast.Import):
            # ``import a.b.c`` binds only the top-level name ``a``, not ``c``;
            # ``import a.b.c as x`` binds ``x``.
            imported_names.update(
                alias.asname or alias.name.split(".", 1)[0] for alias in node.names
            )
    missing_imports = required_imports - imported_names
    assert not missing_imports, (
        f"algorithm.py does not import: {sorted(missing_imports)}"
    )

    # --- MemoryDiTMixin surface -----------------------------------------------
    mixin = next(
        (
            node
            for node in tree.body
            if isinstance(node, ast.ClassDef) and node.name == "MemoryDiTMixin"
        ),
        None,
    )
    assert mixin is not None, "MemoryDiTMixin class not found at module level"

    # If a method is defined twice, the later definition wins, as it does when
    # Python executes the class body.
    methods = {
        node.name: node
        for node in mixin.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }
    required_methods = {
        "_build_causal_memory_banks",
        "_build_preselected_causal_memory_banks",
        "_records_to_stream",
    }
    missing_methods = required_methods - set(methods)
    assert not missing_methods, (
        f"MemoryDiTMixin is missing methods: {sorted(missing_methods)}"
    )

    # --- Calls made by build_memory_streams -----------------------------------
    build_method = methods.get("build_memory_streams")
    assert build_method is not None, (
        "MemoryDiTMixin.build_memory_streams not found"
    )

    call_names = set()
    for node in ast.walk(build_method):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Name):
            call_names.add(func.id)
        elif isinstance(func, ast.Attribute):
            call_names.add(func.attr)

    required_calls = {
        "_build_preselected_causal_memory_banks",
        "deterministic_revisit_retrieval",
        "dememwm_dynamic_compressor",
    }
    missing_calls = required_calls - call_names
    assert not missing_calls, (
        f"build_memory_streams does not call: {sorted(missing_calls)}"
    )
|
|
|
|