# dataops-env / tests/test_inference_api.py
# Last commit a1b343c (verified) by visheshrathi: "Upload folder using huggingface_hub".
"""Integration-style checks for inference wiring and transport/session flows."""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
from subprocess import CompletedProcess
from types import SimpleNamespace
from typing import Any
from fastapi.testclient import TestClient
# Put the project root (the directory above ``tests/``) at the front of
# ``sys.path`` so the project-local modules imported below resolve without
# the project being installed.
_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(_ROOT) for entry in sys.path):
    sys.path[:0] = [str(_ROOT)]
import inference # noqa: E402
import env_loader # noqa: E402
import server.app as app_module # noqa: E402
from client import DataOpsEnvClient # noqa: E402
from models import DataOpsAction # noqa: E402
class _FakeResponse:
def __init__(self, payload: dict[str, Any]) -> None:
self._payload = payload
def raise_for_status(self) -> None:
return None
def json(self) -> dict[str, Any]:
return self._payload
class _FakeHTTPSession:
    """In-memory substitute for the HTTP session passed to ``inference.run_task``.

    Every requested URL is appended to ``urls``. Only three endpoints are
    answered:

    * ``/reset`` restarts the step counter.
    * ``/step`` reports ``done`` once two steps have been taken since the last reset.
    * ``/grader/task_2_medium_syntax`` returns a score of 0.25.

    Any other URL fails the test.
    """

    def __init__(self) -> None:
        self.urls: list[str] = []  # every requested URL, in call order
        self._step_count = 0  # /step calls since the most recent /reset

    @staticmethod
    def _observation(message: str, *, done: bool) -> _FakeResponse:
        """Wrap a successful observation in the reset/step response envelope."""
        return _FakeResponse(
            {
                "observation": {"status": "success", "message": message},
                "reward": 0.0,
                "done": done,
            }
        )

    def request(
        self,
        method: str,
        url: str,
        timeout: float | None = None,
        **kwargs: Any,
    ) -> _FakeResponse:
        """Record *url* and answer with the canned payload for its endpoint.

        Raises:
            AssertionError: If *url* is not one of the expected endpoints.
        """
        del method, timeout, kwargs  # the fake dispatches on the URL alone
        self.urls.append(url)

        if url.endswith("/reset"):
            self._step_count = 0
            return self._observation("Repair the ETL job.", done=False)

        if url.endswith("/step"):
            self._step_count += 1
            # ``done`` flips to True from the second step after a reset onward.
            return self._observation("Read ok.", done=self._step_count >= 2)

        if url.endswith("/grader/task_2_medium_syntax"):
            return _FakeResponse({"score": 0.25})

        raise AssertionError(f"Unexpected URL requested: {url}")
class _FakeChatCompletions:
def __init__(self, messages: list[Any]) -> None:
self._messages = iter(messages)
def create(self, **kwargs: Any) -> Any:
del kwargs
message = next(self._messages)
return SimpleNamespace(choices=[SimpleNamespace(message=message)])
class _FakeClient:
    """Stand-in for the model client passed to ``inference.run_task``.

    It provides only ``chat.completions.create``, which replays *messages*, and
    a fixed ``base_url``.
    """

    def __init__(self, messages: list[Any]) -> None:
        scripted = _FakeChatCompletions(messages)
        self.chat = SimpleNamespace(completions=scripted)
        self.base_url = "https://model.local/v1"
def _tool_message(name: str, arguments: dict[str, Any]) -> Any:
return SimpleNamespace(
tool_calls=[
SimpleNamespace(
id="call-1",
function=SimpleNamespace(
name=name,
arguments=json.dumps(arguments),
),
)
]
)
def test_inference_run_task_uses_env_base_url(monkeypatch, capsys) -> None:
    """Environment traffic goes to ENV_BASE_URL and never to the model endpoint.

    The scripted model makes three turns:

    1. It reads a file.
    2. It takes a turn without calling any tool.
    3. It invokes Python.

    The fake env ends the episode on its second step and grades it 0.25. That
    grade must be returned and logged as ``success=false``.
    """
    env_session = _FakeHTTPSession()
    turns = [
        _tool_message("read_file", {"filepath": "broken_pipeline.py"}),
        SimpleNamespace(tool_calls=[]),
        _tool_message("invoke_python", {"filepath": "broken_pipeline.py", "args": []}),
    ]
    model_client = _FakeClient(turns)

    overrides = {
        "ENV_BASE_URL": "http://env.local",
        "API_BASE_URL": "https://model.local/v1",
        "MODEL_NAME": "mock-model",
    }
    for attribute, value in overrides.items():
        monkeypatch.setattr(inference, attribute, value)

    score = inference.run_task(
        model_client,
        env_session,
        "task_2_medium_syntax",
        max_turns=4,
        seed=3,
    )

    assert score == 0.25
    assert env_session.urls
    for url in env_session.urls:
        assert url.startswith("http://env.local")
        assert "model.local" not in url

    stdout = capsys.readouterr().out
    for marker in ("[START]", "[STEP]", "[END]", "success=false"):
        assert marker in stdout
