import re
from pathlib import Path
def test_config_is_distinct_standalone_memory_dit_path():
    """Check the memory-DiT YAML is a distinct config free of ``ssm_memory`` keys.

    The config must name itself ``dememwm_memory_dit``, mention
    ``base_video_dit``, enable ``memory_token_cross_attention``, contain a
    ``dememwm:`` section and ``debug_force_all_streams``. It must not
    mention ``ssm_memory`` or ``ssm_memory_ckpt_path`` anywhere.
    """
    config_text = Path("configurations/algorithm/dememwm_memory_dit.yaml").read_text()
    must_contain = (
        "_name: dememwm_memory_dit",
        "base_video_dit",
        "memory_token_cross_attention: true",
        "dememwm:",
        "debug_force_all_streams",
    )
    must_not_contain = ("ssm_memory", "ssm_memory_ckpt_path")
    for snippet in must_contain:
        assert snippet in config_text
    for snippet in must_not_contain:
        assert snippet not in config_text
def test_registry_text_contains_new_algorithm_key():
    """Check the DeMemWM classes are registered and the legacy ones are gone.

    ``algorithms/worldmem/__init__.py`` must mention ``DeMemWMMinecraft`` and
    ``DeMemWMMemoryDiTMinecraft``. ``experiments/exp_video.py`` must contain
    the ``dememwm_memory_dit=DeMemWMMinecraft`` registration. Neither file
    may mention ``StateSpaceSpatialMemoryMinecraft`` or ``WorldMemMinecraft``.
    """
    exp_path = Path("experiments/exp_video.py")
    init_path = Path("algorithms/worldmem/__init__.py")
    exp = exp_path.read_text(encoding="utf-8")
    init = init_path.read_text(encoding="utf-8")
    for class_name in ("DeMemWMMinecraft", "DeMemWMMemoryDiTMinecraft"):
        assert class_name in init, f"{class_name} missing from {init_path}"
    # NOTE(review): the key maps to DeMemWMMinecraft rather than
    # DeMemWMMemoryDiTMinecraft; confirm this pairing is intended.
    registration = "dememwm_memory_dit=DeMemWMMinecraft"
    assert registration in exp, f"{registration} missing from {exp_path}"
    # Join with a newline so a legacy name cannot be "found" by straddling
    # the end of one file and the start of the other.
    both = "\n".join((init, exp))
    for legacy in ("StateSpaceSpatialMemoryMinecraft", "WorldMemMinecraft"):
        assert legacy not in both, f"legacy {legacy} still referenced in {init_path} or {exp_path}"
def test_current_config_contract_is_explicit_and_has_no_stale_sections():
    """Pin the explicit DeMemWM memory-DiT YAML contract and reject stale keys.

    Each ``required`` entry must appear as a whole token, meaning it is not
    immediately preceded or followed by a word character or ``.``. A plain
    substring check would wrongly accept ``max_frames: 2`` when the file says
    ``max_frames: 20``, or ``token_patch_size: 2`` when only
    ``memory_token_patch_size: 2`` exists.

    Each ``forbidden`` name is rejected as a plain substring, i.e. anywhere in
    the file, including inside longer keys and comments.
    """
    config_path = Path("configurations/algorithm/dememwm_memory_dit.yaml")
    text = config_path.read_text(encoding="utf-8")
    required = [
        "token_patch_size: 2",
        "exclude_latest_local_frames: 4",
        "deterministic_pose_retrieval: true",
        "max_frames: 2",
        "fov_overlap_threshold: 0.30",
        "high_quality_fov_threshold: 0.70",
        "plucker_weight: 0.10",
        "fov_half_h: 52.5",
        "fov_half_v: 37.5",
        "fov_radius: 30.0",
        "fov_yaw_samples: 25",
        "fov_pitch_samples: 20",
        "fov_depth_samples: 20",
        "pose_preselect_topk: 64",
        "plucker_grid_h: 4",
        "plucker_grid_w: 4",
        "plucker_focal_length: 0.35",
        "noise_bucket_logging: true",
        "eval_ablation:",
        "branch: A_plus_D_plus_R_normal",
        "generated_history_proxy:",
    ]
    for token in required:
        # Whole-token match: reject hits that are only a prefix/suffix of a
        # longer key or value (e.g. ``max_frames: 20``, ``0.300``).
        pattern = rf"(?<![\w.]){re.escape(token)}(?![\w.])"
        assert re.search(pattern, text), f"{token} missing from {config_path}"
    forbidden = (
        "anchor_ratio",
        "dynamic_ratio",
        "revisit_ratio",
        "lambda_abstain",
        "abstention:",
        "force_gate_zero_when_invalid",
        "use_residual_bound_loss",
        "use_utility_loss",
        "use_revisit_classifier_loss",
        "min_score",
        "generated_penalty",
        "min_gap_frames",
        "max_chunks",
        "chunk_frames",
    )
    for token in forbidden:
        assert token not in text, f"stale {token} remains in {config_path}"
