"""Text-level contract tests for the DeMemWM memory-DiT configuration.

Each test reads project files as plain text and asserts that specific
substrings are present or absent.  Nothing is imported or executed from the
files under test.  All paths are relative to the current working directory,
so the suite has to be launched from the directory that contains
``configurations/``, ``experiments/``, ``algorithms/`` and ``scripts/``.
"""

from pathlib import Path
from typing import Iterable

_CONFIG_PATH = "configurations/algorithm/dememwm_memory_dit.yaml"
_ALGORITHM_PATH = "algorithms/worldmem/dememwm/algorithm.py"
_DIAGNOSTICS_PATH = "algorithms/worldmem/dememwm/diagnostics.py"


def _read_text(rel_path: str) -> str:
    """Return the contents of *rel_path* decoded as UTF-8.

    The encoding is pinned so the result does not depend on the locale of
    the machine running the tests.
    """
    return Path(rel_path).read_text(encoding="utf-8")


def _assert_all_present(tokens: Iterable[str], text: str, where: str) -> None:
    """Assert, in order, that every token is a substring of *text*.

    *where* only labels the failure message.
    """
    for token in tokens:
        assert token in text, f"{token!r} missing from {where}"


def _assert_all_absent(tokens: Iterable[str], text: str, where: str) -> None:
    """Assert, in order, that no token is a substring of *text*.

    *where* only labels the failure message.
    """
    for token in tokens:
        assert token not in text, f"{token!r} unexpectedly present in {where}"


def test_config_is_distinct_standalone_memory_dit_path():
    """The algorithm YAML names itself and carries no ``ssm_memory`` text."""
    text = _read_text(_CONFIG_PATH)
    _assert_all_present(
        (
            "_name: dememwm_memory_dit",
            "base_video_dit",
            "memory_token_cross_attention: true",
            "dememwm:",
            "debug_force_all_streams",
        ),
        text,
        _CONFIG_PATH,
    )
    _assert_all_absent(
        ("ssm_memory", "ssm_memory_ckpt_path"),
        text,
        _CONFIG_PATH,
    )


def test_registry_text_contains_new_algorithm_key():
    """The experiment registry and package init mention the DeMemWM names."""
    exp_path = "experiments/exp_video.py"
    init_path = "algorithms/worldmem/__init__.py"
    exp = _read_text(exp_path)
    init = _read_text(init_path)

    _assert_all_present(
        ("DeMemWMMinecraft", "DeMemWMMemoryDiTMinecraft"),
        init,
        init_path,
    )
    _assert_all_present(
        ("dememwm_memory_dit=DeMemWMMinecraft",),
        exp,
        exp_path,
    )
    # The absence checks run against the concatenation of both files.
    _assert_all_absent(
        ("StateSpaceSpatialMemoryMinecraft", "WorldMemMinecraft"),
        init + exp,
        f"{init_path} + {exp_path}",
    )


def test_current_config_contract_is_explicit_and_has_no_stale_sections():
    """The YAML spells out every required key/value and no forbidden key."""
    text = _read_text(_CONFIG_PATH)
    required = [
        "token_patch_size: 2",
        "exclude_latest_local_frames: 4",
        "deterministic_pose_retrieval: true",
        "max_frames: 2",
        "fov_overlap_threshold: 0.30",
        "high_quality_fov_threshold: 0.70",
        "plucker_weight: 0.10",
        "fov_half_h: 52.5",
        "fov_half_v: 37.5",
        "fov_radius: 30.0",
        "fov_yaw_samples: 25",
        "fov_pitch_samples: 20",
        "fov_depth_samples: 20",
        "pose_preselect_topk: 64",
        "plucker_grid_h: 4",
        "plucker_grid_w: 4",
        "plucker_focal_length: 0.35",
        "noise_bucket_logging: true",
        "eval_ablation:",
        "branch: A_plus_D_plus_R_normal",
        "generated_history_proxy:",
    ]
    forbidden = (
        "anchor_ratio",
        "dynamic_ratio",
        "revisit_ratio",
        "lambda_abstain",
        "abstention:",
        "force_gate_zero_when_invalid",
        "use_residual_bound_loss",
        "use_utility_loss",
        "use_revisit_classifier_loss",
        "min_score",
        "generated_penalty",
        "min_gap_frames",
        "max_chunks",
        "chunk_frames",
    )
    _assert_all_present(required, text, _CONFIG_PATH)
    _assert_all_absent(forbidden, text, _CONFIG_PATH)


def test_full_scripts_use_consumed_contract_overrides():
    """Both Slurm scripts pass the required overrides and no stale ones."""
    required = [
        "algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
        "algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
        "algorithm.dememwm.revisit.pose_preselect_topk=64",
        "algorithm.dememwm.revisit.fov_yaw_samples=25",
        "algorithm.dememwm.revisit.fov_pitch_samples=20",
        "algorithm.dememwm.revisit.fov_depth_samples=20",
        "algorithm.dememwm.revisit.plucker_weight=0.10",
        "algorithm.dememwm.stage_policy.noise_bucket_logging=true",
        "algorithm.dememwm.cache.keep_compressed_records=true",
    ]
    stale = [
        "algorithm.dememwm.loss.",
        "algorithm.dememwm.abstention.",
        "algorithm.dememwm.anchor.topk",
        "algorithm.dememwm.anchor.pin_prefix",
        "algorithm.dememwm.dynamic.include_generated_recent",
        "algorithm.dememwm.revisit.deterministic_only",
        "algorithm.dememwm.revisit.min_age_frames",
        "algorithm.dememwm.revisit.topk",
        "algorithm.dememwm.revisit.min_gap_frames",
        "algorithm.dememwm.revisit.max_chunks",
        "algorithm.dememwm.revisit.chunk_frames",
        "algorithm.dememwm.revisit.min_score",
        "algorithm.dememwm.revisit.generated_penalty",
        "algorithm.dememwm.rollout.",
    ]
    # Overrides that are expected in only one of the two scripts.
    expected_by_script = {
        "scripts/dememwm_full_train.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.60",
        ],
        "scripts/dememwm_full_eval.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
            "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
        ],
    }
    for rel, script_specific_required in expected_by_script.items():
        text = _read_text(rel)
        for token in required + script_specific_required:
            assert token in text, f"{token} missing from {rel}"
        for token in stale:
            assert token not in text, f"stale {token} override remains in {rel}"


