# dataops-env / server/app.py
# Uploaded by visheshrathi ("Upload folder using huggingface_hub"), commit a1b343c (verified).
import asyncio
import hmac
import json
import logging
import math
import os
import subprocess
import sys
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
import yaml
from fastapi import (
    Body,
    FastAPI,
    HTTPException,
    Query,
    Request,
    Response,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from openenv.core.env_server.http_server import serialize_observation
from openenv.core.env_server.types import (
    HealthResponse,
    HealthStatus,
    ResetRequest,
    ResetResponse,
    StepRequest,
    StepResponse,
    WSCloseMessage,
    WSErrorCode,
    WSErrorResponse,
    WSObservationResponse,
    WSResetMessage,
    WSStateMessage,
    WSStateResponse,
    WSStepMessage,
)
from pydantic import ValidationError

from env_loader import load_env
from models import DataOpsAction, DataOpsObservation, DataOpsState
from server.dataops_env_environment import DataOpsEnvironment
from server.grading import evaluate_task
from server.session_manager import EnvironmentSessionManager
from server.task_specs import TASK_IDS, task_manifest_entries
# Repo root must be on sys.path (e.g. run `uv run python -m server.app` or uvicorn from project root).
# PROJECT_ROOT is the repository root (parent of this file's directory);
# SERVER_DIR is the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
SERVER_DIR = Path(__file__).resolve().parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s | %(name)s | %(message)s")
logger = logging.getLogger(__name__)
# Must run before the os.getenv() reads below so any variables load_env() provides
# are visible to them (presumably loaded from a .env file — confirm in env_loader).
load_env()
# HTTP sessions are identified by this cookie or, taking precedence, this request header.
SESSION_COOKIE_NAME = "dataops_session_id"
SESSION_HEADER_NAME = "X-Session-ID"
# Capacity and idle timeout handed to EnvironmentSessionManager. The timeout is also
# the session cookie's max-age and the subprocess timeout for POST /baseline.
MAX_HTTP_SESSIONS = int(os.getenv("MAX_HTTP_SESSIONS", "128"))
HTTP_SESSION_TIMEOUT_S = float(os.getenv("HTTP_SESSION_TIMEOUT_S", "1200"))
# Maximum concurrent WebSocket connections; clamped to at least 1.
MAX_WS_SESSIONS = max(1, int(os.getenv("MAX_WS_SESSIONS", "64")))
# When empty, admin-key checks are disabled and admin-only endpoints are open to everyone.
ADMIN_API_KEY = os.getenv("ADMIN_API_KEY", "").strip()
# Controls the Secure flag on the session cookie (enable when served over HTTPS).
COOKIE_SECURE = os.getenv("COOKIE_SECURE", "").lower() in {"1", "true", "yes"}
# Clamp bounds for scores reported by the grader endpoints; a reported score is
# never exactly 0.0 or 1.0.
MIN_REPORTED_SCORE = 0.01
MAX_REPORTED_SCORE = 0.99
def _public_grader_details_enabled() -> bool:
    """Return True when PUBLIC_GRADER_DETAILS asks for full grader output for every caller.

    The variable is read on each call (never cached at import time), so tests and
    operators can flip it without restarting the process. Accepted truthy values
    are ``1``, ``true`` and ``yes`` (case-insensitive, surrounding whitespace ignored).
    """
    flag = os.getenv("PUBLIC_GRADER_DETAILS", "")
    return flag.strip().lower() in ("1", "true", "yes")
# Number of WebSocket connections currently holding a slot; only modified under
# _ws_session_lock and capped at MAX_WS_SESSIONS by _try_acquire_ws_slot.
_ws_active_sessions = 0
_ws_session_lock = asyncio.Lock()
# Manager for HTTP sessions (keyed by the session cookie/header) used by /reset,
# and by /step, /state and /grader through _require_active_env.
session_manager = EnvironmentSessionManager(
    max_sessions=MAX_HTTP_SESSIONS,
    session_timeout_s=HTTP_SESSION_TIMEOUT_S,
)
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncIterator[None]:
    """Application lifespan hook.

    Logs startup and, on shutdown, closes every HTTP session held by
    ``session_manager``. The cleanup lives in ``finally`` so it still runs when
    the server is torn down by an exception delivered at the ``yield``.
    """
    logger.info("DataOpsEnv starting.")
    try:
        yield
    finally:
        logger.info("DataOpsEnv shutting down.")
        session_manager.close_all()
# The ASGI application. `lifespan` logs startup and closes all HTTP sessions on shutdown.
app = FastAPI(
    title="DataOpsEnv",
    description="Enterprise data pipeline remediation environment for training AI agents (OpenEnv-compliant).",
    version="1.0.0",
    lifespan=lifespan,
)
def _cors_allow_origins() -> list[str]:
    """Parse CORS_ALLOW_ORIGINS into the origin list handed to CORSMiddleware.

    * unset or blank -> ``[]`` (no origins allowed)
    * ``"*"``        -> ``["*"]`` (any origin)
    * otherwise      -> comma-separated origins, each trimmed, empty entries dropped
    """
    raw = os.getenv("CORS_ALLOW_ORIGINS", "").strip()
    if raw == "*":
        return ["*"]
    origins: list[str] = []
    for part in raw.split(","):
        origin = part.strip()
        if origin:
            origins.append(origin)
    return origins
# Allowed origins come from CORS_ALLOW_ORIGINS, read once at import time; with the
# variable unset the list is empty. allow_credentials is not passed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins(),
    allow_methods=["*"],
    allow_headers=["*"],
)
def _load_manifest() -> dict:
    """Load ``openenv.yaml`` from the project root.

    Returns:
        The parsed top-level mapping. Returns ``{}`` when the file is missing,
        empty, not valid YAML (logged), or when its top level is not a mapping.
        Every caller uses ``.get`` on the result, so it must always be a dict.
    """
    yaml_path = PROJECT_ROOT / "openenv.yaml"
    try:
        with yaml_path.open(encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        return {}
    except yaml.YAMLError:
        # Callers have built-in fallbacks; log and degrade instead of returning a 500.
        logger.warning("Could not parse %s; ignoring manifest.", yaml_path, exc_info=True)
        return {}
    # A list or scalar at the top level would break every ``manifest.get(...)`` caller.
    return data if isinstance(data, dict) else {}
def _load_yaml_tasks() -> list[dict]:
    """Return the task descriptors advertised by /tasks and counted by /metadata.

    The ``tasks`` list from ``openenv.yaml`` is used only when it is a non-empty
    list of mappings whose ``id`` values equal ``TASK_IDS`` exactly (same ids,
    same order). Otherwise ``task_manifest_entries()`` is returned, so the
    advertised tasks always match what /reset accepts.
    """
    tasks = _load_manifest().get("tasks")
    if (
        isinstance(tasks, list)
        and tasks
        # A non-mapping entry would otherwise crash on ``item.get``.
        and all(isinstance(item, dict) for item in tasks)
        # list(...) so the comparison also holds if TASK_IDS is a tuple.
        and [str(item.get("id", "")) for item in tasks] == list(TASK_IDS)
    ):
        return tasks
    return task_manifest_entries()
def _wrap_obs(obs: DataOpsObservation) -> dict:
    """Convert *obs* to a plain dict (via ``model_dump``) for an HTTP response body."""
    payload = obs.model_dump()
    return payload
def _get_session_id(request: Request) -> str | None:
    """Return the session id the client sent, or None if it sent none.

    The ``X-Session-ID`` header takes precedence over the session cookie. Values
    are whitespace-stripped. A blank header is treated as absent, so it cannot
    mask a valid session cookie.
    """
    for candidate in (
        request.headers.get(SESSION_HEADER_NAME),
        request.cookies.get(SESSION_COOKIE_NAME),
    ):
        if candidate and candidate.strip():
            return candidate.strip()
    return None
def _attach_session(response: Response, session_id: str) -> None:
    """Tell the client which session it is using, via both a cookie and a response header.

    The cookie is HttpOnly and SameSite=Lax. It is Secure only when COOKIE_SECURE
    is set, and its max-age is HTTP_SESSION_TIMEOUT_S seconds.
    """
    cookie_options = {
        "httponly": True,
        "samesite": "lax",
        "secure": COOKIE_SECURE,
        "max_age": int(HTTP_SESSION_TIMEOUT_S),
    }
    response.set_cookie(key=SESSION_COOKIE_NAME, value=session_id, **cookie_options)
    response.headers[SESSION_HEADER_NAME] = session_id
def _require_active_env(request: Request) -> tuple[str, DataOpsEnvironment]:
    """Resolve the caller's HTTP session to ``(session_id, environment)``.

    Raises:
        HTTPException: 400 when ``session_manager.get_session`` yields no session
            id or no environment for the request's session id.
    """
    resolved_id, env = session_manager.get_session(_get_session_id(request))
    if resolved_id is not None and env is not None:
        return resolved_id, env
    raise HTTPException(400, "No active episode. Call /reset first.")
def _ws_error_payload(message: str, code: WSErrorCode) -> str:
    """Build the JSON text of a ``WSErrorResponse`` carrying *message* and the code's value."""
    error_data = {"message": message, "code": code.value}
    frame = WSErrorResponse(data=error_data)
    return frame.model_dump_json()
def _admin_key_matches(request: Request) -> bool:
    """Return True when an admin key is configured and the request's ``X-Admin-Key`` equals it.

    Uses a constant-time comparison so response timing does not reveal how much
    of a guessed key was correct.
    """
    if not ADMIN_API_KEY:
        return False
    provided = request.headers.get("X-Admin-Key", "")
    # compare_digest rejects non-ASCII str input; compare the UTF-8 bytes instead.
    return hmac.compare_digest(provided.encode("utf-8"), ADMIN_API_KEY.encode("utf-8"))


def _require_admin(request: Request) -> None:
    """Guard an admin-only endpoint.

    When ADMIN_API_KEY is unset the check is disabled and every caller passes.

    Raises:
        HTTPException: 403 when a key is configured and the request's
            ``X-Admin-Key`` header is missing or does not match it.
    """
    if not ADMIN_API_KEY:
        return
    if not _admin_key_matches(request):
        raise HTTPException(403, "Missing or invalid admin key.")


def _request_is_admin(request: Request) -> bool:
    """Return True only when an admin key is configured and the request presents it."""
    return _admin_key_matches(request)
def _normalize_reported_score(value: object) -> float:
    """Map a raw grader score onto the range reported to clients.

    The score is rounded to two decimals and clamped to
    [MIN_REPORTED_SCORE, MAX_REPORTED_SCORE], so a reported score is never
    exactly 0.0 or 1.0.

    Args:
        value: Raw score; anything ``float()`` accepts.

    Returns:
        MIN_REPORTED_SCORE for non-numeric or NaN input and for values that are
        <= 0 (before or after rounding). MAX_REPORTED_SCORE for values that are
        >= 1 (before or after rounding, including +inf). Otherwise the score
        rounded to two decimals.
    """
    try:
        score = float(value)
    except (TypeError, ValueError):
        return MIN_REPORTED_SCORE
    # NaN compares False against every bound, so it would otherwise be returned
    # unclamped (and NaN is not valid JSON in the response).
    if math.isnan(score):
        return MIN_REPORTED_SCORE
    if score <= 0.0:
        return MIN_REPORTED_SCORE
    if score >= 1.0:
        return MAX_REPORTED_SCORE
    # Rounding can push values just inside (0, 1) onto 0.0 or 1.0; clamp again.
    return min(max(round(score, 2), MIN_REPORTED_SCORE), MAX_REPORTED_SCORE)
def _normalize_grade_payload(grade: dict) -> dict:
    """Return a shallow copy of *grade* with ``score`` passed through ``_normalize_reported_score``.

    The input dict is not modified. A missing ``score`` is normalised from None.
    """
    normalized = dict(grade)
    raw_score = normalized.get("score")
    normalized["score"] = _normalize_reported_score(raw_score)
    return normalized
def _format_grader_response(grade: dict, request: Request) -> dict:
    """Shape a grader result for an HTTP response.

    The score is always normalised. The full grade dict is returned only when
    PUBLIC_GRADER_DETAILS is enabled or the caller presents the admin key;
    everyone else receives just ``task_id`` and ``score``.
    """
    normalized = _normalize_grade_payload(grade)
    show_details = _public_grader_details_enabled() or _request_is_admin(request)
    if not show_details:
        return {"task_id": normalized.get("task_id"), "score": normalized.get("score")}
    return normalized
async def _try_acquire_ws_slot() -> bool:
    """Try to reserve one of the MAX_WS_SESSIONS WebSocket slots.

    Returns:
        True if a slot was reserved (the caller must later call
        ``_release_ws_slot``), False when every slot is taken.
    """
    global _ws_active_sessions
    async with _ws_session_lock:
        has_room = _ws_active_sessions < MAX_WS_SESSIONS
        if has_room:
            _ws_active_sessions += 1
        return has_room
async def _release_ws_slot() -> None:
    """Give back a slot reserved by ``_try_acquire_ws_slot``; the counter never drops below zero."""
    global _ws_active_sessions
    async with _ws_session_lock:
        remaining = _ws_active_sessions - 1
        _ws_active_sessions = remaining if remaining > 0 else 0
@app.get("/health", response_model=HealthResponse)
def health_endpoint():
    """Liveness probe: unconditionally reports HEALTHY while the process is serving."""
    healthy = HealthResponse(status=HealthStatus.HEALTHY)
    return healthy
@app.get("/metadata")
def metadata_endpoint():
    """Describe the environment from ``openenv.yaml``, using built-in defaults for missing fields."""
    manifest = _load_manifest()
    default_description = (
        "Enterprise data pipeline remediation environment. "
        "Agents debug data streams, fix scripts, and send email reports."
    )
    name = manifest.get("name", "dataops_env")
    description = manifest.get("description", default_description)
    version = manifest.get("version", "1.0.0")
    task_count = len(_load_yaml_tasks())
    return {
        "name": name,
        "description": description,
        "version": version,
        "task_count": task_count,
    }
@app.get("/schema")
def schema_endpoint():
    """Return the JSON Schemas of the action, observation and state models."""
    models_by_role = {
        "action": DataOpsAction,
        "observation": DataOpsObservation,
        "state": DataOpsState,
    }
    return {role: model.model_json_schema() for role, model in models_by_role.items()}
@app.post("/mcp")
def mcp_endpoint(body: dict = Body(default_factory=dict)):
    """Minimal JSON-RPC 2.0 endpoint for MCP clients.

    Only ``tools/list`` is implemented; it advertises one tool per action type.
    Any other method receives a JSON-RPC "Method not found" (-32601) error.
    """
    request_id = body.get("id")
    if body.get("method", "") != "tools/list":
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "error": {"code": -32601, "message": "Method not found"},
        }
    action_types = ("ExecuteSQL", "ReadFile", "WriteFile", "RunScript", "SendEmail")
    tools = [
        {"name": name, "description": f"Execute a {name} action."}
        for name in action_types
    ]
    return {"jsonrpc": "2.0", "id": request_id, "result": {"tools": tools}}
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one environment session over a WebSocket.

    Each connection gets its own ``DataOpsEnvironment`` and holds one of the
    MAX_WS_SESSIONS slots. When no slot is free, the client receives a
    CAPACITY_REACHED error frame and the socket is closed with code 1013.

    Incoming messages are JSON objects whose ``type`` is ``reset``, ``step``,
    ``state`` or ``close``. Malformed input produces an error frame and the loop
    continues; ``close`` ends the session. The environment is closed and the
    slot released on every exit path.
    """
    await websocket.accept()
    if not await _try_acquire_ws_slot():
        await websocket.send_text(
            _ws_error_payload(
                "WebSocket session capacity reached.",
                WSErrorCode.CAPACITY_REACHED,
            )
        )
        await websocket.close(code=1013)
        return

    env: DataOpsEnvironment | None = None
    try:
        # Constructed inside the try so the slot is released even if this raises.
        env = DataOpsEnvironment()
        while True:
            raw_message = await websocket.receive_text()
            try:
                message_dict = json.loads(raw_message)
            except json.JSONDecodeError:
                await websocket.send_text(
                    _ws_error_payload("Invalid JSON payload.", WSErrorCode.INVALID_JSON)
                )
                continue
            # Valid JSON that is not an object (list, number, string, null) would
            # otherwise crash the whole connection on ``.get`` below.
            if not isinstance(message_dict, dict):
                await websocket.send_text(
                    _ws_error_payload(
                        "Message must be a JSON object.",
                        WSErrorCode.VALIDATION_ERROR,
                    )
                )
                continue

            message_type = message_dict.get("type", "")
            try:
                if message_type == "reset":
                    message = WSResetMessage(**message_dict)
                    observation = env.reset(**message.data)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "step":
                    message = WSStepMessage(**message_dict)
                    action = DataOpsAction(**message.data)
                    observation = env.step(action)
                    response = WSObservationResponse(
                        data=serialize_observation(observation)
                    )
                elif message_type == "state":
                    WSStateMessage(**message_dict)
                    response = WSStateResponse(data=env.state.model_dump())
                elif message_type == "close":
                    WSCloseMessage(**message_dict)
                    break
                else:
                    await websocket.send_text(
                        _ws_error_payload(
                            f"Unknown message type: {message_type}",
                            WSErrorCode.UNKNOWN_TYPE,
                        )
                    )
                    continue
                await websocket.send_text(response.model_dump_json())
            except WebSocketDisconnect:
                # The client went away while we were replying. End the session via
                # the outer handler instead of treating it as an execution error
                # (which would try to send on the dead socket again).
                raise
            except ValidationError:
                await websocket.send_text(
                    _ws_error_payload(
                        "Validation error while handling the WebSocket message.",
                        WSErrorCode.VALIDATION_ERROR,
                    )
                )
            except Exception:
                logger.exception("WebSocket execution error")
                await websocket.send_text(
                    _ws_error_payload(
                        "Execution error while handling the WebSocket message.",
                        WSErrorCode.EXECUTION_ERROR,
                    )
                )
    except WebSocketDisconnect:
        logger.debug("WebSocket client disconnected.")
    finally:
        try:
            if env is not None:
                env.close()
        finally:
            # Always return the slot, even if env.close() raises.
            await _release_ws_slot()
@app.post("/reset", response_model=ResetResponse)
def reset_endpoint(
    request: Request,
    response: Response,
    task_id: str = Query("task_1_easy_anomaly", description="Task to initialise."),
    body: ResetRequest = Body(default_factory=ResetRequest),
):
    """Start a new episode of *task_id* for the caller.

    The session id from the request (header or cookie, if any) is passed to
    ``session_manager.reset_session``. The resolved session id is sent back via
    cookie and header, and the initial observation is returned.

    Raises:
        HTTPException: 400 when *task_id* is not one of TASK_IDS.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(400, f"Invalid task_id. Choose from: {TASK_IDS}")
    requested_session = _get_session_id(request)
    active_session, _env, first_obs = session_manager.reset_session(
        task_id=task_id,
        seed=body.seed,
        episode_id=body.episode_id,
        session_id=requested_session,
    )
    _attach_session(response, active_session)
    return ResetResponse(
        observation=_wrap_obs(first_obs),
        reward=first_obs.reward,
        done=first_obs.done,
    )
