Spaces:
Sleeping
Sleeping
| """ | |
| DataOps benchmark runner: drives the sandbox over HTTP (`/reset`, `/step`, `/grader`) with an OpenAI | |
| tool-calling loop. Tool schemas are task-scoped (e.g. send_email only for the hard E2E task). | |
| Flow per task: reset → chat completions (prefer `tool_choice="required"`) → validate tool args → POST each action → | |
| append tool/observation messages until the env reports `done` or `max_turns` → GET grader score. Success is | |
| derived from the score vs `SUCCESS_SCORE_THRESHOLD`. | |
| Stdout is the harness protocol only: one `[START]`, one `[STEP]` per env step, one `[END]` (always). Use | |
| `--json-scores` to append a single JSON object (scores, average, metadata) for `/baseline` ingestion. | |
| CLI: `--task` (repeatable), `--seed`, `--max-turns`, `--json-scores`. The environment HTTP base URL comes from | |
| `ENV_BASE_URL`, or if unset `http://127.0.0.1:$PORT` (default port 7860). Auth uses either `API_KEY` or | |
| `HF_TOKEN`. `API_BASE_URL` is optional: when omitted, the runner defaults to Google's OpenAI-compatible Gemini | |
| endpoint for `API_KEY` and Hugging Face's router for `HF_TOKEN`. | |
| Library logging is disabled so parsers see only these lines. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import zlib | |
| from datetime import datetime, timezone | |
| from typing import Any, Optional, Type | |
| import requests | |
| from openai import BadRequestError, OpenAI | |
| from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam | |
| from pydantic import BaseModel, ValidationError | |
| from env_loader import load_env | |
| from models import ( | |
| ExecuteSQLPayload, | |
| ReadFilePayload, | |
| RunScriptPayload, | |
| SendEmailPayload, | |
| WriteFilePayload, | |
| ) | |
| from server.task_specs import TASK_IDS, TASK_METADATA | |
# Silence all library logging (httpx, openai, urllib3, env_loader, etc.).
logging.disable(logging.CRITICAL)
# Project env loader; runs before any os.getenv below so its values are seen.
# NOTE(review): presumably populates os.environ from a .env file — confirm.
load_env()
DEFAULT_GOOGLE_OPENAI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
DEFAULT_HF_OPENAI_BASE_URL = "https://router.huggingface.co/v1"
# A set-but-blank PORT falls back to 7860 instead of crashing on int("").
_DEFAULT_PORT = int(os.getenv("PORT", "").strip() or "7860")
# Strip whitespace and trailing slashes so f"{ENV_BASE_URL}{path}" never
# produces a double slash such as "http://host//reset".
ENV_BASE_URL = (
    (os.getenv("ENV_BASE_URL") or "").strip().rstrip("/")
    or f"http://127.0.0.1:{_DEFAULT_PORT}"
)
MODEL_NAME = (os.getenv("MODEL_NAME") or "").strip() or "gemini-3.1-flash-lite-preview"
BENCHMARK = "dataops_env"
# Default cap on model turns per task (overridable with --max-turns).
MAX_TURNS = 12
# Reported scores are kept within [MIN_REPORTED_SCORE, MAX_REPORTED_SCORE].
MIN_REPORTED_SCORE = 0.01
MAX_REPORTED_SCORE = 0.99
# A task counts as a success only when its reported score reaches the ceiling.
SUCCESS_SCORE_THRESHOLD = MAX_REPORTED_SCORE
| _TOOL_HELP: dict[str, str] = { | |
| "execute_sql": "execute_sql — SQL over the task warehouse (field: query).", | |
| "read_file": "read_file — read a workspace file (field: filepath).", | |
| "write_file": "write_file — overwrite a file (fields: filepath, content).", | |
| "invoke_python": "invoke_python — run a Python script (fields: filepath, optional args).", | |
| "send_email": "send_email — send email (fields: to_email, subject, body).", | |
| } | |
| _ACTION_TO_TOOL: dict[str, str] = { | |
| "ExecuteSQL": "execute_sql", | |
| "ReadFile": "read_file", | |
| "WriteFile": "write_file", | |
| "RunScript": "invoke_python", | |
| "SendEmail": "send_email", | |
| } | |
def _normalize_reported_score(value: Any) -> float:
    """Coerce a raw grader score into the reportable range.

    Non-numeric input (anything ``float()`` rejects with ``TypeError`` or
    ``ValueError``), NaN, and values ``<= 0`` map to ``MIN_REPORTED_SCORE``.
    Values ``>= 1`` map to ``MAX_REPORTED_SCORE``. Anything else is rounded to
    two decimals and re-checked, so rounding can never produce 0.00 or 1.00.

    Args:
        value: Raw score (number, numeric string, ``None``, ...).

    Returns:
        A float in ``[MIN_REPORTED_SCORE, MAX_REPORTED_SCORE]``.
    """
    import math  # local import keeps this fix self-contained

    try:
        score = float(value)
    except (TypeError, ValueError):
        return MIN_REPORTED_SCORE
    # NaN compares False against every bound below, so without this check it
    # would fall through and be reported as "nan".
    if math.isnan(score):
        return MIN_REPORTED_SCORE
    if score <= 0.0:
        return MIN_REPORTED_SCORE
    if score >= 1.0:
        return MAX_REPORTED_SCORE
    score = round(score, 2)
    # Rounding can push values like 0.004 to 0.0 or 0.996 to 1.0.
    if score <= 0.0:
        return MIN_REPORTED_SCORE
    if score >= 1.0:
        return MAX_REPORTED_SCORE
    return score
def _normalize_grade_payload(grade: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *grade* whose ``score`` has been normalized.

    The input mapping is left untouched. A missing ``score`` key is passed to
    ``_normalize_reported_score`` as ``None``.
    """
    return {**grade, "score": _normalize_reported_score(grade.get("score"))}
