Spaces:
Sleeping
Sleeping
File size: 2,398 Bytes
37ab50f 0e662d7 a3c4cc9 fc70913 37ab50f b0ac794 37ab50f b0ac794 37ab50f f183aea 37ab50f f183aea a3c4cc9 574fde3 0e662d7 202f5d9 fc70913 7d5c46d f5810e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 | from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from models import AttackAction, StepResult, ResetResponse, EpisodeState
from server.environment import RedTeamEnvironment
from server.config import get_settings
from rewards.compute_rewards import RewardComputer
from llm.pipeline import run_llm_pipeline
# Process-wide environment shared by every endpoint below. It is ``None``
# until ``lifespan`` has run at application startup.
env: "RedTeamEnvironment | None" = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build and wire the shared ``RedTeamEnvironment`` for the app's lifetime.

    Runs once at startup: reads settings via ``get_settings()``, creates the
    environment with ``settings.max_turns``, then attaches a fresh
    ``RewardComputer`` and the ``run_llm_pipeline`` callable. Nothing is torn
    down after ``yield``.

    Args:
        app: The FastAPI application (required by the lifespan protocol; not
            used here).
    """
    global env
    settings = get_settings()
    environment = RedTeamEnvironment(max_turns=settings.max_turns)
    environment.set_reward_computer(RewardComputer())
    environment.set_llm_pipeline(run_llm_pipeline)
    # Publish to the global only once fully wired, so it is never observable
    # in a half-configured state.
    env = environment
    yield
# The ASGI application; ``lifespan`` wires up the shared environment on startup.
app = FastAPI(
    lifespan=lifespan,
    title="RedTeamOS",
    version="0.1.0",
)

# NOTE(review): per the CORS spec a literal "*" origin cannot be combined with
# credentials; depending on the Starlette version the request Origin may be
# reflected instead, effectively allowing credentialed requests from any site.
# Nothing visible in this module reads cookies or auth headers — confirm
# allow_credentials=True is actually needed.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/health")
async def health_check():
    """Liveness probe.

    Returns:
        ``{"status": "healthy", "version": <version>}``. The version is read
        from the FastAPI app so it always matches the value passed to
        ``FastAPI(version=...)`` instead of duplicating the literal.
    """
    return {"status": "healthy", "version": app.version}
@app.post("/reset", response_model=ResetResponse)
async def reset_episode():
    """Start a new episode and return its initial observation.

    Returns:
        A ``ResetResponse`` carrying the observation from ``env.reset()`` and
        that observation's ``episode_id``.

    Raises:
        HTTPException: 500 with the error message if resetting fails. The
            original exception is logged with its traceback and chained as
            the cause; without this, the traceback is lost once the failure
            is converted into an HTTP response.
    """
    try:
        observation = await env.reset()
        return ResetResponse(observation=observation, episode_id=observation.episode_id)
    except Exception as e:
        import logging

        logging.getLogger(__name__).exception("Failed to reset episode")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/step", response_model=StepResult)
async def step_episode(action: AttackAction):
    """Apply one attack action to the current episode.

    Args:
        action: The request body, validated by FastAPI as an ``AttackAction``.

    Returns:
        Whatever ``env.step(action)`` returns, serialized via ``StepResult``.

    Raises:
        HTTPException: 400 when ``env.step`` raises ``ValueError`` (treated as
            a client error); 500 for any other exception, which is also
            logged with its traceback. The original exception is chained as
            the cause in both cases.
    """
    try:
        return await env.step(action)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        import logging

        logging.getLogger(__name__).exception("Failed to step episode")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/state", response_model=EpisodeState)
async def get_state():
    """Return a snapshot of the current episode state.

    The value from ``env.get_state()`` is serialized through the
    ``EpisodeState`` response model.
    """
    current_state = env.get_state()
    return current_state
@app.get("/history")
async def get_history():
    """Return the environment's history wrapped as ``{"history": ...}``."""
    records = env.get_history()
    payload = {"history": records}
    return payload
@app.post("/grade")
async def grade_episode():
    """Grade the most recent episode once it has finished.

    Returns:
        The object returned by ``graders.programmatic_grader.grade_episode``
        for the episode history, with its ``"episode_id"`` key set to
        ``env.episode_id`` (the grader's result is mutated in place).

    Raises:
        HTTPException: 400 if the episode is still active, or if there is no
            history to grade.
    """
    if env.is_active:
        raise HTTPException(status_code=400, detail="Episode still active — finish it before grading.")

    episode_history = env.get_history()
    if not episode_history:
        raise HTTPException(status_code=400, detail="No episode history to grade.")

    # Imported lazily, only after both preconditions have passed; aliased so
    # it does not shadow this endpoint's own name.
    from graders.programmatic_grader import grade_episode as run_grader

    report = run_grader(episode_history)
    report["episode_id"] = env.episode_id
    return report
|