import asyncio
import hmac
import json
import logging
import math
import os
import subprocess
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
import yaml
from fastapi import (
    Body,
    FastAPI,
    HTTPException,
    Query,
    Request,
    Response,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from openenv.core.env_server.http_server import serialize_observation
from openenv.core.env_server.types import (
    HealthResponse,
    HealthStatus,
    ResetRequest,
    ResetResponse,
    StepRequest,
    StepResponse,
    WSCloseMessage,
    WSErrorCode,
    WSErrorResponse,
    WSObservationResponse,
    WSResetMessage,
    WSStateMessage,
    WSStateResponse,
    WSStepMessage,
)
from pydantic import ValidationError

from env_loader import load_env
from models import DataOpsAction, DataOpsObservation, DataOpsState
from server.dataops_env_environment import DataOpsEnvironment
from server.grading import evaluate_task
from server.session_manager import EnvironmentSessionManager
from server.task_specs import TASK_IDS, task_manifest_entries

# Repo root must be on sys.path (e.g. run `uv run python -m server.app` or uvicorn from project root).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SERVER_DIR = Path(__file__).resolve().parent

logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger(__name__)

# Load environment configuration before any os.getenv() lookups below.
load_env()

# Values (case-insensitive, surrounding whitespace ignored) that switch a
# boolean environment flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes"})


def _env_flag(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy value."""
    return os.getenv(name, "").strip().lower() in _TRUTHY_ENV_VALUES


SESSION_COOKIE_NAME = "dataops_session_id"
SESSION_HEADER_NAME = "X-Session-ID"
MAX_HTTP_SESSIONS = int(os.getenv("MAX_HTTP_SESSIONS", "128"))
HTTP_SESSION_TIMEOUT_S = float(os.getenv("HTTP_SESSION_TIMEOUT_S", "1200"))
MAX_WS_SESSIONS = max(1, int(os.getenv("MAX_WS_SESSIONS", "64")))
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "").strip()
COOKIE_SECURE = _env_flag("COOKIE_SECURE")
# Scores reported to clients are clamped to this closed interval.
MIN_REPORTED_SCORE = 0.01
MAX_REPORTED_SCORE = 0.99


def _public_grader_details_enabled() -> bool:
    """Read at request time so env / tests can control visibility without stale import-time state."""
    return _env_flag("PUBLIC_GRADER_DETAILS")


# Number of currently open WebSocket sessions; guarded by _ws_session_lock.
_ws_active_sessions = 0
_ws_session_lock = asyncio.Lock()

session_manager = EnvironmentSessionManager(
    max_sessions=MAX_HTTP_SESSIONS,
    session_timeout_s=HTTP_SESSION_TIMEOUT_S,
)


@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan: log startup and close all HTTP sessions on shutdown."""
    logger.info("DataOpsEnv starting.")
    try:
        yield
    finally:
        # Runs even when the lifespan is cancelled or an error is thrown into
        # the generator, so sessions are not left open.
        session_manager.close_all()
        logger.info("DataOpsEnv shutting down.")


app = FastAPI(
    title="DataOpsEnv",
    description="Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant).",
    version="1.0.0",
    lifespan=lifespan,
)


def _cors_allow_origins() -> list[str]:
    """Parse CORS_ALLOW_ORIGINS (comma-separated, or "*") into a list of origins.

    An unset or blank variable yields an empty list.
    """
    configured = os.getenv("CORS_ALLOW_ORIGINS", "").strip()
    if not configured:
        return []
    if configured == "*":
        return ["*"]
    return [item.strip() for item in configured.split(",") if item.strip()]


app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins(),
    allow_methods=["*"],
    allow_headers=["*"],
)


def _load_manifest() -> dict:
    """Load ``openenv.yaml`` from the directory above this file's directory.

    Returns an empty dict when the file is missing, empty, or its top-level
    YAML value is not a mapping. YAML syntax errors are not caught here.
    """
    yaml_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "openenv.yaml")
    try:
        with open(yaml_path, encoding="utf-8") as f:
            manifest = yaml.safe_load(f)
    except FileNotFoundError:
        return {}
    # Callers use .get(); a list or scalar document would break them.
    return manifest if isinstance(manifest, dict) else {}