def test_full_scripts_use_consumed_contract_overrides():
    """Check the full train/eval SLURM scripts pass only current DeMemWM overrides.

    Both scripts must contain every shared ``required`` override plus their
    script-specific ones: the train script uses
    ``fov_overlap_threshold=0.60``, while the eval script uses ``0.30`` and
    ``high_quality_fov_threshold=0.70``.

    Required overrides are matched as whole tokens, so ``…topk=64`` does not
    pass against ``…topk=640``. Every ``stale`` prefix is rejected as a plain
    substring anywhere in the script.
    """
    required = [
        "algorithm.dememwm.dynamic.exclude_latest_local_frames=4",
        "algorithm.dememwm.revisit.deterministic_pose_retrieval=true",
        "algorithm.dememwm.revisit.pose_preselect_topk=64",
        "algorithm.dememwm.revisit.fov_yaw_samples=25",
        "algorithm.dememwm.revisit.fov_pitch_samples=20",
        "algorithm.dememwm.revisit.fov_depth_samples=20",
        "algorithm.dememwm.revisit.plucker_weight=0.10",
        "algorithm.dememwm.stage_policy.noise_bucket_logging=true",
        "algorithm.dememwm.cache.keep_compressed_records=true",
    ]
    stale = [
        "algorithm.dememwm.loss.",
        "algorithm.dememwm.abstention.",
        "algorithm.dememwm.anchor.topk",
        "algorithm.dememwm.anchor.pin_prefix",
        "algorithm.dememwm.dynamic.include_generated_recent",
        "algorithm.dememwm.revisit.deterministic_only",
        "algorithm.dememwm.revisit.min_age_frames",
        "algorithm.dememwm.revisit.topk",
        "algorithm.dememwm.revisit.min_gap_frames",
        "algorithm.dememwm.revisit.max_chunks",
        "algorithm.dememwm.revisit.chunk_frames",
        "algorithm.dememwm.revisit.min_score",
        "algorithm.dememwm.revisit.generated_penalty",
        "algorithm.dememwm.rollout.",
    ]
    expected_by_script = {
        "scripts/dememwm_full_train.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.60",
        ],
        "scripts/dememwm_full_eval.slurm": [
            "algorithm.dememwm.revisit.fov_overlap_threshold=0.30",
            "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70",
        ],
    }
    for rel, script_specific_required in expected_by_script.items():
        text = Path(rel).read_text(encoding="utf-8")
        for token in required + script_specific_required:
            # Whole-token match: the value must not continue with more
            # digits/letters/dots (``=0.605``, ``=640``).
            pattern = rf"(?<![\w.]){re.escape(token)}(?![\w.])"
            assert re.search(pattern, text), f"{token} missing from {rel}"
        for token in stale:
            assert token not in text, f"stale {token} override remains in {rel}"
def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args():
    """Check algorithm.py wires in the config-contract guard and revisit geometry.

    The source must alias ``_dememwm_validate_config_contract`` to
    ``_validate_config_contract``. It must mention every expected guard,
    effective-enable flag and revisit-geometry argument. It must not read a
    ``topk`` key from ``revisit_cfg`` or mention ``lambda_abstain``.
    """
    source = Path("algorithms/worldmem/dememwm/algorithm.py").read_text()
    assert "_dememwm_validate_config_contract = _validate_config_contract" in source
    expected_tokens = (
        "_validate_config_contract",
        "deterministic_pose_retrieval",
        "exclude_latest_local_frames",
        "noise_bucket_logging",
        "anchor_effective_enabled",
        "dynamic_effective_enabled",
        "revisit_effective_enabled",
        "stale DeMemWM config fields",
        "revisit_retrieval_kwargs",
        "fov_half_h",
        "fov_yaw_samples",
        "plucker_grid_h",
        "plucker_focal_length",
        "pose_preselect_topk",
    )
    for expected in expected_tokens:
        assert expected in source
    banned_snippets = ('_cfg_get(revisit_cfg, "topk"', "lambda_abstain")
    for snippet in banned_snippets:
        assert snippet not in source