def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
    """``algorithm.py`` mentions the contract guard and revisit geometry names."""
    text = _read_text(_ALGORITHM_PATH)
    alias_line = "_dememwm_validate_config_contract = _validate_config_contract"
    assert alias_line in text, f"{alias_line!r} missing from {_ALGORITHM_PATH}"

    _assert_all_present(
        [
            "_validate_config_contract",
            "deterministic_pose_retrieval",
            "exclude_latest_local_frames",
            "noise_bucket_logging",
            "anchor_effective_enabled",
            "dynamic_effective_enabled",
            "revisit_effective_enabled",
            "stale DeMemWM config fields",
            "revisit_retrieval_kwargs",
            "fov_half_h",
            "fov_yaw_samples",
            "plucker_grid_h",
            "plucker_focal_length",
            "pose_preselect_topk",
        ],
        text,
        _ALGORITHM_PATH,
    )
    _assert_all_absent(
        ('_cfg_get(revisit_cfg, "topk"', "lambda_abstain"),
        text,
        _ALGORITHM_PATH,
    )


def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
    """Revisit retrieval names appear in the sources; stale lookups do not."""
    retrieval = _read_text("algorithms/worldmem/dememwm/retrieval.py")
    labels = _read_text("algorithms/worldmem/dememwm/labels.py")
    algorithm = _read_text(_ALGORITHM_PATH)
    diagnostics = _read_text(_DIAGNOSTICS_PATH)

    # Each token may live in any of the four files, so search them jointly.
    combined = retrieval + labels + algorithm + diagnostics
    _assert_all_present(
        [
            "exclude_local_context_frames",
            "fov_overlap_threshold",
            "plucker_weight",
            "high_quality_fov_threshold",
            "best_selected_frame_fov_overlap",
            "deterministic_fov_coverage_plucker",
            "valid_revisit_mask",
            "revisit_candidate_frame_count",
            "valid_candidate_label_count",
            "valid_revisit_frame_count",
            "no_valid_revisit_count",
            "revisit_selected_frame_count",
            "revisit_frame_fov_overlap",
            "revisit_abstained_count",
        ],
        combined,
        "retrieval.py + labels.py + algorithm.py + diagnostics.py",
    )
    _assert_all_absent(
        ("same_video", "wrong_video"),
        retrieval + labels,
        "retrieval.py + labels.py",
    )
    for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
        lookup = f'self._cfg_get(revisit_cfg, "{stale}"'
        assert lookup not in algorithm, (
            f"stale revisit config lookup {lookup!r} remains in {_ALGORITHM_PATH}"
        )


def test_dynamic_compressor_excludes_c_short_contract():
    """Local-frame exclusion names are present; the old comparisons are gone."""
    compression_path = "algorithms/worldmem/dememwm/compression.py"
    cache_path = "algorithms/worldmem/dememwm/cache.py"
    compression = _read_text(compression_path)
    cache = _read_text(cache_path)
    algorithm = _read_text(_ALGORITHM_PATH)

    _assert_all_present(
        [
            "exclude_latest_local_frames",
            "src_frames_b < target - exclude_latest_local_frames",
            "dynamic_min_gap_to_target_per_target",
            "dynamic_max_gap_to_target_per_target",
            "dynamic_exclude_latest_local_frames",
            "_local_context_exclusion_frames",
        ],
        compression + cache + algorithm,
        "compression.py + cache.py + algorithm.py",
    )
    _assert_all_absent(
        ("src_frames_b < target, as_tuple=False",),
        compression,
        compression_path,
    )
    _assert_all_absent(
        ("src < int(target), as_tuple=False",),
        cache,
        cache_path,
    )


def test_eval_ablation_and_noise_bucket_logging_contracts():
    """Ablation branch names and noise-bucket helper names are present."""
    schedules_path = "algorithms/worldmem/dememwm/schedules.py"
    schedules = _read_text(schedules_path)
    diagnostics = _read_text(_DIAGNOSTICS_PATH)
    algorithm = _read_text(_ALGORITHM_PATH)

    _assert_all_present(
        [
            "memory_off",
            "A_only",
            "D_only",
            "A_plus_D",
            "A_plus_D_plus_R_normal",
            "R_forced_off",
            "R_forced_on",
            "wrong_pose",
            "time_shuffle",
            "source_matched_random",
            "pose_shuffle",
            "wrong_video",
            "local_context_overlap_fake_revisit",
        ],
        schedules,
        schedules_path,
    )
    _assert_all_present(
        [
            "noise_bucket_from_denoising_fraction",
            "noise_bucket_from_noise_levels",
            "summarize_noise_bucket_diagnostics",
            "noise_bucket_id",
            "summarize_eval_ablation_diagnostics",
            "eval_bucket_true_revisit_count",
            "eval_bucket_no_valid_revisit_count",
            "eval_bucket_corrupted_memory_count",
            "apply_revisit_eval_corruption",
        ],
        schedules + diagnostics + algorithm,
        "schedules.py + diagnostics.py + algorithm.py",
    )