Spaces:
Running
Running
"""Server endpoint tests for the FastAPI app exported by ``server.app``.

Coverage, grouped by endpoint:

- GET /health (API 01) and GET /runtime (Scientist runtime metadata).
- GET / landing page and GET /web OpenEnv fallback UI (API 19).
- GET /scenarios smoke test (API 04).
- CORS middleware verification (API 13).
- POST /reset (API 02) and POST /step (API 03), plus POST /agent-step.
- GET /replay/{episode_id}, exercised after terminal REST and WebSocket steps.
- WebSocket session handler (API 06), including idle-timeout and graceful
  disconnect cleanup (API 07).
- Log-level configuration (OBS 02).
"""
| from __future__ import annotations | |
| import json | |
| import time | |
| from unittest.mock import patch | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from starlette.websockets import WebSocketDisconnect | |
| from replicalab.models import ScientistAction | |
| from server.app import app | |
# Scenario families that GET /scenarios must advertise (compared as a set),
# and the difficulty tiers each family must list (compared as an ordered list).
_EXPECTED_FAMILIES = set(("math_reasoning", "ml_benchmark", "finance_trading"))
_EXPECTED_DIFFICULTIES = "easy medium hard".split()
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` bound to the FastAPI ``app``.

    Registered as a pytest fixture because every test method below declares
    a ``client`` parameter. With the default function scope, each test gets
    its own client instance. The client is not entered as a context
    manager. Per the Starlette TestClient docs, lifespan/startup handlers
    therefore do not run.
    """
    return TestClient(app)
class TestHealthEndpoint:
    """GET /health — API 01: liveness probe with a stable payload shape."""

    def test_health_returns_200(self, client: TestClient) -> None:
        assert client.get("/health").status_code == 200

    def test_health_payload_has_stable_keys(self, client: TestClient) -> None:
        payload = client.get("/health").json()
        assert payload["status"] == "ok"
        assert payload["env"] in ("real", "stub")
        assert "version" in payload

    def test_health_version_matches_app(self, client: TestClient) -> None:
        # Compare against the module-level ``app`` imported from server.app.
        payload = client.get("/health").json()
        assert payload["version"] == app.version

    def test_health_is_deterministic(self, client: TestClient) -> None:
        first, second = (client.get("/health").json() for _ in range(2))
        assert first == second
class TestRuntimeEndpoint:
    """GET /runtime — metadata for model-backed Scientist stepping."""

    def test_runtime_defaults_to_baseline_without_api_key(
        self, client: TestClient, monkeypatch
    ) -> None:
        # Neither a runtime override nor an API key is visible to the server.
        for var in ("REPLICALAB_SCIENTIST_RUNTIME", "ANTHROPIC_API_KEY"):
            monkeypatch.delenv(var, raising=False)

        response = client.get("/runtime")
        assert response.status_code == 200
        body = response.json()
        assert body["scientist_runtime"] == "baseline"
        assert body["scientist_model"] == "baseline-heuristic"
        assert body["agent_step_available"] is True

    def test_runtime_reports_anthropic_when_enabled(
        self, client: TestClient, monkeypatch
    ) -> None:
        overrides = {
            "REPLICALAB_SCIENTIST_RUNTIME": "anthropic",
            "ANTHROPIC_API_KEY": "test-key",
        }
        for var, value in overrides.items():
            monkeypatch.setenv(var, value)

        response = client.get("/runtime")
        assert response.status_code == 200
        body = response.json()
        assert body["scientist_runtime"] == "anthropic"
        assert body["scientist_ready"] is True
        assert body["agent_step_available"] is True
class TestLogConfig:
    """OBS 02 — log level configurability.

    ``replicalab.config`` reads the environment at import time, so these tests
    reload it. Each reload happens inside a scoped ``monkeypatch.context()``.
    The module is reloaded again in ``finally`` so later tests see constants
    that match the restored environment, even when an assertion fails.
    """

    @staticmethod
    def _reload_config():
        """Reload ``replicalab.config`` so its constants re-read the environment."""
        import importlib

        import replicalab.config as config_mod

        return importlib.reload(config_mod)

    def test_default_log_level_is_info(self, monkeypatch) -> None:
        """With REPLICALAB_LOG_LEVEL unset, the configured level is INFO."""
        try:
            with monkeypatch.context() as mp:
                mp.delenv("REPLICALAB_LOG_LEVEL", raising=False)
                config_mod = self._reload_config()
                assert config_mod.LOG_LEVEL == "INFO"
        finally:
            # The env var is restored when the context exits; reload so the
            # module constants match the caller's environment again.
            self._reload_config()

    def test_log_level_env_var_is_respected(self, monkeypatch) -> None:
        """REPLICALAB_LOG_LEVEL controls the level; lower-case input is upper-cased."""
        try:
            with monkeypatch.context() as mp:
                mp.setenv("REPLICALAB_LOG_LEVEL", "debug")
                config_mod = self._reload_config()
                assert config_mod.LOG_LEVEL == "DEBUG"
        finally:
            # Runs even if the assertion fails, so a DEBUG level never leaks
            # into later tests.
            self._reload_config()

    def test_log_format_is_readable(self) -> None:
        from replicalab.config import LOG_FORMAT

        assert "%(asctime)s" in LOG_FORMAT
        assert "%(levelname)s" in LOG_FORMAT
        assert "%(name)s" in LOG_FORMAT
class TestRootEndpoint:
    """GET / — lightweight landing page for hosted backend deployments."""

    def test_root_returns_200_html(self, client: TestClient) -> None:
        response = client.get("/")
        assert response.status_code == 200
        assert "text/html" in response.headers["content-type"]

    def test_root_mentions_core_api_endpoints(self, client: TestClient) -> None:
        page = client.get("/").text
        assert "ReplicaLab" in page
        # When frontend/dist exists, / serves the SPA rather than the API
        # landing page; the endpoint listing is only checked for the latter.
        if "ReplicaLab API" not in page:
            return
        for endpoint in ("GET /health", "POST /reset"):
            assert endpoint in page
class TestWebFallback:
    """GET /web — API 19: OpenEnv fallback UI."""

    def test_web_returns_200_html(self, client: TestClient) -> None:
        response = client.get("/web")
        assert response.status_code == 200
        assert "text/html" in response.headers["content-type"]

    def test_web_contains_interactive_controls(self, client: TestClient) -> None:
        page = client.get("/web").text
        # Branding, the three control buttons, and the two REST endpoints
        # the page talks to.
        required_fragments = (
            "ReplicaLab",
            "btnReset",
            "btnPropose",
            "btnAccept",
            "/reset",
            "/step",
        )
        for fragment in required_fragments:
            assert fragment in page

    def test_web_is_self_contained(self, client: TestClient) -> None:
        """Fallback UI must work without external JS/CSS dependencies."""
        page = client.get("/web").text
        for inline_tag in ("<script>", "<style>"):
            assert inline_tag in page
class TestScenariosEndpoint:
    """GET /scenarios — API 04: catalogue of families and difficulty tiers."""

    def test_returns_200(self, client: TestClient) -> None:
        assert client.get("/scenarios").status_code == 200

    def test_response_has_scenarios_key(self, client: TestClient) -> None:
        body = client.get("/scenarios").json()
        assert "scenarios" in body
        assert isinstance(body["scenarios"], list)

    def test_all_families_present(self, client: TestClient) -> None:
        entries = client.get("/scenarios").json()["scenarios"]
        assert {entry["family"] for entry in entries} == _EXPECTED_FAMILIES

    def test_each_family_has_difficulties(self, client: TestClient) -> None:
        entries = client.get("/scenarios").json()["scenarios"]
        for entry in entries:
            assert entry["difficulties"] == _EXPECTED_DIFFICULTIES

    def test_no_extra_keys(self, client: TestClient) -> None:
        entries = client.get("/scenarios").json()["scenarios"]
        for entry in entries:
            assert set(entry) == {"family", "difficulties"}
| # --------------------------------------------------------------------------- | |
# CORS (API 13) and POST /reset (API 02)
| # --------------------------------------------------------------------------- | |
class TestCorsConfiguration:
    """API 13: CORS middleware for local frontend and HF Spaces."""

    @staticmethod
    def _preflight(client: TestClient, path: str, origin: str, method: str):
        """Send a CORS preflight (OPTIONS) for *method* on *path* from *origin*."""
        return client.options(
            path,
            headers={
                "Origin": origin,
                "Access-Control-Request-Method": method,
            },
        )

    def test_preflight_allows_localhost_vite_origin(self, client: TestClient) -> None:
        vite_origin = "http://localhost:5173"
        response = self._preflight(client, "/reset", vite_origin, "POST")
        assert response.status_code == 200
        assert response.headers["access-control-allow-origin"] == vite_origin
        assert response.headers["access-control-allow-credentials"] == "true"

    def test_preflight_allows_hf_space_origin(self, client: TestClient) -> None:
        space_origin = "https://replicalab-demo.hf.space"
        response = self._preflight(client, "/health", space_origin, "GET")
        assert response.status_code == 200
        assert response.headers["access-control-allow-origin"] == space_origin
        assert response.headers["access-control-allow-credentials"] == "true"

    def test_preflight_rejects_unconfigured_origin(self, client: TestClient) -> None:
        response = self._preflight(
            client, "/reset", "https://evil.example.com", "POST"
        )
        assert response.status_code == 400
        assert "access-control-allow-origin" not in response.headers
class TestResetEndpoint:
    """POST /reset — API 02."""

    def test_reset_returns_200_with_expected_keys(self, client: TestClient) -> None:
        resp = client.post("/reset", json={"seed": 1})
        assert resp.status_code == 200
        data = resp.json()
        assert "session_id" in data
        assert "episode_id" in data
        assert "observation" in data

    def test_reset_observation_has_both_roles(self, client: TestClient) -> None:
        resp = client.post("/reset", json={"seed": 1})
        assert resp.status_code == 200
        obs = resp.json()["observation"]
        assert "scientist" in obs
        assert "lab_manager" in obs
        assert obs["scientist"]["paper_title"]
        assert obs["lab_manager"]["budget_total"] > 0

    def test_reset_with_explicit_session_id_reuses_slot(
        self, client: TestClient
    ) -> None:
        """Passing session_id reuses the same slot and returns the same id."""
        sid = "my-fixed-session"
        r1 = client.post("/reset", json={"seed": 1, "session_id": sid})
        assert r1.status_code == 200
        d1 = r1.json()
        assert d1["session_id"] == sid
        r2 = client.post("/reset", json={"seed": 2, "session_id": sid})
        assert r2.status_code == 200
        d2 = r2.json()
        assert d2["session_id"] == sid
        # Each reset starts a new episode in the reused slot.
        assert d2["episode_id"] != d1["episode_id"]

    def test_reset_reuse_closes_prior_env(self, client: TestClient) -> None:
        """Resetting with the same session_id produces a fresh episode.

        NOTE(review): this only observes that the episode id changes; it does
        not directly verify that the prior env's close() was called.
        """
        sid = "reuse-session"
        r1 = client.post("/reset", json={"seed": 10, "session_id": sid})
        assert r1.status_code == 200
        r2 = client.post("/reset", json={"seed": 20, "session_id": sid})
        assert r2.status_code == 200
        assert r1.json()["episode_id"] != r2.json()["episode_id"]

    def test_reset_default_params(self, client: TestClient) -> None:
        """Omitting scenario and difficulty uses defaults without error."""
        resp = client.post("/reset", json={"seed": 0})
        assert resp.status_code == 200
        data = resp.json()
        assert data["observation"]["scientist"]["paper_title"]

    def test_reset_custom_scenario_and_difficulty(self, client: TestClient) -> None:
        """Every advertised family/difficulty pair resets successfully."""
        # Reuse the catalogue constants so this stays in sync with /scenarios;
        # sorting keeps the iteration order (and failure order) deterministic.
        for family in sorted(_EXPECTED_FAMILIES):
            for diff in _EXPECTED_DIFFICULTIES:
                resp = client.post(
                    "/reset",
                    json={"seed": 42, "scenario": family, "difficulty": diff},
                )
                assert resp.status_code == 200, f"Failed for {family}/{diff}"
                obs = resp.json()["observation"]
                assert obs["scientist"]["paper_title"]
                assert obs["lab_manager"]["budget_total"] > 0

    def test_reset_deterministic_with_same_seed(self, client: TestClient) -> None:
        """Same seed + scenario + difficulty → identical observations."""
        params = {"seed": 99, "scenario": "math_reasoning", "difficulty": "medium"}
        r1 = client.post("/reset", json=params)
        r2 = client.post("/reset", json=params)
        assert r1.status_code == 200
        assert r2.status_code == 200
        d1, d2 = r1.json(), r2.json()
        assert d1["observation"] == d2["observation"]
        # Episode ids still differ: each reset mints a new id.
        assert d1["episode_id"] != d2["episode_id"]
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _reset(client: TestClient, **kwargs) -> dict:
    """POST /reset and return the decoded JSON body.

    Defaults to seed 42 / ``math_reasoning`` / ``easy``. Any keyword argument
    overrides or extends the request payload. Asserts the request succeeded
    (HTTP 200) before decoding.
    """
    defaults = {"seed": 42, "scenario": "math_reasoning", "difficulty": "easy"}
    response = client.post("/reset", json={**defaults, **kwargs})
    assert response.status_code == 200
    return response.json()
def _good_action_payload(
    client: TestClient,
    *,
    seed: int = 42,
    template: str = "math_reasoning",
    difficulty: str = "easy",
) -> dict:
    """Build a ``propose_protocol`` action payload derived from a generated scenario.

    The payload reuses the scenario's first available equipment item and
    reagent (when present), its hidden reference spec summary/elements/target,
    and a duration of ``min(2, time_limit_days)`` floored at 1 day.

    Args:
        client: Unused; kept so existing call sites continue to work.
        seed: Scenario seed passed to ``generate_scenario``.
        template: Scenario family passed to ``generate_scenario``.
        difficulty: Difficulty tier passed to ``generate_scenario``.
            Pass the same seed/template/difficulty used for ``/reset`` to
            get an action matched to that episode. The defaults reproduce
            the original seed-42 ``math_reasoning``/``easy`` scenario.

    Returns:
        A JSON-serialisable dict used as the ``action`` field of POST /step
        and WebSocket ``step`` messages.
    """
    from replicalab.scenarios import generate_scenario

    scenario = generate_scenario(seed=seed, template=template, difficulty=difficulty)
    lab = scenario.lab_manager_observation
    spec = scenario.hidden_reference_spec
    return {
        "action_type": "propose_protocol",
        "sample_size": 10,
        "controls": ["baseline", "ablation"],
        # Truncate the summary to keep the technique field short.
        "technique": spec.summary[:60] if spec.summary else "replication_plan",
        # At most 2 days (and at most the lab's limit), but never below 1.
        "duration_days": max(1, min(2, lab.time_limit_days)),
        "required_equipment": (
            list(lab.equipment_available[:1]) if lab.equipment_available else []
        ),
        "required_reagents": (
            list(lab.reagents_in_stock[:1]) if lab.reagents_in_stock else []
        ),
        "questions": [],
        "rationale": (
            f"Plan addresses: {', '.join(spec.required_elements[:2])}. "
            f"Target metric: {spec.target_metric}. "
            f"Target value: {spec.target_value}. "
            "Stay within budget and schedule."
        ),
    }
def _accept_action_payload() -> dict:
    """Return a fresh ``accept`` action: every field other than the type is
    empty (``""``/``[]``) or zero, in the same key order as propose payloads."""
    return dict(
        action_type="accept",
        sample_size=0,
        controls=[],
        technique="",
        duration_days=0,
        required_equipment=[],
        required_reagents=[],
        questions=[],
        rationale="",
    )
| # --------------------------------------------------------------------------- | |
| # POST /step — API 03 | |
| # --------------------------------------------------------------------------- | |
class TestStepEndpoint:
    """POST /step — API 03, plus GET /replay/{episode_id} after terminal steps."""

    def test_reset_then_step_happy_path(self, client: TestClient) -> None:
        """Reset, then step with a valid action → 200 with StepResult."""
        reset_data = _reset(client)
        session_id = reset_data["session_id"]
        action = _good_action_payload(client)
        resp = client.post("/step", json={"session_id": session_id, "action": action})
        assert resp.status_code == 200
        data = resp.json()
        assert "observation" in data
        assert "reward" in data
        assert "done" in data
        assert "info" in data
        assert data["done"] is False
        assert data["info"]["error"] is None

    def test_step_invalid_session_returns_404(self, client: TestClient) -> None:
        """Step with a non-existent session_id → 404."""
        action = _good_action_payload(client)
        resp = client.post(
            "/step",
            json={"session_id": "nonexistent-session-id", "action": action},
        )
        assert resp.status_code == 404
        assert "Session not found" in resp.json()["detail"]

    def test_terminal_step_returns_real_reward_breakdown(
        self, client: TestClient
    ) -> None:
        """Propose → accept: terminal step has real reward_breakdown,
        judge_notes, and verdict from the env (not stubs)."""
        reset_data = _reset(client)
        session_id = reset_data["session_id"]
        # Step 1: propose
        action = _good_action_payload(client)
        resp1 = client.post("/step", json={"session_id": session_id, "action": action})
        assert resp1.status_code == 200
        assert resp1.json()["done"] is False
        # Step 2: accept
        resp2 = client.post(
            "/step",
            json={"session_id": session_id, "action": _accept_action_payload()},
        )
        assert resp2.status_code == 200
        data = resp2.json()
        assert data["done"] is True
        assert data["reward"] > 0.0
        info = data["info"]
        assert info["agreement_reached"] is True
        assert info["verdict"] == "accept"
        assert info["judge_notes"] is not None
        assert "rigor" in info["judge_notes"]
        rb = info["reward_breakdown"]
        assert rb is not None
        assert 0.0 <= rb["rigor"] <= 1.0
        assert 0.0 <= rb["feasibility"] <= 1.0
        assert 0.0 <= rb["fidelity"] <= 1.0
        # Guard against the old stub breakdown where every score was 0.8.
        assert not (
            rb["rigor"] == 0.8 and rb["feasibility"] == 0.8 and rb["fidelity"] == 0.8
        )

    def test_semantic_invalid_action_returns_200_with_error(
        self, client: TestClient
    ) -> None:
        """A semantically invalid action (e.g. duration=999) returns 200
        with info.error set, not a crash or 422."""
        reset_data = _reset(client)
        session_id = reset_data["session_id"]
        bad_action = {
            "action_type": "propose_protocol",
            "sample_size": 5,
            "controls": ["baseline"],
            "technique": "some technique",
            "duration_days": 999,
            "required_equipment": [],
            "required_reagents": [],
            "questions": [],
            "rationale": "Duration is impossibly long for the lab time limit.",
        }
        resp = client.post(
            "/step", json={"session_id": session_id, "action": bad_action}
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["done"] is False
        assert data["info"]["error"] is not None
        assert "Validation errors" in data["info"]["error"]

    def test_replay_uses_real_judge_data(self, client: TestClient) -> None:
        """After a terminal step, GET /replay/{episode_id} returns
        real judge_notes and verdict, not stub values."""
        reset_data = _reset(client)
        session_id = reset_data["session_id"]
        episode_id = reset_data["episode_id"]
        # Propose then accept. Check each step so a failure is reported where
        # it happens instead of surfacing as a confusing replay mismatch.
        propose = client.post(
            "/step",
            json={"session_id": session_id, "action": _good_action_payload(client)},
        )
        assert propose.status_code == 200
        accept = client.post(
            "/step",
            json={"session_id": session_id, "action": _accept_action_payload()},
        )
        assert accept.status_code == 200
        assert accept.json()["done"] is True
        # Fetch replay
        resp = client.get(f"/replay/{episode_id}")
        assert resp.status_code == 200
        replay = resp.json()
        assert replay["agreement_reached"] is True
        assert "rigor" in replay["judge_notes"]
        assert replay["verdict"] == "accept"
        assert replay["reward_breakdown"] is not None
        assert replay["total_reward"] > 0.0
        # Not the old stub string
        assert "Stub audit" not in replay["judge_notes"]

    def test_replay_includes_top_failure_reasons(self, client: TestClient) -> None:
        """Terminal replay records persist the canonical audit reasons."""
        reset_data = _reset(client)
        session_id = reset_data["session_id"]
        episode_id = reset_data["episode_id"]
        # Use the scenario's own round limit rather than a hard-coded count,
        # so the timeout path is reached regardless of difficulty settings.
        max_rounds = reset_data["observation"]["scientist"]["max_rounds"]
        action = _good_action_payload(client)
        # Keep proposing without accepting so the episode times out and the
        # audit builder emits failure reasons.
        done = False
        for _ in range(max_rounds):
            resp = client.post(
                "/step",
                json={"session_id": session_id, "action": action},
            )
            assert resp.status_code == 200
            if resp.json()["done"]:
                done = True
                break
        assert done, "episode did not terminate within max_rounds"
        replay_resp = client.get(f"/replay/{episode_id}")
        assert replay_resp.status_code == 200
        replay = replay_resp.json()
        assert replay["verdict"] == "timeout"
        assert isinstance(replay["top_failure_reasons"], list)
        assert replay["top_failure_reasons"]
        assert any(
            "round limit" in reason.lower() or "without agreement" in reason.lower()
            for reason in replay["top_failure_reasons"]
        )
class TestAgentStepEndpoint:
    """POST /agent-step — steps the session using the configured Scientist runtime."""

    def test_agent_step_runs_runtime_policy(
        self, client: TestClient, monkeypatch
    ) -> None:
        import server.app as server_app

        monkeypatch.setenv("REPLICALAB_SCIENTIST_RUNTIME", "anthropic")
        monkeypatch.setenv("ANTHROPIC_API_KEY", "test-key")
        # Clear the server-side policy cache before switching runtimes.
        server_app._SCIENTIST_POLICY_CACHE.clear()

        canned_action = ScientistAction.model_validate(_good_action_payload(client))

        def fake_policy(observation, **kwargs):
            # The endpoint must hand the policy the session's scenario metadata.
            assert observation.paper_title
            assert kwargs["scenario"] == "ml_benchmark"
            assert kwargs["difficulty"] == "medium"
            return canned_action

        monkeypatch.setattr(server_app, "_get_scientist_policy", lambda: fake_policy)

        session = _reset(client, scenario="ml_benchmark", difficulty="medium")
        response = client.post("/agent-step", json={"session_id": session["session_id"]})
        assert response.status_code == 200
        data = response.json()
        info = data["info"]
        assert info["scientist_runtime"] == "anthropic"
        assert info["scientist_model"]
        assert info["scientist_action"]["action_type"] == "propose_protocol"
        assert data["observation"]["scientist"]["round_number"] == 1

    def test_agent_step_invalid_session_returns_404(self, client: TestClient) -> None:
        response = client.post("/agent-step", json={"session_id": "missing"})
        assert response.status_code == 404
        assert "Session not found" in response.json()["detail"]

    def test_agent_step_falls_back_to_baseline_on_runtime_failure(
        self, client: TestClient, monkeypatch
    ) -> None:
        import server.app as server_app

        def failing_resolver(session):
            # Simulate the model runtime blowing up mid-request.
            raise RuntimeError("model timeout")

        monkeypatch.setenv("REPLICALAB_SCIENTIST_RUNTIME", "ollama")
        monkeypatch.setattr(server_app, "_resolve_scientist_action", failing_resolver)

        session = _reset(client, scenario="math_reasoning", difficulty="easy")
        response = client.post("/agent-step", json={"session_id": session["session_id"]})
        assert response.status_code == 200
        info = response.json()["info"]
        assert info["scientist_runtime"] == "ollama_fallback"
        assert info["scientist_error"] == "model timeout"
| # --------------------------------------------------------------------------- | |
| # WebSocket handler — API 06 | |
| # --------------------------------------------------------------------------- | |
def _ws_send_recv(ws, msg: dict) -> dict:
    """Round-trip one JSON frame: serialise *msg*, send it, decode the reply."""
    outgoing = json.dumps(msg)
    ws.send_text(outgoing)
    reply = ws.receive_text()
    return json.loads(reply)
class TestWebSocket:
    """API 06: WebSocket session handler with isolated env per connection."""

    # -- basic connectivity --------------------------------------------------

    def test_ws_ping_pong(self, client: TestClient) -> None:
        with client.websocket_connect("/ws") as ws:
            resp = _ws_send_recv(ws, {"type": "ping"})
            assert resp["type"] == "pong"

    def test_ws_reset_returns_observation(self, client: TestClient) -> None:
        with client.websocket_connect("/ws") as ws:
            resp = _ws_send_recv(ws, {
                "type": "reset", "seed": 42,
                "scenario": "math_reasoning", "difficulty": "easy",
            })
            assert resp["type"] == "reset_ok"
            assert resp["episode_id"]
            obs = resp["observation"]
            assert obs["scientist"]["paper_title"]
            assert obs["lab_manager"]["budget_total"] > 0

    def test_ws_step_returns_result(self, client: TestClient) -> None:
        action = _good_action_payload(client)
        with client.websocket_connect("/ws") as ws:
            _ws_send_recv(ws, {"type": "reset", "seed": 42})
            resp = _ws_send_recv(ws, {"type": "step", "action": action})
            assert resp["type"] == "step_ok"
            assert resp["done"] is False
            assert resp["reward"] > 0.0
            assert resp["info"]["step_reward_components"]["protocol_delta_bonus"] > 0.0
            assert resp["observation"] is not None

    def test_ws_full_episode_real_reward(self, client: TestClient) -> None:
        """Propose → accept returns real reward breakdown, not stub 0.8."""
        action = _good_action_payload(client)
        with client.websocket_connect("/ws") as ws:
            _ws_send_recv(ws, {"type": "reset", "seed": 42})
            # The propose step must succeed and stay open before accepting.
            propose = _ws_send_recv(ws, {"type": "step", "action": action})
            assert propose["type"] == "step_ok"
            assert propose["done"] is False
            resp = _ws_send_recv(ws, {"type": "step", "action": _accept_action_payload()})
            assert resp["type"] == "step_ok"
            assert resp["done"] is True
            assert resp["reward"] > 0.0
            info = resp["info"]
            assert info["agreement_reached"] is True
            assert info["verdict"] == "accept"
            rb = info["reward_breakdown"]
            assert rb is not None
            assert 0.0 <= rb["rigor"] <= 1.0
            assert 0.0 <= rb["feasibility"] <= 1.0
            assert 0.0 <= rb["fidelity"] <= 1.0
            assert not (rb["rigor"] == 0.8 and rb["feasibility"] == 0.8)

    # -- error handling ------------------------------------------------------

    def test_ws_invalid_json(self, client: TestClient) -> None:
        with client.websocket_connect("/ws") as ws:
            ws.send_text("not valid json {{{")
            resp = json.loads(ws.receive_text())
            assert resp["type"] == "error"
            assert "Invalid JSON" in resp["message"]

    def test_ws_missing_action_field(self, client: TestClient) -> None:
        with client.websocket_connect("/ws") as ws:
            _ws_send_recv(ws, {"type": "reset", "seed": 42})
            resp = _ws_send_recv(ws, {"type": "step"})
            assert resp["type"] == "error"
            assert "Missing" in resp["message"]

    def test_ws_invalid_action_payload(self, client: TestClient) -> None:
        """Structurally invalid action (missing required fields) → WS error."""
        with client.websocket_connect("/ws") as ws:
            _ws_send_recv(ws, {"type": "reset", "seed": 42})
            resp = _ws_send_recv(ws, {
                "type": "step",
                "action": {"action_type": "propose_protocol"},
            })
            assert resp["type"] == "error"
            assert "Invalid action" in resp["message"]

    def test_ws_unknown_message_type(self, client: TestClient) -> None:
        with client.websocket_connect("/ws") as ws:
            resp = _ws_send_recv(ws, {"type": "banana"})
            assert resp["type"] == "error"
            assert "Unknown" in resp["message"]

    # -- session isolation ---------------------------------------------------

    def test_ws_session_isolation(self, client: TestClient) -> None:
        """Two WebSocket connections have independent env state."""
        action = _good_action_payload(client)
        with client.websocket_connect("/ws") as ws1:
            r1 = _ws_send_recv(ws1, {"type": "reset", "seed": 1})
            _ws_send_recv(ws1, {"type": "step", "action": action})
            with client.websocket_connect("/ws") as ws2:
                r2 = _ws_send_recv(ws2, {"type": "reset", "seed": 2})
                assert r1["episode_id"] != r2["episode_id"]
                # ws2 is at round 0, ws1 is at round 1; if state were shared,
                # ws2's first step would not land on round 1.
                step2 = _ws_send_recv(ws2, {"type": "step", "action": action})
                assert step2["observation"]["scientist"]["round_number"] == 1

    # -- real-env integration (user-requested) --------------------------------

    def test_ws_semantic_invalid_action_returns_step_ok_with_info_error(
        self, client: TestClient
    ) -> None:
        """A structurally valid but semantically invalid action (e.g.
        duration_days=999) returns step_ok with info.error — NOT a
        transport-level WS error frame."""
        with client.websocket_connect("/ws") as ws:
            _ws_send_recv(ws, {"type": "reset", "seed": 42})
            bad_action = {
                "action_type": "propose_protocol",
                "sample_size": 5,
                "controls": ["baseline"],
                "technique": "some technique",
                "duration_days": 999,
                "required_equipment": [],
                "required_reagents": [],
                "questions": [],
                "rationale": "Duration is impossibly long for the lab.",
            }
            resp = _ws_send_recv(ws, {"type": "step", "action": bad_action})
            assert resp["type"] == "step_ok"
            assert resp["done"] is False
            assert resp["info"]["error"] is not None
            assert "Validation errors" in resp["info"]["error"]

    def test_ws_timeout_verdict(self, client: TestClient) -> None:
        """Run to max_rounds without accept → done=True, verdict=timeout,
        a negative reward, and a positive timeout penalty in the reward
        breakdown. Proves real-env integration."""
        action = _good_action_payload(client)
        with client.websocket_connect("/ws") as ws:
            reset_resp = _ws_send_recv(ws, {"type": "reset", "seed": 42})
            assert reset_resp["type"] == "reset_ok"
            max_rounds = reset_resp["observation"]["scientist"]["max_rounds"]
            resp = None
            for _ in range(max_rounds):
                resp = _ws_send_recv(ws, {"type": "step", "action": action})
            assert resp is not None
            assert resp["done"] is True
            assert resp["info"]["verdict"] == "timeout"
            assert resp["reward"] < 0.0
            assert resp["info"]["reward_breakdown"] is not None
            assert resp["info"]["reward_breakdown"]["penalties"]["timeout"] > 0.0

    def test_ws_terminal_episode_persists_real_replay_log(
        self, client: TestClient
    ) -> None:
        """Complete a WS episode, then verify GET /replay/{episode_id}
        returns real reward_breakdown, judge_notes, and verdict —
        not stub strings."""
        action = _good_action_payload(client)
        with client.websocket_connect("/ws") as ws:
            reset_resp = _ws_send_recv(ws, {"type": "reset", "seed": 42})
            episode_id = reset_resp["episode_id"]
            _ws_send_recv(ws, {"type": "step", "action": action})
            accept = _ws_send_recv(ws, {"type": "step", "action": _accept_action_payload()})
            # The episode must actually have terminated before checking replay.
            assert accept["done"] is True
        # Fetch replay via REST after WS connection is closed
        replay_resp = client.get(f"/replay/{episode_id}")
        assert replay_resp.status_code == 200
        replay = replay_resp.json()
        assert replay["agreement_reached"] is True
        assert replay["verdict"] == "accept"
        assert replay["total_reward"] > 0.0
        # Real judge_notes, not stub
        assert replay["judge_notes"] != ""
        assert "Stub audit" not in replay["judge_notes"]
        assert "rigor" in replay["judge_notes"]
        # Real reward_breakdown with non-stub scores
        rb = replay["reward_breakdown"]
        assert rb is not None
        assert 0.0 < rb["rigor"] <= 1.0
        assert 0.0 < rb["feasibility"] <= 1.0
        assert 0.0 < rb["fidelity"] <= 1.0
        assert not (rb["rigor"] == 0.8 and rb["feasibility"] == 0.8)

    # -- idle timeout & disconnect cleanup (API 07) -------------------------

    def test_ws_idle_timeout_closes_connection(self, client: TestClient) -> None:
        """API 07: server closes WebSocket after idle timeout (no messages)."""
        with patch("server.app._WS_IDLE_TIMEOUT", 0.5):
            with client.websocket_connect("/ws") as ws:
                # Don't send anything — let the server-side timeout fire
                time.sleep(1.0)
                with pytest.raises(WebSocketDisconnect) as exc_info:
                    ws.receive_text()
                assert exc_info.value.code == 1000

    def test_ws_env_closes_on_disconnect(self, client: TestClient) -> None:
        """API 07: env.close() runs in the finally block on disconnect."""
        import server.app as _app

        original_make_env = _app._make_env
        close_calls: list[bool] = []

        def _tracked_make_env(*args, **kwargs):
            # Forward any arguments the server passes, then wrap the real
            # env's close() so each call is recorded before delegating.
            env = original_make_env(*args, **kwargs)
            original_close = env.close

            def _tracking_close():
                close_calls.append(True)
                original_close()

            env.close = _tracking_close
            return env

        with patch.object(_app, "_make_env", _tracked_make_env):
            with client.websocket_connect("/ws") as ws:
                _ws_send_recv(ws, {"type": "ping"})
        # Context manager exit sends disconnect; server runs finally block.
        # TestClient joins the ASGI thread, so close() has already run.
        assert len(close_calls) == 1