FlowMo-WM / tests /test_experiment_framework.py
cccat6's picture
Add files using upload-large-folder tool
ee93556 verified
raw
history blame
7.72 kB
from __future__ import annotations
import importlib
import inspect
from pathlib import Path
import numpy as np
from experiments.run_paper_image_pipeline import load_config, stage_planning, stage_prediction, stage_probe, stage_train
from experiments.shared.src.methods import METHODS, PAPER_LEARNED_METHODS, TRADITIONAL_METHODS
from experiments.shared.src.utils.parameter_count import parameter_count_by_component
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
from experiments.shared.src.vision.pose_from_image import estimate_pose_from_clean_image
from experiments.evaluate_flowmo_latent_probes import fit_ridge, predict_ridge, regression_metrics
def angle_error(a: float, b: float) -> float:
    """Return the absolute angular distance between ``a`` and ``b`` in radians.

    The raw difference is wrapped through ``arctan2(sin, cos)`` so that two
    angles on opposite sides of the +/-pi seam are reported as close rather
    than almost 2*pi apart.
    """
    delta = a - b
    wrapped = np.arctan2(np.sin(delta), np.cos(delta))
    return float(abs(wrapped))
def test_formal_method_registry_contains_only_paper_methods() -> None:
    """The method registry holds exactly the paper's learned and traditional methods.

    Checks the exact ordered learned and traditional method lists, that the
    registry keys are precisely their union, and that every registry entry
    carries the category string matching the list it belongs to.
    """
    expected_learned = ["flowmo", "leworldmodel", "planet", "tdmpc2"]
    expected_traditional = [
        "pid_los_controller",
        "physics_mpc_no_flow",
        "current_estimator_mpc",
        "oracle_flow_mpc",
    ]

    assert PAPER_LEARNED_METHODS == expected_learned
    # Registry keys must be exactly the learned + traditional methods, nothing more.
    assert set(METHODS) == set(expected_learned) | set(expected_traditional)
    assert TRADITIONAL_METHODS == expected_traditional

    category_of_group = {
        "A_learned_world_model": PAPER_LEARNED_METHODS,
        "B_traditional_controller": TRADITIONAL_METHODS,
    }
    for category, member_names in category_of_group.items():
        assert all(METHODS[member].category == category for member in member_names)
def test_paper_method_directories_follow_public_layout() -> None:
    """Every paper method has the public directory layout under ``experiments/``.

    All methods need ``src/``, ``result/`` and a ``README.md``; learned world
    models additionally need a ``checkpoint/`` directory. Paths are relative,
    so this assumes pytest runs from the repository root.
    """
    experiments_root = Path("experiments")

    for method_name in (*PAPER_LEARNED_METHODS, *TRADITIONAL_METHODS):
        method_dir = experiments_root / method_name
        assert method_dir.is_dir()
        for required_subdir in ("src", "result"):
            assert (method_dir / required_subdir).is_dir()
        assert (method_dir / "README.md").is_file()

    for method_name in PAPER_LEARNED_METHODS:
        checkpoint_dir = experiments_root / method_name / "checkpoint"
        assert checkpoint_dir.is_dir()
def test_public_experiment_docs_only_describe_ab_categories() -> None:
    """Public experiment docs must mention only the A/B method categories.

    Scans each public Markdown document under ``experiments/`` for tokens that
    refer to a category C or to full-agent / paper-table material. Fails with
    the full list of (file, token) hits so that every offending line can be
    fixed in one pass.
    """
    docs = [
        Path("experiments/README.md"),
        Path("experiments/BASELINES.md"),
        Path("experiments/EXPERIMENT_MATRIX.md"),
        Path("experiments/METHOD_AUDIT.md"),
        Path("experiments/TASK_PLAN.md"),
    ]
    # Some tokens are assembled from pieces so this test file does not contain
    # the literal strings itself (presumably to keep repo-wide greps clean).
    forbidden = [
        "Category " + "C",
        "C " + "类",
        "full agent",
        "full-agent",
        "paper_table",
        "related_work_full_agent",
    ]
    violations: list[tuple[str, str]] = []
    for path in docs:
        # Decode explicitly as UTF-8: the forbidden tokens include Chinese
        # text, and the platform default encoding is not UTF-8 everywhere
        # (e.g. Windows), which would break or mis-decode the read.
        text = path.read_text(encoding="utf-8")
        violations.extend((str(path), token) for token in forbidden if token in text)
    assert not violations, f"forbidden tokens found in public docs: {violations}"