def _load_yaml_tasks() -> list[dict]:
    """Return the manifest's task list when its ids equal TASK_IDS.

    Falls back to ``task_manifest_entries()`` when the manifest has no usable
    task list or its ids differ from TASK_IDS.
    """
    tasks = _load_manifest().get("tasks")
    if (
        isinstance(tasks, list)
        and tasks
        and all(isinstance(item, dict) for item in tasks)
    ):
        manifest_ids = [str(item.get("id", "")) for item in tasks]
        if manifest_ids == TASK_IDS:
            return tasks
    return task_manifest_entries()


def _wrap_obs(obs: DataOpsObservation) -> dict:
    """Serialise an observation to the standard OpenEnv response dict."""
    return obs.model_dump()


def _get_session_id(request: Request) -> str | None:
    """Return the session id from the session header, else the session cookie.

    Values are stripped; a blank header no longer hides a valid cookie.
    Returns None when neither source carries a non-blank value.
    """
    candidates = (
        request.headers.get(SESSION_HEADER_NAME),
        request.cookies.get(SESSION_COOKIE_NAME),
    )
    for raw in candidates:
        cleaned = raw.strip() if raw else ""
        if cleaned:
            return cleaned
    return None


def _attach_session(response: Response, session_id: str) -> None:
    """Expose *session_id* to the client as both a cookie and a response header."""
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=session_id,
        httponly=True,
        samesite="lax",
        secure=COOKIE_SECURE,
        max_age=int(HTTP_SESSION_TIMEOUT_S),
    )
    response.headers[SESSION_HEADER_NAME] = session_id


def _require_active_env(request: Request) -> tuple[str, DataOpsEnvironment]:
    """Return ``(session_id, env)`` for the request's session.

    Raises:
        HTTPException: 400 when the session manager has no session for it.
    """
    session_id, env = session_manager.get_session(_get_session_id(request))
    if session_id is None or env is None:
        raise HTTPException(400, "No active episode. Call /reset first.")
    return session_id, env


def _ws_error_payload(message: str, code: WSErrorCode) -> str:
    """Build the JSON text of a WebSocket error response."""
    return WSErrorResponse(
        data={
            "message": message,
            "code": code.value,
        }
    ).model_dump_json()


def _admin_key_matches(request: Request) -> bool:
    """Compare the ``X-Admin-Key`` header with ADMIN_API_KEY in constant time."""
    supplied = request.headers.get("X-Admin-Key", "")
    # compare_digest() rejects non-ASCII str arguments, so compare bytes.
    return hmac.compare_digest(
        supplied.encode("utf-8"), ADMIN_API_KEY.encode("utf-8")
    )


def _require_admin(request: Request) -> None:
    """Reject the request unless it carries the admin key.

    When ADMIN_API_KEY is empty the check is skipped and every request passes.

    Raises:
        HTTPException: 403 when a key is configured and the header differs.
    """
    if not ADMIN_API_KEY:
        return
    if not _admin_key_matches(request):
        raise HTTPException(403, "Missing or invalid admin key.")


def _request_is_admin(request: Request) -> bool:
    """Return True only when an admin key is configured and the request supplies it."""
    return bool(ADMIN_API_KEY) and _admin_key_matches(request)


