Spaces:
Sleeping
Sleeping
| """Integration-style checks for inference wiring and transport/session flows.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from subprocess import CompletedProcess | |
| from types import SimpleNamespace | |
| from typing import Any | |
| from fastapi.testclient import TestClient | |
| _ROOT = Path(__file__).resolve().parents[1] | |
| if str(_ROOT) not in sys.path: | |
| sys.path.insert(0, str(_ROOT)) | |
| import inference # noqa: E402 | |
| import env_loader # noqa: E402 | |
| import server.app as app_module # noqa: E402 | |
| from client import DataOpsEnvClient # noqa: E402 | |
| from models import DataOpsAction # noqa: E402 | |
class _FakeResponse:
    """Stand-in for an HTTP response that carries a fixed JSON body.

    ``raise_for_status`` never raises, so every fake response looks successful.
    """

    def __init__(self, payload: dict[str, Any]) -> None:
        self._payload = payload

    def json(self) -> dict[str, Any]:
        """Return the stored payload object itself (not a copy)."""
        return self._payload

    def raise_for_status(self) -> None:
        """Accept every response; this fake never models HTTP errors."""
class _FakeHTTPSession:
    """Scripted stand-in for the HTTP session passed to ``inference.run_task``.

    Routing is purely on the URL suffix:

    * ``/reset`` zeroes the step counter and returns the task prompt;
    * ``/step`` bumps the counter and reports ``done`` from the second step on;
    * ``/grader/task_2_medium_syntax`` returns a fixed score of 0.25.

    Every URL is appended to ``urls`` before routing. Any other URL fails the
    test with ``AssertionError``.
    """

    def __init__(self) -> None:
        self.urls: list[str] = []
        self._step_count = 0

    def request(
        self,
        method: str,
        url: str,
        timeout: float | None = None,
        **kwargs: Any,
    ) -> _FakeResponse:
        # Only the URL matters to this fake; discard everything else.
        del method, timeout, kwargs
        self.urls.append(url)

        if url.endswith("/reset"):
            self._step_count = 0
            message, done = "Repair the ETL job.", False
        elif url.endswith("/step"):
            self._step_count += 1
            message, done = "Read ok.", self._step_count >= 2
        elif url.endswith("/grader/task_2_medium_syntax"):
            return _FakeResponse({"score": 0.25})
        else:
            raise AssertionError(f"Unexpected URL requested: {url}")

        # /reset and /step share one envelope shape; build a fresh dict per call.
        return _FakeResponse(
            {
                "observation": {"status": "success", "message": message},
                "reward": 0.0,
                "done": done,
            }
        )
class _FakeChatCompletions:
    """Scripted replacement for ``client.chat.completions``.

    Each ``create`` call consumes the next scripted message and wraps it so it
    is reachable as ``response.choices[0].message``.
    """

    def __init__(self, messages: list[Any]) -> None:
        self._messages = iter(messages)

    def create(self, **kwargs: Any) -> Any:
        """Return the next scripted message wrapped as a chat completion.

        Raises:
            AssertionError: if the code under test requests more completions
                than the test scripted. A bare ``StopIteration`` escaping an
                ordinary method can be mistaken for iterator exhaustion by a
                caller (e.g. inside a ``map`` callback) or be turned into
                ``RuntimeError`` inside a generator, hiding the real cause.
        """
        del kwargs  # request parameters are irrelevant to the script
        try:
            message = next(self._messages)
        except StopIteration:
            raise AssertionError(
                "Fake chat client ran out of scripted messages"
            ) from None
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])
class _FakeClient:
    """Minimal model client whose ``chat.completions.create`` replays ``messages``."""

    def __init__(self, messages: list[Any]) -> None:
        # Fixed model-host URL; tests assert env requests never mention model.local.
        self.base_url = "https://model.local/v1"
        scripted = _FakeChatCompletions(messages)
        self.chat = SimpleNamespace(completions=scripted)
def _tool_message(name: str, arguments: dict[str, Any]) -> Any:
    """Build a fake assistant message that makes exactly one tool call.

    The call id is always ``"call-1"``, and ``arguments`` is JSON-encoded into a
    string (presumably mirroring how the SDK delivers function arguments).
    """
    encoded_args = json.dumps(arguments)
    function = SimpleNamespace(name=name, arguments=encoded_args)
    single_call = SimpleNamespace(id="call-1", function=function)
    return SimpleNamespace(tool_calls=[single_call])
def test_inference_run_task_uses_env_base_url(monkeypatch, capsys) -> None:
    """run_task sends every env request to ENV_BASE_URL, never to the model host."""
    scripted_replies = [
        _tool_message("read_file", {"filepath": "broken_pipeline.py"}),
        SimpleNamespace(tool_calls=[]),  # a turn in which the model calls no tool
        _tool_message("invoke_python", {"filepath": "broken_pipeline.py", "args": []}),
    ]
    fake_client = _FakeClient(scripted_replies)
    fake_http = _FakeHTTPSession()
    for attr, value in (
        ("ENV_BASE_URL", "http://env.local"),
        ("API_BASE_URL", "https://model.local/v1"),
        ("MODEL_NAME", "mock-model"),
    ):
        monkeypatch.setattr(inference, attr, value)

    score = inference.run_task(
        fake_client,
        fake_http,
        "task_2_medium_syntax",
        max_turns=4,
        seed=3,
    )

    assert score == 0.25
    requested = fake_http.urls
    assert requested
    for url in requested:
        assert url.startswith("http://env.local")
        assert "model.local" not in url

    out = capsys.readouterr().out
    for marker in ("[START]", "[STEP]", "[END]", "success=false"):
        assert marker in out
def test_inference_normalizes_boundary_grader_scores(monkeypatch, capsys) -> None:
    """A raw grader score of 1.0 is reported as inference.MAX_REPORTED_SCORE."""

    class _PerfectScoreHTTPSession(_FakeHTTPSession):
        # Same scripted env as the parent, but the grader awards a perfect score.
        def request(
            self,
            method: str,
            url: str,
            timeout: float | None = None,
            **kwargs: Any,
        ) -> _FakeResponse:
            if not url.endswith("/grader/task_2_medium_syntax"):
                return super().request(method, url, timeout=timeout, **kwargs)
            return _FakeResponse({"score": 1.0})

    idle_turns = [SimpleNamespace(tool_calls=[]) for _turn in range(4)]
    monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
    monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")

    score = inference.run_task(
        _FakeClient(idle_turns),
        _PerfectScoreHTTPSession(),
        "task_2_medium_syntax",
        max_turns=4,
        seed=3,
    )

    assert score == inference.MAX_REPORTED_SCORE
    out = capsys.readouterr().out
    assert "success=true" in out
    assert "score=0.99" in out
def test_inference_emits_grader_details_to_stderr_when_enabled(monkeypatch, capsys) -> None:
    """With PUBLIC_GRADER_DETAILS=true, the grader payload appears on stderr as JSON."""

    class _DetailedHTTPSession(_FakeHTTPSession):
        # Parent's scripted env, except the grader response carries ``details``.
        def request(
            self,
            method: str,
            url: str,
            timeout: float | None = None,
            **kwargs: Any,
        ) -> _FakeResponse:
            if not url.endswith("/grader/task_2_medium_syntax"):
                return super().request(method, url, timeout=timeout, **kwargs)
            detailed_grade = {
                "task_id": "task_2_medium_syntax",
                "score": 0.25,
                "details": {"reason": "Visible repair only"},
            }
            return _FakeResponse(detailed_grade)

    single_read = _tool_message("read_file", {"filepath": "broken_pipeline.py"})
    monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "true")
    monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
    monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")

    inference.run_task(
        _FakeClient([single_read]),
        _DetailedHTTPSession(),
        "task_2_medium_syntax",
        max_turns=1,
        seed=3,
    )

    err_text = capsys.readouterr().err.strip()
    assert err_text
    reported = json.loads(err_text)
    assert reported["details"]["reason"] == "Visible repair only"
def test_baseline_endpoint_passes_env_base_url(monkeypatch) -> None:
    """With ENV_BASE_URL unset, /baseline hands http://127.0.0.1:7860 to its subprocess.

    ``subprocess.run`` is replaced, so no child process starts. The fake
    records how it was invoked and returns canned inference output for the
    endpoint to parse.
    """
    captured: dict[str, Any] = {}
    expected_stderr = json.dumps({"task_id": "task_1_easy_anomaly", "score": 0.99})
    summary = {
        "scores": {"task_1_easy_anomaly": 0.99},
        "grades": {
            "task_1_easy_anomaly": {
                "task_id": "task_1_easy_anomaly",
                "score": 0.99,
                "details": {"reason": "Perfect"},
            }
        },
        "average": 0.99,
        "model": "fake-model",
        "metadata": {"env_base_url": "http://127.0.0.1:7860"},
    }
    canned_stdout = "\n".join(
        (
            "[START] task=task_1_easy_anomaly env=dataops_env model=fake-model",
            "[END] success=true steps=1 score=0.99 rewards=1.00",
            json.dumps(summary),
        )
    )

    def fake_run(
        command: list[str],
        *,
        cwd: str,
        capture_output: bool,
        text: bool,
        timeout: float,
        env: dict[str, str],
    ) -> CompletedProcess[str]:
        captured.update(
            command=command,
            cwd=cwd,
            capture_output=capture_output,
            text=text,
            timeout=timeout,
            env=env,
        )
        return CompletedProcess(command, 0, stdout=canned_stdout, stderr=expected_stderr)

    monkeypatch.setenv("API_KEY", "test-key")
    monkeypatch.delenv("ENV_BASE_URL", raising=False)
    monkeypatch.setattr(app_module.subprocess, "run", fake_run)

    request_body = {
        "task_ids": ["task_1_easy_anomaly"],
        "seed": 7,
        "max_turns": 5,
    }
    with TestClient(app_module.app) as client:
        response = client.post("/baseline", json=request_body)
        assert response.status_code == 200
        body = response.json()
        assert "[START] task=task_1_easy_anomaly" in body["stdout"]
        assert body["stderr"] == expected_stderr
        assert body["scores"]["task_1_easy_anomaly"] == 0.99
        assert body["grades"]["task_1_easy_anomaly"]["details"]["reason"] == "Perfect"
        assert captured["env"]["ENV_BASE_URL"] == "http://127.0.0.1:7860"
        for flag in ("--seed", "--max-turns", "--task"):
            assert flag in captured["command"]
def test_inference_default_api_base_url_uses_google_for_api_key(
    monkeypatch,
) -> None:
    """With only API_KEY set, the default base URL is DEFAULT_GOOGLE_OPENAI_BASE_URL."""
    monkeypatch.setenv("API_KEY", "test-key")
    for name in ("HF_TOKEN", "API_BASE_URL"):
        monkeypatch.delenv(name, raising=False)

    resolved = inference._resolve_api_base_url()
    assert resolved == inference.DEFAULT_GOOGLE_OPENAI_BASE_URL
def test_inference_default_api_base_url_uses_hf_router_for_hf_token(
    monkeypatch,
) -> None:
    """With only HF_TOKEN set, the default base URL is DEFAULT_HF_OPENAI_BASE_URL."""
    monkeypatch.setenv("HF_TOKEN", "test-token")
    for name in ("API_KEY", "API_BASE_URL"):
        monkeypatch.delenv(name, raising=False)

    resolved = inference._resolve_api_base_url()
    assert resolved == inference.DEFAULT_HF_OPENAI_BASE_URL
def test_session_id_header_can_resume_http_episode() -> None:
    """The session id returned by /reset, sent back as X-Session-ID, resumes the episode."""
    with TestClient(app_module.app) as client:
        reset = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
        assert reset.status_code == 200
        session_header = {"X-Session-ID": reset.headers["X-Session-ID"]}

        # Clear cookies so the header is the only thing identifying the session.
        client.cookies.clear()

        state = client.get("/state", headers=session_header)
        assert state.status_code == 200
        snapshot = state.json()
        assert snapshot["task_id"] == "task_1_easy_anomaly"
        assert snapshot["seed"] == 5
def test_reset_replaces_unknown_client_supplied_session_id() -> None:
    """/reset issues its own session id instead of adopting an unknown client-chosen one."""
    forged_id = "attacker-chosen-session"
    with TestClient(app_module.app) as client:
        reset = client.post(
            "/reset?task_id=task_1_easy_anomaly",
            headers={"X-Session-ID": forged_id},
            json={"seed": 4},
        )
        assert reset.status_code == 200
        issued_id = reset.headers["X-Session-ID"]
        assert issued_id != forged_id

        # From here on, only the explicit header carries the session.
        client.cookies.clear()

        forged = client.get("/state", headers={"X-Session-ID": forged_id})
        assert forged.status_code == 400

        genuine = client.get("/state", headers={"X-Session-ID": issued_id})
        assert genuine.status_code == 200
        assert genuine.json()["seed"] == 4
def test_websocket_reset_state_and_step_flow() -> None:
    """Drive reset, then state, then an ExecuteSQL step over /ws and check each reply."""
    null_amount_query = (
        "SELECT id, amount FROM transactions "
        "WHERE amount IS NULL ORDER BY id"
    )
    with (
        TestClient(app_module.app) as client,
        client.websocket_connect("/ws") as websocket,
    ):

        def exchange(message: dict[str, Any]) -> Any:
            # Send one JSON frame and block until its JSON reply arrives.
            websocket.send_json(message)
            return websocket.receive_json()

        reset_reply = exchange(
            {
                "type": "reset",
                "data": {"task_id": "task_1_easy_anomaly", "seed": 3},
            }
        )
        assert reset_reply["data"]["observation"]["status"] == "success"

        state_reply = exchange({"type": "state"})
        assert state_reply["data"]["task_id"] == "task_1_easy_anomaly"

        step_reply = exchange(
            {
                "type": "step",
                "data": {
                    "action_type": "ExecuteSQL",
                    "payload": {"query": null_amount_query},
                },
            }
        )
        step_observation = step_reply["data"]["observation"]
        assert step_observation["status"] == "success"
        assert step_observation["sql_results"]

        websocket.send_json({"type": "close", "data": {}})
def test_http_client_overlays_top_level_reward_and_done() -> None:
    """DataOpsEnvClient observations expose the response's top-level reward and done."""

    class _FakeSession:
        # Replaces the client's HTTP session; only /reset and /step POSTs are expected.
        def post(self, url: str, **kwargs: Any) -> _FakeResponse:
            del kwargs
            if url.endswith("/reset"):
                message, reward, done = "ready", 0.0, False
            elif url.endswith("/step"):
                message, reward, done = "ok", 0.25, True
            else:
                raise AssertionError(f"Unexpected URL requested: {url}")
            return _FakeResponse(
                {
                    "observation": {"status": "success", "message": message},
                    "reward": reward,
                    "done": done,
                }
            )

        def get(self, url: str, **kwargs: Any) -> _FakeResponse:
            del kwargs
            raise AssertionError(f"Unexpected URL requested: {url}")

        def close(self) -> None:
            return None

    client = DataOpsEnvClient(base_url="http://env.local")
    client._session = _FakeSession()
    try:
        reset_obs = client.reset(task_id="task_1_easy_anomaly", seed=5)
        assert reset_obs.reward == 0.0
        assert reset_obs.done is False

        sql_action = DataOpsAction(
            action_type="ExecuteSQL",
            payload={"query": "SELECT 1"},
        )
        step_obs = client.step(sql_action)
        assert step_obs.reward == 0.25
        assert step_obs.done is True
    finally:
        client.close()