@app.post("/step", response_model=StepResponse)
def step_endpoint(request: Request, response: Response, body: StepRequest):
    """Apply one action to the caller's active episode and return the new observation.

    Raises:
        HTTPException: 422 when ``body.action`` is not a valid DataOpsAction;
            400 when the caller has no active session.
    """
    try:
        parsed_action = DataOpsAction(**body.action)
    except ValidationError as exc:
        raise HTTPException(422, f"Invalid action: {exc}") from exc
    active_session, env = _require_active_env(request)
    _attach_session(response, active_session)
    observation = env.step(parsed_action, timeout_s=body.timeout_s)
    return StepResponse(
        observation=_wrap_obs(observation),
        reward=observation.reward,
        done=observation.done,
    )
@app.get("/state", response_model=DataOpsState)
def state_endpoint(request: Request, response: Response):
    """Return the state of the caller's active episode (400 if the caller has no session)."""
    active_session, active_env = _require_active_env(request)
    _attach_session(response, active_session)
    current_state = active_env.state
    return current_state
@app.get("/tasks")
def tasks_endpoint():
    """List the available tasks together with the action/observation/state JSON Schemas."""
    payload = {"tasks": _load_yaml_tasks()}
    for key, model in (
        ("action_schema", DataOpsAction),
        ("observation_schema", DataOpsObservation),
        ("state_schema", DataOpsState),
    ):
        payload[key] = model.model_json_schema()
    return payload
