Spaces:
Sleeping
Sleeping
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import subprocess | |
| import sys | |
| from collections.abc import AsyncIterator | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| import yaml | |
| import uvicorn | |
| from fastapi import ( | |
| Body, | |
| FastAPI, | |
| HTTPException, | |
| Query, | |
| Request, | |
| Response, | |
| WebSocket, | |
| WebSocketDisconnect, | |
| ) | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from openenv.core.env_server.http_server import serialize_observation | |
| from openenv.core.env_server.types import ( | |
| HealthResponse, | |
| HealthStatus, | |
| ResetRequest, | |
| ResetResponse, | |
| StepRequest, | |
| StepResponse, | |
| WSCloseMessage, | |
| WSErrorCode, | |
| WSErrorResponse, | |
| WSObservationResponse, | |
| WSResetMessage, | |
| WSStateMessage, | |
| WSStateResponse, | |
| WSStepMessage, | |
| ) | |
| from pydantic import ValidationError | |
| from env_loader import load_env | |
| from models import DataOpsAction, DataOpsObservation, DataOpsState | |
| from server.dataops_env_environment import DataOpsEnvironment | |
| from server.grading import evaluate_task | |
| from server.session_manager import EnvironmentSessionManager | |
| from server.task_specs import TASK_IDS, task_manifest_entries | |
# Repo root must be on sys.path (e.g. run `uv run python -m server.app` or uvicorn from project root).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SERVER_DIR = Path(__file__).resolve().parent

logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger(__name__)

# Called before any os.getenv read below so whatever it loads is visible to them
# (presumably a .env loader — confirm in env_loader).
load_env()

SESSION_COOKIE_NAME = "dataops_session_id"
SESSION_HEADER_NAME = "X-Session-ID"
# Clamped to >= 1 like MAX_WS_SESSIONS so a 0/negative setting never reaches
# EnvironmentSessionManager (its handling of such values is not visible here).
MAX_HTTP_SESSIONS = max(1, int(os.getenv("MAX_HTTP_SESSIONS", "128")))
HTTP_SESSION_TIMEOUT_S = float(os.getenv("HTTP_SESSION_TIMEOUT_S", "1200"))
MAX_WS_SESSIONS = max(1, int(os.getenv("MAX_WS_SESSIONS", "64")))
# Empty string means "no admin key configured".
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "").strip()
# Stripped before comparison, matching how PUBLIC_GRADER_DETAILS is parsed.
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "").strip().lower() in {"1", "true", "yes"}
# Bounds used when normalising grader scores for reporting.
MIN_REPORTED_SCORE = 0.01
MAX_REPORTED_SCORE = 0.99
def _public_grader_details_enabled() -> bool:
    """Report whether full grader details may be shown to every caller.

    PUBLIC_GRADER_DETAILS is looked up on each call rather than cached at import,
    so the environment (or a test) can flip it at runtime. Accepted truthy
    spellings are "1", "true" and "yes", case-insensitive, surrounding
    whitespace ignored.
    """
    raw_flag = os.getenv("PUBLIC_GRADER_DETAILS", "")
    normalized_flag = raw_flag.strip().lower()
    return normalized_flag in ("1", "true", "yes")
# Registry of HTTP sessions, configured from MAX_HTTP_SESSIONS and HTTP_SESSION_TIMEOUT_S.
session_manager = EnvironmentSessionManager(
    session_timeout_s=HTTP_SESSION_TIMEOUT_S,
    max_sessions=MAX_HTTP_SESSIONS,
)

# Number of currently open WebSocket sessions; mutated only while holding
# _ws_session_lock (see _try_acquire_ws_slot / _release_ws_slot).
_ws_session_lock = asyncio.Lock()
_ws_active_sessions = 0
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """FastAPI lifespan: log startup, then close every HTTP session on shutdown.

    Wrapped with ``asynccontextmanager`` because Starlette expects a
    context-manager factory; bare async-generator lifespans go through a
    deprecated compatibility path. The ``try``/``finally`` makes sure sessions
    are closed even when the lifespan is exited via an exception.
    """
    logger.info("DataOpsEnv starting.")
    try:
        yield
    finally:
        session_manager.close_all()
        logger.info("DataOpsEnv shutting down.")