def test_inference_normalizes_boundary_grader_scores(monkeypatch, capsys) -> None:
    """A raw grader score of exactly 1.0 is reported as ``MAX_REPORTED_SCORE``.

    The run must still be logged as a success, with ``score=0.99``.
    """

    class _PerfectScoreHTTPSession(_FakeHTTPSession):
        """Variant of the fake session whose grader returns the boundary score 1.0."""

        def request(
            self,
            method: str,
            url: str,
            timeout: float | None = None,
            **kwargs: Any,
        ) -> _FakeResponse:
            if not url.endswith("/grader/task_2_medium_syntax"):
                return super().request(method, url, timeout=timeout, **kwargs)
            return _FakeResponse({"score": 1.0})

    idle_turns = [SimpleNamespace(tool_calls=[]) for _ in range(4)]
    env_session = _PerfectScoreHTTPSession()
    model_client = _FakeClient(idle_turns)
    monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
    monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")

    score = inference.run_task(
        model_client,
        env_session,
        "task_2_medium_syntax",
        max_turns=4,
        seed=3,
    )

    assert score == inference.MAX_REPORTED_SCORE
    stdout = capsys.readouterr().out
    assert "success=true" in stdout
    assert "score=0.99" in stdout
def test_inference_emits_grader_details_to_stderr_when_enabled(monkeypatch, capsys) -> None:
    """With ``PUBLIC_GRADER_DETAILS=true`` the grader payload is written to stderr as JSON."""

    class _DetailedHTTPSession(_FakeHTTPSession):
        """Variant of the fake session whose grader reply includes a ``details`` section."""

        def request(
            self,
            method: str,
            url: str,
            timeout: float | None = None,
            **kwargs: Any,
        ) -> _FakeResponse:
            if not url.endswith("/grader/task_2_medium_syntax"):
                return super().request(method, url, timeout=timeout, **kwargs)
            grade = {
                "task_id": "task_2_medium_syntax",
                "score": 0.25,
                "details": {"reason": "Visible repair only"},
            }
            return _FakeResponse(grade)

    env_session = _DetailedHTTPSession()
    only_turn = _tool_message("read_file", {"filepath": "broken_pipeline.py"})
    model_client = _FakeClient([only_turn])
    monkeypatch.setenv("PUBLIC_GRADER_DETAILS", "true")
    monkeypatch.setattr(inference, "ENV_BASE_URL", "http://env.local")
    monkeypatch.setattr(inference, "MODEL_NAME", "mock-model")

    inference.run_task(
        model_client,
        env_session,
        "task_2_medium_syntax",
        max_turns=1,
        seed=3,
    )

    emitted = capsys.readouterr().err.strip()
    assert emitted
    assert json.loads(emitted)["details"]["reason"] == "Visible repair only"
def test_baseline_endpoint_passes_env_base_url(monkeypatch) -> None:
    """``POST /baseline`` launches a subprocess with a local ``ENV_BASE_URL``.

    ``subprocess.run`` is replaced by a fake that records its arguments and
    replays canned baseline output. The test checks two things:

    * the subprocess invocation (environment and CLI flags);
    * how the endpoint surfaces stdout, stderr, scores and grades.
    """
    captured: dict[str, Any] = {}
    expected_stderr = json.dumps({"task_id": "task_1_easy_anomaly", "score": 0.99})
    summary = {
        "scores": {"task_1_easy_anomaly": 0.99},
        "grades": {
            "task_1_easy_anomaly": {
                "task_id": "task_1_easy_anomaly",
                "score": 0.99,
                "details": {"reason": "Perfect"},
            }
        },
        "average": 0.99,
        "model": "fake-model",
        "metadata": {"env_base_url": "http://127.0.0.1:7860"},
    }
    canned_stdout = "\n".join(
        (
            "[START] task=task_1_easy_anomaly env=dataops_env model=fake-model",
            "[END] success=true steps=1 score=0.99 rewards=1.00",
            json.dumps(summary),
        )
    )

    def fake_run(
        command: list[str],
        *,
        cwd: str,
        capture_output: bool,
        text: bool,
        timeout: float,
        env: dict[str, str],
    ) -> CompletedProcess[str]:
        captured.update(
            command=command,
            cwd=cwd,
            capture_output=capture_output,
            text=text,
            timeout=timeout,
            env=env,
        )
        return CompletedProcess(command, 0, stdout=canned_stdout, stderr=expected_stderr)

    monkeypatch.setenv("API_KEY", "test-key")
    monkeypatch.delenv("ENV_BASE_URL", raising=False)
    monkeypatch.setattr(app_module.subprocess, "run", fake_run)

    request_body = {
        "task_ids": ["task_1_easy_anomaly"],
        "seed": 7,
        "max_turns": 5,
    }
    with TestClient(app_module.app) as client:
        response = client.post("/baseline", json=request_body)

    assert response.status_code == 200
    body = response.json()
    assert "[START] task=task_1_easy_anomaly" in body["stdout"]
    assert body["stderr"] == expected_stderr
    assert body["scores"]["task_1_easy_anomaly"] == 0.99
    assert body["grades"]["task_1_easy_anomaly"]["details"]["reason"] == "Perfect"
    assert captured["env"]["ENV_BASE_URL"] == "http://127.0.0.1:7860"
    command = captured["command"]
    for flag in ("--seed", "--max-turns", "--task"):
        assert flag in command
