File size: 10,460 Bytes
c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 d9b66e9 c492c3f d9b66e9 c492c3f d9b66e9 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f 099b3c1 c492c3f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 | """
OpenENV Moderation Environment — FastAPI application.
Standard OpenEnv endpoints:
    WS   /ws        — persistent WebSocket session (primary client interface)
    GET  /health    — liveness check
    POST /reset     — start a new episode
    POST /step      — take an action
    GET  /state     — current observation / state
    GET  /docs      — OpenAPI documentation (auto-generated)
Custom endpoints:
    GET  /tasks     — available tasks
    GET  /grader    — final episode score
    GET  /baseline  — run rule-based baseline agent and return its score
    POST /agent/run — run selected LLM agent on a full episode
"""
from __future__ import annotations
import json
import logging
from dotenv import load_dotenv
load_dotenv() # loads .env from project root before anything else
from fastapi import FastAPI, HTTPException, Body, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from openenv.core.env_server.types import (
HealthResponse,
HealthStatus,
ResetRequest as OEResetRequest,
ResetResponse,
StepRequest,
StepResponse,
WSObservationResponse,
WSStateResponse,
WSErrorResponse,
WSErrorCode,
)
from data.tasks import TASKS
from env.grader import Grader
from env.state_manager import StateManager
from models.schemas import (
Action,
BaselineResult,
EpisodeScore,
ResetRequest,
TaskConfig,
)
# Configure logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object that every route below is registered on.
app = FastAPI(
    version="1.0.0",
    title="OpenENV β Content Moderation Environment",
    description=(
        "A multi-step RL environment for AI content moderation agents. "
        "Agents receive partial observations and must investigate context, "
        "classify violations, and make final moderation decisions."
    ),
)

