DeMemWM / tests /test_dememwm_config_static.py
BonanDing's picture
Clean DeMemWM deterministic memory slot handling
93d7b0a
from pathlib import Path
def test_config_is_distinct_standalone_memory_dit_path() -> None:
    """Check the DeMemWM algorithm YAML for required and forbidden text.

    Reads ``configurations/algorithm/dememwm_memory_dit.yaml`` (relative to the
    current working directory) and performs plain substring checks: a handful
    of keys must be present and the ``ssm_memory`` keys must be absent.
    """
    config_path = Path("configurations/algorithm/dememwm_memory_dit.yaml")
    # Decode explicitly as UTF-8 so the outcome does not depend on the locale.
    text = config_path.read_text(encoding="utf-8")

    required = (
        "_name: dememwm_memory_dit",
        "base_video_dit",
        "memory_token_cross_attention: true",
        "dememwm:",
        "debug_force_all_streams",
    )
    for token in required:
        assert token in text, f"{token!r} missing from {config_path}"

    forbidden = ("ssm_memory", "ssm_memory_ckpt_path")
    for token in forbidden:
        assert token not in text, f"{token!r} must not appear in {config_path}"
def test_registry_text_contains_new_algorithm_key() -> None:
    """Check the registry sources mention the DeMemWM classes and key.

    Reads ``experiments/exp_video.py`` and ``algorithms/worldmem/__init__.py``
    (relative to the current working directory) as text. The package init must
    name both DeMemWM classes, the experiment file must contain the
    ``dememwm_memory_dit=DeMemWMMinecraft`` mapping text, and neither file may
    mention the two other class names listed below.
    """
    exp_path = Path("experiments/exp_video.py")
    init_path = Path("algorithms/worldmem/__init__.py")
    # Decode explicitly as UTF-8 so the outcome does not depend on the locale.
    exp = exp_path.read_text(encoding="utf-8")
    init = init_path.read_text(encoding="utf-8")

    for class_name in ("DeMemWMMinecraft", "DeMemWMMemoryDiTMinecraft"):
        assert class_name in init, f"{class_name!r} missing from {init_path}"

    registry_entry = "dememwm_memory_dit=DeMemWMMinecraft"
    assert registry_entry in exp, f"{registry_entry!r} missing from {exp_path}"

    # Check each file on its own instead of a separator-less concatenation,
    # which could match a name straddling the boundary between the two texts.
    sources = ((init_path, init), (exp_path, exp))
    for removed_name in ("StateSpaceSpatialMemoryMinecraft", "WorldMemMinecraft"):
        for path, source in sources:
            assert removed_name not in source, (
                f"{removed_name!r} must not appear in {path}"
            )
def test_current_config_contract_is_explicit_and_has_no_stale_sections() -> None:
    """Check the DeMemWM YAML spells out required settings and no forbidden ones.

    Reads ``configurations/algorithm/dememwm_memory_dit.yaml`` (relative to the
    current working directory). Every entry of ``required`` must occur verbatim
    in the text and no entry of ``forbidden`` may occur anywhere in it. Both
    checks are plain substring matches, not YAML-aware lookups.
    """
    config_path = Path("configurations/algorithm/dememwm_memory_dit.yaml")
    # Decode explicitly as UTF-8 so the outcome does not depend on the locale.
    text = config_path.read_text(encoding="utf-8")

    required = [
        "token_patch_size: 2",
        "exclude_latest_local_frames: 4",
        "deterministic_pose_retrieval: true",
        "max_frames: 2",
        "fov_overlap_threshold: 0.30",
        "high_quality_fov_threshold: 0.70",
        "plucker_weight: 0.10",
        "fov_half_h: 52.5",
        "fov_half_v: 37.5",
        "fov_radius: 30.0",
        "fov_yaw_samples: 25",
        "fov_pitch_samples: 20",
        "fov_depth_samples: 20",
        "pose_preselect_topk: 64",
        "plucker_grid_h: 4",
        "plucker_grid_w: 4",
        "plucker_focal_length: 0.35",
        "noise_bucket_logging: true",
        "eval_ablation:",
        "branch: A_plus_D_plus_R_normal",
        "generated_history_proxy:",
    ]
    for token in required:
        assert token in text, f"{token!r} missing from {config_path}"

    forbidden = (
        "anchor_ratio",
        "dynamic_ratio",
        "revisit_ratio",
        "lambda_abstain",
        "abstention:",
        "force_gate_zero_when_invalid",
        "use_residual_bound_loss",
        "use_utility_loss",
        "use_revisit_classifier_loss",
        "min_score",
        "generated_penalty",
        "min_gap_frames",
        "max_chunks",
        "chunk_frames",
    )
    for token in forbidden:
        assert token not in text, f"forbidden {token!r} remains in {config_path}"
