| """ |
| OpenENV Moderation Environment — FastAPI application. |
| |
| Standard OpenEnv endpoints: |
| WS /ws — persistent WebSocket session (primary client interface) |
| GET /health — liveness check |
| POST /reset — start a new episode |
| POST /step — take an action |
| GET /state — current observation / state |
| GET /docs — OpenAPI documentation (auto-generated) |
| |
| Custom endpoints: |
| GET /tasks — available tasks |
| GET /grader — final episode score |
| GET /baseline — run rule-based baseline agent and return its score |
| POST /agent/run — run selected LLM agent on a full episode |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import logging |
|
|
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| from fastapi import FastAPI, HTTPException, Body, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from openenv.core.env_server.types import ( |
| HealthResponse, |
| HealthStatus, |
| ResetRequest as OEResetRequest, |
| ResetResponse, |
| StepRequest, |
| StepResponse, |
| WSObservationResponse, |
| WSStateResponse, |
| WSErrorResponse, |
| WSErrorCode, |
| ) |
|
|
| from data.tasks import TASKS |
| from env.grader import Grader |
| from env.state_manager import StateManager |
| from models.schemas import ( |
| Action, |
| BaselineResult, |
| EpisodeScore, |
| ResetRequest, |
| TaskConfig, |
| ) |
|
|
# Configure root logging for the whole process at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="OpenENV — Content Moderation Environment",
    version="1.0.0",
    description=(
        "A multi-step RL environment for AI content moderation agents. "
        "Agents receive partial observations and must investigate context, "
        "classify violations, and make final moderation decisions."
    ),
)

# Fully permissive CORS: any origin, method and header may call the API.
# NOTE(review): tighten allow_origins if this server is exposed beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Single shared instances. Every endpoint below that touches episode state,
# HTTP and WebSocket alike, uses these same two objects. Presumably this means
# one active episode per process; confirm against StateManager.
_state_manager = StateManager()
_grader = Grader()
|
|
|
|
| |
| |
| |
|
|
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness probe; unconditionally reports ``HealthStatus.HEALTHY``."""
    status = HealthStatus.HEALTHY
    return HealthResponse(status=status)
|
|
|
|
@app.get("/tasks")
def list_tasks() -> dict[str, TaskConfig]:
    """Return the registry of available tasks, keyed by task id."""
    available: dict[str, TaskConfig] = TASKS
    return available
