| """ |
| OpenENV Moderation Environment — FastAPI application. |
| |
| Standard OpenEnv endpoints: |
| WS /ws — persistent WebSocket session (primary client interface) |
| GET /health — liveness check |
| POST /reset — start a new episode |
| POST /step — take an action |
| GET /state — current observation / state |
| GET /docs — OpenAPI documentation (auto-generated) |
| |
| Custom endpoints: |
| GET /tasks — available tasks |
| GET /grader — final episode score |
| GET /baseline — run rule-based baseline agent and return its score |
| POST /agent/run — run selected LLM agent on a full episode |
| |
| Web interface (when ENABLE_WEB_INTERFACE=true): |
| GET /web — Gradio playground UI |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
|
|
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
| from fastapi import FastAPI, HTTPException, Body, WebSocket, WebSocketDisconnect |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from openenv.core.env_server.types import ( |
| HealthResponse, |
| HealthStatus, |
| ResetRequest as OEResetRequest, |
| ResetResponse, |
| StepRequest, |
| StepResponse, |
| WSObservationResponse, |
| WSStateResponse, |
| WSErrorResponse, |
| WSErrorCode, |
| ) |
|
|
| from .tasks import TASKS |
| from .grader import Grader |
| from .state_manager import StateManager |
| from models import ( |
| Action, |
| BaselineResult, |
| EpisodeScore, |
| ResetRequest, |
| TaskConfig, |
| ) |
|
|
# Configure logging once at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# The FastAPI application; every route in this module is registered on it.
app = FastAPI(
    version="1.0.0",
    title="OpenENV — Content Moderation Environment",
    description=(
        "A multi-step RL environment for AI content moderation agents. "
        "Agents receive partial observations and must investigate context, "
        "classify violations, and make final moderation decisions."
    ),
)


# CORS is fully permissive here: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)


# Module-level singletons used by every HTTP route and WebSocket handler in
# this module, so all callers operate on the same episode state.
_state_manager = StateManager()
_grader = Grader()
|
|
|
|
| |
| |
| |
|
|
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness check: unconditionally report the service as healthy."""
    current_status = HealthStatus.HEALTHY
    return HealthResponse(status=current_status)
|
|
|
|
@app.get("/tasks")
def list_tasks() -> dict[str, TaskConfig]:
    """Return the ``TASKS`` registry imported from ``.tasks``, unmodified."""
    registry = TASKS
    return registry
