# ThejasRao's picture
# Upload folder using huggingface_hub
# 099b3c1 verified
"""
OpenENV Moderation Environment — FastAPI application.
Standard OpenEnv endpoints:
WS /ws — persistent WebSocket session (primary client interface)
GET /health — liveness check
POST /reset — start a new episode
POST /step — take an action
GET /state — current observation / state
GET /docs — OpenAPI documentation (auto-generated)
Custom endpoints:
GET /tasks — available tasks
GET /grader — final episode score
GET /baseline — run rule-based baseline agent and return its score
POST /agent/run — run selected LLM agent on a full episode
"""
from __future__ import annotations
import json
import logging
from dotenv import load_dotenv
load_dotenv() # loads .env from project root before anything else
from fastapi import FastAPI, HTTPException, Body, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from openenv.core.env_server.types import (
HealthResponse,
HealthStatus,
ResetRequest as OEResetRequest,
ResetResponse,
StepRequest,
StepResponse,
WSObservationResponse,
WSStateResponse,
WSErrorResponse,
WSErrorCode,
)
from data.tasks import TASKS
from env.grader import Grader
from env.state_manager import StateManager
from models.schemas import (
Action,
BaselineResult,
EpisodeScore,
ResetRequest,
TaskConfig,
)
# Configure root logging at import time; INFO level so the per-step and
# grading log lines emitted by the endpoints below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# The ASGI application object that every endpoint below is registered on.
app = FastAPI(
title="OpenENV — Content Moderation Environment",
description=(
"A multi-step RL environment for AI content moderation agents. "
"Agents receive partial observations and must investigate context, "
"classify violations, and make final moderation decisions."
),
version="1.0.0",
)
# Permissive CORS: any origin, method and header is accepted.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Open for HF Spaces + local dev
allow_methods=["*"],
allow_headers=["*"],
)
# Single shared state manager (single-threaded MVP)
# NOTE: every HTTP endpoint and every WebSocket connection in this module uses
# this one instance, so all clients operate on the same episode.
_state_manager = StateManager()
# Grader used by /grader and handed to the baseline / LLM agents.
_grader = Grader()
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health", response_model=HealthResponse)
def health() -> HealthResponse:
    """Liveness probe: unconditionally report the service as healthy."""
    current_status = HealthStatus.HEALTHY
    return HealthResponse(status=current_status)
@app.get("/tasks")
def list_tasks() -> dict[str, TaskConfig]:
    """Return the ``TASKS`` registry (task id -> task configuration)."""
    registry = TASKS
    return registry
@app.post("/reset", response_model=ResetResponse)
def reset(request: OEResetRequest | None = Body(default=None)) -> ResetResponse:
    """Start a new episode on the shared state manager.

    The request body is optional. The task is resolved from the extra
    ``task_id`` field, falling back to ``episode_id`` and finally to
    ``"easy_harassment"``. The seed defaults to 42 when none is supplied.

    Returns:
        ResetResponse holding the initial observation, ``reward=None`` and
        the observation's ``done`` flag.

    Raises:
        HTTPException: 400 when the resolved task id is not in ``TASKS``.
    """
    extra: dict = {}
    episode_id = None
    seed = None
    if request is not None:
        # task_id is not a declared field; it arrives as an extra field.
        extra = request.model_extra or {}
        episode_id = request.episode_id
        seed = request.seed

    task_id = extra.get("task_id") or episode_id or "easy_harassment"

    # Test against None instead of truthiness: 0 is a legitimate seed and
    # must not be silently replaced by the default.
    if seed is None:
        seed = 42

    if task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASKS.keys())}",
        )

    # Copy the registered task so the registry entry itself keeps its seed.
    task = TASKS[task_id].model_copy(update={"seed": seed})
    obs = _state_manager.reset(task)
    return ResetResponse(observation=obs.model_dump(), reward=None, done=obs.done)