# CORS is deliberately wide open (any origin, method and header) so the app
# can be reached from HF Spaces as well as local development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Module-level singletons shared by every endpoint (single-threaded MVP):
# all HTTP and WebSocket handlers operate on the same episode state.
_state_manager = StateManager()
_grader = Grader()
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness check: unconditionally report the service as healthy."""
    current_status = HealthStatus.HEALTHY
    return HealthResponse(status=current_status)
@app.get("/tasks")
def list_tasks() -> dict[str, TaskConfig]:
    """Return the registry of available tasks, keyed by task id."""
    available_tasks = TASKS
    return available_tasks
@app.post("/reset", response_model=ResetResponse)
def reset(request: OEResetRequest | None = Body(default=None)) -> ResetResponse:
    """Start a new episode on the shared state manager.

    The task is chosen from, in order of preference: the extra ``task_id``
    field of the request, the request's ``episode_id``, or the default
    ``"easy_harassment"``. The seed defaults to 42 when the request does not
    provide one.

    Args:
        request: Optional reset request body; may be omitted entirely.

    Returns:
        A ``ResetResponse`` carrying the initial observation, no reward, and
        the observation's ``done`` flag.

    Raises:
        HTTPException: 400 when the resolved task id is not in ``TASKS``.
    """
    if request is None:
        extra, episode_id, seed = {}, None, None
    else:
        # task_id is passed as an extra (undeclared) field on the request.
        extra = request.model_extra or {}
        episode_id = request.episode_id
        seed = request.seed

    task_id = extra.get("task_id") or episode_id or "easy_harassment"

    # Compare against None instead of relying on truthiness: a seed of 0 is
    # a legitimate value and must not be replaced by the default.
    if seed is None:
        seed = 42

    if task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[task_id].model_copy(update={"seed": seed})
    obs = _state_manager.reset(task)
    return ResetResponse(observation=obs.model_dump(), reward=None, done=obs.done)
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Apply one action to the active episode.

    Args:
        request: Step request whose ``action`` mapping is expanded into the
            keyword arguments of ``Action``.

    Returns:
        A ``StepResponse`` with the resulting observation, reward and done
        flag.

    Raises:
        HTTPException: 400 when there is no active episode or the state
            manager rejects the action with ``ValueError``; 422 when the
            action payload cannot be turned into an ``Action``.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    try:
        action = Action(**request.action)
    except Exception as exc:
        # Any construction failure is reported to the client as a validation
        # error; chain the cause so it is not lost when debugging.
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    try:
        result = _state_manager.step(action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    logger.info(
        "Step %d: action=%s reward=%.3f done=%s",
        result.observation.step,
        action.action_type.value,
        result.reward,
        result.done,
    )
    return StepResponse(
        observation=result.observation.model_dump(),
        reward=result.reward,
        done=result.done,
    )
@app.get("/state")
def get_state() -> dict:
    """Return the active episode's current state as a plain dict.

    Raises:
        HTTPException: 400 when no episode is active.
    """
    if _state_manager.has_active_episode():
        current = _state_manager.get_state()
        return current.model_dump()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
def _ws_error(message: str, code: WSErrorCode) -> str:
    """Serialize a WebSocket error frame carrying *message* and *code*."""
    return WSErrorResponse(data={"message": message, "code": code}).model_dump_json()


def _ws_observation(observation, reward, done) -> str:
    """Serialize a WebSocket observation frame for a reset or step result."""
    return WSObservationResponse(
        data={"observation": observation.model_dump(), "reward": reward, "done": done}
    ).model_dump_json()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve a persistent WebSocket session over the shared state manager.

    Each incoming text frame must be a JSON object with a ``type`` key and,
    for ``reset`` and ``step``, an optional ``data`` object:

    * ``reset`` - start a new episode (``task_id`` / ``episode_id`` / ``seed``).
    * ``step``  - apply the action described by ``data``.
    * ``state`` - return the current state.
    * ``close`` - close the connection and end the session.

    Malformed or invalid messages are answered with an error frame and the
    session continues; a client disconnect ends the handler quietly.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_text(_ws_error("Invalid JSON", WSErrorCode.INVALID_JSON))
                continue

            # Valid JSON that is not an object (list, string, number, null)
            # has no "type" to dispatch on; reject it instead of crashing.
            if not isinstance(data, dict):
                await websocket.send_text(
                    _ws_error("Message must be a JSON object", WSErrorCode.VALIDATION_ERROR)
                )
                continue

            msg_type = data.get("type")

            if msg_type in ("reset", "step"):
                payload = data.get("data")
                if payload is None:
                    # Treat a missing or null "data" field as an empty payload.
                    payload = {}
                if not isinstance(payload, dict):
                    await websocket.send_text(
                        _ws_error("'data' must be a JSON object", WSErrorCode.VALIDATION_ERROR)
                    )
                    continue

            if msg_type == "reset":
                task_id = payload.get("task_id") or payload.get("episode_id") or "easy_harassment"
                # A seed of 0 is valid; only fall back to 42 when it is absent.
                seed = payload.get("seed")
                if seed is None:
                    seed = 42
                # The isinstance guard keeps unhashable values (e.g. a list)
                # from raising in the membership test.
                if not isinstance(task_id, str) or task_id not in TASKS:
                    await websocket.send_text(
                        _ws_error(f"Unknown task_id '{task_id}'", WSErrorCode.VALIDATION_ERROR)
                    )
                    continue
                task = TASKS[task_id].model_copy(update={"seed": seed})
                obs = _state_manager.reset(task)
                await websocket.send_text(_ws_observation(obs, None, obs.done))

            elif msg_type == "step":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error("No active episode. Send reset first.", WSErrorCode.SESSION_ERROR)
                    )
                    continue
                try:
                    action = Action(**payload)
                except Exception as exc:
                    # Any failure to build the Action is reported as a
                    # validation error rather than ending the session.
                    await websocket.send_text(_ws_error(str(exc), WSErrorCode.VALIDATION_ERROR))
                    continue
                try:
                    result = _state_manager.step(action)
                except ValueError as exc:
                    await websocket.send_text(_ws_error(str(exc), WSErrorCode.EXECUTION_ERROR))
                    continue
                await websocket.send_text(
                    _ws_observation(result.observation, result.reward, result.done)
                )

            elif msg_type == "state":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error("No active episode.", WSErrorCode.SESSION_ERROR)
                    )
                    continue
                obs = _state_manager.get_state()
                await websocket.send_text(WSStateResponse(data=obs.model_dump()).model_dump_json())

            elif msg_type == "close":
                # Send an explicit close frame so the client sees a normal
                # closure instead of the handler just returning.
                await websocket.close()
                break

            else:
                await websocket.send_text(
                    _ws_error(f"Unknown message type: {msg_type!r}", WSErrorCode.UNKNOWN_TYPE)
                )
    except WebSocketDisconnect:
        # The client went away; there is nothing left to send or clean up here.
        pass
@app.get("/grader", response_model=EpisodeScore)
def grade() -> EpisodeScore:
    """Score the finished episode held by the shared state manager.

    Raises:
        HTTPException: 400 when no episode is active, or when the active
            episode's observation is not yet marked done.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    finished = _state_manager.get_episode_state()
    if finished.observation.done:
        result = _grader.score(finished)
        logger.info("Graded episode: total=%.4f", result.total)
        return result

    raise HTTPException(
        status_code=400,
        detail="Episode is not finished yet. Complete the episode before grading.",
    )