def test_inference_default_api_base_url_uses_google_for_api_key(
    monkeypatch,
) -> None:
    """With only ``API_KEY`` set, the default model URL is ``DEFAULT_GOOGLE_OPENAI_BASE_URL``."""
    monkeypatch.setenv("API_KEY", "test-key")
    for unset_name in ("HF_TOKEN", "API_BASE_URL"):
        monkeypatch.delenv(unset_name, raising=False)

    resolved = inference._resolve_api_base_url()
    assert resolved == inference.DEFAULT_GOOGLE_OPENAI_BASE_URL
def test_inference_default_api_base_url_uses_hf_router_for_hf_token(
    monkeypatch,
) -> None:
    """With only ``HF_TOKEN`` set, the default model URL is ``DEFAULT_HF_OPENAI_BASE_URL``."""
    monkeypatch.setenv("HF_TOKEN", "test-token")
    for unset_name in ("API_KEY", "API_BASE_URL"):
        monkeypatch.delenv(unset_name, raising=False)

    resolved = inference._resolve_api_base_url()
    assert resolved == inference.DEFAULT_HF_OPENAI_BASE_URL
def test_session_id_header_can_resume_http_episode() -> None:
    """An HTTP episode can be resumed by echoing the issued ``X-Session-ID`` header."""
    with TestClient(app_module.app) as client:
        reset_response = client.post("/reset?task_id=task_1_easy_anomaly", json={"seed": 5})
        assert reset_response.status_code == 200
        session_id = reset_response.headers["X-Session-ID"]

        # Drop cookies so that only the header identifies the session below.
        client.cookies.clear()
        state_response = client.get("/state", headers={"X-Session-ID": session_id})

    assert state_response.status_code == 200
    state = state_response.json()
    assert state["task_id"] == "task_1_easy_anomaly"
    assert state["seed"] == 5
def test_reset_replaces_unknown_client_supplied_session_id() -> None:
    """``/reset`` must not adopt a session id that the client made up.

    The server issues its own id. Requests using the forged id get a 400,
    while the issued id resolves to the freshly reset episode.
    """
    forged_id = "attacker-chosen-session"
    with TestClient(app_module.app) as client:
        reset_response = client.post(
            "/reset?task_id=task_1_easy_anomaly",
            headers={"X-Session-ID": forged_id},
            json={"seed": 4},
        )
        assert reset_response.status_code == 200
        issued_id = reset_response.headers["X-Session-ID"]
        assert issued_id != forged_id

        # Drop cookies so that only the header identifies the session below.
        client.cookies.clear()
        forged_state = client.get("/state", headers={"X-Session-ID": forged_id})
        assert forged_state.status_code == 400

        issued_state = client.get("/state", headers={"X-Session-ID": issued_id})
        assert issued_state.status_code == 200
        assert issued_state.json()["seed"] == 4
def test_websocket_reset_state_and_step_flow() -> None:
    """Drive a reset -> state -> step -> close exchange over the ``/ws`` endpoint."""
    reset_message = {
        "type": "reset",
        "data": {"task_id": "task_1_easy_anomaly", "seed": 3},
    }
    null_amount_query = (
        "SELECT id, amount FROM transactions "
        "WHERE amount IS NULL ORDER BY id"
    )
    step_message = {
        "type": "step",
        "data": {
            "action_type": "ExecuteSQL",
            "payload": {"query": null_amount_query},
        },
    }

    with TestClient(app_module.app) as client, client.websocket_connect("/ws") as ws:
        ws.send_json(reset_message)
        reset_reply = ws.receive_json()
        assert reset_reply["data"]["observation"]["status"] == "success"

        ws.send_json({"type": "state"})
        state_reply = ws.receive_json()
        assert state_reply["data"]["task_id"] == "task_1_easy_anomaly"

        ws.send_json(step_message)
        step_reply = ws.receive_json()
        step_observation = step_reply["data"]["observation"]
        assert step_observation["status"] == "success"
        assert step_observation["sql_results"]

        ws.send_json({"type": "close", "data": {}})