@app.post("/step", response_model=StepResponse)
def step(request: StepRequest) -> StepResponse:
    """Apply one action to the active episode.

    Raises:
        HTTPException: 400 when there is no active episode or the state
            manager rejects the action with ``ValueError``; 422 when the
            payload cannot be turned into an ``Action``.
    """
    episode_running = _state_manager.has_active_episode()
    if not episode_running:
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    # Build the typed action from the raw request dict.
    try:
        parsed_action = Action(**request.action)
    except Exception as exc:
        raise HTTPException(status_code=422, detail=str(exc))

    # Advance the environment by one step.
    try:
        outcome = _state_manager.step(parsed_action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    logger.info(
        "Step %d: action=%s reward=%.3f done=%s",
        outcome.observation.step,
        parsed_action.action_type.value,
        outcome.reward,
        outcome.done,
    )

    observation_payload = outcome.observation.model_dump()
    return StepResponse(
        observation=observation_payload,
        reward=outcome.reward,
        done=outcome.done,
    )
@app.get("/state")
def get_state() -> dict:
    """Return the current state of the active episode as a plain dict.

    Raises:
        HTTPException: 400 when no episode has been started.
    """
    if _state_manager.has_active_episode():
        current = _state_manager.get_state()
        return current.model_dump()
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
def _ws_error(message: str, code: WSErrorCode) -> str:
    """Serialize a WebSocket error frame carrying *message* and *code*."""
    return WSErrorResponse(data={"message": message, "code": code}).model_dump_json()


def _ws_observation(observation, reward, done) -> str:
    """Serialize an observation frame in the shape shared by reset and step."""
    payload = {"observation": observation.model_dump(), "reward": reward, "done": done}
    return WSObservationResponse(data=payload).model_dump_json()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve one WebSocket session driving the shared state manager.

    Each incoming text frame must be a JSON object with a ``type`` key:

    * ``"reset"`` - start an episode; ``data`` may carry ``task_id``
      (or ``episode_id``) and ``seed`` (default 42).
    * ``"step"``  - ``data`` holds the fields of an ``Action``.
    * ``"state"`` - return the current state.
    * ``"close"`` - leave the receive loop and end the handler.

    Malformed or invalid messages are answered with an error frame and the
    session stays open. A client disconnect ends the handler silently.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_text(_ws_error("Invalid JSON", WSErrorCode.INVALID_JSON))
                continue

            # json.loads also accepts arrays, strings and numbers; without
            # this guard the .get() calls below would raise and tear down
            # the connection.
            if not isinstance(data, dict):
                await websocket.send_text(
                    _ws_error("Message must be a JSON object", WSErrorCode.VALIDATION_ERROR)
                )
                continue

            msg_type = data.get("type")

            if msg_type == "reset":
                # A missing or null payload means "use all defaults".
                reset_data = data.get("data") or {}
                if not isinstance(reset_data, dict):
                    await websocket.send_text(
                        _ws_error("Reset 'data' must be a JSON object", WSErrorCode.VALIDATION_ERROR)
                    )
                    continue

                task_id = reset_data.get("task_id") or reset_data.get("episode_id") or "easy_harassment"
                # Test against None, not truthiness: seed 0 is valid.
                seed = reset_data.get("seed")
                if seed is None:
                    seed = 42

                if task_id not in TASKS:
                    await websocket.send_text(
                        _ws_error(f"Unknown task_id '{task_id}'", WSErrorCode.VALIDATION_ERROR)
                    )
                    continue

                task = TASKS[task_id].model_copy(update={"seed": seed})
                obs = _state_manager.reset(task)
                await websocket.send_text(_ws_observation(obs, None, obs.done))

            elif msg_type == "step":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error("No active episode. Send reset first.", WSErrorCode.SESSION_ERROR)
                    )
                    continue

                action_data = data.get("data", {})
                try:
                    # A non-mapping payload raises TypeError here and is
                    # reported as a validation error like any bad action.
                    action = Action(**action_data)
                except Exception as exc:
                    await websocket.send_text(_ws_error(str(exc), WSErrorCode.VALIDATION_ERROR))
                    continue

                try:
                    result = _state_manager.step(action)
                except ValueError as exc:
                    await websocket.send_text(_ws_error(str(exc), WSErrorCode.EXECUTION_ERROR))
                    continue

                await websocket.send_text(
                    _ws_observation(result.observation, result.reward, result.done)
                )

            elif msg_type == "state":
                if not _state_manager.has_active_episode():
                    await websocket.send_text(
                        _ws_error("No active episode.", WSErrorCode.SESSION_ERROR)
                    )
                    continue

                obs = _state_manager.get_state()
                await websocket.send_text(WSStateResponse(data=obs.model_dump()).model_dump_json())

            elif msg_type == "close":
                break

            else:
                await websocket.send_text(
                    _ws_error(f"Unknown message type: {msg_type!r}", WSErrorCode.UNKNOWN_TYPE)
                )
    except WebSocketDisconnect:
        # Client went away; nothing to clean up for this session.
        pass
