""" OpenENV Moderation Environment — FastAPI application. Standard OpenEnv endpoints: WS /ws — persistent WebSocket session (primary client interface) GET /health — liveness check POST /reset — start a new episode POST /step — take an action GET /state — current observation / state GET /docs — OpenAPI documentation (auto-generated) Custom endpoints: GET /tasks — available tasks GET /grader — final episode score GET /baseline — run rule-based baseline agent and return its score POST /agent/run — run selected LLM agent on a full episode """ from __future__ import annotations import json import logging from dotenv import load_dotenv load_dotenv() # loads .env from project root before anything else from fastapi import FastAPI, HTTPException, Body, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from openenv.core.env_server.types import ( HealthResponse, HealthStatus, ResetRequest as OEResetRequest, ResetResponse, StepRequest, StepResponse, WSObservationResponse, WSStateResponse, WSErrorResponse, WSErrorCode, ) from data.tasks import TASKS from env.grader import Grader from env.state_manager import StateManager from models.schemas import ( Action, BaselineResult, EpisodeScore, ResetRequest, TaskConfig, ) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI( title="OpenENV — Content Moderation Environment", description=( "A multi-step RL environment for AI content moderation agents. " "Agents receive partial observations and must investigate context, " "classify violations, and make final moderation decisions." ), version="1.0.0", ) app.add_middleware( CORSMiddleware, allow_origins=["*"], # Open for HF Spaces + local dev allow_methods=["*"], allow_headers=["*"], ) # Single shared state manager (single-threaded MVP) _state_manager = StateManager() _grader = Grader() # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @app.get("/health", response_model=HealthResponse) def health() -> HealthResponse: return HealthResponse(status=HealthStatus.HEALTHY) @app.get("/tasks") def list_tasks() -> dict[str, TaskConfig]: return TASKS @app.post("/reset", response_model=ResetResponse) def reset(request: OEResetRequest | None = Body(default=None)) -> ResetResponse: # task_id passed as extra field; fall back to episode_id or default extra = (request.model_extra or {}) if request else {} task_id = extra.get("task_id") or (request.episode_id if request else None) or "easy_harassment" seed = (request.seed if request else None) or 42 if task_id not in TASKS: raise HTTPException( status_code=400, detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}", ) task = TASKS[task_id] task = task.model_copy(update={"seed": seed}) obs = _state_manager.reset(task) return ResetResponse(observation=obs.model_dump(), reward=None, done=obs.done) @app.post("/step", response_model=StepResponse) def step(request: StepRequest) -> StepResponse: if not _state_manager.has_active_episode(): raise HTTPException(status_code=400, detail="No active episode. Call /reset first.") try: action = Action(**request.action) except Exception as exc: raise HTTPException(status_code=422, detail=str(exc)) try: result = _state_manager.step(action) except ValueError as exc: raise HTTPException(status_code=400, detail=str(exc)) logger.info( "Step %d: action=%s reward=%.3f done=%s", result.observation.step, action.action_type.value, result.reward, result.done, ) return StepResponse( observation=result.observation.model_dump(), reward=result.reward, done=result.done, ) @app.get("/state") def get_state() -> dict: if not _state_manager.has_active_episode(): raise HTTPException(status_code=400, detail="No active episode. Call /reset first.") return _state_manager.get_state().model_dump() @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket) -> None: await websocket.accept() try: while True: try: raw = await websocket.receive_text() data = json.loads(raw) except json.JSONDecodeError: await websocket.send_text( WSErrorResponse(data={"message": "Invalid JSON", "code": WSErrorCode.INVALID_JSON}).model_dump_json() ) continue msg_type = data.get("type") if msg_type == "reset": reset_data = data.get("data", {}) task_id = reset_data.get("task_id") or reset_data.get("episode_id") or "easy_harassment" seed = reset_data.get("seed") or 42 if task_id not in TASKS: await websocket.send_text( WSErrorResponse(data={"message": f"Unknown task_id '{task_id}'", "code": WSErrorCode.VALIDATION_ERROR}).model_dump_json() ) continue task = TASKS[task_id].model_copy(update={"seed": seed}) obs = _state_manager.reset(task) await websocket.send_text( WSObservationResponse(data={"observation": obs.model_dump(), "reward": None, "done": obs.done}).model_dump_json() ) elif msg_type == "step": if not _state_manager.has_active_episode(): await websocket.send_text( WSErrorResponse(data={"message": "No active episode. Send reset first.", "code": WSErrorCode.SESSION_ERROR}).model_dump_json() ) continue action_data = data.get("data", {}) try: action = Action(**action_data) except Exception as exc: await websocket.send_text( WSErrorResponse(data={"message": str(exc), "code": WSErrorCode.VALIDATION_ERROR}).model_dump_json() ) continue try: result = _state_manager.step(action) except ValueError as exc: await websocket.send_text( WSErrorResponse(data={"message": str(exc), "code": WSErrorCode.EXECUTION_ERROR}).model_dump_json() ) continue await websocket.send_text( WSObservationResponse(data={"observation": result.observation.model_dump(), "reward": result.reward, "done": result.done}).model_dump_json() ) elif msg_type == "state": if not _state_manager.has_active_episode(): await websocket.send_text( WSErrorResponse(data={"message": "No active episode.", "code": WSErrorCode.SESSION_ERROR}).model_dump_json() ) continue obs = _state_manager.get_state() await websocket.send_text( WSStateResponse(data=obs.model_dump()).model_dump_json() ) elif msg_type == "close": break else: await websocket.send_text( WSErrorResponse(data={"message": f"Unknown message type: {msg_type!r}", "code": WSErrorCode.UNKNOWN_TYPE}).model_dump_json() ) except WebSocketDisconnect: pass @app.get("/grader", response_model=EpisodeScore) def grade() -> EpisodeScore: if not _state_manager.has_active_episode(): raise HTTPException(status_code=400, detail="No active episode. Call /reset first.") episode = _state_manager.get_episode_state() if not episode.observation.done: raise HTTPException( status_code=400, detail="Episode is not finished yet. Complete the episode before grading.", ) score = _grader.score(episode) logger.info("Graded episode: total=%.4f", score.total) return score @app.get("/baseline", response_model=BaselineResult) def baseline(task_id: str = "easy_harassment", seed: int | None = None) -> BaselineResult: """Run the built-in rule-based baseline agent and return its score.""" from baseline.agent import BaselineAgent if task_id not in TASKS: raise HTTPException( status_code=400, detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}", ) task = TASKS[task_id] if seed is not None: task = task.model_copy(update={"seed": seed}) agent = BaselineAgent(state_manager=_state_manager, grader=_grader) result = agent.run(task) return result @app.post("/agent/run", response_model=BaselineResult) def agent_run(request: ResetRequest) -> BaselineResult: """ Run the selected LLM agent (OpenAI or Gemini) on a full episode and return the graded result. Requires OPENAI_API_KEY, or GOOGLE_API_KEY/GEMINI_API_KEY depending on LLM_PROVIDER. """ import os from agent.openai_agent import OpenAIAgent from agent.gemini_agent import GeminiAgent provider = os.getenv("LLM_PROVIDER", "openai").lower() if request.task_id not in TASKS: raise HTTPException( status_code=400, detail=f"Unknown task_id '{request.task_id}'. Available: {list(TASKS.keys())}", ) task = TASKS[request.task_id] if request.seed is not None: task = task.model_copy(update={"seed": request.seed}) try: if provider == "gemini": agent = GeminiAgent(state_manager=_state_manager, grader=_grader) else: agent = OpenAIAgent(state_manager=_state_manager, grader=_grader) except EnvironmentError as exc: raise HTTPException(status_code=500, detail=str(exc)) result = agent.run(task) logger.info( "%s agent finished: task=%s total=%.4f", provider.capitalize(), task.task_id, result.score.total ) return result