"""Integration-style checks for inference wiring and transport/session flows.""" from __future__ import annotations import json import os import sys from pathlib import Path from subprocess import CompletedProcess from types import SimpleNamespace from typing import Any from fastapi.testclient import TestClient _ROOT = Path(__file__).resolve().parents[1] if str(_ROOT) not in sys.path: sys.path.insert(0, str(_ROOT)) import inference # noqa: E402 import env_loader # noqa: E402 import server.app as app_module # noqa: E402 from client import DataOpsEnvClient # noqa: E402 from models import DataOpsAction # noqa: E402 class _FakeResponse: def __init__(self, payload: dict[str, Any]) -> None: self._payload = payload def raise_for_status(self) -> None: return None def json(self) -> dict[str, Any]: return self._payload class _FakeHTTPSession: def __init__(self) -> None: self.urls: list[str] = [] self._step_count = 0 def request( self, method: str, url: str, timeout: float | None = None, **kwargs: Any, ) -> _FakeResponse: del method, timeout, kwargs self.urls.append(url) if url.endswith("/reset"): self._step_count = 0 return _FakeResponse( { "observation": { "status": "success", "message": "Repair the ETL job.", }, "reward": 0.0, "done": False, } ) if url.endswith("/step"): self._step_count += 1 return _FakeResponse( { "observation": { "status": "success", "message": "Read ok.", }, "reward": 0.0, "done": self._step_count >= 2, } ) if url.endswith("/grader/task_2_medium_syntax"): return _FakeResponse({"score": 0.25}) raise AssertionError(f"Unexpected URL requested: {url}") class _FakeChatCompletions: def __init__(self, messages: list[Any]) -> None: self._messages = iter(messages) def create(self, **kwargs: Any) -> Any: del kwargs message = next(self._messages) return SimpleNamespace(choices=[SimpleNamespace(message=message)]) class _FakeClient: def __init__(self, messages: list[Any]) -> None: self.chat = SimpleNamespace(completions=_FakeChatCompletions(messages)) self.base_url = "https://model.local/v1" def _tool_message(name: str, arguments: dict[str, Any]) -> Any: return SimpleNamespace( tool_calls=[ SimpleNamespace( id="call-1", function=SimpleNamespace( name=name, arguments=json.dumps(arguments), ), ) ] ) def test_inference_run_task_uses_env_base_url(monkeypatch, capsys) -> None: fake_http = _FakeHTTPSession() fake_client = _FakeClient( [ _tool_message("read_file", {"filepath": "broken_pipeline.py"}), SimpleNamespace(tool_calls=[]), _tool_message("invoke_python", {"filepath": "broken_pipeline.py", "args": []}), ] ) monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local") monkeypatch.setattr(inference, "API_BASE_URL", "https://model.local/v1") monkeypatch.setattr(inference, "MODEL_NAME", "mock-model") score = inference.run_task( fake_client, fake_http, "task_2_medium_syntax", max_turns=4, seed=3, ) assert score == 0.25 assert fake_http.urls assert all(url.startswith("http://env.local") for url in fake_http.urls) assert all("model.local" not in url for url in fake_http.urls) stdout = capsys.readouterr().out assert "[START]" in stdout assert "[STEP]" in stdout assert "[END]" in stdout assert "success=false" in stdout def test_inference_normalizes_boundary_grader_scores(monkeypatch, capsys) -> None: class _PerfectScoreHTTPSession(_FakeHTTPSession): def request( self, method: str, url: str, timeout: float | None = None, **kwargs: Any, ) -> _FakeResponse: if url.endswith("/grader/task_2_medium_syntax"): return _FakeResponse({"score": 1.0}) return super().request(method, url, timeout=timeout, **kwargs) fake_http = _PerfectScoreHTTPSession() fake_client = _FakeClient([SimpleNamespace(tool_calls=[]) for _ in range(4)]) monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local") monkeypatch.setattr(inference, "MODEL_NAME", "mock-model") score = inference.run_task( fake_client, fake_http, "task_2_medium_syntax", max_turns=4, seed=3, ) assert score == inference.MAX_REPORTED_SCORE stdout = capsys.readouterr().out assert "success=true" in stdout assert "score=0.99" in stdout def test_inference_emits_grader_details_to_stderr_when_enabled(monkeypatch, capsys) -> None: class _DetailedHTTPSession(_FakeHTTPSession): def request( self, method: str, url: str, timeout: float | None = None, **kwargs: Any, ) -> _FakeResponse: if url.endswith("/grader/task_2_medium_syntax"): return _FakeResponse( { "task_id": "task_2_medium_syntax", "score": 0.25, "details": {"reason": "Visible repair only"}, } ) return super().request(method, url, timeout=timeout, **kwargs) fake_http = _DetailedHTTPSession() fake_client = _FakeClient( [_tool_message("read_file", {"filepath": "broken_pipeline.py"})] ) monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "true") monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local") monkeypatch.setattr(inference, "MODEL_NAME", "mock-model") inference.run_task( fake_client, fake_http, "task_2_medium_syntax", max_turns=1, seed=3, ) stderr = capsys.readouterr().err.strip() assert stderr assert json.loads(stderr)["details"]["reason"] == "Visible repair only" def test_baseline_endpoint_passes_env_base_url(monkeypatch) -> None: captured: dict[str, Any] = {} def fake_run( command: list[str], *, cwd: str, capture_output: bool, text: bool, timeout: float, env: dict[str, str], ) -> CompletedProcess[str]: captured["command"] = command captured["cwd"] = cwd captured["capture_output"] = capture_output captured["text"] = text captured["timeout"] = timeout captured["env"] = env stdout = "\n".join( [ "[START] task=task_1_easy_anomaly env=dataops_env model=fake-model", "[END] success=true steps=1 score=0.99 rewards=1.00", json.dumps( { "scores": {"task_1_easy_anomaly": 0.99}, "grades": { "task_1_easy_anomaly": { "task_id": "task_1_easy_anomaly", "score": 0.99, "details": {"reason": "Perfect"}, } }, "average": 0.99, "model": "fake-model", "metadata": {"env_base_url": "http://127.0.0.1:7860"}, } ), ] ) stderr = json.dumps({"task_id": "task_1_easy_anomaly", "score": 0.99}) return CompletedProcess(command, 0, stdout=stdout, stderr=stderr) monkeypatch.setenv("API_KEY", "test-key") monkeypatch.delenv("ENV_BASE_URL", raising=False) monkeypatch.setattr(app_module.subprocess, "run", fake_run) with TestClient(app_module.app) as client: response = client.post( "/baseline", json={ "task_ids": ["task_1_easy_anomaly"], "seed": 7, "max_turns": 5, }, ) assert response.status_code == 200 assert "[START] task=task_1_easy_anomaly" in response.json()["stdout"] assert response.json()["stderr"] == json.dumps({"task_id": "task_1_easy_anomaly", "score": 0.99}) assert response.json()["scores"]["task_1_easy_anomaly"] == 0.99 assert response.json()["grades"]["task_1_easy_anomaly"]["details"]["reason"] == "Perfect" assert captured["env"]["ENV_BASE_URL"] == "http://127.0.0.1:7860" assert "--seed" in captured["command"] assert "--max-turns" in captured["command"] assert "--task" in captured["command"] def test_inference_default_api_base_url_uses_google_for_api_key( monkeypatch, ) -> None: monkeypatch.setenv("API_KEY", "test-key") monkeypatch.delenv("HF_TOKEN", raising=False) monkeypatch.delenv("API_BASE_URL", raising=False) assert ( inference._resolve_api_base_url() == inference.DEFAULT_GOOGLE_OPENAI_BASE_URL ) def test_inference_default_api_base_url_uses_hf_router_for_hf_token( monkeypatch, ) -> None: monkeypatch.setenv("HF_TOKEN", "test-token") monkeypatch.delenv("API_KEY", raising=False) monkeypatch.delenv("API_BASE_URL", raising=False) assert inference._resolve_api_base_url() == inference.DEFAULT_HF_OPENAI_BASE_URL def test_session_id_header_can_resume_http_episode() -> None: with TestClient(app_module.app) as client: reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5}) assert reset.status_code == 200 session_id = reset.headers["X-Session-ID"] client.cookies.clear() state = client.get("/state", headers={"X-Session-ID": session_id}) assert state.status_code == 200 payload = state.json() assert payload["task_id"] == "task_1_easy_anomaly" assert payload["seed"] == 5 def test_reset_replaces_unknown_client_supplied_session_id() -> None: with TestClient(app_module.app) as client: reset = client.post( "/reset?task_id=task_1_easy_anomaly", headers={"X-Session-ID": "attacker-chosen-session"}, json={"seed": 4}, ) assert reset.status_code == 200 issued_session_id = reset.headers["X-Session-ID"] assert issued_session_id != "attacker-chosen-session" client.cookies.clear() forged_state = client.get("/state", headers={"X-Session-ID": "attacker-chosen-session"}) assert forged_state.status_code == 400 restored_state = client.get("/state", headers={"X-Session-ID": issued_session_id}) assert restored_state.status_code == 200 assert restored_state.json()["seed"] == 4 def test_websocket_reset_state_and_step_flow() -> None: with TestClient(app_module.app) as client: with client.websocket_connect("/ws") as websocket: websocket.send_json( { "type": "reset", "data": {"task_id": "task_1_easy_anomaly", "seed": 3}, } ) reset_payload = websocket.receive_json() assert reset_payload["data"]["observation"]["status"] == "success" websocket.send_json({"type": "state"}) state_payload = websocket.receive_json() assert state_payload["data"]["task_id"] == "task_1_easy_anomaly" websocket.send_json( { "type": "step", "data": { "action_type": "ExecuteSQL", "payload": { "query": ( "SELECT id, amount FROM transactions " "WHERE amount IS NULL ORDER BY id" ) }, }, } ) step_payload = websocket.receive_json() assert step_payload["data"]["observation"]["status"] == "success" assert step_payload["data"]["observation"]["sql_results"] websocket.send_json({"type": "close", "data": {}}) def test_http_client_overlays_top_level_reward_and_done() -> None: class _FakeSession: def post(self, url: str, **kwargs: Any) -> _FakeResponse: del kwargs if url.endswith("/reset"): return _FakeResponse( { "observation": {"status": "success", "message": "ready"}, "reward": 0.0, "done": False, } ) if url.endswith("/step"): return _FakeResponse( { "observation": {"status": "success", "message": "ok"}, "reward": 0.25, "done": True, } ) raise AssertionError(f"Unexpected URL requested: {url}") def get(self, url: str, **kwargs: Any) -> _FakeResponse: del kwargs raise AssertionError(f"Unexpected URL requested: {url}") def close(self) -> None: return None client = DataOpsEnvClient(base_url="http://env.local") client._session = _FakeSession() try: reset_obs = client.reset(task_id="task_1_easy_anomaly", seed=5) assert reset_obs.reward == 0.0 assert reset_obs.done is False step_obs = client.step( DataOpsAction( action_type="ExecuteSQL", payload={"query": "SELECT 1"}, ) ) assert step_obs.reward == 0.25 assert step_obs.done is True finally: client.close() def test_env_loader_uses_root_env_to_find_secondary_env_file( tmp_path: Path, monkeypatch ) -> None: monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path) monkeypatch.delenv("PUBLIC_GRADER_DETAILS", raising=False) monkeypatch.delenv("MODEL_NAME", raising=False) (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8") (tmp_path / ".env.dev").write_text( "PUBLIC_GRADER_DETAILS=true\nMODEL_NAME=debug-model\n", encoding="utf-8", ) env_loader.load_env() assert os.getenv("PUBLIC_GRADER_DETAILS") == "true" assert os.getenv("MODEL_NAME") == "debug-model" def test_env_loader_preserves_external_runtime_env( tmp_path: Path, monkeypatch ) -> None: monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path) monkeypatch.setenv("PORT", "7861") monkeypatch.delenv("MODEL_NAME", raising=False) (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8") (tmp_path / ".env.dev").write_text( "PORT=7860\nMODEL_NAME=debug-model\n", encoding="utf-8", ) env_loader.load_env() assert os.getenv("PORT") == "7861" assert os.getenv("MODEL_NAME") == "debug-model"