def test_env_loader_uses_root_env_to_find_secondary_env_file(
    tmp_path: Path, monkeypatch
) -> None:
    """The root ``.env`` names a secondary file whose values load_env applies.

    Every variable load_env may write here is registered with ``monkeypatch``
    first, so the values it sets are rolled back when the test ends.
    """
    monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
    for name in ("ENV_FILE", "PUBLIC_GRADER_DETAILS", "MODEL_NAME"):
        # delenv(raising=False) on an unset variable records nothing, so a value
        # load_env() writes later would leak into other tests. setenv first makes
        # monkeypatch remember the pre-test state; delenv then clears it.
        # ENV_FILE is cleared too so an externally set value cannot redirect
        # load_env away from the root .env written below.
        monkeypatch.setenv(name, "")
        monkeypatch.delenv(name)

    (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
    (tmp_path / ".env.dev").write_text(
        "PUBLIC_GRADER_DETAILS=true\nMODEL_NAME=debug-model\n",
        encoding="utf-8",
    )

    env_loader.load_env()

    assert os.getenv("PUBLIC_GRADER_DETAILS") == "true"
    assert os.getenv("MODEL_NAME") == "debug-model"
def test_env_loader_preserves_external_runtime_env(
    tmp_path: Path, monkeypatch
) -> None:
    """Variables already in the process environment win over env-file values.

    PORT is set externally and must survive the file's ``PORT=7860``. MODEL_NAME
    starts unset and must be filled from the secondary file.
    """
    monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
    monkeypatch.setenv("PORT", "7861")
    for name in ("ENV_FILE", "MODEL_NAME"):
        # setenv first so monkeypatch records the pre-test state and restores it
        # afterwards; delenv(raising=False) on an unset variable records nothing,
        # letting whatever load_env() writes leak into later tests.
        monkeypatch.setenv(name, "")
        monkeypatch.delenv(name)

    (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
    (tmp_path / ".env.dev").write_text(
        "PORT=7860\nMODEL_NAME=debug-model\n",
        encoding="utf-8",
    )

    env_loader.load_env()

    assert os.getenv("PORT") == "7861"
    assert os.getenv("MODEL_NAME") == "debug-model"