"""FastAPI entry point exposing the RedTeamOS environment over HTTP.

A single module-level ``RedTeamEnvironment`` is created when the application
starts (see ``lifespan``) and is shared by every request handler below.
"""

import logging
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from models import AttackAction, StepResult, ResetResponse, EpisodeState
from server.environment import RedTeamEnvironment
from server.config import get_settings
from rewards.compute_rewards import RewardComputer
from llm.pipeline import run_llm_pipeline

logger = logging.getLogger(__name__)

# Single source of truth for the version reported by the OpenAPI metadata
# and by the /health endpoint.
APP_VERSION = "0.1.0"

# Shared environment instance; stays None until the lifespan hook has run.
env: Optional[RedTeamEnvironment] = None


def _require_env() -> RedTeamEnvironment:
    """Return the shared environment, or fail with 503 if it is not set up.

    Raises:
        HTTPException: 503 when ``env`` is still ``None`` (the lifespan
            startup has not populated it).
    """
    if env is None:
        raise HTTPException(
            status_code=503,
            detail="Environment is not initialised.",
        )
    return env


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared environment before the application starts serving.

    Reads ``max_turns`` from the settings object, constructs the environment
    and attaches a ``RewardComputer`` and the LLM pipeline callable to it.
    Nothing is torn down after the ``yield``.
    """
    global env
    settings = get_settings()
    env = RedTeamEnvironment(max_turns=settings.max_turns)
    env.set_reward_computer(RewardComputer())
    env.set_llm_pipeline(run_llm_pipeline)
    yield


app = FastAPI(
    title="RedTeamOS",
    version=APP_VERSION,
    lifespan=lifespan,
)

# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive CORS policy; the FastAPI CORS documentation advises listing
# origins explicitly when credentials are allowed. Left unchanged here so
# existing clients keep working - confirm whether this is intended outside
# local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health_check():
    """Report a static liveness payload with the application version."""
    return {"status": "healthy", "version": APP_VERSION}


@app.post("/reset", response_model=ResetResponse)
async def reset_episode():
    """Reset the environment and return the first observation.

    Raises:
        HTTPException: 503 if the environment is not initialised; 500 with
            the exception text if ``env.reset()`` raises.
    """
    # Resolved outside the try block so the 503 is not rewrapped as a 500.
    environment = _require_env()
    try:
        observation = await environment.reset()
        return ResetResponse(
            observation=observation,
            episode_id=observation.episode_id,
        )
    except Exception as exc:
        logger.exception("Failed to reset episode")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/step", response_model=StepResult)
async def step_episode(action: AttackAction):
    """Apply one action to the environment and return its result.

    Raises:
        HTTPException: 503 if the environment is not initialised; 400 when
            ``env.step`` raises ``ValueError`` (presumably an invalid action
            or episode state - confirm against the environment); 500 for any
            other exception.
    """
    environment = _require_env()
    try:
        return await environment.step(action)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to step episode")
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/state", response_model=EpisodeState)
async def get_state():
    """Return whatever ``env.get_state()`` reports for the current episode."""
    return _require_env().get_state()


@app.get("/history")
async def get_history():
    """Return the environment's history wrapped under the ``history`` key."""
    return {"history": _require_env().get_history()}


@app.post("/grade")
async def grade_episode():
    """Grade the recorded history of a finished episode.

    Returns:
        The grader's result with an ``episode_id`` entry added.

    Raises:
        HTTPException: 503 if the environment is not initialised; 400 if the
            episode is still active or there is no history to grade.
    """
    environment = _require_env()
    if environment.is_active:
        raise HTTPException(
            status_code=400,
            detail="Episode still active — finish it before grading.",
        )

    history = environment.get_history()
    if not history:
        raise HTTPException(
            status_code=400,
            detail="No episode history to grade.",
        )

    # Imported at call time and aliased because this handler has the same
    # name as the grader function. NOTE(review): the deferred import is kept
    # from the original; presumably it avoids an import cycle or start-up
    # cost - confirm.
    from graders.programmatic_grader import grade_episode as do_grade

    result = do_grade(history)
    result["episode_id"] = environment.episode_id
    return result