# FlowMo-WM / tests/test_experiment_framework.py
# Last commit 761835a (verified) by cccat6: "Update tests for LOS controller baseline names"
from __future__ import annotations
import importlib
import inspect
from pathlib import Path
import numpy as np
from experiments.run_paper_image_pipeline import load_config, stage_planning, stage_prediction, stage_probe, stage_train
from experiments.shared.src.methods import METHODS, PAPER_LEARNED_METHODS, TRADITIONAL_METHODS
from experiments.shared.src.utils.parameter_count import parameter_count_by_component
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
from experiments.shared.src.vision.pose_from_image import estimate_pose_from_clean_image
from experiments.evaluate_flowmo_latent_probes import fit_ridge, predict_ridge, regression_metrics
def angle_error(a: float, b: float) -> float:
    """Return the absolute angular distance between ``a`` and ``b`` in radians.

    The raw difference is wrapped onto [-pi, pi] via ``arctan2(sin, cos)``, so
    two headings on either side of the +/-pi seam compare as close rather than
    nearly 2*pi apart.
    """
    delta = a - b
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return float(np.abs(wrapped))
def test_formal_method_registry_contains_only_paper_methods() -> None:
    """The method registry exposes exactly the A (learned) and B (traditional) methods."""
    expected_learned = ["flowmo", "leworldmodel", "planet", "tdmpc2"]
    expected_traditional = [
        "pid_los_controller",
        "no_flow_los_controller",
        "current_estimator_los_controller",
        "oracle_flow_los_controller",
    ]

    # The per-category lists are compared in order; the registry only by its key set.
    assert PAPER_LEARNED_METHODS == expected_learned
    assert set(METHODS) == {*expected_learned, *expected_traditional}
    assert TRADITIONAL_METHODS == expected_traditional

    for name in PAPER_LEARNED_METHODS:
        assert METHODS[name].category == "A_learned_world_model"
    for name in TRADITIONAL_METHODS:
        assert METHODS[name].category == "B_traditional_controller"
def test_paper_method_directories_follow_public_layout() -> None:
    """Every paper method has ``src/``, ``result/`` and ``README.md`` under ``experiments/<method>``.

    Learned world models must additionally provide a ``checkpoint/`` directory.
    Paths are relative, so the test expects to run from the repository root.
    """
    experiments_root = Path("experiments")
    for method in (*PAPER_LEARNED_METHODS, *TRADITIONAL_METHODS):
        method_dir = experiments_root / method
        assert method_dir.is_dir()
        for subdir in ("src", "result"):
            assert (method_dir / subdir).is_dir()
        assert (method_dir / "README.md").is_file()

    for method in PAPER_LEARNED_METHODS:
        assert (experiments_root / method / "checkpoint").is_dir()
def test_public_experiment_docs_only_describe_ab_categories() -> None:
    """Public experiment docs mention only the A/B method categories.

    None of the listed documents may contain tokens that refer to a category C
    or a "full agent" setting. Some tokens are assembled by concatenation,
    presumably so this test file itself does not contain them verbatim.
    """
    docs = [
        Path("experiments/README.md"),
        Path("experiments/BASELINES.md"),
        Path("experiments/EXPERIMENT_MATRIX.md"),
        Path("experiments/METHOD_AUDIT.md"),
        Path("experiments/TASK_PLAN.md"),
    ]
    forbidden = [
        "Category " + "C",
        "C " + "类",
        "full agent",
        "full-agent",
        "paper_table",
        "related_work_full_agent",
    ]
    for path in docs:
        # Decode explicitly as UTF-8: the docs may contain non-ASCII text (one
        # forbidden token is Chinese), and the platform default encoding
        # (e.g. cp1252 on Windows) would make this test locale-dependent.
        text = path.read_text(encoding="utf-8")
        for token in forbidden:
            assert token not in text, f"{path} mentions forbidden token {token!r}"
def test_paper_pipeline_uses_stage_specific_precision() -> None:
    """Train, prediction and probe stages run in bf16; the planning stage runs in fp32."""

    class Args:
        # Stand-in for the pipeline's parsed CLI namespace (the class itself is
        # passed). ``None`` fields presumably defer to the config file; the
        # precision assertions below rely on that for ``precision``.
        methods = None
        train_episodes = None
        test_episodes = None
        train_windows = None
        test_windows = None
        batch_size = None
        steps = None
        checkpoint_name = "paper.pt"
        checkpoint_interval = None
        train_workers = None
        num_workers = None
        device = "cuda"
        precision = None
        prediction_out = "prediction.json"
        probe_results = "probe.json"
        planning_episodes = None
        max_steps = None
        make_gifs = None
        gif_stride = 1
        gif_duration_ms = 55
        cem_horizon = None
        cem_population = None
        cem_elites = None
        cem_iterations = None
        cem_action_std = None
        cem_knots = None
        cem_w_route = None
        cem_w_heading_goal = None
        cem_w_progress = None
        cem_w_action = None
        cem_w_smooth = None
        cem_w_boundary = None
        cem_w_goal = None
        cem_w_path = None
        cem_w_lookahead = None
        cem_w_via = None
        cem_route_horizon_distance = None
        cem_boundary_margin = None
        planning_workers = None
        planning_out = "planning"

    def value_after_precision_flag(command: list[str]) -> str:
        # Each stage returns an argv-style list; the value follows its flag.
        flag_position = command.index("--precision")
        return command[flag_position + 1]

    cfg = load_config("experiments/shared/config/paper_image.json")

    assert value_after_precision_flag(stage_train(cfg, Args)) == "bf16"
    assert value_after_precision_flag(stage_prediction(cfg, Args)) == "bf16"
    assert value_after_precision_flag(stage_probe(cfg, Args)) == "bf16"
    planning_command = stage_planning(cfg, Args, "reach_target", "twin", "uniform")
    assert value_after_precision_flag(planning_command) == "fp32"