def _allowed_tool_names_csv(task_id: str) -> str:
    """Return the task's permitted tool names joined by ``", "``.

    Names are emitted in the canonical tool order (SQL, read, write, script,
    email), regardless of how the task metadata orders its actions.
    """
    permitted = set()
    for action in TASK_METADATA[task_id].allowed_actions:
        permitted.add(_ACTION_TO_TOOL[action])
    canonical = ("execute_sql", "read_file", "write_file", "invoke_python", "send_email")
    return ", ".join(filter(permitted.__contains__, canonical))
def _system_prompt_for_task(task_id: str) -> str:
    """Build the system prompt advertising only the tools this task allows.

    Args:
        task_id: Key into ``TASK_METADATA``.

    Returns:
        Prompt text listing the ``_TOOL_HELP`` line of each permitted tool,
        in canonical tool order, followed by the general agent rules.
    """
    # Compute the permitted set once rather than once per candidate tool.
    allowed = {_ACTION_TO_TOOL[a] for a in TASK_METADATA[task_id].allowed_actions}
    lines = [
        _TOOL_HELP[t]
        for t in (
            "execute_sql",
            "read_file",
            "write_file",
            "invoke_python",
            "send_email",
        )
        if t in allowed
    ]
    tools_block = "\n".join(f" - {line}" for line in lines)
    return f"""\
You are an expert DataOps agent in a task-scoped benchmark. Only the tools listed below exist for this task — do not assume other actions are available.
Available tools:
{tools_block}
Rules:
- Always read files before modifying them when read_file is available.
- After writing a fix, run the script to verify it works when invoke_python is available.
- Be precise. Do not drop tables. Do not guess — inspect first.
- For tasks that include send_email, match subject and body to the task description exactly.
"""
# Per-task user instructions; the environment's reset message is appended
# to these when an episode starts. Its keys are also the default task list.
TASK_PROMPTS: dict[str, str] = dict(
    task_1_easy_anomaly=(
        "Solve the seeded cleanup task carefully. Inspect before mutating. "
        "Only NULL-amount rows are corrupted; preserve every non-null row "
        "exactly, including legitimate zero or negative adjustments."
    ),
    task_2_medium_syntax=(
        "Solve the seeded script-repair task. "
        "Read the file, make the minimal correct fix, and verify with execution."
    ),
    task_3_hard_e2e=(
        "Solve the seeded incident task end to end. "
        "Use SQL for the exact slice, write the exact JSON file, "
        "repair the formatter, execute it, and email the exact generated report."
    ),
)
def log_start(task: str, env: str, model: str) -> None:
    """Print the protocol's single ``[START]`` line for one task run."""
    fields = f"task={task} env={env} model={model}"
    print("[START]", fields, flush=True)
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Print one protocol ``[STEP]`` line.

    ``reward`` is shown with two decimals, ``done`` as lowercase text, and a
    falsy ``error`` (``None`` or ``""``) as ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP]", " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the protocol's single ``[END]`` summary line.

    ``rewards`` is rendered as a comma-separated list with two decimals per
    entry; it is empty when no steps were taken.
    """
    reward_list = ",".join(format(r, ".2f") for r in rewards)
    outcome = str(success).lower()
    print(
        "[END]",
        f"success={outcome} steps={steps} score={score:.2f} rewards={reward_list}",
        flush=True,
    )
| def _public_grader_details_enabled() -> bool: | |
| return os.getenv("PUBLIC_GRADER_DETAILS", "").strip().lower() in {"1", "true", "yes"} | |
def _emit_grader_details_to_stderr(grade: dict[str, Any]) -> None:
    """Dump the grade payload as JSON to stderr when detail output is enabled.

    Output happens only if ``PUBLIC_GRADER_DETAILS`` is enabled and the payload
    contains a ``details`` key. Stdout is never touched, so the harness
    protocol stays clean.
    """
    if _public_grader_details_enabled() and "details" in grade:
        print(json.dumps(grade, ensure_ascii=False), file=sys.stderr, flush=True)
def _request_json(
    http: requests.Session,
    method: str,
    path: str,
    *,
    timeout: float,
    **kwargs: Any,
) -> dict[str, Any]:
    """Call the environment API at ``ENV_BASE_URL + path`` and decode its JSON body.

    Extra keyword arguments (``params``, ``json``, ...) are forwarded to
    ``Session.request``. HTTP error statuses raise via ``raise_for_status``.
    """
    url = ENV_BASE_URL + path
    resp = http.request(method, url, timeout=timeout, **kwargs)
    resp.raise_for_status()
    return resp.json()
def _build_tools(task_id: str) -> list[ChatCompletionToolParam]:
    """Return the OpenAI function-tool schemas permitted for *task_id*.

    Each tool's ``parameters`` JSON schema is generated from its pydantic
    payload model. Tools appear in the catalog order below, filtered to the
    actions listed in the task's metadata.
    """
    catalog: dict[str, tuple[str, Type[BaseModel]]] = {
        "execute_sql": (
            "Run a task-scoped SQL query against the SQLite warehouse DB.",
            ExecuteSQLPayload,
        ),
        "read_file": ("Read a file in the workspace.", ReadFilePayload),
        "write_file": ("Overwrite a file with new content.", WriteFilePayload),
        "invoke_python": (
            "Execute a Python script in the workspace (optional args).",
            RunScriptPayload,
        ),
        "send_email": ("Send a formatted email notification.", SendEmailPayload),
    }
    permitted = {
        _ACTION_TO_TOOL[action]
        for action in TASK_METADATA[task_id].allowed_actions
    }
    schemas: list[ChatCompletionToolParam] = []
    # Dict insertion order determines the order the tools are offered in.
    for tool_name, (description, payload_model) in catalog.items():
        if tool_name not in permitted:
            continue
        schemas.append(
            {
                "type": "function",
                "function": {
                    "name": tool_name,
                    "description": description,
                    "parameters": payload_model.model_json_schema(),
                },
            }
        )
    return schemas
def _tool_call_to_action(name: str, arguments: str) -> dict[str, Any]:
    """Translate a model tool call into an environment ``/step`` action.

    Args:
        name: Tool name from the model. ``"run_script"`` is accepted as an
            alias for ``"invoke_python"``.
        arguments: JSON-encoded argument object. Blank or ``None`` means ``{}``.

    Returns:
        ``{"action_type": ..., "payload": ...}``, where the payload has been
        validated and dumped by the tool's pydantic model.

    Raises:
        ValueError: ``name`` is not a known tool.
        json.JSONDecodeError: ``arguments`` is not valid JSON.
        pydantic.ValidationError: the decoded arguments fail model validation.
    """
    tool = "invoke_python" if name == "run_script" else name
    registry: dict[str, tuple[str, Type[BaseModel]]] = {
        "execute_sql": ("ExecuteSQL", ExecuteSQLPayload),
        "read_file": ("ReadFile", ReadFilePayload),
        "write_file": ("WriteFile", WriteFilePayload),
        "invoke_python": ("RunScript", RunScriptPayload),
        "send_email": ("SendEmail", SendEmailPayload),
    }
    entry = registry.get(tool)
    if entry is None:
        raise ValueError(f"Unknown tool: {tool}")
    action_type, payload_model = entry
    if (arguments or "").strip():
        data = json.loads(arguments)
    else:
        data = {}
    payload = payload_model.model_validate(data).model_dump()
    return {"action_type": action_type, "payload": payload}
| _MALFORMED_TOOL = re.compile( | |
| r"^([a-zA-Z_][a-zA-Z0-9_]*)[\s,=\(]+(\{.*\})\)?\s*$", re.DOTALL | |
| ) | |
| def _normalize_tool_name_and_args(name: str, arguments: str) -> tuple[str, str]: | |
| name = (name or "").strip() | |
| arguments = (arguments or "").strip() | |
| m = _MALFORMED_TOOL.match(name) | |
| if m: | |
| base, embedded = m.group(1).strip(), m.group(2).strip() | |
| if not arguments: | |
| return base, embedded | |
| return name, arguments | |
def _action_from_tool_call(tc: Any) -> dict[str, Any]:
    """Convert an OpenAI tool-call object into an environment action dict.

    Reads ``tc.function.name`` and ``tc.function.arguments`` (``None`` becomes
    ``""``), repairs names with embedded arguments, then validates the call.
    """
    fn = tc.function
    clean_name, clean_args = _normalize_tool_name_and_args(
        fn.name or "", fn.arguments or ""
    )
    return _tool_call_to_action(clean_name, clean_args)
| def _action_str(action_payload: dict[str, Any]) -> str: | |
| at = action_payload.get("action_type", "") | |
| pl = action_payload.get("payload") or {} | |
| raw = f"{at}({json.dumps(pl, ensure_ascii=False)})" | |
| if len(raw) > 1200: | |
| return raw[:600] + "..." + raw[-550:] | |
| return raw | |
| def _obs_error(obs: dict[str, Any]) -> Optional[str]: | |
| if obs.get("status") != "error": | |
| return None | |
| msg = obs.get("message") | |
| if isinstance(msg, str) and msg.strip(): | |
| return msg.replace("\n", " ").strip() | |
| return None | |
def _resolve_api_base_url() -> str:
    """Choose the OpenAI-compatible base URL for model calls.

    A non-blank ``API_BASE_URL`` always wins. Otherwise the Hugging Face
    router is used when ``HF_TOKEN`` is non-blank, and Google's Gemini
    OpenAI-compatible endpoint in all other cases.
    """
    override = os.getenv("API_BASE_URL", "").strip()
    if override:
        return override
    has_hf_token = bool(os.getenv("HF_TOKEN", "").strip())
    return DEFAULT_HF_OPENAI_BASE_URL if has_hf_token else DEFAULT_GOOGLE_OPENAI_BASE_URL


# Resolved once at import time.
API_BASE_URL = _resolve_api_base_url()
def _openai_client() -> OpenAI:
    """Create the OpenAI-compatible client used for all model calls.

    The credential comes from ``HF_TOKEN`` when it is non-blank, otherwise
    from ``API_KEY``. This is the same ``HF_TOKEN`` check
    ``_resolve_api_base_url`` uses to pick the default endpoint. When neither
    variable is set, a diagnostic goes to stderr and the process exits with
    status 1.
    """
    hf_token = os.getenv("HF_TOKEN", "").strip()
    api_key = os.getenv("API_KEY", "").strip()
    # Strip each variable before falling back: a whitespace-only HF_TOKEN
    # must not shadow a valid API_KEY.
    key = hf_token or api_key
    if not key:
        print(
            "[inference] Missing API_KEY or HF_TOKEN for model access.",
            file=sys.stderr,
            flush=True,
        )
        sys.exit(1)
    return OpenAI(api_key=key, base_url=API_BASE_URL)
| def _llm_seed(env_seed: int | None, task_id: str) -> int | None: | |
| if env_seed is None: | |
| return None | |
| mixed = (int(env_seed) * 1_000_003) ^ (zlib.crc32(task_id.encode()) & 0xFFFFFFFF) | |
| return mixed & 0x7FFFFFFF | |
def _create_chat_completion(
    client: OpenAI,
    messages: list[ChatCompletionMessageParam],
    tools: list[ChatCompletionToolParam],
    *,
    task_id: str,
    env_seed: int | None,
) -> Any:
    """Request one chat completion, degrading gracefully on unsupported options.

    Prefers ``tool_choice="required"`` so the model cannot end a turn without
    a tool call. Requests use ``temperature=0``, ``top_p=1.0``, no parallel tool
    calls, and a per-task seed when ``env_seed`` is given. On
    ``BadRequestError``:

    * if the error mentions ``seed`` and a seed was sent, retry once with
      ``tool_choice="required"`` and no seed;
    * if the latest error mentions ``tool_choice``, ``required`` or
      ``unsupported``, fall back to ``tool_choice="auto"``;
    * otherwise re-raise the latest error.

    Raises:
        BadRequestError: the most recent provider rejection when no fallback
            applies, or whatever the ``"auto"`` fallback raises.
    """
    kwargs: dict[str, Any] = {
        "model": MODEL_NAME,
        "messages": messages,
        "tools": tools,
        "parallel_tool_calls": False,
        "temperature": 0,
        "top_p": 1.0,
    }
    llm_seed = _llm_seed(env_seed, task_id)
    if llm_seed is not None:
        kwargs["seed"] = llm_seed

    def _call(tool_choice: str) -> Any:
        return client.chat.completions.create(**kwargs, tool_choice=tool_choice)

    try:
        return _call("required")
    except BadRequestError as first_error:
        latest: BadRequestError = first_error
        if "seed" in str(first_error).lower() and llm_seed is not None:
            kwargs.pop("seed", None)
            try:
                return _call("required")
            except BadRequestError as retry_error:
                # Remember the retry's failure explicitly. A bare `raise` further
                # down would re-raise `first_error` (the exception this outer
                # handler owns) and hide why the seedless retry failed.
                latest = retry_error
        err = str(latest).lower()
        if not any(x in err for x in ("tool_choice", "required", "unsupported")):
            raise latest
        return _call("auto")
def run_task(
    client: OpenAI,
    http: requests.Session,
    task_id: str,
    *,
    max_turns: int,
    seed: int | None,
) -> float:
    """Run one task episode against the environment and return its score.

    Protocol output: one ``[START]`` line up front, one ``[STEP]`` line per
    environment step, and one ``[END]`` line. ``[END]`` is emitted from
    ``finally``, so it is printed even when the episode fails. Any
    ``Exception`` raised during the episode is reported on stderr and
    swallowed. The returned score is then whatever was last assigned
    (``MIN_REPORTED_SCORE`` unless grading already completed).

    Args:
        client: OpenAI-compatible client used for chat completions.
        http: Session used for the environment's ``/reset``, ``/step`` and
            ``/grader/{task_id}`` endpoints.
        task_id: Task to run; used as a key into ``TASK_PROMPTS`` and the
            task metadata.
        max_turns: Maximum number of model turns. A single turn may issue a
            second completion request after a tool-related ``BadRequestError``.
        seed: Optional environment seed sent to ``/reset``. It is also passed
            to ``_create_chat_completion`` for LLM seeding.

    Returns:
        The normalized grader score. ``success`` in the ``[END]`` line is
        ``score >= SUCCESS_SCORE_THRESHOLD``.
    """
    rewards: list[float] = []
    steps_taken = 0
    score = MIN_REPORTED_SCORE
    success = False
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        tools = _build_tools(task_id)
        names_csv = _allowed_tool_names_csv(task_id)
        reset_resp = _request_json(
            http,
            "POST",
            "/reset",
            timeout=10,
            params={"task_id": task_id},
            json={} if seed is None else {"seed": seed},
        )
        # Accept either {"observation": {...}} or a bare observation object.
        reset_obs = reset_resp.get("observation", reset_resp)
        messages: list[ChatCompletionMessageParam] = [
            {"role": "system", "content": _system_prompt_for_task(task_id)},
            {
                "role": "user",
                "content": TASK_PROMPTS[task_id]
                + f"\n\nEnvironment says: {reset_obs['message']}",
            },
        ]
        done = False
        step_num = 0
        # Consecutive turns in which the model answered without any tool call.
        no_tool_streak = 0
        for turn in range(1, max_turns + 1):
            try:
                response = _create_chat_completion(
                    client,
                    messages,
                    tools,
                    task_id=task_id,
                    env_seed=seed,
                )
            except BadRequestError as e:
                err_str = str(e).lower()
                # Only tool/function-related rejections get one corrective
                # retry; anything else propagates to the outer handler.
                if "tool" not in err_str and "function" not in err_str:
                    raise
                # NOTE(review): messages[-1] is normally a dict here (system,
                # user or tool message); a ChatCompletionMessage object would
                # not support .get().
                if messages and messages[-1].get("role") == "assistant":  # type: ignore[union-attr]
                    messages.pop()
                # Remind the model of the exact tool names before retrying.
                messages.append(
                    {
                        "role": "user",
                        "content": (
                            "IMPORTANT: Call tools using ONLY these exact names: "
                            f"{names_csv}. "
                            "Put ALL parameters inside the tool's JSON arguments field. "
                            "Do NOT embed parameters in the tool name itself."
                        ),
                    }
                )
                try:
                    response = _create_chat_completion(
                        client,
                        messages,
                        tools,
                        task_id=task_id,
                        env_seed=seed,
                    )
                except BadRequestError:
                    # Second rejection in the same turn: stop and go to grading.
                    break
            msg = response.choices[0].message
            if not msg.tool_calls:
                no_tool_streak += 1
                # Give up after four consecutive text-only replies.
                if no_tool_streak > 3:
                    break
                messages.append(msg)  # type: ignore[arg-type]
                messages.append(
                    {
                        "role": "user",
                        "content": (
                            f"You must respond with exactly one tool call ({names_csv}). "
                            "Do not reply with plain text only."
                        ),
                    }
                )
                continue
            no_tool_streak = 0
            messages.append(msg)  # type: ignore[arg-type]
            for tc in msg.tool_calls:
                try:
                    action_payload = _action_from_tool_call(tc)
                except (json.JSONDecodeError, ValidationError, ValueError) as e:
                    # Report the bad call back to the model. No environment
                    # step is taken and no [STEP] line is printed.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tc.id,
                            "content": f"Invalid tool arguments: {e}",
                        }
                    )
                    continue
                step_num += 1
                step_resp = _request_json(
                    http,
                    "POST",
                    "/step",
                    timeout=30,
                    json={"action": action_payload},
                )
                obs = step_resp.get("observation", step_resp)
                reward_raw = step_resp.get("reward")
                # A missing/null reward is recorded as 0.0.
                reward = 0.0 if reward_raw is None else float(reward_raw)
                done = step_resp.get("done", False)
                rewards.append(reward)
                steps_taken = step_num
                err = _obs_error(obs if isinstance(obs, dict) else {})
                log_step(
                    step=step_num,
                    action=_action_str(action_payload),
                    reward=reward,
                    done=done,
                    error=err,
                )
                messages.append(
                    {"role": "tool", "tool_call_id": tc.id, "content": json.dumps(obs)}
                )
                # Skip any remaining tool calls once the environment is done.
                if done:
                    break
            if done:
                break
        # Grade regardless of how the loop ended (done, turn cap, or break).
        grade = _normalize_grade_payload(
            _request_json(http, "GET", f"/grader/{task_id}", timeout=10)
        )
        _emit_grader_details_to_stderr(grade)
        score = _normalize_reported_score(grade["score"])
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        print(
            f"[inference] task={task_id} failed: {exc!r}", file=sys.stderr, flush=True
        )
    finally:
        # Always emit exactly one [END] line, even when the episode failed.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
def _parse_args() -> argparse.Namespace:
    """Parse the runner's command-line options from ``sys.argv``.

    Options: ``--task`` (repeatable, stored as ``tasks``), ``--seed``,
    ``--max-turns`` and ``--json-scores``.
    """
    parser = argparse.ArgumentParser(
        description="DataOpsEnv inference (OpenAI client; protocol lines on stdout)."
    )
    parser.add_argument(
        "--task",
        dest="tasks",
        action="append",
        choices=TASK_IDS,
        help="Run only the selected task(s). Defaults to all tasks.",
    )
    parser.add_argument(
        "--seed",
        default=None,
        type=int,
        help="Environment seed for /reset; also used for LLM seed when the API supports it.",
    )
    parser.add_argument(
        "--max-turns",
        default=MAX_TURNS,
        type=int,
        help=f"Maximum tool-using turns per task (default: {MAX_TURNS}).",
    )
    parser.add_argument(
        "--json-scores",
        action="store_true",
        help="Print a final JSON object with scores to stdout (for POST /baseline).",
    )
    return parser.parse_args()
def _run_inference_sync(args: argparse.Namespace) -> None:
    """Run every requested task and optionally print a JSON score summary.

    Tasks default to every key of ``TASK_PROMPTS`` when ``--task`` is absent.
    With ``--json-scores``, each task's grader payload is re-fetched after its
    run. If that request fails, a ``{"task_id", "score"}`` entry is built from
    the run score instead. One JSON line with scores, grades, the average score
    and run metadata is then printed to stdout.
    """
    client = _openai_client()
    turn_cap = max(1, int(args.max_turns))
    task_ids = args.tasks or list(TASK_PROMPTS)
    scores: dict[str, float] = {}
    grades: dict[str, dict[str, Any]] = {}
    with requests.Session() as http:
        for task_id in task_ids:
            scores[task_id] = run_task(
                client,
                http,
                task_id,
                max_turns=turn_cap,
                seed=args.seed,
            )
            if not args.json_scores:
                continue
            try:
                fetched = _request_json(http, "GET", f"/grader/{task_id}", timeout=10)
                grades[task_id] = _normalize_grade_payload(fetched)
            except Exception:
                # Best effort: keep the summary complete even if grading fails.
                grades[task_id] = {
                    "task_id": task_id,
                    "score": _normalize_reported_score(scores[task_id]),
                }
    if not args.json_scores:
        return
    average = sum(scores.values()) / len(scores)
    summary = {
        "scores": scores,
        "grades": grades,
        "average": round(average, 4),
        "model": MODEL_NAME,
        "metadata": {
            "env_base_url": ENV_BASE_URL,
            "seed": args.seed,
            "max_turns": turn_cap,
            "tasks": task_ids,
            "generated_at_utc": datetime.now(timezone.utc).isoformat(),
            "model_base_url": str(getattr(client, "base_url", "")),
        },
    }
    print(json.dumps(summary), flush=True)
async def main() -> None:
    """Parse CLI arguments, then run the synchronous runner in a worker thread."""
    await asyncio.to_thread(_run_inference_sync, _parse_args())


if __name__ == "__main__":
    asyncio.run(main())