@app.get("/baseline", response_model=BaselineResult)
def baseline(task_id: str = "easy_harassment", seed: int | None = None) -> BaselineResult:
    """Run the built-in rule-based baseline agent on a task and return its result.

    The agent is given the module's shared state manager and grader. When
    *seed* is provided it overrides the task's own seed.

    Raises:
        HTTPException: 400 when *task_id* is not in ``TASKS``.
    """
    # Imported lazily, at call time, exactly as the endpoint always has.
    from baseline.agent import BaselineAgent

    if task_id not in TASKS:
        known_ids = list(TASKS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {known_ids}",
        )

    selected = TASKS[task_id]
    if seed is not None:
        selected = selected.model_copy(update={"seed": seed})

    runner = BaselineAgent(state_manager=_state_manager, grader=_grader)
    return runner.run(selected)
@app.post("/agent/run", response_model=BaselineResult)
def agent_run(request: ResetRequest) -> BaselineResult:
    """
    Run the selected LLM agent (OpenAI or Gemini) on a full episode and return the graded result.
    Requires OPENAI_API_KEY, or GOOGLE_API_KEY/GEMINI_API_KEY depending on LLM_PROVIDER.

    The provider is read from the ``LLM_PROVIDER`` environment variable
    (default ``"openai"``); any value other than ``"gemini"`` selects the
    OpenAI agent. When ``request.seed`` is set it overrides the task's seed.

    Raises:
        HTTPException: 400 when ``request.task_id`` is not in ``TASKS``;
            500 when constructing the agent raises ``EnvironmentError``.
    """
    import os

    provider = os.getenv("LLM_PROVIDER", "openai").lower()

    if request.task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{request.task_id}'. Available: {list(TASKS.keys())}",
        )

    task = TASKS[request.task_id]
    if request.seed is not None:
        task = task.model_copy(update={"seed": request.seed})

    # Import only the selected provider's module so that a missing optional
    # dependency for the other provider cannot break this endpoint.
    if provider == "gemini":
        from agent.gemini_agent import GeminiAgent as agent_cls
    else:
        from agent.openai_agent import OpenAIAgent as agent_cls

    try:
        agent = agent_cls(state_manager=_state_manager, grader=_grader)
    except EnvironmentError as exc:
        # Surfaced as a server error; the cause is chained for debugging.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    result = agent.run(task)
    logger.info(
        "%s agent finished: task=%s total=%.4f", provider.capitalize(), task.task_id, result.score.total
    )
    return result
|