def test_train_test_and_planning_use_same_flow_families() -> None:
    """The paper config points at the paper data splits and lists exactly these flow families."""
    expected_families = [
        "noflow",
        "uniform",
        "vortex_center",
        "double_gyre",
        "source_sink",
        "source_sink_pair",
        "gradient",
        "shear",
        "turbulent_patch",
        "random_fourier",
    ]

    cfg = load_config("experiments/shared/config/paper_image.json")
    data_cfg = cfg["data"]

    assert data_cfg["train_source"] == "data/paper/train.npz"
    assert data_cfg["test_source"] == "data/paper/test.npz"
    assert cfg["prediction_eval"]["out"] == "experiments/reports/paper_prediction.json"
    assert cfg["flow_families"] == expected_families
def test_flowmo_probe_stage_uses_frozen_checkpoint_and_test_split() -> None:
    """The probe stage runs the latent-probe evaluator on the named checkpoint and test split."""

    class Args:
        # Minimal stand-in for the CLI namespace consumed by ``stage_probe``.
        checkpoint_name = "paper.pt"
        test_episodes = None
        num_workers = None
        device = "cuda"
        precision = None
        probe_results = "probe.json"

    cfg = load_config("experiments/shared/config/paper_image.json")
    command = stage_probe(cfg, Args)

    assert "experiments.evaluate_flowmo_latent_probes" in command
    assert "--checkpoint-name" in command
    checkpoint_value = command[command.index("--checkpoint-name") + 1]
    assert checkpoint_value == "paper.pt"

    # The split spec may be embedded inside a larger argument, so search the
    # flattened command line rather than individual list entries.
    flattened = " ".join(command)
    assert "test:data/paper/test.npz:480" in flattened
def test_paper_learned_world_models_are_parameter_matched() -> None:
    """All learned world models have total parameter counts within 3% of each other.

    Each method's model is built from its package's ``default_config()``. The
    ratio of the largest to the smallest total parameter count must stay
    below 1.03.
    """
    totals = {}
    for method in PAPER_LEARNED_METHODS:
        cfg = importlib.import_module(f"experiments.{method}.src.config").default_config()
        model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg)
        totals[method] = parameter_count_by_component(model)["total"]

    # Guard the division below: an empty registry or a zero count should fail
    # as a test assertion, not as ValueError/ZeroDivisionError.
    assert totals and min(totals.values()) > 0, f"invalid parameter counts: {totals}"

    # Report per-method counts so a mismatch points at the offending model.
    ratio = max(totals.values()) / min(totals.values())
    assert ratio < 1.03, f"parameter counts differ by {ratio:.4f}x: {totals}"
def test_learned_world_models_do_not_use_pose_extractor() -> None:
    """Learned models and the image pipelines never reference the clean-image pose extractor.

    The check is textual: the source of each module must not mention the
    extractor's module or function name.
    """
    banned_tokens = ("pose_from_image", "estimate_pose_from_clean_image")

    # Per-method model modules first, then the shared image training/eval scripts.
    module_names = [f"experiments.{method}.src.model" for method in PAPER_LEARNED_METHODS]
    module_names += [
        "experiments.train_image_world_models",
        "experiments.evaluate_image_world_models",
        "experiments.evaluate_image_planning",
    ]

    for module_name in module_names:
        module_source = inspect.getsource(importlib.import_module(module_name))
        for token in banned_tokens:
            assert token not in module_source
def test_clean_image_pose_extractor_is_accurate_for_traditional_controllers() -> None:
    """Pose recovered from a clean render matches the rendered position and heading.

    Boats are rendered at (5, 5) with several headings. Position must be within
    0.05 and heading within 0.09 rad, with heading error wrapped by ``angle_error``.
    """
    headings = (-2.4, -1.0, 0.2, 1.6, 2.9)
    for boat in ("twin", "triangle"):
        for heading in headings:
            true_state = np.array([5.0, 5.0, heading, 0.0, 0.0, 0.0], dtype=np.float32)
            frame = render_clean_boat_array(true_state, boat, image_size=160, visual_scale=2.5)
            pose = estimate_pose_from_clean_image(frame, visual_scale=2.5)

            position_error = np.linalg.norm(pose[:2] - true_state[:2])
            assert position_error < 0.05

            # Heading is recovered from pose entries 2 and 3 as arctan2(pose[3], pose[2]).
            estimated_heading = float(np.arctan2(pose[3], pose[2]))
            assert angle_error(estimated_heading, heading) < 0.09
def test_linear_probe_regression_recovers_linear_signal() -> None:
    """Ridge regression with a tiny penalty recovers a noisy linear map (mean R^2 > 0.99)."""
    rng = np.random.default_rng(123)
    n_samples, n_features, n_targets, n_train = 128, 5, 2, 96

    # Draw order (features, weights, noise) fixes the seeded data.
    features = rng.normal(size=(n_samples, n_features)).astype(np.float32)
    weights = rng.normal(size=(n_features, n_targets)).astype(np.float32)
    signal = features @ weights
    targets = signal + 0.01 * rng.normal(size=(n_samples, n_targets)).astype(np.float32)

    train_rows = slice(None, n_train)
    test_rows = slice(n_train, None)
    ridge = fit_ridge(features[train_rows], targets[train_rows], alpha=1.0e-4)
    predictions = predict_ridge(ridge, features[test_rows])
    metrics = regression_metrics(predictions, targets[test_rows], ["a", "b"])

    assert metrics["r2_mean"] > 0.99