def test_revisit_retrieval_is_deterministic_fov_plucker_contract():
    """Check the revisit path exposes the FOV/Plucker retrieval contract.

    Every required name must appear in at least one of retrieval.py,
    labels.py, algorithm.py or diagnostics.py. ``same_video`` and
    ``wrong_video`` must not appear in retrieval.py or labels.py.
    algorithm.py must not read the stale revisit keys via
    ``self._cfg_get(revisit_cfg, "<key>"``.
    """
    dememwm_dir = Path("algorithms/worldmem/dememwm")
    retrieval = (dememwm_dir / "retrieval.py").read_text(encoding="utf-8")
    labels = (dememwm_dir / "labels.py").read_text(encoding="utf-8")
    algorithm = (dememwm_dir / "algorithm.py").read_text(encoding="utf-8")
    diagnostics = (dememwm_dir / "diagnostics.py").read_text(encoding="utf-8")
    # Join once, outside the loop, with newlines. The old per-iteration ``+``
    # concatenation also let a token match across a file boundary.
    combined = "\n".join((retrieval, labels, algorithm, diagnostics))
    for token in [
        "exclude_local_context_frames",
        "fov_overlap_threshold",
        "plucker_weight",
        "high_quality_fov_threshold",
        "best_selected_frame_fov_overlap",
        "deterministic_fov_coverage_plucker",
        "valid_revisit_mask",
        "revisit_candidate_frame_count",
        "valid_candidate_label_count",
        "valid_revisit_frame_count",
        "no_valid_revisit_count",
        "revisit_selected_frame_count",
        "revisit_frame_fov_overlap",
        "revisit_abstained_count",
    ]:
        assert token in combined, (
            f"{token} missing from retrieval.py/labels.py/algorithm.py/diagnostics.py"
        )
    retrieval_and_labels = "\n".join((retrieval, labels))
    for leaked in ("same_video", "wrong_video"):
        assert leaked not in retrieval_and_labels, (
            f"{leaked} referenced in retrieval.py or labels.py"
        )
    for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]:
        stale_read = f'self._cfg_get(revisit_cfg, "{stale}"'
        assert stale_read not in algorithm, (
            f"algorithm.py still reads stale revisit key {stale}"
        )
def test_dynamic_compressor_excludes_c_short_contract():
    """Check the dynamic compressor/cache exclude the latest local frames.

    The exclusion expression and the dynamic gap/exclusion diagnostic names
    must appear in at least one of compression.py, cache.py or algorithm.py.
    The bare ``< target`` source-frame comparisons without the exclusion
    offset must be absent from compression.py and cache.py respectively.
    """
    dememwm_dir = Path("algorithms/worldmem/dememwm")
    compression = (dememwm_dir / "compression.py").read_text(encoding="utf-8")
    cache = (dememwm_dir / "cache.py").read_text(encoding="utf-8")
    algorithm = (dememwm_dir / "algorithm.py").read_text(encoding="utf-8")
    # Join once with newlines; per-iteration ``+`` concatenation re-built the
    # string every time and allowed matches across file boundaries.
    combined = "\n".join((compression, cache, algorithm))
    for token in [
        "exclude_latest_local_frames",
        "src_frames_b < target - exclude_latest_local_frames",
        "dynamic_min_gap_to_target_per_target",
        "dynamic_max_gap_to_target_per_target",
        "dynamic_exclude_latest_local_frames",
        "_local_context_exclusion_frames",
    ]:
        assert token in combined, f"{token} missing from compression.py/cache.py/algorithm.py"
    unoffset_compression = "src_frames_b < target, as_tuple=False"
    assert unoffset_compression not in compression, (
        f"compression.py still uses {unoffset_compression!r}"
    )
    unoffset_cache = "src < int(target), as_tuple=False"
    assert unoffset_cache not in cache, f"cache.py still uses {unoffset_cache!r}"
def test_eval_ablation_and_noise_bucket_logging_contracts():
    """Check eval-ablation branches and noise-bucket diagnostics are present.

    Every branch name must appear in schedules.py as a whole identifier.
    A substring check made ``A_plus_D`` pass whenever
    ``A_plus_D_plus_R_normal`` existed. Every helper/counter name must appear
    in at least one of schedules.py, diagnostics.py or algorithm.py.
    """
    dememwm_dir = Path("algorithms/worldmem/dememwm")
    schedules = (dememwm_dir / "schedules.py").read_text(encoding="utf-8")
    diagnostics = (dememwm_dir / "diagnostics.py").read_text(encoding="utf-8")
    algorithm = (dememwm_dir / "algorithm.py").read_text(encoding="utf-8")
    for branch in [
        "memory_off",
        "A_only",
        "D_only",
        "A_plus_D",
        "A_plus_D_plus_R_normal",
        "R_forced_off",
        "R_forced_on",
        "wrong_pose",
        "time_shuffle",
        "source_matched_random",
        "pose_shuffle",
        "wrong_video",
        "local_context_overlap_fake_revisit",
    ]:
        # Not preceded/followed by a word char, so one branch name cannot be
        # satisfied by a longer branch name that merely contains it.
        pattern = rf"(?<!\w){re.escape(branch)}(?!\w)"
        assert re.search(pattern, schedules), f"eval ablation branch {branch} missing from schedules.py"
    # Join once with newlines so tokens cannot match across file boundaries.
    combined = "\n".join((schedules, diagnostics, algorithm))
    for token in [
        "noise_bucket_from_denoising_fraction",
        "noise_bucket_from_noise_levels",
        "summarize_noise_bucket_diagnostics",
        "noise_bucket_id",
        "summarize_eval_ablation_diagnostics",
        "eval_bucket_true_revisit_count",
        "eval_bucket_no_valid_revisit_count",
        "eval_bucket_corrupted_memory_count",
        "apply_revisit_eval_corruption",
    ]:
        assert token in combined, f"{token} missing from schedules.py/diagnostics.py/algorithm.py"