def test_paper_pipeline_uses_stage_specific_precision() -> None:
    """Each pipeline stage picks its own ``--precision`` when none is forced.

    With ``precision = None`` on the CLI stand-in, the train, prediction and
    probe commands are expected to request ``bf16``, while the planning
    command is expected to request ``fp32``.
    """

    class Args:
        # Stand-in for the parsed CLI namespace; only attribute access is used.
        # --- training / data ---
        methods = None
        train_episodes = None
        test_episodes = None
        train_windows = None
        test_windows = None
        batch_size = None
        steps = None
        checkpoint_name = "paper.pt"
        checkpoint_interval = None
        num_workers = None
        device = "cuda"
        # None presumably means "use the stage default" — the behavior under test.
        precision = None
        # --- evaluation outputs ---
        prediction_unseen_flow = "prediction_unseen_flow.json"
        prediction_unseen_boat_params = "prediction_unseen_boat_params.json"
        prediction_seen_flow_diagnostic = "prediction_seen_flow_diagnostic.json"
        probe_results = "probe.json"
        # --- planning / rollout ---
        planning_episodes = None
        max_steps = None
        make_gifs = None
        gif_stride = 1
        gif_duration_ms = 55
        # --- CEM planner knobs ---
        cem_horizon = None
        cem_population = None
        cem_elites = None
        cem_iterations = None
        cem_action_std = None
        cem_knots = None
        cem_w_route = None
        cem_w_heading_goal = None
        cem_w_progress = None
        cem_w_action = None
        cem_w_smooth = None
        cem_w_boundary = None
        planning_out = "planning"

    cfg = load_config("experiments/shared/config/paper_image.json")

    def precision_of(command: list[str]) -> str:
        flag_position = command.index("--precision")
        return command[flag_position + 1]

    # Commands are built lazily, one at a time, so a failing stage stops the
    # test before any later stage is constructed.
    expectations = (
        (lambda: stage_train(cfg, Args), "bf16"),
        (lambda: stage_prediction(cfg, Args), "bf16"),
        (lambda: stage_probe(cfg, Args), "bf16"),
        (lambda: stage_planning(cfg, Args, "reach_uniform", "twin"), "fp32"),
    )
    for build_command, expected_precision in expectations:
        assert precision_of(build_command()) == expected_precision
def test_prediction_splits_are_unseen_first_with_seen_diagnostic() -> None:
    """Prediction splits list the two unseen splits first, then the seen diagnostic.

    The two unseen splits are flagged primary and the seen-flow diagnostic is
    not. The diagnostic split's data source path is also pinned.
    """
    cfg = load_config("experiments/shared/config/paper_image.json")
    splits = cfg["prediction_eval"]["splits"]

    split_names = [split["name"] for split in splits]
    assert split_names == ["unseen_flow", "unseen_boat_params", "seen_flow_diagnostic"]

    # Identity checks on purpose: "primary" must be the bool singletons, not merely truthy.
    for split, expected_primary in zip(splits, (True, True, False)):
        assert split["primary"] is expected_primary

    assert cfg["data"]["diagnostic_source"] == "data/paper/diagnostic_seen_flow.npz"
def test_flowmo_probe_stage_uses_frozen_checkpoint_and_all_splits() -> None:
    """The probe command targets the FlowMo probe module, checkpoint, and every split.

    Verifies the command invokes ``experiments.evaluate_flowmo_latent_probes``,
    passes the requested checkpoint name, and lists all three evaluation splits
    as ``name:path:480`` specs.
    """

    class Args:
        # Minimal CLI stand-in: only the attributes the probe stage needs.
        checkpoint_name = "paper.pt"
        num_workers = None
        device = "cuda"
        precision = None
        probe_results = "probe.json"

    cfg = load_config("experiments/shared/config/paper_image.json")
    command = stage_probe(cfg, Args)

    assert "experiments.evaluate_flowmo_latent_probes" in command
    assert "--checkpoint-name" in command
    checkpoint_flag_at = command.index("--checkpoint-name")
    assert command[checkpoint_flag_at + 1] == "paper.pt"

    command_line = " ".join(command)
    expected_split_specs = (
        "unseen_flow:data/paper/test_unseen_flow.npz:480",
        "unseen_boat_params:data/paper/test_unseen_boat_params.npz:480",
        "seen_flow_diagnostic:data/paper/diagnostic_seen_flow.npz:480",
    )
    for split_spec in expected_split_specs:
        assert split_spec in command_line