# The ASGI application; startup/shutdown is handled by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="DataOpsEnv",
    version="1.0.0",
    description="Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant).",
)
def _cors_allow_origins() -> list[str]:
    """Translate CORS_ALLOW_ORIGINS into the origin list for CORSMiddleware.

    Unset or blank gives an empty list. Exactly "*" (after trimming) gives
    ["*"]. Anything else is split on commas, each entry trimmed, and empty
    entries dropped.
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "").strip()
    if raw == "*":
        return ["*"]
    origins: list[str] = []
    for piece in raw.split(","):
        cleaned = piece.strip()
        if cleaned:
            origins.append(cleaned)
    return origins
# CORS: origins come from CORS_ALLOW_ORIGINS, read once here at import time;
# all methods and headers are allowed for those origins.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=_cors_allow_origins(),
)
def _load_manifest() -> dict:
    """Load ``openenv.yaml`` from the project root as a dict.

    Returns:
        The parsed mapping. An empty dict is returned when the file does not
        exist or when its top level is not a mapping (empty file, list, scalar),
        so callers can always use ``.get`` on the result.

    Raises:
        yaml.YAMLError: if the file exists but is not valid YAML (left to
            propagate so a broken manifest is visible rather than ignored).
    """
    yaml_path = PROJECT_ROOT / "openenv.yaml"
    try:
        with open(yaml_path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        return {}
    # safe_load yields None for an empty file and may yield a list/scalar;
    # callers call .get() on the result, so only a mapping is passed through.
    return data if isinstance(data, dict) else {}
def _load_yaml_tasks() -> list[dict]:
    """Return the task entries to publish.

    The ``tasks`` list from ``openenv.yaml`` is used only when it is a
    non-empty list of mappings whose ``id`` values equal ``TASK_IDS`` exactly
    and in the same order. In every other case the built-in
    ``task_manifest_entries()`` is returned instead.
    """
    manifest = _load_manifest()
    tasks = manifest.get("tasks") if isinstance(manifest, dict) else None
    if not isinstance(tasks, list) or not tasks:
        return task_manifest_entries()
    # A non-mapping item would make item.get() raise; treat it as a mismatch.
    if not all(isinstance(item, dict) for item in tasks):
        return task_manifest_entries()
    manifest_ids = [str(item.get("id", "")) for item in tasks]
    # list() keeps the comparison meaningful even if TASK_IDS is a tuple.
    if manifest_ids == list(TASK_IDS):
        return tasks
    return task_manifest_entries()
def _wrap_obs(obs: DataOpsObservation) -> dict:
    """Dump an observation model to the plain dict embedded in OpenEnv responses."""
    payload = obs.model_dump()
    return payload
def _get_session_id(request: Request) -> str | None:
    """Extract the caller's session id from the request, if any.

    A non-empty X-Session-ID header wins over the session cookie; the cookie
    is consulted only when the header is missing or empty. The chosen value is
    trimmed, and a value that is blank after trimming yields None (a
    whitespace-only header therefore does not fall back to the cookie).
    """
    header_value = request.headers.get(SESSION_HEADER_NAME)
    raw = header_value if header_value else request.cookies.get(SESSION_COOKIE_NAME)
    if not raw:
        return None
    return raw.strip() or None
def _attach_session(response: Response, session_id: str) -> None:
    """Send the session id back to the client as a cookie and as a response header.

    Cookie attributes: HttpOnly, SameSite=Lax, Secure only when COOKIE_SECURE
    is set, and Max-Age equal to HTTP_SESSION_TIMEOUT_S truncated to an int.
    """
    cookie_lifetime = int(HTTP_SESSION_TIMEOUT_S)
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=session_id,
        max_age=cookie_lifetime,
        secure=COOKIE_SECURE,
        httponly=True,
        samesite="lax",
    )
    response.headers[SESSION_HEADER_NAME] = session_id
def _require_active_env(request: Request) -> tuple[str, DataOpsEnvironment]:
    """Look up the caller's session and return ``(session_id, env)``.

    Raises:
        HTTPException: 400 when the session manager returns no id or no env.
    """
    resolved_id, environment = session_manager.get_session(_get_session_id(request))
    if resolved_id is not None and environment is not None:
        return resolved_id, environment
    raise HTTPException(400, "No active episode. Call /reset first.")
def _ws_error_payload(message: str, code: WSErrorCode) -> str:
    """Serialise a WebSocket error frame carrying ``message`` and the enum value of ``code``."""
    error_frame = WSErrorResponse(data={"message": message, "code": code.value})
    return error_frame.model_dump_json()
def _require_admin(request: Request) -> None:
    """Guard an admin-only endpoint using the X-Admin-Key header.

    When ADMIN_API_KEY is empty the check is skipped and every caller passes.
    Otherwise the header must equal the configured key; the comparison is
    constant-time so response timing does not leak how much of the key matched.

    Raises:
        HTTPException: 403 when a key is configured and the header is missing
            or does not match.
    """
    import secrets

    if not ADMIN_API_KEY:
        # NOTE(review): fail-open — with no key configured, admin-only
        # endpoints (e.g. baseline_endpoint) are reachable by anyone. Confirm intended.
        return
    provided = request.headers.get("X-Admin-Key", "")
    # Compare bytes: compare_digest raises TypeError for non-ASCII str inputs.
    if not secrets.compare_digest(provided.encode("utf-8"), ADMIN_API_KEY.encode("utf-8")):
        raise HTTPException(403, "Missing or invalid admin key.")
def _request_is_admin(request: Request) -> bool:
    """Return True only if an admin key is configured and X-Admin-Key matches it.

    Unlike _require_admin, an unconfigured key means "not an admin". The key
    comparison is constant-time.
    """
    import secrets

    if not ADMIN_API_KEY:
        return False
    provided = request.headers.get("X-Admin-Key", "")
    # Compare bytes: compare_digest raises TypeError for non-ASCII str inputs.
    return secrets.compare_digest(provided.encode("utf-8"), ADMIN_API_KEY.encode("utf-8"))
| def _normalize_reported_score(value: object) -> float: | |
| try: | |
| score = float(value) | |
| except (TypeError, ValueError): | |
| return MIN_REPORTED_SCORE | |
| if score <= 0.0: | |
| return MIN_REPORTED_SCORE | |
| if score >= 1.0: | |
| return MAX_REPORTED_SCORE | |
| score = round(score, 2) | |
| if score <= 0.0: | |
| return MIN_REPORTED_SCORE | |
| if score >= 1.0: | |
| return MAX_REPORTED_SCORE | |
| return score | |
def _normalize_grade_payload(grade: dict) -> dict:
    """Return a shallow copy of ``grade`` with its "score" passed through _normalize_reported_score.

    The input dict is left untouched; a missing "score" is normalised from None.
    """
    return {**grade, "score": _normalize_reported_score(grade.get("score"))}
def _format_grader_response(grade: dict, request: Request) -> dict:
    """Shape a grade for the HTTP response.

    The score is always normalised. The complete grade dict is returned when
    PUBLIC_GRADER_DETAILS is enabled or the request carries the admin key;
    otherwise only ``task_id`` and ``score`` are exposed.
    """
    normalized = _normalize_grade_payload(grade)
    show_details = _public_grader_details_enabled() or _request_is_admin(request)
    if not show_details:
        return {"task_id": normalized.get("task_id"), "score": normalized.get("score")}
    return normalized
async def _try_acquire_ws_slot() -> bool:
    """Reserve one WebSocket slot under the lock.

    Returns True if a slot was taken, or False (reserving nothing) when
    MAX_WS_SESSIONS slots are already in use.
    """
    global _ws_active_sessions
    async with _ws_session_lock:
        has_room = _ws_active_sessions < MAX_WS_SESSIONS
        if has_room:
            _ws_active_sessions += 1
        return has_room
async def _release_ws_slot() -> None:
    """Give back one WebSocket slot; the counter is floored at zero."""
    global _ws_active_sessions
    async with _ws_session_lock:
        _ws_active_sessions = _ws_active_sessions - 1 if _ws_active_sessions > 0 else 0
def health_endpoint():
    """Liveness check; unconditionally reports HEALTHY."""
    healthy = HealthStatus.HEALTHY
    return HealthResponse(status=healthy)
def metadata_endpoint():
    """Describe the environment.

    Name, description and version come from openenv.yaml when present, with
    built-in fallbacks; task_count is the number of published tasks.
    """
    fallback_description = (
        "Enterprise data pipeline remediation environment. "
        "Agents debug data streams, fix scripts, and send email reports."
    )
    manifest = _load_manifest()
    name = manifest.get("name", "dataops_env")
    description = manifest.get("description", fallback_description)
    version = manifest.get("version", "1.0.0")
    task_count = len(_load_yaml_tasks())
    return {
        "name": name,
        "description": description,
        "version": version,
        "task_count": task_count,
    }
def schema_endpoint():
    """Return the JSON Schemas of the action, observation and state models."""
    models = {
        "action": DataOpsAction,
        "observation": DataOpsObservation,
        "state": DataOpsState,
    }
    return {key: model.model_json_schema() for key, model in models.items()}
def mcp_endpoint(body: dict = Body(default_factory=dict)):
    """Minimal JSON-RPC 2.0 handler.

    Answers ``tools/list`` with one tool per action type; every other method
    receives error -32601 ("Method not found"). The request ``id`` is echoed.
    """
    # NOTE(review): MCP tool descriptors normally include an inputSchema;
    # none is sent here — confirm clients tolerate that.
    request_id = body.get("id")
    if body.get("method", "") != "tools/list":
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": -32601, "message": "Method not found"},
        }
    action_types = ("ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail")
    tools = []
    for action_type in action_types:
        tools.append({"name": action_type, "description": f"Execute a {action_type} action."})
    return {"jsonrpc": "2.0", "id": request_id, "result": {"tools": tools}}
async def websocket_endpoint(websocket: WebSocket):
    """Serve one OpenEnv session over a WebSocket connection.

    Each connection owns a private DataOpsEnvironment and processes JSON
    messages of type "reset", "step", "state" and "close" until the client
    sends "close" or disconnects. A bad individual message is answered with a
    WS error frame and does not end the session. At most MAX_WS_SESSIONS
    connections are served at once; extra ones get a CAPACITY_REACHED error
    frame and close code 1013.
    """
    await websocket.accept()
    if not await _try_acquire_ws_slot():
        await websocket.send_text(
            _ws_error_payload(
                "WebSocket session capacity reached.",
                WSErrorCode.CAPACITY_REACHED,
            )
        )
        await websocket.close(code=1013)
        return

    env = None
    try:
        # Built inside the try so a failing constructor still releases the slot.
        env = DataOpsEnvironment()
        while True:
            raw_message = await websocket.receive_text()
            try:
                message_dict = json.loads(raw_message)
            except json.JSONDecodeError:
                await websocket.send_text(
                    _ws_error_payload("Invalid JSON payload.", WSErrorCode.INVALID_JSON)
                )
                continue
            # Valid JSON that is not an object (list, number, string) has no
            # .get(); reject it instead of letting AttributeError drop the socket.
            if not isinstance(message_dict, dict):
                await websocket.send_text(
                    _ws_error_payload(
                        "WebSocket message must be a JSON object.",
                        WSErrorCode.INVALID_JSON,
                    )
                )
                continue

            message_type = message_dict.get("type", "")
            try:
                if message_type == "reset":
                    message = WSResetMessage(**message_dict)
                    observation = env.reset(**message.data)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "step":
                    message = WSStepMessage(**message_dict)
                    action = DataOpsAction(**message.data)
                    observation = env.step(action)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "state":
                    WSStateMessage(**message_dict)
                    response = WSStateResponse(data=env.state.model_dump())
                elif message_type == "close":
                    WSCloseMessage(**message_dict)
                    break
                else:
                    await websocket.send_text(
                        _ws_error_payload(
                            f"Unknown message type: {message_type}",
                            WSErrorCode.UNKNOWN_TYPE,
                        )
                    )
                    continue
                await websocket.send_text(response.model_dump_json())
            except WebSocketDisconnect:
                # Client vanished mid-send: end the session via the outer handler
                # rather than reporting it as an execution error.
                raise
            except ValidationError:
                await websocket.send_text(
                    _ws_error_payload(
                        "Validation error while handling the WebSocket message.",
                        WSErrorCode.VALIDATION_ERROR,
                    )
                )
            except Exception:
                logger.exception("WebSocket execution error")
                await websocket.send_text(
                    _ws_error_payload(
                        "Execution error while handling the WebSocket message.",
                        WSErrorCode.EXECUTION_ERROR,
                    )
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket client disconnected.")
    finally:
        # Nested finally: the slot is released even if env.close() raises.
        try:
            if env is not None:
                env.close()
        finally:
            await _release_ws_slot()
def reset_endpoint(
    request: Request,
    response: Response,
    task_id: str = Query("task_1_easy_anomaly", description="Task to initialise."),
    body: ResetRequest = Body(default_factory=ResetRequest),
):
    """Start a fresh episode of ``task_id`` in the caller's HTTP session.

    The session id from the header/cookie (possibly None) is passed to
    ``session_manager.reset_session``; the id it resolves to is echoed back
    via cookie and header.

    Raises:
        HTTPException: 400 when ``task_id`` is not a known task.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(400, f"Invalid task_id. Choose from: {TASK_IDS}")
    caller_session = _get_session_id(request)
    new_session_id, _, observation = session_manager.reset_session(
        task_id=task_id,
        seed=body.seed,
        episode_id=body.episode_id,
        session_id=caller_session,
    )
    _attach_session(response, new_session_id)
    return ResetResponse(
        observation=_wrap_obs(observation),
        reward=observation.reward,
        done=observation.done,
    )