def test_full_scripts_use_consumed_contract_overrides() -> None:
    """Check the full train/eval SLURM scripts pass the expected overrides.

    For each script in ``expected_by_script`` (paths relative to the current
    working directory) the text must contain every shared ``required`` override
    plus the script-specific ones, and must contain none of the ``stale``
    override prefixes. All checks are plain substring matches.
    """
    required = [
        "algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
        "algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
        "algorithm.dememwm.revisit.pose_preselect_topk=64",
        "algorithm.dememwm.revisit.fov_yaw_samples=25",
        "algorithm.dememwm.revisit.fov_pitch_samples=20",
        "algorithm.dememwm.revisit.fov_depth_samples=20",
        "algorithm.dememwm.revisit.plucker_weight=0.10",
        "algorithm.dememwm.stage_policy.noise_bucket_logging=true",
        "algorithm.dememwm.cache.keep_compressed_records=true",
    ]
    stale = [
        "algorithm.dememwm.loss.",
        "algorithm.dememwm.abstention.",
        "algorithm.dememwm.anchor.topk",
        "algorithm.dememwm.anchor.pin_prefix",
        "algorithm.dememwm.dynamic.include_generated_recent",
        "algorithm.dememwm.revisit.deterministic_only",
        "algorithm.dememwm.revisit.min_age_frames",
        "algorithm.dememwm.revisit.topk",
        "algorithm.dememwm.revisit.min_gap_frames",
        "algorithm.dememwm.revisit.max_chunks",
        "algorithm.dememwm.revisit.chunk_frames",
        "algorithm.dememwm.revisit.min_score",
        "algorithm.dememwm.revisit.generated_penalty",
        "algorithm.dememwm.rollout.",
    ]
    # The train and eval scripts are expected to differ in their FOV thresholds.
    expected_by_script = {
        "scripts/dememwm_full_train.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.60",
        ],
        "scripts/dememwm_full_eval.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
            "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
        ],
    }
    for rel, script_specific_required in expected_by_script.items():
        # Decode explicitly as UTF-8 so the outcome does not depend on the locale.
        text = Path(rel).read_text(encoding="utf-8")
        for token in required + script_specific_required:
            assert token in text, f"{token} missing from {rel}"
        for token in stale:
            assert token not in text, f"stale {token} override remains in {rel}"
def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args() -> None:
    """Check ``algorithm.py`` references the config-contract and geometry names.

    Reads ``algorithms/worldmem/dememwm/algorithm.py`` (relative to the current
    working directory) as text. The alias assignment and every listed name must
    occur in it; the ``topk`` lookup on ``revisit_cfg`` and ``lambda_abstain``
    must not. All checks are plain substring matches on the source text.
    """
    algorithm_path = Path("algorithms/worldmem/dememwm/algorithm.py")
    # Decode explicitly as UTF-8 so the outcome does not depend on the locale.
    text = algorithm_path.read_text(encoding="utf-8")

    alias_assignment = "_dememwm_validate_config_contract = _validate_config_contract"
    assert alias_assignment in text, (
        f"{alias_assignment!r} missing from {algorithm_path}"
    )

    expected_names = [
        "_validate_config_contract",
        "deterministic_pose_retrieval",
        "exclude_latest_local_frames",
        "noise_bucket_logging",
        "anchor_effective_enabled",
        "dynamic_effective_enabled",
        "revisit_effective_enabled",
        "stale DeMemWM config fields",
        "revisit_retrieval_kwargs",
        "fov_half_h",
        "fov_yaw_samples",
        "plucker_grid_h",
        "plucker_focal_length",
        "pose_preselect_topk",
    ]
    for token in expected_names:
        assert token in text, f"{token!r} missing from {algorithm_path}"

    for removed in ('_cfg_get(revisit_cfg, "topk"', "lambda_abstain"):
        assert removed not in text, f"{removed!r} must not appear in {algorithm_path}"