def test_paper_learned_world_models_are_parameter_matched() -> None:
    """All learned world models have comparable total parameter counts.

    Builds each learned method's model from its default config. The largest
    total parameter count must be less than 3% above the smallest.
    """

    def total_parameters(method: str):
        cfg = importlib.import_module(f"experiments.{method}.src.config").default_config()
        model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg)
        return parameter_count_by_component(model)["total"]

    totals = [total_parameters(method) for method in PAPER_LEARNED_METHODS]
    largest, smallest = max(totals), min(totals)
    assert largest / smallest < 1.03
def test_learned_world_models_do_not_use_pose_extractor() -> None:
    """Learned models and their train/eval scripts never reference the pose extractor.

    This is a textual check: the source of each learned method's model module
    and of the shared image train/eval/planning scripts must not contain the
    pose-extractor module or function names.
    """
    model_module_names = [f"experiments.{method}.src.model" for method in PAPER_LEARNED_METHODS]
    script_module_names = [
        "experiments.train_image_world_models",
        "experiments.evaluate_image_world_models",
        "experiments.evaluate_image_planning",
    ]
    banned_names = ("pose_from_image", "estimate_pose_from_clean_image")

    for dotted_name in model_module_names + script_module_names:
        module_source = inspect.getsource(importlib.import_module(dotted_name))
        for banned in banned_names:
            assert banned not in module_source
def test_clean_image_pose_extractor_is_accurate_for_traditional_controllers() -> None:
    """The clean-image pose extractor recovers position and heading accurately.

    For each boat shape and a sweep of headings, a boat at (5, 5) is rendered
    and then re-estimated. The position must be within 0.05 and the wrapped
    heading error within 0.09 rad.
    """
    headings = [-2.4, -1.0, 0.2, 1.6, 2.9]
    for boat_name in ["twin", "triangle"]:
        for heading in headings:
            true_state = np.array([5.0, 5.0, heading, 0.0, 0.0, 0.0], dtype=np.float32)
            frame = render_clean_boat_array(true_state, boat_name, image_size=160, visual_scale=2.5)
            pose = estimate_pose_from_clean_image(frame, visual_scale=2.5)

            # pose[2], pose[3] presumably encode (cos, sin) of the heading — the
            # original decodes them with arctan2(pose[3], pose[2]).
            estimated_heading = float(np.arctan2(pose[3], pose[2]))
            position_error = np.linalg.norm(pose[:2] - true_state[:2])

            assert position_error < 0.05
            assert angle_error(estimated_heading, heading) < 0.09
def test_linear_probe_regression_recovers_linear_signal() -> None:
    """The ridge probe fits a nearly noise-free linear map with R^2 above 0.99.

    Synthetic data: 128 samples of 5 features mapped linearly to 2 targets
    with small Gaussian noise. The first 96 samples train, the rest test.
    """
    generator = np.random.default_rng(123)
    n_samples, n_features, n_targets, n_train = 128, 5, 2, 96

    # Draw order (features, weights, noise) fixes the seeded data.
    features = generator.normal(size=(n_samples, n_features)).astype(np.float32)
    weights = generator.normal(size=(n_features, n_targets)).astype(np.float32)
    noise = 0.01 * generator.normal(size=(n_samples, n_targets)).astype(np.float32)
    targets = features @ weights + noise

    train_x, test_x = features[:n_train], features[n_train:]
    train_y, test_y = targets[:n_train], targets[n_train:]

    probe = fit_ridge(train_x, train_y, alpha=1.0e-4)
    predictions = predict_ridge(probe, test_x)
    metrics = regression_metrics(predictions, test_y, ["a", "b"])
    assert metrics["r2_mean"] > 0.99