"""Smoke tests for ``POST /predict``. These tests require the quantile-calibration quantile bundle on disk; if it is missing they skip rather than fail so a contributor without the artifact can still run the rest of the suite. """ from __future__ import annotations import pytest from fastapi.testclient import TestClient PRIMARY_TARGETS = { "range_km", "energy_margin_raw_pct", "slope_capability_deg", "total_mass_kg", } @pytest.fixture(autouse=True) def _skip_if_no_v7_1_artifact(surrogate_v7_1_compatible: bool) -> None: if not surrogate_v7_1_compatible: pytest.skip( "schema-v7_1 quantile_bundles.joblib not on disk; skipping " "predict tests (pre-v7_1 bundles lack " "scenario_operational_duty_cycle and KeyError on the v7_1 " "feature row produced by build_feature_row)." ) def test_predict_returns_monotone_quantiles( client: TestClient, sample_design: dict[str, float | int], ) -> None: payload = {"design": sample_design, "scenario_name": "equatorial_mare_traverse"} response = client.post("/predict", json=payload) assert response.status_code == 200, response.text body = response.json() assert body["scenario_name"] == "equatorial_mare_traverse" targets = {p["target"] for p in body["predictions"]} assert targets == PRIMARY_TARGETS for pred in body["predictions"]: # repair_crossings defaults to True -> must be monotone. assert pred["q05"] <= pred["q50"] <= pred["q95"], pred def test_predict_feature_row_includes_categoricals( client: TestClient, sample_design: dict[str, float | int], ) -> None: payload = {"design": sample_design, "scenario_name": "polar_prospecting"} response = client.post("/predict", json=payload) assert response.status_code == 200 body = response.json() cols = body["feature_row"]["columns"] # 27 columns: 11 design (v7 dropped designed_duty_cycle) + 12 scenario # numerics (v7_1 added scenario_operational_duty_cycle; v9 added # scenario_payload_mass_kg + scenario_payload_power_w) + 4 scenario # categoricals. assert len(cols) == 27 assert "scenario_operational_duty_cycle" in cols assert "scenario_payload_mass_kg" in cols assert "scenario_payload_power_w" in cols # Family is forwarded from the scenario name on the canonical four. fam_idx = cols.index("scenario_family") assert body["feature_row"]["values"][fam_idx] == "polar_prospecting" def test_predict_payload_override_reaches_feature_row( client: TestClient, sample_design: dict[str, float | int], ) -> None: """Schema v9: a per-call payload override must land in the echoed feature row so the surrogate scores the mission's own payload, not the scenario default. """ payload = { "design": sample_design, "scenario_name": "equatorial_mare_traverse", "payload_mass_kg": 12.5, "payload_power_w": 7.0, } response = client.post("/predict", json=payload) assert response.status_code == 200, response.text row = response.json()["feature_row"] cols = row["columns"] mass_idx = cols.index("scenario_payload_mass_kg") power_idx = cols.index("scenario_payload_power_w") assert row["values"][mass_idx] == pytest.approx(12.5) assert row["values"][power_idx] == pytest.approx(7.0) def test_predict_mission_duration_override_reaches_feature_row( client: TestClient, sample_design: dict[str, float | int], ) -> None: """A per-call duration override must land in the echoed feature row.""" payload = { "design": sample_design, "scenario_name": "equatorial_mare_traverse", "mission_duration_earth_days": 21.0, } response = client.post("/predict", json=payload) assert response.status_code == 200, response.text row = response.json()["feature_row"] duration_idx = row["columns"].index("scenario_mission_duration_earth_days") assert row["values"][duration_idx] == pytest.approx(21.0) def test_predict_accepts_required_obstacle_height_override( client: TestClient, sample_design: dict[str, float | int], ) -> None: """Per-call obstacle requirement must not 422 on the surrogate route.""" payload = { "design": sample_design, "scenario_name": "equatorial_mare_traverse", "required_obstacle_height_m": 0.12, } response = client.post("/predict", json=payload) assert response.status_code == 200, response.text def test_predict_rejects_out_of_bounds_payload( client: TestClient, sample_design: dict[str, float | int], ) -> None: response = client.post( "/predict", json={ "design": sample_design, "scenario_name": "equatorial_mare_traverse", "payload_power_w": -1.0, }, ) assert response.status_code == 422 def test_predict_unknown_scenario_returns_404( client: TestClient, sample_design: dict[str, float | int], ) -> None: payload = {"design": sample_design, "scenario_name": "nope"} response = client.post("/predict", json=payload) assert response.status_code == 404 def test_predict_rejects_out_of_bounds_design( client: TestClient, sample_design: dict[str, float | int], ) -> None: bad = dict(sample_design) bad["wheel_radius_m"] = 5.0 # schema ceiling is 0.20 m response = client.post( "/predict", json={"design": bad, "scenario_name": "equatorial_mare_traverse"}, ) # Pydantic v2 returns 422 for body validation failures by default. assert response.status_code == 422 def test_predict_raw_quantiles_may_be_non_monotone( client: TestClient, sample_design: dict[str, float | int], ) -> None: """With ``repair_crossings=False`` the API exposes raw model output. The contract here is *not* that crossings will appear (they usually don't on a single point) but that the repair flag is plumbed end to end -- so we just check the response is well-formed. """ payload = { "design": sample_design, "scenario_name": "highland_slope_capability", "repair_crossings": False, } response = client.post("/predict", json=payload) assert response.status_code == 200 body = response.json() for pred in body["predictions"]: for key in ("q05", "q50", "q95"): assert isinstance(pred[key], (int, float))