|
|
|
|
@app.post("/reset", response_model=ResetResponse)
def reset(request: OEResetRequest | None = Body(default=None)) -> ResetResponse:
    """Start a new episode and return its first observation.

    The request body is optional. The task is resolved, in order, from an
    extra ``task_id`` field on the request, from ``episode_id``, and finally
    from the default ``"easy_harassment"``. The seed defaults to 42 only when
    no seed was supplied, so an explicit ``seed`` of 0 is honoured.

    Args:
        request: Optional reset request; ``None`` selects all defaults.

    Returns:
        A ``ResetResponse`` with the dumped initial observation, no reward,
        and the observation's ``done`` flag.

    Raises:
        HTTPException: 400 when the resolved task id is not in ``TASKS``.
    """
    if request is None:
        extra, episode_id, seed = {}, None, None
    else:
        extra = request.model_extra or {}
        episode_id = request.episode_id
        seed = request.seed

    task_id = extra.get("task_id") or episode_id or "easy_harassment"

    # Test against None rather than truthiness: 0 is a legitimate seed and
    # must not be silently swapped for the default.
    if seed is None:
        seed = 42

    if task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}",
        )

    # Copy the registered task so the seed override never touches TASKS.
    task = TASKS[task_id].model_copy(update={"seed": seed})

    obs = _state_manager.reset(task)
    return ResetResponse(observation=obs.model_dump(), reward=None, done=obs.done)
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Apply one action to the active episode and return the outcome.

    Args:
        request: Step request whose ``action`` mapping is unpacked into an
            ``Action``.

    Returns:
        A ``StepResponse`` with the dumped observation, the reward and the
        ``done`` flag produced by the state manager.

    Raises:
        HTTPException: 400 when no episode is active or the state manager
            rejects the action with ``ValueError``; 422 when the payload
            cannot be turned into an ``Action``.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    try:
        action = Action(**request.action)
    except Exception as exc:
        # Any construction failure is a client error; keep the cause chained
        # so the original validation error stays visible in tracebacks.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    try:
        result = _state_manager.step(action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    logger.info(
        "Step %d: action=%s reward=%.3f done=%s",
        result.observation.step,
        action.action_type.value,
        result.reward,
        result.done,
    )
    return StepResponse(
        observation=result.observation.model_dump(),
        reward=result.reward,
        done=result.done,
    )
|
|
|
|
@app.get("/state")
def get_state() -> dict:
    """Return the state manager's current state as a plain dict.

    Raises:
        HTTPException: 400 when no episode is active.
    """
    if _state_manager.has_active_episode():
        current = _state_manager.get_state()
        return current.model_dump()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
|
|
|
|
def _ws_error(code, message: str) -> str:
    """Serialise a WebSocket error frame carrying *message* and *code*."""
    return WSErrorResponse(data={"message": message, "code": code}).model_dump_json()


def _ws_observation(observation, reward, done) -> str:
    """Serialise an observation frame (dumped observation, reward, done)."""
    body = {"observation": observation.model_dump(), "reward": reward, "done": done}
    return WSObservationResponse(data=body).model_dump_json()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve the message-based session protocol over a WebSocket.

    Each incoming text frame must be a JSON object with a ``type`` of
    ``"reset"``, ``"step"``, ``"state"`` or ``"close"`` and an optional
    ``data`` object. Malformed frames are answered with an error frame and
    the connection stays open. The handler operates on the module-level
    ``_state_manager``, i.e. the same episode state the HTTP routes use.

    Args:
        websocket: The connection to accept and serve until the client
            disconnects or sends ``"close"``.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_text(_ws_error(WSErrorCode.INVALID_JSON, "Invalid JSON"))
                continue

            # Valid JSON that is not an object (list, string, number, ...)
            # has no ``.get``; reject it instead of crashing the session.
            if not isinstance(data, dict):
                await websocket.send_text(
                    _ws_error(WSErrorCode.VALIDATION_ERROR, "Message must be a JSON object")
                )
                continue

            msg_type = data.get("type")
            # Treat an explicit ``"data": null`` the same as a missing field.
            payload = data.get("data")
            if payload is None:
                payload = {}

            if msg_type == "reset":
                if not isinstance(payload, dict):
                    await websocket.send_text(
                        _ws_error(WSErrorCode.VALIDATION_ERROR, "Reset data must be a JSON object")
                    )
                    continue

                task_id = payload.get("task_id") or payload.get("episode_id") or "easy_harassment"
                # 0 is a legitimate seed; only fall back when none was sent.
                seed = payload.get("seed")
                if seed is None:
                    seed = 42

                if task_id not in TASKS:
                    await websocket.send_text(
                        _ws_error(WSErrorCode.VALIDATION_ERROR, f"Unknown task_id '{task_id}'")
                    )
                    continue

                task = TASKS[task_id].model_copy(update={"seed": seed})
                obs = _state_manager.reset(task)
                await websocket.send_text(_ws_observation(obs, None, obs.done))

            elif msg_type == "step":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error(WSErrorCode.SESSION_ERROR, "No active episode. Send reset first.")
                    )
                    continue

                try:
                    # A non-mapping payload raises TypeError here and is
                    # reported to the client like any other invalid action.
                    action = Action(**payload)
                except Exception as exc:
                    await websocket.send_text(_ws_error(WSErrorCode.VALIDATION_ERROR, str(exc)))
                    continue

                try:
                    result = _state_manager.step(action)
                except ValueError as exc:
                    await websocket.send_text(_ws_error(WSErrorCode.EXECUTION_ERROR, str(exc)))
                    continue

                await websocket.send_text(
                    _ws_observation(result.observation, result.reward, result.done)
                )

            elif msg_type == "state":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error(WSErrorCode.SESSION_ERROR, "No active episode.")
                    )
                    continue
                obs = _state_manager.get_state()
                await websocket.send_text(WSStateResponse(data=obs.model_dump()).model_dump_json())

            elif msg_type == "close":
                # Send a close frame explicitly rather than relying on the
                # server to tear the connection down when this coroutine
                # returns.
                await websocket.close()
                break

            else:
                await websocket.send_text(
                    _ws_error(WSErrorCode.UNKNOWN_TYPE, f"Unknown message type: {msg_type!r}")
                )

    except WebSocketDisconnect:
        # The client went away; there is nothing left to send or clean up.
        pass