def test_revisit_retrieval_is_deterministic_fov_plucker_contract() -> None:
    """Check the revisit-retrieval names across four DeMemWM modules.

    Reads ``retrieval.py``, ``labels.py``, ``algorithm.py`` and
    ``diagnostics.py`` under ``algorithms/worldmem/dememwm`` (relative to the
    current working directory) as text. Each expected name must occur in at
    least one of the four files; ``same_video`` / ``wrong_video`` must not occur
    in ``retrieval.py`` or ``labels.py``; and ``algorithm.py`` must not read the
    listed weights via ``self._cfg_get(revisit_cfg, ...)``.
    """
    package = Path("algorithms/worldmem/dememwm")
    # Read each file once, decoding explicitly as UTF-8 (locale-independent).
    sources = {
        name: (package / f"{name}.py").read_text(encoding="utf-8")
        for name in ("retrieval", "labels", "algorithm", "diagnostics")
    }

    expected_names = [
        "exclude_local_context_frames",
        "fov_overlap_threshold",
        "plucker_weight",
        "high_quality_fov_threshold",
        "best_selected_frame_fov_overlap",
        "deterministic_fov_coverage_plucker",
        "valid_revisit_mask",
        "revisit_candidate_frame_count",
        "valid_candidate_label_count",
        "valid_revisit_frame_count",
        "no_valid_revisit_count",
        "revisit_selected_frame_count",
        "revisit_frame_fov_overlap",
        "revisit_abstained_count",
    ]
    # Test membership per file rather than rebuilding one concatenated string
    # on every iteration; this also avoids matches straddling a file boundary.
    for token in expected_names:
        assert any(token in source for source in sources.values()), (
            f"{token!r} missing from all of {sorted(sources)}"
        )

    for forbidden in ("same_video", "wrong_video"):
        for name in ("retrieval", "labels"):
            assert forbidden not in sources[name], (
                f"{forbidden!r} must not appear in {name}.py"
            )

    algorithm = sources["algorithm"]
    for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
        stale_lookup = f'self._cfg_get(revisit_cfg, "{stale}"'
        assert stale_lookup not in algorithm, (
            f"stale lookup {stale_lookup!r} remains in algorithm.py"
        )
def test_dynamic_compressor_excludes_c_short_contract() -> None:
    """Check the local-frame exclusion text in the dynamic compression sources.

    Reads ``compression.py``, ``cache.py`` and ``algorithm.py`` under
    ``algorithms/worldmem/dememwm`` (relative to the current working directory)
    as text. Each expected snippet must occur in at least one of the three
    files, and the two comparisons without the exclusion offset must be absent
    from ``compression.py`` and ``cache.py`` respectively.
    """
    package = Path("algorithms/worldmem/dememwm")
    # Read each file once, decoding explicitly as UTF-8 (locale-independent).
    sources = {
        name: (package / f"{name}.py").read_text(encoding="utf-8")
        for name in ("compression", "cache", "algorithm")
    }

    expected_snippets = [
        "exclude_latest_local_frames",
        "src_frames_b < target - exclude_latest_local_frames",
        "dynamic_min_gap_to_target_per_target",
        "dynamic_max_gap_to_target_per_target",
        "dynamic_exclude_latest_local_frames",
        "_local_context_exclusion_frames",
    ]
    # Test membership per file rather than rebuilding one concatenated string
    # on every iteration; this also avoids matches straddling a file boundary.
    for token in expected_snippets:
        assert any(token in source for source in sources.values()), (
            f"{token!r} missing from all of {sorted(sources)}"
        )

    old_comparisons = (
        ("compression", "src_frames_b < target, as_tuple=False"),
        ("cache", "src < int(target), as_tuple=False"),
    )
    for name, snippet in old_comparisons:
        assert snippet not in sources[name], f"{snippet!r} remains in {name}.py"
def test_eval_ablation_and_noise_bucket_logging_contracts() -> None:
    """Check eval-ablation branch names and noise-bucket helper names exist.

    Reads ``schedules.py``, ``diagnostics.py`` and ``algorithm.py`` under
    ``algorithms/worldmem/dememwm`` (relative to the current working directory)
    as text. Every branch name must occur in ``schedules.py``; every helper or
    metric name must occur in at least one of the three files. All checks are
    plain substring matches.
    """
    package = Path("algorithms/worldmem/dememwm")
    # Read each file once, decoding explicitly as UTF-8 (locale-independent).
    sources = {
        name: (package / f"{name}.py").read_text(encoding="utf-8")
        for name in ("schedules", "diagnostics", "algorithm")
    }

    branches = [
        "memory_off",
        "A_only",
        "D_only",
        "A_plus_D",
        "A_plus_D_plus_R_normal",
        "R_forced_off",
        "R_forced_on",
        "wrong_pose",
        "time_shuffle",
        "source_matched_random",
        "pose_shuffle",
        "wrong_video",
        "local_context_overlap_fake_revisit",
    ]
    schedules = sources["schedules"]
    for branch in branches:
        assert branch in schedules, f"branch {branch!r} missing from schedules.py"

    expected_names = [
        "noise_bucket_from_denoising_fraction",
        "noise_bucket_from_noise_levels",
        "summarize_noise_bucket_diagnostics",
        "noise_bucket_id",
        "summarize_eval_ablation_diagnostics",
        "eval_bucket_true_revisit_count",
        "eval_bucket_no_valid_revisit_count",
        "eval_bucket_corrupted_memory_count",
        "apply_revisit_eval_corruption",
    ]
    # Test membership per file rather than rebuilding one concatenated string
    # on every iteration; this also avoids matches straddling a file boundary.
    for token in expected_names:
        assert any(token in source for source in sources.values()), (
            f"{token!r} missing from all of {sorted(sources)}"
        )