def step_endpoint(request: Request, response: Response, body: StepRequest):
    """Apply one action to the caller's active episode.

    The action is validated before the session lookup, so a malformed action
    yields 422 even when there is no active episode.

    Raises:
        HTTPException: 422 for an invalid action; 400 when no episode is active.
    """
    try:
        parsed_action = DataOpsAction(**body.action)
    except ValidationError as exc:
        raise HTTPException(422, f"Invalid action: {exc}") from exc
    sid, active_env = _require_active_env(request)
    _attach_session(response, sid)
    result = active_env.step(parsed_action, timeout_s=body.timeout_s)
    return StepResponse(
        observation=_wrap_obs(result),
        reward=result.reward,
        done=result.done,
    )
def state_endpoint(request: Request, response: Response):
    """Return the state object of the caller's active episode (400 if none)."""
    sid, active_env = _require_active_env(request)
    _attach_session(response, sid)
    current_state = active_env.state
    return current_state
def tasks_endpoint():
    """List the published tasks along with the action/observation/state JSON Schemas."""
    listing = {"tasks": _load_yaml_tasks()}
    listing["action_schema"] = DataOpsAction.model_json_schema()
    listing["observation_schema"] = DataOpsObservation.model_json_schema()
    listing["state_schema"] = DataOpsState.model_json_schema()
    return listing