@app.get("/grader")
def grader_current_endpoint(request: Request, response: Response):
    """Grade the current episode (uses active task_id from state).

    Raises:
        HTTPException: 400 when the caller has no session or its state carries
            no task_id.
    """
    active_session, env = _require_active_env(request)
    _attach_session(response, active_session)
    current_task = env.state.task_id
    if current_task:
        grade = evaluate_task(current_task, env)
        return _format_grader_response(grade, request)
    raise HTTPException(400, "No active episode. Call /reset first.")
@app.get("/grader/{task_id}")
def grader_endpoint(task_id: str, request: Request, response: Response):
    """Grade the caller's episode for an explicitly named task.

    Raises:
        HTTPException: 404 for a task_id outside TASK_IDS; 400 when the caller
            has no session or the active episode belongs to a different task.
    """
    if task_id not in TASK_IDS:
        raise HTTPException(404, f"Unknown task: {task_id}")
    active_session, env = _require_active_env(request)
    _attach_session(response, active_session)
    active_task_id = env.state.task_id
    belongs_elsewhere = bool(active_task_id) and active_task_id != task_id
    if belongs_elsewhere:
        raise HTTPException(
            400,
            f"Active episode belongs to task '{active_task_id}'. Reset the requested task first.",
        )
    grade = evaluate_task(task_id, env)
    return _format_grader_response(grade, request)
