"""Static structure checks for the standalone DeMemWM memory-DiT modules.

These tests inspect the source files as text and via :mod:`ast`; the modules
under test are never imported or executed. Paths are relative to the current
working directory, so pytest is expected to be run from the repository root.
"""

import ast
from pathlib import Path

_ENTRYPOINT_PATH = Path("algorithms/worldmem/dememwm_memory_dit.py")
_ALGORITHM_PATH = Path("algorithms/worldmem/dememwm/algorithm.py")


def _read_source(path: Path) -> str:
    """Return the text of *path*, decoded as UTF-8 regardless of the locale."""
    return path.read_text(encoding="utf-8")


def _base_names(class_def: ast.ClassDef) -> "list[str]":
    """Return the simple names of *class_def*'s bases, in declaration order.

    ``Foo`` yields ``"Foo"`` and ``pkg.Foo`` yields ``"Foo"``; any other base
    expression without an ``attr`` (e.g. a subscript or call) yields ``""``.
    """
    names = []
    for base in class_def.bases:
        if isinstance(base, ast.Name):
            names.append(base.id)
        else:
            names.append(getattr(base, "attr", ""))
    return names


def _bound_import_names(tree: ast.AST) -> "set[str]":
    """Return every name bound by an import statement anywhere in *tree*.

    ``import a.b`` binds ``a`` (not ``b``), ``import a.b as c`` binds ``c``,
    and ``from m import x as y`` binds ``y``. Imports nested inside functions
    or classes are included because :func:`ast.walk` visits the whole tree.
    """
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # A dotted plain import binds only its top-level package.
                names.add(alias.asname or alias.name.split(".", 1)[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                names.add(alias.asname or alias.name)
    return names


def _find_class(tree: ast.Module, name: str) -> ast.ClassDef:
    """Return the first top-level class called *name*; fail the test if absent."""
    for node in tree.body:
        if isinstance(node, ast.ClassDef) and node.name == name:
            return node
    raise AssertionError(f"top-level class {name!r} not found")


def _method_defs(class_def: ast.ClassDef) -> "list[ast.AST]":
    """Return the sync and async function definitions directly in *class_def*."""
    return [
        node
        for node in class_def.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]


def _find_method(class_def: ast.ClassDef, name: str) -> ast.AST:
    """Return the first method called *name* in *class_def*; fail if absent."""
    for node in _method_defs(class_def):
        if node.name == name:
            return node
    raise AssertionError(f"method {class_def.name}.{name} not found")


def _called_names(node: ast.AST) -> "set[str]":
    """Return the names of everything called anywhere inside *node*.

    ``f()`` contributes ``f`` and ``obj.f()`` contributes ``f``; calls through
    other expressions (e.g. ``fns[0]()``) are ignored.
    """
    names = set()
    for child in ast.walk(node):
        if not isinstance(child, ast.Call):
            continue
        func = child.func
        if isinstance(func, ast.Name):
            names.add(func.id)
        elif isinstance(func, ast.Attribute):
            names.add(func.attr)
    return names


def test_entrypoint_class_uses_standalone_mixin_and_base_video_dit():
    """DeMemWMMinecraft must inherit exactly (MemoryDiTMixin, BaseVideoDiTMinecraft)."""
    src = _read_source(_ENTRYPOINT_PATH)
    tree = ast.parse(src)

    # If several top-level classes share a name, the last definition wins.
    classes = {
        node.name: _base_names(node)
        for node in tree.body
        if isinstance(node, ast.ClassDef)
    }
    assert "DeMemWMMinecraft" in classes, "class DeMemWMMinecraft not found"
    assert classes["DeMemWMMinecraft"] == ["MemoryDiTMixin", "BaseVideoDiTMinecraft"]
    # The module must also expose DeMemWMMemoryDiTMinecraft as an alias
    # (checked textually, so the exact assignment spelling matters).
    assert "DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft" in src
    # Case-sensitive substring check: no "SSM" anywhere in the module text.
    assert "SSM" not in src


def test_algorithm_mixin_has_strict_checkpoint_helper_and_no_old_imports():
    """algorithm.py defines the strict checkpoint helper plus its alias, and no ssm_memory refs."""
    src = _read_source(_ALGORITHM_PATH)

    assert "strict_checkpoint_key_check" in src
    assert "strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check" in src
    # "ssm_memory" alone would also reject the two longer names; they are
    # listed first so a failure names the more specific offender.
    for forbidden in ("spatial_ssm_memory", "df_video_ssm_memory", "ssm_memory"):
        assert forbidden not in src, f"unexpected reference to {forbidden!r}"


def test_algorithm_mixin_wires_standalone_memory_retrieval_surface():
    """MemoryDiTMixin imports, defines and calls the memory-retrieval helpers."""
    src = _read_source(_ALGORITHM_PATH)
    tree = ast.parse(src)

    required_imports = {
        "CausalMemoryBank",
        "MemoryBankQuery",
        "stack_record_tokens",
        "deterministic_revisit_retrieval",
        "MemorySourceType",
    }
    missing_imports = required_imports - _bound_import_names(tree)
    assert not missing_imports, f"names not imported: {sorted(missing_imports)}"

    mixin = _find_class(tree, "MemoryDiTMixin")
    method_names = {node.name for node in _method_defs(mixin)}
    required_methods = {
        "_build_causal_memory_banks",
        "_build_preselected_causal_memory_banks",
        "_records_to_stream",
    }
    missing_methods = required_methods - method_names
    assert not missing_methods, (
        f"MemoryDiTMixin is missing methods: {sorted(missing_methods)}"
    )

    build_method = _find_method(mixin, "build_memory_streams")
    required_calls = {
        "_build_preselected_causal_memory_banks",
        "deterministic_revisit_retrieval",
        "dememwm_dynamic_compressor",
    }
    missing_calls = required_calls - _called_names(build_method)
    assert not missing_calls, (
        f"build_memory_streams does not call: {sorted(missing_calls)}"
    )