def _normalize_reported_score(value: object) -> float:
    """Coerce *value* to a float rounded to 2 decimals within the reported range.

    Values that cannot be converted to float, and NaN, map to
    MIN_REPORTED_SCORE. Everything else is rounded and clamped to
    [MIN_REPORTED_SCORE, MAX_REPORTED_SCORE].
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return MIN_REPORTED_SCORE
    # NaN fails every ordering comparison and would otherwise leak through.
    if math.isnan(score):
        return MIN_REPORTED_SCORE
    return min(max(round(score, 2), MIN_REPORTED_SCORE), MAX_REPORTED_SCORE)


def _normalize_grade_payload(grade: dict) -> dict:
    """Return a shallow copy of *grade* with its ``score`` normalised."""
    payload = dict(grade)
    payload["score"] = _normalize_reported_score(payload.get("score"))
    return payload


def _format_grader_response(grade: dict, request: Request) -> dict:
    """Return the full grade for admins / public-details mode, else id and score only."""
    grade = _normalize_grade_payload(grade)
    if _public_grader_details_enabled() or _request_is_admin(request):
        return grade
    return {"task_id": grade.get("task_id"), "score": grade.get("score")}


async def _try_acquire_ws_slot() -> bool:
    """Reserve one WebSocket session slot; return False when at MAX_WS_SESSIONS."""
    global _ws_active_sessions
    async with _ws_session_lock:
        if _ws_active_sessions >= MAX_WS_SESSIONS:
            return False
        _ws_active_sessions += 1
        return True


async def _release_ws_slot() -> None:
    """Give back one WebSocket session slot (the counter never goes below zero)."""
    global _ws_active_sessions
    async with _ws_session_lock:
        _ws_active_sessions = max(0, _ws_active_sessions - 1)


@app.get("/health", response_model=HealthResponse)
def health_endpoint():
    """Report a healthy status."""
    return HealthResponse(status=HealthStatus.HEALTHY)


@app.get("/metadata")
def metadata_endpoint():
    """Return name, description, version and task count, preferring manifest values."""
    manifest = _load_manifest()
    return {
        "name": manifest.get("name", "dataops_env"),
        "description": manifest.get(
            "description",
            (
                "Enterprise data pipeline remediation environment. "
                "Agents debug data streams, fix scripts, and send email reports."
            ),
        ),
        "version": manifest.get("version", "1.0.0"),
        "task_count": len(_load_yaml_tasks()),
    }


@app.get("/schema")
def schema_endpoint():
    """Return the JSON schemas of the action, observation and state models."""
    return {
        "action": DataOpsAction.model_json_schema(),
        "observation": DataOpsObservation.model_json_schema(),
        "state": DataOpsState.model_json_schema(),
    }


@app.post("/mcp")
def mcp_endpoint(body: dict = Body(default_factory=dict)):
    """Answer JSON-RPC ``tools/list``; every other method gets error -32601."""
    method = body.get("method", "")
    req_id = body.get("id")
    if method == "tools/list":
        tools = [
            {"name": atype, "description": f"Execute a {atype} action."}
            for atype in [
                "ExecuteSQL",
                "ReadFile",
                "WriteFile",
                "RunScript",
                "SendEmail",
            ]
        ]
        return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": tools}}
    return {
        "jsonrpc": "2.0",
        "id": req_id,
        "error": {"code": -32601, "message": "Method not found"},
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one environment over a WebSocket connection.

    Handles ``reset``, ``step``, ``state`` and ``close`` messages against a
    DataOpsEnvironment created for this connection. Malformed or failing
    messages are answered with an error payload and the loop continues.
    """
    await websocket.accept()
    if not await _try_acquire_ws_slot():
        await websocket.send_text(
            _ws_error_payload(
                "WebSocket session capacity reached.",
                WSErrorCode.CAPACITY_REACHED,
            )
        )
        await websocket.close(code=1013)
        return

    env: DataOpsEnvironment | None = None
    try:
        # Constructed inside the try so a constructor failure still releases
        # the slot acquired above.
        env = DataOpsEnvironment()
        while True:
            raw_message = await websocket.receive_text()
            try:
                message_dict = json.loads(raw_message)
            except json.JSONDecodeError:
                await websocket.send_text(
                    _ws_error_payload("Invalid JSON payload.", WSErrorCode.INVALID_JSON)
                )
                continue
            # Valid JSON that is not an object (e.g. [] or 1) has no "type".
            if not isinstance(message_dict, dict):
                await websocket.send_text(
                    _ws_error_payload(
                        "WebSocket message must be a JSON object.",
                        WSErrorCode.VALIDATION_ERROR,
                    )
                )
                continue

            message_type = message_dict.get("type", "")
            try:
                # NOTE(review): env.reset()/env.step() are called synchronously
                # from this coroutine; confirm they cannot block for long.
                if message_type == "reset":
                    message = WSResetMessage(**message_dict)
                    observation = env.reset(**message.data)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "step":
                    message = WSStepMessage(**message_dict)
                    action = DataOpsAction(**message.data)
                    observation = env.step(action)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "state":
                    WSStateMessage(**message_dict)
                    response = WSStateResponse(data=env.state.model_dump())
                elif message_type == "close":
                    WSCloseMessage(**message_dict)
                    break
                else:
                    await websocket.send_text(
                        _ws_error_payload(
                            f"Unknown message type: {message_type}",
                            WSErrorCode.UNKNOWN_TYPE,
                        )
                    )
                    continue
                await websocket.send_text(response.model_dump_json())
            except ValidationError:
                await websocket.send_text(
                    _ws_error_payload(
                        "Validation error while handling the WebSocket message.",
                        WSErrorCode.VALIDATION_ERROR,
                    )
                )
            except Exception:
                # Connection-level boundary: log with traceback and keep the
                # session alive for the next message.
                logger.exception("WebSocket execution error")
                await websocket.send_text(
                    _ws_error_payload(
                        "Execution error while handling the WebSocket message.",
                        WSErrorCode.EXECUTION_ERROR,
                    )
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket client disconnected.")
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            # Release the slot even if env.close() raises.
            await _release_ws_slot()


@app.post("/reset", response_model=ResetResponse)
def reset_endpoint(
    request: Request,
    response: Response,
    task_id: str = Query("task_1_easy_anomaly", description="Task to initialise."),
    body: ResetRequest = Body(default_factory=ResetRequest),
):
    """Start a new episode for *task_id* and return its first observation.

    Raises:
        HTTPException: 400 when *task_id* is not in TASK_IDS.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(400, f"Invalid task_id. Choose from: {TASK_IDS}")
    session_id = _get_session_id(request)
    resolved_session_id, _env, obs = session_manager.reset_session(
        task_id=task_id,
        seed=body.seed,
        episode_id=body.episode_id,
        session_id=session_id,
    )
    _attach_session(response, resolved_session_id)
    return ResetResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)


@app.post("/step", response_model=StepResponse)
def step_endpoint(request: Request, response: Response, body: StepRequest):
    """Apply one action to the session's environment and return the observation.

    Raises:
        HTTPException: 422 for an invalid action, 400 without an active episode.
    """
    try:
        action = DataOpsAction(**body.action)
    except ValidationError as e:
        raise HTTPException(422, f"Invalid action: {e}") from e
    session_id, env = _require_active_env(request)
    _attach_session(response, session_id)
    obs = env.step(action, timeout_s=body.timeout_s)
    return StepResponse(observation=_wrap_obs(obs), reward=obs.reward, done=obs.done)


@app.get("/state", response_model=DataOpsState)
def state_endpoint(request: Request, response: Response):
    """Return the state of the session's environment."""
    session_id, env = _require_active_env(request)
    _attach_session(response, session_id)
    return env.state


@app.get("/tasks")
def tasks_endpoint():
    """List the tasks together with the action, observation and state schemas."""
    return {
        "tasks": _load_yaml_tasks(),
        "action_schema": DataOpsAction.model_json_schema(),
        "observation_schema": DataOpsObservation.model_json_schema(),
        "state_schema": DataOpsState.model_json_schema(),
    }


@app.get("/grader")
def grader_current_endpoint(request: Request, response: Response):
    """Grade the current episode (uses active task_id from state)."""
    session_id, env = _require_active_env(request)
    _attach_session(response, session_id)
    task_id = env.state.task_id
    if not task_id:
        raise HTTPException(400, "No active episode. Call /reset first.")
    return _format_grader_response(evaluate_task(task_id, env), request)


@app.get("/grader/{task_id}")
def grader_endpoint(task_id: str, request: Request, response: Response):
    """Grade the current episode, which must belong to *task_id*.

    Raises:
        HTTPException: 404 for an unknown task; 400 when there is no active
            episode or the active episode belongs to a different task.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(404, f"Unknown task: {task_id}")
    session_id, env = _require_active_env(request)
    _attach_session(response, session_id)
    active_task_id = env.state.task_id
    if active_task_id and active_task_id != task_id:
        raise HTTPException(
            400,
            f"Active episode belongs to task '{active_task_id}'. Reset the requested task first.",
        )
    return _format_grader_response(evaluate_task(task_id, env), request)


def _optional_int(body: dict, key: str) -> int | None:
    """Return ``body[key]`` as an int, None when absent/null, or raise HTTP 422."""
    raw = body.get(key)
    if raw is None:
        return None
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(422, f"'{key}' must be an integer.") from None


def _requested_task_ids(body: dict) -> list[str]:
    """Return the known task ids listed under ``task_ids`` in *body*.

    A single string is treated as a one-element list. Entries that are not
    strings or are not in TASK_IDS are dropped.

    Raises:
        HTTPException: 422 when ``task_ids`` is neither a string nor a list.
    """
    raw = body.get("task_ids") or []
    if isinstance(raw, str):
        raw = [raw]
    if not isinstance(raw, (list, tuple)):
        raise HTTPException(422, "'task_ids' must be a list of task ids.")
    return [tid for tid in raw if isinstance(tid, str) and tid in TASK_IDS]


@app.post("/baseline")
def baseline_endpoint(request: Request, body: dict = Body(default_factory=dict)):
    """Run inference.py (OpenAI tool-calling agent) against all tasks; same entrypoint as local baseline.

    Optional body keys: ``seed``, ``max_turns`` (integers) and ``task_ids``.

    Raises:
        HTTPException: 403 without a valid admin key, 422 for malformed
            options, 503 when no API credentials are configured, 500 when
            inference.py is missing, 504 on timeout, 502 when the script
            fails or prints no parsable scores.
    """
    _require_admin(request)
    if not (
        os.environ.get("API_KEY", "").strip() or os.environ.get("HF_TOKEN", "").strip()
    ):
        raise HTTPException(
            503,
            "API_KEY or HF_TOKEN must be set on the server process to run POST /baseline.",
        )
    script_path = PROJECT_ROOT / "inference.py"
    if not script_path.is_file():
        raise HTTPException(500, "inference.py missing from project root.")

    port = int(os.getenv("PORT", "7860"))
    timeout_s = HTTP_SESSION_TIMEOUT_S
    env = {
        **os.environ,
        "ENV_BASE_URL": os.getenv("ENV_BASE_URL", f"http://127.0.0.1:{port}"),
    }

    # Argument list (no shell), so request values cannot inject shell syntax.
    command = [sys.executable, str(script_path), "--json-scores"]
    seed = _optional_int(body, "seed")
    if seed is not None:
        command.extend(["--seed", str(seed)])
    max_turns = _optional_int(body, "max_turns")
    if max_turns is not None:
        command.extend(["--max-turns", str(max_turns)])
    for task_id in _requested_task_ids(body):
        command.extend(["--task", task_id])

    try:
        proc = subprocess.run(
            command,
            cwd=str(PROJECT_ROOT),
            capture_output=True,
            text=True,
            timeout=timeout_s,
            env=env,
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(
            504, f"Baseline exceeded HTTP_SESSION_TIMEOUT_S ({timeout_s}s)."
        ) from None

    if proc.returncode != 0:
        tail = (proc.stderr or proc.stdout or "")[-6000:]
        logger.error("inference.py failed rc=%s stderr=%s", proc.returncode, tail[:500])
        raise HTTPException(
            502,
            {"message": "inference.py exited with an error.", "detail": tail},
        )

    # The scores are taken from the last stdout line that parses as JSON.
    lines = [ln.strip() for ln in (proc.stdout or "").splitlines() if ln.strip()]
    parsed = None
    for line in reversed(lines):
        try:
            parsed = json.loads(line)
            break
        except json.JSONDecodeError:
            continue
    if not isinstance(parsed, dict) or "scores" not in parsed:
        raise HTTPException(
            502,
            {
                "message": "Could not parse JSON scores from inference.py stdout.",
                "stdout_tail": "\n".join(lines[-5:]),
            },
        )
    return {
        "message": "Model baseline completed via inference.py.",
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "scores": parsed["scores"],
        "grades": parsed.get("grades"),
        "average": parsed.get("average"),
        "model": parsed.get("model"),
        "metadata": parsed.get("metadata"),
    }


def main() -> None:
    """Entry point for `dataops-env` script and `openenv serve`."""
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "7860"))
    reload = _env_flag("DEBUG")
    # The import target depends on whether we were launched from inside the
    # server directory or from the project root.
    launched_from_server_dir = Path.cwd().resolve() == SERVER_DIR
    app_target = "app:app" if launched_from_server_dir else "server.app:app"
    app_dir = str(SERVER_DIR if launched_from_server_dir else PROJECT_ROOT)
    uvicorn.run(
        app_target,
        host=host,
        port=port,
        reload=reload,
        reload_dirs=[str(PROJECT_ROOT)] if reload else None,
        ws="wsproto",
        app_dir=app_dir,
    )


if __name__ == "__main__":
    main()