|
|
|
|
@app.get("/grader", response_model=EpisodeScore)
def grade() -> EpisodeScore:
    """Score the current episode once it has finished.

    Raises:
        HTTPException: 400 when no episode is active, or when the episode's
            observation is not yet marked ``done``.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    finished = _state_manager.get_episode_state()
    if finished.observation.done:
        result = _grader.score(finished)
        logger.info("Graded episode: total=%.4f", result.total)
        return result

    raise HTTPException(
        status_code=400,
        detail="Episode is not finished yet. Complete the episode before grading.",
    )
|
|
|
|
@app.get("/baseline", response_model=BaselineResult)
def baseline(task_id: str = "easy_harassment", seed: int | None = None) -> BaselineResult:
    """Run the built-in rule-based baseline agent and return its score.

    Args:
        task_id: Key into ``TASKS`` selecting the task to run.
        seed: Optional seed override; when ``None`` the registered task is
            used as-is.

    Raises:
        HTTPException: 400 when ``task_id`` is not in ``TASKS``.
    """
    # Imported lazily, at call time, exactly as the route has always done.
    from baseline.agent import BaselineAgent

    if task_id not in TASKS:
        known_ids = list(TASKS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {known_ids}",
        )

    registered = TASKS[task_id]
    task = registered if seed is None else registered.model_copy(update={"seed": seed})

    runner = BaselineAgent(state_manager=_state_manager, grader=_grader)
    return runner.run(task)
|
|
|
|
@app.post("/agent/run", response_model=BaselineResult)
def agent_run(request: ResetRequest) -> BaselineResult:
    """
    Run the selected LLM agent (OpenAI or Gemini) on a full episode and return the graded result.

    The provider is read from the ``LLM_PROVIDER`` environment variable
    (default ``"openai"``, compared case-insensitively); ``"gemini"`` selects
    the Gemini agent and every other value selects the OpenAI agent.

    Requires OPENAI_API_KEY, or GOOGLE_API_KEY/GEMINI_API_KEY depending on LLM_PROVIDER.

    Args:
        request: Carries the ``task_id`` to run and an optional ``seed``
            override.

    Raises:
        HTTPException: 400 when ``task_id`` is not in ``TASKS``; 500 when
            constructing the agent raises ``EnvironmentError``.
    """
    # ``os`` is already imported at module level; no local re-import needed.
    provider = os.getenv("LLM_PROVIDER", "openai").lower()

    if request.task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{request.task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[request.task_id]
    if request.seed is not None:
        task = task.model_copy(update={"seed": request.seed})

    try:
        # Agent modules are imported lazily so only the selected provider's
        # module is loaded.
        if provider == "gemini":
            from agent.gemini_agent import GeminiAgent
            agent = GeminiAgent(state_manager=_state_manager, grader=_grader)
        else:
            from agent.openai_agent import OpenAIAgent
            agent = OpenAIAgent(state_manager=_state_manager, grader=_grader)
    except EnvironmentError as exc:
        # Chain the cause so the underlying error stays in the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    result = agent.run(task)
    logger.info(
        "%s agent finished: task=%s total=%.4f", provider.capitalize(), task.task_id, result.score.total
    )
    return result
|
|
|
|
| |
| |
| |
|
|
def _build_web_interface():
    """Build the Gradio playground UI and return the ``gr.Blocks`` object.

    The UI offers three actions — reset, step and get-state — that call the
    module-level ``_state_manager`` directly. Mounting the returned object
    on the FastAPI app is left to the caller.
    """
    import gradio as gr
    from models import ActionType

    task_choices = list(TASKS.keys())
    action_choices = [a.value for a in ActionType]

    def _fmt_obs(obs_dict: dict) -> str:
        """Render every non-None observation field as a markdown paragraph."""
        lines = [f"**{k}:** {v}" for k, v in obs_dict.items() if v is not None]
        return "\n\n".join(lines) if lines else "*No observation yet*"

    def do_reset(task_id: str) -> tuple[str, str, str]:
        """Reset to *task_id* with seed 42; return (markdown, raw JSON, status)."""
        try:
            if task_id not in TASKS:
                return "", "", f"Unknown task_id: {task_id}"
            task = TASKS[task_id].model_copy(update={"seed": 42})
            obs = _state_manager.reset(task)
            obs_dict = obs.model_dump()
            return _fmt_obs(obs_dict), json.dumps(obs_dict, indent=2, default=str), "Reset successful."
        except Exception as exc:
            # UI boundary: show the failure in the status box instead of raising.
            return "", "", f"Error: {exc}"

    def do_step(action_type: str, params_json: str) -> tuple[str, str, str]:
        """Send one action; return (markdown, raw JSON, status)."""
        if not _state_manager.has_active_episode():
            return "", "", "No active episode. Click Reset first."
        # Tolerate a missing value from the textbox as "no parameters".
        has_params = bool((params_json or "").strip())
        try:
            params = json.loads(params_json) if has_params else {}
        except json.JSONDecodeError as exc:
            return "", "", f"Invalid JSON params: {exc}"
        try:
            action = Action(action_type=ActionType(action_type), parameters=params)
            result = _state_manager.step(action)
            obs_dict = result.observation.model_dump()
            status = f"Reward: {result.reward:.3f} | Done: {result.done}"
            return _fmt_obs(obs_dict), json.dumps(obs_dict, indent=2, default=str), status
        except Exception as exc:
            # UI boundary: show the failure in the status box instead of raising.
            return "", "", f"Error: {exc}"

    def do_state() -> str:
        """Return the current state as pretty-printed JSON, or a status message."""
        if not _state_manager.has_active_episode():
            return "No active episode."
        try:
            return json.dumps(_state_manager.get_state().model_dump(), indent=2, default=str)
        except Exception as exc:
            return f"Error: {exc}"

    # NOTE(review): the nesting of the layout below was reconstructed from the
    # statement grouping — confirm the status box and "Get state" button
    # belong at the Blocks level rather than inside the last row.
    with gr.Blocks(title="OpenENV — Content Moderation Playground") as demo:
        gr.Markdown("# OpenENV Content Moderation Playground")
        gr.Markdown(
            "Interact with the content moderation environment. "
            "Select a task, click **Reset**, then send **Step** actions."
        )
        with gr.Row():
            with gr.Column(scale=2):
                obs_display = gr.Markdown(value="*Click Reset to start an episode.*")
            with gr.Column(scale=1):
                raw_json = gr.Code(label="Raw JSON", language="json", interactive=False)

        with gr.Row():
            task_dd = gr.Dropdown(
                choices=task_choices,
                value=task_choices[0] if task_choices else None,
                label="Task",
            )
            reset_btn = gr.Button("Reset", variant="secondary")

        with gr.Row():
            action_dd = gr.Dropdown(
                choices=action_choices,
                value=action_choices[0] if action_choices else None,
                label="Action type",
            )
            params_tb = gr.Textbox(
                label="Parameters (JSON)",
                placeholder='{"clause_id": "harassment_1"}',
                value="{}",
            )
            step_btn = gr.Button("Step", variant="primary")

        status_tb = gr.Textbox(label="Status", interactive=False)
        state_btn = gr.Button("Get state", variant="secondary")

        # Wire the buttons to their handlers.
        reset_btn.click(fn=do_reset, inputs=[task_dd], outputs=[obs_display, raw_json, status_tb])
        step_btn.click(fn=do_step, inputs=[action_dd, params_tb], outputs=[obs_display, raw_json, status_tb])
        state_btn.click(fn=do_state, outputs=[raw_json])

    return demo
|
|
|
|
# Mount the optional Gradio playground only when explicitly enabled through
# the ENABLE_WEB_INTERFACE environment variable.
if os.getenv("ENABLE_WEB_INTERFACE", "").lower() in {"true", "1", "yes"}:
    try:
        import gradio as gr
        from fastapi.responses import RedirectResponse

        _gradio_demo = _build_web_interface()

        @app.get("/web", include_in_schema=False)
        async def _web_redirect():
            """Redirect the bare ``/web`` path to ``/web/``."""
            return RedirectResponse(url="/web/")

        # The redirect route is registered before the mount; ``app`` is then
        # rebound to whatever ``mount_gradio_app`` returns.
        app = gr.mount_gradio_app(app, _gradio_demo, path="/web")
        logger.info("Web interface mounted at /web")
    except ImportError:
        # Any ImportError raised while building or mounting the UI lands here.
        logger.warning("gradio not installed — ENABLE_WEB_INTERFACE=true has no effect")
|
|
|
|
def main() -> None:
    """Serve ``server.app:app`` with uvicorn.

    The bind address comes from the ``HOST`` environment variable (default
    ``0.0.0.0``) and the port from ``PORT`` (default ``8000``); auto-reload
    is disabled.
    """
    import uvicorn

    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "8000"))
    uvicorn.run("server.app:app", host=bind_host, port=bind_port, reload=False)
|
|