def _baseline_int_option(body: dict, key: str) -> int | None:
    """Return ``body[key]`` converted with ``int()``, or None when it is absent or null.

    Raises:
        HTTPException: 422 when the value cannot be converted to an integer.
    """
    value = body.get(key)
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        raise HTTPException(422, f"'{key}' must be an integer.") from None


def _baseline_task_ids(body: dict) -> list[str]:
    """Return the requested ``task_ids`` that appear in TASK_IDS, in request order.

    Unknown ids are skipped. An empty result adds no ``--task`` flags (presumably
    making inference.py run every task — confirm in inference.py).

    Raises:
        HTTPException: 422 when ``task_ids`` is present but not a JSON array
            (a bare string would otherwise be iterated character by character).
    """
    requested = body.get("task_ids") or []
    if not isinstance(requested, list):
        raise HTTPException(422, "'task_ids' must be a list of task ids.")
    return [
        task_id
        for task_id in requested
        if isinstance(task_id, str) and task_id in TASK_IDS
    ]


def _parse_inference_output(stdout: str) -> dict | None:
    """Return the last non-blank stdout line that parses as a JSON object with a ``scores`` key.

    JSON lines of other shapes (numbers, lists, objects without ``scores``) are
    skipped instead of ending the search. Returns None when no line qualifies.
    """
    for raw_line in reversed(stdout.splitlines()):
        line = raw_line.strip()
        if not line:
            continue
        try:
            candidate = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(candidate, dict) and "scores" in candidate:
            return candidate
    return None