def test_http_client_overlays_top_level_reward_and_done() -> None:
    """``DataOpsEnvClient`` exposes the response's top-level ``reward``/``done`` on observations."""

    class _FakeSession:
        """Session stub that answers POST ``/reset`` and ``/step``; any GET is unexpected."""

        def post(self, url: str, **kwargs: Any) -> _FakeResponse:
            del kwargs
            # suffix -> (observation, reward, done). The table is rebuilt on
            # every call, so each response gets fresh dicts.
            replies = {
                "/reset": ({"status": "success", "message": "ready"}, 0.0, False),
                "/step": ({"status": "success", "message": "ok"}, 0.25, True),
            }
            for suffix, (observation, reward, done) in replies.items():
                if url.endswith(suffix):
                    return _FakeResponse(
                        {"observation": observation, "reward": reward, "done": done}
                    )
            raise AssertionError(f"Unexpected URL requested: {url}")

        def get(self, url: str, **kwargs: Any) -> _FakeResponse:
            del kwargs
            raise AssertionError(f"Unexpected URL requested: {url}")

        def close(self) -> None:
            return None

    client = DataOpsEnvClient(base_url="http://env.local")
    client._session = _FakeSession()
    try:
        first = client.reset(task_id="task_1_easy_anomaly", seed=5)
        assert first.reward == 0.0
        assert first.done is False

        action = DataOpsAction(action_type="ExecuteSQL", payload={"query": "SELECT 1"})
        second = client.step(action)
        assert second.reward == 0.25
        assert second.done is True
    finally:
        client.close()
def test_env_loader_uses_root_env_to_find_secondary_env_file(
    tmp_path: Path, monkeypatch
) -> None:
    """The root ``.env`` can name a secondary env file through ``ENV_FILE``.

    ``load_env`` writes the loaded values into ``os.environ``.
    ``monkeypatch.delenv(name, raising=False)`` only registers an undo when
    *name* already exists, so a plain ``delenv`` would let the loaded values
    leak into every later test.

    To prevent that, each variable the loader may set is first registered with
    ``setenv`` (which records the original value or its absence) and then
    removed. pytest therefore restores the real environment afterwards, and
    each variable starts out unset here.

    ``ENV_FILE`` is cleared as well. An inherited value could otherwise
    compete with the one in the root ``.env`` (the loader keeps pre-existing
    variables).
    """
    monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
    for name in ("ENV_FILE", "PUBLIC_GRADER_DETAILS", "MODEL_NAME"):
        monkeypatch.setenv(name, "")  # records the original state for undo
        monkeypatch.delenv(name)  # the test itself starts with the variable unset
    (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
    (tmp_path / ".env.dev").write_text(
        "PUBLIC_GRADER_DETAILS=true\nMODEL_NAME=debug-model\n",
        encoding="utf-8",
    )

    env_loader.load_env()

    assert os.getenv("PUBLIC_GRADER_DETAILS") == "true"
    assert os.getenv("MODEL_NAME") == "debug-model"
def test_env_loader_preserves_external_runtime_env(
    tmp_path: Path, monkeypatch
) -> None:
    """Variables already present in the process env win over values from env files.

    ``PORT`` is set externally and must survive. ``MODEL_NAME`` is absent and
    must be filled in from the secondary file.

    ``MODEL_NAME`` and ``ENV_FILE`` are registered with ``setenv`` before
    being removed. ``delenv`` alone registers no undo for an absent variable,
    so without this step the values written by ``load_env`` would leak into
    later tests.
    """
    monkeypatch.setattr(env_loader, "_PROJECT_ROOT", tmp_path)
    monkeypatch.setenv("PORT", "7861")
    for name in ("ENV_FILE", "MODEL_NAME"):
        monkeypatch.setenv(name, "")  # records the original state for undo
        monkeypatch.delenv(name)  # the test itself starts with the variable unset
    (tmp_path / ".env").write_text("ENV_FILE=.env.dev\n", encoding="utf-8")
    (tmp_path / ".env.dev").write_text(
        "PORT=7860\nMODEL_NAME=debug-model\n",
        encoding="utf-8",
    )

    env_loader.load_env()

    assert os.getenv("PORT") == "7861"
    assert os.getenv("MODEL_NAME") == "debug-model"