Breach-OS / server /app.py
subhdotsol's picture
feat(app): add POST /grade endpoint with active-episode guard and programmatic grader
f5810e8
raw
history blame
2.4 kB
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from models import AttackAction, StepResult, ResetResponse, EpisodeState
from server.environment import RedTeamEnvironment
from server.config import get_settings
from rewards.compute_rewards import RewardComputer
from llm.pipeline import run_llm_pipeline
# Process-wide environment instance shared by every endpoint in this module.
# It is None at import time and is assigned inside `lifespan` before the app
# starts serving; handlers read it through this module-level name.
env: RedTeamEnvironment = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create and wire the module-level environment, then yield to the app.

    The environment is built with the configured ``max_turns`` and is given a
    reward computer and the LLM pipeline callable before control is handed
    over. Nothing is done after the ``yield``.
    """
    global env
    turn_limit = get_settings().max_turns
    env = RedTeamEnvironment(max_turns=turn_limit)
    env.set_reward_computer(RewardComputer())
    env.set_llm_pipeline(run_llm_pipeline)
    yield
# Application object; `lifespan` prepares the shared environment on startup.
app = FastAPI(
    title="RedTeamOS",
    version="0.1.0",
    lifespan=lifespan,
)

# CORS: the API is open to every origin, method and header. Credentialed
# cross-origin requests are disabled because the FastAPI/Starlette CORS
# documentation states that wildcard origins/methods/headers must not be
# combined with allow_credentials=True. If cookies or auth headers are needed
# cross-origin, list the allowed origins explicitly and re-enable credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
async def health_check():
    """Report that the service is up, together with the application version.

    Returns:
        A dict with a fixed ``"status"`` of ``"healthy"`` and the ``"version"``
        string the FastAPI application was constructed with.
    """
    # Read the version from the app object instead of repeating the literal,
    # so this endpoint cannot drift from the value passed to FastAPI(...).
    return {"status": "healthy", "version": app.version}
@app.post("/reset", response_model=ResetResponse)
async def reset_episode():
    """Reset the shared environment and return the observation it produces.

    Returns:
        A ``ResetResponse`` carrying the observation from ``env.reset()`` and
        the ``episode_id`` read from that observation.

    Raises:
        HTTPException: status 500 with the error text as detail if resetting
            the environment or building the response fails.
    """
    try:
        observation = await env.reset()
        return ResetResponse(observation=observation, episode_id=observation.episode_id)
    except Exception as exc:
        # API boundary: every failure becomes a 500. Chain the cause so the
        # original traceback is kept in server logs instead of being replaced.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/step", response_model=StepResult)
async def step_episode(action: AttackAction):
    """Apply one action to the shared environment and return the step result.

    Args:
        action: The request body, parsed as an ``AttackAction``.

    Returns:
        Whatever ``env.step(action)`` returns, serialized as ``StepResult``.

    Raises:
        HTTPException: status 400 if the environment raises ``ValueError``;
            status 500 for any other exception. The detail is the error text.
    """
    try:
        return await env.step(action)
    except ValueError as exc:
        # ValueError is reported as a client error; keep the cause chained.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        # Anything else is a server-side failure; keep the cause chained so
        # the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/state", response_model=EpisodeState)
async def get_state():
    """Return the state reported by the shared environment's ``get_state()``."""
    current_state = env.get_state()
    return current_state
@app.get("/history")
async def get_history():
    """Return the environment's history wrapped under a ``"history"`` key."""
    recorded = env.get_history()
    return {"history": recorded}
@app.post("/grade")
async def grade_episode():
    """Grade the recorded episode history with the programmatic grader.

    Returns:
        The grader's result with an ``"episode_id"`` entry set from
        ``env.episode_id``.

    Raises:
        HTTPException: status 400 if the environment reports the episode as
            still active, or if there is no history to grade.
    """
    episode_running = env.is_active
    if episode_running:
        raise HTTPException(status_code=400, detail="Episode still active — finish it before grading.")

    transcript = env.get_history()
    if not transcript:
        raise HTTPException(status_code=400, detail="No episode history to grade.")

    # Deferred import: the grader module is loaded only once both guards have
    # passed. The alias keeps it from shadowing this endpoint's own name.
    from graders.programmatic_grader import grade_episode as run_grader

    report = run_grader(transcript)
    report["episode_id"] = env.episode_id
    return report