@app.post("/baseline")
def baseline_endpoint(request: Request, body: dict = Body(default_factory=dict)):
    """Run inference.py (OpenAI tool-calling agent) against all tasks; same entrypoint as local baseline.

    Admin-only via ``_require_admin``. Optional body fields: ``seed`` and
    ``max_turns`` (integers) and ``task_ids`` (list of task ids to restrict the
    run to). The script runs as a subprocess with ENV_BASE_URL pointing at this
    server (unless overridden), and must print a JSON line containing ``scores``.

    Raises:
        HTTPException: 503 without API_KEY/HF_TOKEN; 500 if inference.py is
            missing; 422 for malformed body fields; 504 when the run exceeds
            HTTP_SESSION_TIMEOUT_S; 502 when the script fails or prints no scores.
    """
    _require_admin(request)
    if not (
        os.environ.get("API_KEY", "").strip()
        or os.environ.get("HF_TOKEN", "").strip()
    ):
        raise HTTPException(
            503,
            "API_KEY or HF_TOKEN must be set on the server process to run POST /baseline.",
        )
    script_path = PROJECT_ROOT / "inference.py"
    if not script_path.is_file():
        raise HTTPException(500, "inference.py missing from project root.")
    port = int(os.getenv("PORT", "7860"))
    timeout_s = HTTP_SESSION_TIMEOUT_S
    env = {
        **os.environ,
        # Point the agent back at this server unless an explicit URL is configured.
        "ENV_BASE_URL": os.getenv("ENV_BASE_URL", f"http://127.0.0.1:{port}"),
    }

    command = [sys.executable, str(script_path), "--json-scores"]
    seed = _baseline_int_option(body, "seed")
    if seed is not None:
        command.extend(["--seed", str(seed)])
    max_turns = _baseline_int_option(body, "max_turns")
    if max_turns is not None:
        command.extend(["--max-turns", str(max_turns)])
    for task_id in _baseline_task_ids(body):
        command.extend(["--task", task_id])

    try:
        # Argument list without a shell, so body values cannot inject shell syntax.
        proc = subprocess.run(
            command,
            cwd=str(PROJECT_ROOT),
            capture_output=True,
            text=True,
            timeout=timeout_s,
            env=env,
        )
    except subprocess.TimeoutExpired:
        raise HTTPException(
            504, f"Baseline exceeded HTTP_SESSION_TIMEOUT_S ({timeout_s}s)."
        ) from None
    if proc.returncode != 0:
        tail = (proc.stderr or proc.stdout or "")[-6000:]
        logger.error("inference.py failed rc=%s stderr=%s", proc.returncode, tail[:500])
        raise HTTPException(
            502,
            {"message": "inference.py exited with an error.", "detail": tail},
        )

    stdout_text = proc.stdout or ""
    parsed = _parse_inference_output(stdout_text)
    if parsed is None:
        lines = [ln.strip() for ln in stdout_text.splitlines() if ln.strip()]
        raise HTTPException(
            502,
            {
                "message": "Could not parse JSON scores from inference.py stdout.",
                "stdout_tail": "\n".join(lines[-5:]),
            },
        )
    return {
        "message": "Model baseline completed via inference.py.",
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "scores": parsed["scores"],
        "grades": parsed.get("grades"),
        "average": parsed.get("average"),
        "model": parsed.get("model"),
        "metadata": parsed.get("metadata"),
    }
def main():
    """Entry point for `dataops-env` script and `openenv serve`.

    HOST and PORT select the bind address (default 0.0.0.0:7860). DEBUG=1/true
    turns on auto-reload watching the project root. The import target depends on
    the working directory: ``app:app`` when launched from inside the server
    directory, otherwise ``server.app:app`` from the project root.
    """
    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "7860"))
    auto_reload = os.getenv("DEBUG", "").lower() in ("1", "true")
    if Path.cwd().resolve() == SERVER_DIR:
        app_target, app_dir = "app:app", str(SERVER_DIR)
    else:
        app_target, app_dir = "server.app:app", str(PROJECT_ROOT)
    watch_dirs = [str(PROJECT_ROOT)] if auto_reload else None
    uvicorn.run(
        app_target,
        host=bind_host,
        port=bind_port,
        reload=auto_reload,
        reload_dirs=watch_dirs,
        ws="wsproto",
        app_dir=app_dir,
    )
# Start the server when this module is executed directly (e.g. `python -m server.app`).
if __name__ == "__main__":
    main()