def grader_current_endpoint(request: Request, response: Response):
    """Grade the caller's current episode, using the task_id stored in its state.

    Raises:
        HTTPException: 400 when there is no session or the state has no task_id.
    """
    sid, active_env = _require_active_env(request)
    _attach_session(response, sid)
    active_task = active_env.state.task_id
    if active_task:
        grade = evaluate_task(active_task, active_env)
        return _format_grader_response(grade, request)
    raise HTTPException(400, "No active episode. Call /reset first.")
def grader_endpoint(task_id: str, request: Request, response: Response):
    """Grade the caller's active episode against an explicitly named task.

    Raises:
        HTTPException: 404 when ``task_id`` is unknown; 400 when there is no
            active episode (no session, or a state without a task_id — the same
            rule grader_current_endpoint applies); 400 when the active episode
            belongs to a different task.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(404, f"Unknown task: {task_id}")
    session_id, env = _require_active_env(request)
    _attach_session(response, session_id)
    active_task_id = env.state.task_id
    if not active_task_id:
        raise HTTPException(400, "No active episode. Call /reset first.")
    if active_task_id != task_id:
        raise HTTPException(
            400,
            f"Active episode belongs to task '{active_task_id}'. Reset the requested task first.",
        )
    return _format_grader_response(evaluate_task(task_id, env), request)
def baseline_endpoint(request: Request, body: dict = Body(default_factory=dict)):
    """Run inference.py (OpenAI tool-calling agent) against all tasks; same entrypoint as local baseline.

    Optional body fields: ``seed`` and ``max_turns`` (integers) and
    ``task_ids`` (a list of task ids, or a single id string; unknown ids are
    ignored). The child process runs with this server's environment plus
    ENV_BASE_URL and is killed after HTTP_SESSION_TIMEOUT_S seconds.

    Raises:
        HTTPException: 403 (admin check), 422 (malformed body fields),
            503 (no API_KEY/HF_TOKEN), 500 (inference.py missing),
            504 (timeout), 502 (non-zero exit or unparseable output).
    """
    _require_admin(request)
    if not (
        os.environ.get("API_KEY", "").strip()
        or os.environ.get("HF_TOKEN", "").strip()
    ):
        raise HTTPException(
            503,
            "API_KEY or HF_TOKEN must be set on the server process to run POST /baseline.",
        )
    script_path = PROJECT_ROOT / "inference.py"
    if not script_path.is_file():
        raise HTTPException(500, "inference.py missing from project root.")

    command = [sys.executable, str(script_path), "--json-scores"]
    seed = _baseline_int_option(body, "seed")
    if seed is not None:
        command.extend(["--seed", str(seed)])
    max_turns = _baseline_int_option(body, "max_turns")
    if max_turns is not None:
        command.extend(["--max-turns", str(max_turns)])
    # If no requested id is valid, no --task flag is passed at all.
    for task_id in _baseline_task_ids(body):
        command.extend(["--task", task_id])

    port = int(os.getenv("PORT", "7860"))
    timeout_s = HTTP_SESSION_TIMEOUT_S
    # Presumably inference.py uses ENV_BASE_URL to reach this server — confirm.
    child_env = {
        **os.environ,
        "ENV_BASE_URL": os.getenv("ENV_BASE_URL", f"http://127.0.0.1:{port}"),
    }
    try:
        proc = subprocess.run(
            command,
            cwd=str(PROJECT_ROOT),
            capture_output=True,
            text=True,
            timeout=timeout_s,
            env=child_env,
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(
            504, f"Baseline exceeded HTTP_SESSION_TIMEOUT_S ({timeout_s}s)."
        ) from None
    if proc.returncode != 0:
        tail = (proc.stderr or proc.stdout or "")[-6000:]
        logger.error("inference.py failed rc=%s stderr=%s", proc.returncode, tail[:500])
        raise HTTPException(
            502,
            {"message": "inference.py exited with an error.", "detail": tail},
        )
    parsed = _parse_baseline_scores(proc.stdout or "")
    return {
        "message": "Model baseline completed via inference.py.",
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "scores": parsed["scores"],
        "grades": parsed.get("grades"),
        "average": parsed.get("average"),
        "model": parsed.get("model"),
        "metadata": parsed.get("metadata"),
    }


def _baseline_int_option(body: dict, key: str) -> int | None:
    """Return ``body[key]`` as an int, None when absent/null, or raise HTTP 422 when not integer-like."""
    value = body.get(key)
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(422, f"'{key}' must be an integer.") from None


def _baseline_task_ids(body: dict) -> list[str]:
    """Return the known task ids requested in ``body["task_ids"]``, in order; unknown ids are dropped."""
    requested = body.get("task_ids") or []
    # A bare string would otherwise be iterated character by character.
    if isinstance(requested, str):
        requested = [requested]
    if not isinstance(requested, (list, tuple)):
        raise HTTPException(422, "'task_ids' must be a list of task ids.")
    return [str(task_id) for task_id in requested if task_id in TASK_IDS]


def _parse_baseline_scores(stdout: str) -> dict:
    """Return the last stdout line that parses as a JSON object containing "scores".

    Non-JSON lines and JSON values that are not such an object are skipped.

    Raises:
        HTTPException: 502 when no line qualifies.
    """
    lines = [ln.strip() for ln in stdout.splitlines() if ln.strip()]
    for line in reversed(lines):
        try:
            candidate = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and "scores" in candidate:
            return candidate
    raise HTTPException(
        502,
        {
            "message": "Could not parse JSON scores from inference.py stdout.",
            "stdout_tail": "\n".join(lines[-5:]),
        },
    )
def main():
    """Entry point for `dataops-env` script and `openenv serve`.

    Reads HOST, PORT and DEBUG from the environment and starts uvicorn. When
    launched from inside the server directory the app is imported as
    ``app:app`` with SERVER_DIR as app_dir; otherwise as ``server.app:app``
    from PROJECT_ROOT. DEBUG set to "1"/"true" enables auto-reload watching
    the project root.
    """
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    reload = os.getenv("DEBUG", "").lower() in ("1", "true")
    if Path.cwd().resolve() == SERVER_DIR:
        app_target, app_dir = "app:app", str(SERVER_DIR)
    else:
        app_target, app_dir = "server.app:app", str(PROJECT_ROOT)
    watched_dirs = [str(PROJECT_ROOT)] if reload else None
    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=reload,
        reload_dirs=watched_dirs,
        ws="wsproto",
        app_dir=app_dir,
    )


if __name__ == "__main__":
    main()