@app.get("/grader", response_model=EpisodeScore)
def grade() -> EpisodeScore:
    """Score the current episode once it has finished.

    Raises:
        HTTPException: 400 when no episode is active, or when the active
            episode's observation is not yet marked done.
    """
    if not _state_manager.has_active_episode():
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")

    finished_episode = _state_manager.get_episode_state()
    if finished_episode.observation.done:
        episode_score = _grader.score(finished_episode)
        logger.info("Graded episode: total=%.4f", episode_score.total)
        return episode_score

    raise HTTPException(
        status_code=400,
        detail="Episode is not finished yet. Complete the episode before grading.",
    )
@app.get("/baseline", response_model=BaselineResult)
def baseline(task_id: str = "easy_harassment", seed: int | None = None) -> BaselineResult:
    """Run the rule-based baseline agent on one task and return its result.

    ``seed`` overrides the task's own seed only when it is given.

    Raises:
        HTTPException: 400 when ``task_id`` is not in ``TASKS``.
    """
    # Imported on demand rather than at module load.
    from baseline.agent import BaselineAgent

    if task_id not in TASKS:
        known = list(TASKS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {known}",
        )

    registered = TASKS[task_id]
    task = registered if seed is None else registered.model_copy(update={"seed": seed})

    runner = BaselineAgent(state_manager=_state_manager, grader=_grader)
    return runner.run(task)
@app.post("/agent/run", response_model=BaselineResult)
def agent_run(request: ResetRequest) -> BaselineResult:
    """
    Run the selected LLM agent (OpenAI or Gemini) on a full episode and return the graded result.
    Requires OPENAI_API_KEY, or GOOGLE_API_KEY/GEMINI_API_KEY depending on LLM_PROVIDER.

    The provider is read from the ``LLM_PROVIDER`` environment variable
    (default ``"openai"``). Any value other than ``"gemini"`` runs the
    OpenAI agent; unrecognised values are logged as a warning first.

    Raises:
        HTTPException: 400 when ``request.task_id`` is not in ``TASKS``;
            500 when constructing the agent raises ``EnvironmentError``.
    """
    import os

    if request.task_id not in TASKS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{request.task_id}'. Available: {list(TASKS.keys())}",
        )
    task = TASKS[request.task_id]
    if request.seed is not None:
        task = task.model_copy(update={"seed": request.seed})

    provider = (os.getenv("LLM_PROVIDER") or "openai").strip().lower()
    if provider not in ("openai", "gemini"):
        # Keep the historical fallback, but make it visible and keep the
        # completion log line below truthful about which agent ran.
        logger.warning("Unknown LLM_PROVIDER %r; falling back to 'openai'.", provider)
        provider = "openai"

    try:
        # Import only the selected provider so that a failing import in the
        # other provider's module cannot break this request.
        if provider == "gemini":
            from agent.gemini_agent import GeminiAgent

            agent = GeminiAgent(state_manager=_state_manager, grader=_grader)
        else:
            from agent.openai_agent import OpenAIAgent

            agent = OpenAIAgent(state_manager=_state_manager, grader=_grader)
    except EnvironmentError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    result = agent.run(task)
    logger.info(
        "%s agent finished: task=%s total=%.4f", provider.capitalize(), task.task_id, result.score.total
    )
    return result