# DeMemWM / tests/test_dememwm_algorithm_static.py
# Committed by BonanDing: "Initial commit" (b47a1ce)
import ast
from pathlib import Path
def test_entrypoint_class_uses_standalone_mixin_and_base_video_dit():
    """The entrypoint class composes the standalone mixin with the base video DiT.

    Statically parses ``algorithms/worldmem/dememwm_memory_dit.py`` (via
    ``ast.parse``, without importing it) and checks that:

    * ``DeMemWMMinecraft`` is a top-level class whose bases are exactly
      ``MemoryDiTMixin`` followed by ``BaseVideoDiTMinecraft``;
    * the alias assignment ``DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft``
      is present in the module text;
    * the substring ``SSM`` appears nowhere in the module text (case-sensitive,
      comments and docstrings included).
    """
    # This file lives in <repo>/tests/, so anchor on it instead of the CWD;
    # otherwise the test only passes when pytest is launched from the repo root.
    repo_root = Path(__file__).resolve().parents[1]
    path = repo_root / "algorithms" / "worldmem" / "dememwm_memory_dit.py"
    # Python source files are UTF-8 by default (PEP 3120); don't use the locale.
    src = path.read_text(encoding="utf-8")
    tree = ast.parse(src)

    # Top-level class name -> list of base names. A dotted base such as
    # ``pkg.Base`` contributes ``Base``; any other base expression contributes "".
    classes = {
        node.name: [
            base.id if isinstance(base, ast.Name) else getattr(base, "attr", "")
            for base in node.bases
        ]
        for node in tree.body
        if isinstance(node, ast.ClassDef)
    }
    assert "DeMemWMMinecraft" in classes, "DeMemWMMinecraft is not a top-level class"
    assert classes["DeMemWMMinecraft"] == ["MemoryDiTMixin", "BaseVideoDiTMinecraft"]
    assert "DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft" in src
    assert "SSM" not in src
def test_algorithm_mixin_has_strict_checkpoint_helper_and_no_old_imports():
    """The mixin module defines the strict checkpoint helper and drops SSM names.

    Performs plain substring checks on the text of
    ``algorithms/worldmem/dememwm/algorithm.py``:

    * ``strict_checkpoint_key_check`` is mentioned, and the alias assignment
      ``strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check``
      is present;
    * none of the legacy ``*ssm_memory`` names appear anywhere in the text
      (not only in import statements).
    """
    # This file lives in <repo>/tests/, so anchor on it instead of the CWD;
    # otherwise the test only passes when pytest is launched from the repo root.
    repo_root = Path(__file__).resolve().parents[1]
    path = repo_root / "algorithms" / "worldmem" / "dememwm" / "algorithm.py"
    # Python source files are UTF-8 by default (PEP 3120); don't use the locale.
    src = path.read_text(encoding="utf-8")

    assert "strict_checkpoint_key_check" in src
    assert "strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check" in src

    # The specific legacy names are checked first so a failure names the exact
    # offender; the bare "ssm_memory" check then catches any other variant.
    for legacy_name in ("spatial_ssm_memory", "df_video_ssm_memory", "ssm_memory"):
        assert legacy_name not in src, f"legacy reference {legacy_name!r} still present"
def _bound_import_names(tree):
    """Return every name bound by an import statement anywhere in *tree*.

    Follows Python's binding rules: ``import a.b as c`` binds ``c``,
    ``import a.b`` binds only ``a``, and ``from m import x as y`` binds ``y``
    (``from m import x`` binds ``x``). Imports nested in functions or
    conditionals are included because the whole tree is walked.
    """
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # A plain dotted import binds its first component, not its last.
                names.add(alias.asname or alias.name.partition(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                names.add(alias.asname or alias.name)
    return names


def _call_names(node):
    """Return the simple callee names of every call expression inside *node*.

    ``f(...)`` contributes ``f`` and ``obj.f(...)`` contributes ``f``; calls on
    any other expression (e.g. ``fns[0](...)``) contribute nothing.
    """
    names = set()
    for child in ast.walk(node):
        if not isinstance(child, ast.Call):
            continue
        func = child.func
        if isinstance(func, ast.Name):
            names.add(func.id)
        elif isinstance(func, ast.Attribute):
            names.add(func.attr)
    return names


def _find_named(nodes, node_types, name):
    """Return the first node in *nodes* that is a *node_types* named *name*, else None."""
    return next(
        (node for node in nodes if isinstance(node, node_types) and node.name == name),
        None,
    )


def test_algorithm_mixin_wires_standalone_memory_retrieval_surface():
    """The mixin imports the standalone memory API and uses it in build_memory_streams.

    Statically inspects ``algorithms/worldmem/dememwm/algorithm.py`` and checks
    that:

    * every name in ``required_imports`` is bound by an import somewhere in the
      module;
    * the top-level ``MemoryDiTMixin`` class defines the memory-bank helper
      methods and ``build_memory_streams``;
    * ``build_memory_streams`` calls the preselection helper,
      ``deterministic_revisit_retrieval`` and ``dememwm_dynamic_compressor``
      (either as bare names or as attribute calls).
    """
    # This file lives in <repo>/tests/, so anchor on it instead of the CWD;
    # otherwise the test only passes when pytest is launched from the repo root.
    repo_root = Path(__file__).resolve().parents[1]
    path = repo_root / "algorithms" / "worldmem" / "dememwm" / "algorithm.py"
    # Python source files are UTF-8 by default (PEP 3120); don't use the locale.
    src = path.read_text(encoding="utf-8")
    tree = ast.parse(src)

    required_imports = {
        "CausalMemoryBank",
        "MemoryBankQuery",
        "stack_record_tokens",
        "deterministic_revisit_retrieval",
        "MemorySourceType",
    }
    missing_imports = required_imports - _bound_import_names(tree)
    assert not missing_imports, f"names not bound by any import: {sorted(missing_imports)}"

    mixin = _find_named(tree.body, ast.ClassDef, "MemoryDiTMixin")
    assert mixin is not None, "MemoryDiTMixin is not a top-level class"

    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    method_names = {node.name for node in mixin.body if isinstance(node, function_types)}
    missing_methods = {
        "_build_causal_memory_banks",
        "_build_preselected_causal_memory_banks",
        "_records_to_stream",
    } - method_names
    assert not missing_methods, f"MemoryDiTMixin lacks methods: {sorted(missing_methods)}"

    build_method = _find_named(mixin.body, function_types, "build_memory_streams")
    assert build_method is not None, "MemoryDiTMixin.build_memory_streams is not defined"

    missing_calls = {
        "_build_preselected_causal_memory_banks",
        "deterministic_revisit_retrieval",
        "dememwm_dynamic_compressor",
    } - _call_names(build_method)
    assert not missing_calls, f"build_memory_streams never calls: {sorted(missing_calls)}"