|
|
|
|
@app.post("/reset", response_model=ResetResponse)
def reset(request: OEResetRequest | None = Body(default=None)) -> ResetResponse:
    """Start a new episode and return its initial observation.

    The task id is taken from, in order:
    1. an extra ``task_id`` field on the request body,
    2. the OpenEnv ``episode_id`` field,
    3. ``"easy_harassment"``.

    The seed is ``request.seed``. It defaults to ``42`` only when it is omitted
    or null, so ``seed=0`` is honoured.

    Raises:
        HTTPException: 400 if the resolved task id is not a known task.
    """
    extra = (request.model_extra or {}) if request else {}
    task_id = extra.get("task_id") or (request.episode_id if request else None) or "easy_harassment"

    # 0 is a legitimate seed; only fall back to the default when it is absent.
    seed = request.seed if request is not None and request.seed is not None else 42

    # task_id may come from an untyped extra field. A non-string value (e.g. a
    # JSON list) is unhashable and would make ``in TASKS`` raise TypeError (500).
    if not isinstance(task_id, str) or task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[task_id].model_copy(update={"seed": seed})
    obs = _state_manager.reset(task)
    return ResetResponse(observation=obs.model_dump(), reward=None, done=obs.done)
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Apply one action to the active episode and return the resulting observation.

    Raises:
        HTTPException: 400 if no episode is active or if the state manager
            rejects the action with ``ValueError``. 422 if ``request.action``
            cannot be parsed into an :class:`Action`.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    try:
        action = Action(**request.action)
    except Exception as exc:  # any parse failure is a client-side error (422)
        # Chain the cause so server-side tracebacks show the original failure.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    try:
        result = _state_manager.step(action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    logger.info(
        "Step %d: action=%s reward=%.3f done=%s",
        result.observation.step,
        action.action_type.value,
        result.reward,
        result.done,
    )
    return StepResponse(
        observation=result.observation.model_dump(),
        reward=result.reward,
        done=result.done,
    )
|
|
|
|
@app.get("/state")
def get_state() -> dict:
    """Return the current episode state as a plain dict.

    Raises:
        HTTPException: 400 when there is no active episode.
    """
    if _state_manager.has_active_episode():
        return _state_manager.get_state().model_dump()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
|
|
|
|
def _ws_error(message: str, code: WSErrorCode) -> str:
    """Serialise a WebSocket error frame carrying ``message`` and ``code``."""
    return WSErrorResponse(data={"message": message, "code": code}).model_dump_json()


def _ws_handle_reset(payload: object) -> str:
    """Start a new episode from a ``reset`` payload; return the reply frame as JSON."""
    # A null payload means "use defaults", mirroring HTTP /reset with no body.
    if payload is None:
        payload = {}
    # Non-object payloads previously crashed on ``.get`` and dropped the socket.
    if not isinstance(payload, dict):
        return _ws_error("'data' must be a JSON object", WSErrorCode.VALIDATION_ERROR)

    task_id = payload.get("task_id") or payload.get("episode_id") or "easy_harassment"
    seed = payload.get("seed")
    if seed is None:  # 0 is a legitimate seed; only default when omitted/null
        seed = 42

    # A non-string id (e.g. a JSON list) is unhashable; ``in TASKS`` would raise TypeError.
    if not isinstance(task_id, str) or task_id not in TASKS:
        return _ws_error(f"Unknown task_id '{task_id}'", WSErrorCode.VALIDATION_ERROR)

    task = TASKS[task_id].model_copy(update={"seed": seed})
    obs = _state_manager.reset(task)
    return WSObservationResponse(
        data={"observation": obs.model_dump(), "reward": None, "done": obs.done}
    ).model_dump_json()


def _ws_handle_step(payload: object) -> str:
    """Apply the action in a ``step`` payload to the active episode; return the reply frame."""
    if not _state_manager.has_active_episode():
        return _ws_error("No active episode. Send reset first.", WSErrorCode.SESSION_ERROR)

    try:
        # A non-mapping payload raises TypeError here and is reported as a validation error.
        action = Action(**payload)
    except Exception as exc:
        return _ws_error(str(exc), WSErrorCode.VALIDATION_ERROR)

    try:
        result = _state_manager.step(action)
    except ValueError as exc:
        return _ws_error(str(exc), WSErrorCode.EXECUTION_ERROR)

    return WSObservationResponse(
        data={
            "observation": result.observation.model_dump(),
            "reward": result.reward,
            "done": result.done,
        }
    ).model_dump_json()


def _ws_handle_state() -> str:
    """Return the current state frame, or a session error when no episode is active."""
    if not _state_manager.has_active_episode():
        return _ws_error("No active episode.", WSErrorCode.SESSION_ERROR)
    return WSStateResponse(data=_state_manager.get_state().model_dump()).model_dump_json()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one persistent WebSocket session.

    Each text frame must be a JSON object whose ``type`` is ``"reset"``,
    ``"step"``, ``"state"`` or ``"close"``. For reset and step, an optional
    ``data`` object carries the arguments.

    Every frame except ``close`` receives exactly one JSON reply. Malformed
    frames get a ``WSErrorResponse`` rather than tearing down the connection.

    NOTE(review): all connections share the module-level ``_state_manager``
    with the HTTP endpoints. Concurrent sessions therefore presumably act on
    the same episode.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                message = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_text(_ws_error("Invalid JSON", WSErrorCode.INVALID_JSON))
                continue

            # Valid JSON that is not an object (e.g. ``[]`` or ``"reset"``) used to
            # raise AttributeError on ``.get`` and kill the connection.
            if not isinstance(message, dict):
                reply = _ws_error("Message must be a JSON object", WSErrorCode.VALIDATION_ERROR)
            else:
                msg_type = message.get("type")
                if msg_type == "close":
                    break
                if msg_type == "reset":
                    reply = _ws_handle_reset(message.get("data", {}))
                elif msg_type == "step":
                    reply = _ws_handle_step(message.get("data", {}))
                elif msg_type == "state":
                    reply = _ws_handle_state()
                else:
                    reply = _ws_error(
                        f"Unknown message type: {msg_type!r}", WSErrorCode.UNKNOWN_TYPE
                    )

            await websocket.send_text(reply)
    except WebSocketDisconnect:
        # Normal client disconnect; end the session quietly.
        pass
|
|
|
|
@app.get("/grader", response_model=EpisodeScore)
def grade() -> EpisodeScore:
    """Score the current episode once it has finished.

    Raises:
        HTTPException: 400 if there is no active episode, or if the episode
            is not done yet.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    episode = _state_manager.get_episode_state()
    finished = episode.observation.done
    if not finished:
        raise HTTPException(
            status_code=400,
            detail="Episode is not finished yet. Complete the episode before grading.",
        )

    score = _grader.score(episode)
    logger.info("Graded episode: total=%.4f", score.total)
    return score
|
|
|
|
@app.get("/baseline", response_model=BaselineResult)
def baseline(task_id: str = "easy_harassment", seed: int | None = None) -> BaselineResult:
    """Run the built-in rule-based baseline agent and return its score.

    ``seed`` overrides the task's configured seed when given.

    NOTE(review): the agent receives the shared module-level ``_state_manager``.
    Running it therefore presumably replaces any episode a client has in progress.

    Raises:
        HTTPException: 400 for an unknown ``task_id``.
    """
    from baseline.agent import BaselineAgent

    if task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[task_id] if seed is None else TASKS[task_id].model_copy(update={"seed": seed})
    return BaselineAgent(state_manager=_state_manager, grader=_grader).run(task)
|
|
|
|
@app.post("/agent/run", response_model=BaselineResult)
def agent_run(request: ResetRequest) -> BaselineResult:
    """
    Run the selected LLM agent (OpenAI or Gemini) on a full episode and return the graded result.

    Provider selection:
    - The provider is read from the ``LLM_PROVIDER`` environment variable
      (default ``"openai"``); case and surrounding whitespace are ignored.
    - ``"gemini"`` selects Gemini.
    - Anything else runs OpenAI; a warning is logged when the value is not
      ``"openai"``.

    Requires OPENAI_API_KEY, or GOOGLE_API_KEY/GEMINI_API_KEY depending on LLM_PROVIDER.

    Raises:
        HTTPException: 400 for an unknown ``task_id``. 500 if the selected
            agent module cannot be imported, or if constructing the agent
            raises ``EnvironmentError`` (e.g. a missing API key).
    """
    import os

    provider = os.getenv("LLM_PROVIDER", "openai").strip().lower()
    if provider not in ("openai", "gemini"):
        # Preserve the historical fallback, but make it visible and keep later
        # log lines truthful about which agent actually ran.
        logger.warning("Unknown LLM_PROVIDER %r; falling back to openai", provider)
        provider = "openai"

    if request.task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{request.task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[request.task_id]
    if request.seed is not None:
        task = task.model_copy(update={"seed": request.seed})

    # Import only the selected agent. An import-time failure in the other
    # provider's module (e.g. an uninstalled SDK) then cannot break this endpoint.
    try:
        if provider == "gemini":
            from agent.gemini_agent import GeminiAgent as agent_cls
        else:
            from agent.openai_agent import OpenAIAgent as agent_cls
    except ImportError as exc:
        raise HTTPException(
            status_code=500, detail=f"{provider} agent is unavailable: {exc}"
        ) from exc

    try:
        agent = agent_cls(state_manager=_state_manager, grader=_grader)
    except EnvironmentError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    result = agent.run(task)
    logger.info(
        "%s agent finished: task=%s total=%.4f", provider.capitalize(), task.task_id, result.score.total
    )
    return result
|
|