Spaces:
Sleeping
Sleeping
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from models import AttackAction, StepResult, ResetResponse, EpisodeState | |
| from server.environment import RedTeamEnvironment | |
| from server.config import get_settings | |
| from rewards.compute_rewards import RewardComputer | |
| from llm.pipeline import run_llm_pipeline | |
# Process-wide environment shared by every handler in this module.
# It is None at import time and is replaced by `lifespan` at startup, so the
# annotation admits None. The annotation is quoted so it is not evaluated at
# import time and does not require runtime support for the `X | Y` syntax.
env: "RedTeamEnvironment | None" = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared environment at startup, then hand control to the app.

    Wrapped with ``asynccontextmanager`` so that the function passed as
    ``FastAPI(lifespan=...)`` is an async context manager rather than a bare
    async generator. Everything before ``yield`` runs during startup; nothing
    follows the ``yield``, so there is no shutdown step.

    Args:
        app: The application being started. Not used in the body.
    """
    global env
    settings = get_settings()
    env = RedTeamEnvironment(max_turns=settings.max_turns)
    # Wire in the collaborators the environment needs before serving requests.
    env.set_reward_computer(RewardComputer())
    env.set_llm_pipeline(run_llm_pipeline)
    yield
# ASGI application object; `lifespan` is invoked by the framework around the
# app's run to populate the module-level `env`.
app = FastAPI(title="RedTeamOS", version="0.1.0", lifespan=lifespan)
# Fully permissive CORS: any origin, method and header, with credentials on.
# NOTE(review): wildcard origins together with allow_credentials=True lets any
# site make credentialed cross-origin requests — confirm this is intended, or
# list the trusted origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def health_check():
    """Return a fixed liveness payload: status "healthy" and version "0.1.0"."""
    payload = {"status": "healthy", "version": "0.1.0"}
    return payload
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def reset_episode():
    """Reset the shared environment and return the first observation.

    Returns:
        A ``ResetResponse`` built from the observation returned by
        ``env.reset()`` and that observation's ``episode_id``.

    Raises:
        HTTPException: 500 for any exception raised while resetting or while
            building the response; ``detail`` carries the exception text.
    """
    try:
        observation = await env.reset()
        return ResetResponse(observation=observation, episode_id=observation.episode_id)
    except Exception as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def step_episode(action: AttackAction):
    """Apply one attack action to the current episode.

    Args:
        action: The action forwarded unchanged to ``env.step``.

    Returns:
        Whatever ``env.step(action)`` returns.

    Raises:
        HTTPException: 400 when ``env.step`` raises ``ValueError``, 500 for
            any other exception; ``detail`` carries the exception text.
    """
    try:
        return await env.step(action)
    except ValueError as e:
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def get_state():
    """Return the environment's current state exactly as ``env.get_state()`` reports it."""
    current_state = env.get_state()
    return current_state
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def get_history():
    """Return the result of ``env.get_history()`` under the ``"history"`` key."""
    records = env.get_history()
    return {"history": records}
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm how it is registered on `app`.
async def grade_episode():
    """Grade the recorded episode history and tag the result with its episode id.

    Returns:
        The object returned by the programmatic grader, with its
        ``"episode_id"`` entry set to ``env.episode_id``.

    Raises:
        HTTPException: 400 when ``env.is_active`` is truthy, or when
            ``env.get_history()`` returns something falsy.
    """
    # Refuse to grade while the environment still reports itself as active.
    if env.is_active:
        raise HTTPException(status_code=400, detail="Episode still active — finish it before grading.")

    transcript = env.get_history()
    if not transcript:
        raise HTTPException(status_code=400, detail="No episode history to grade.")

    # Imported at call time; aliased because this handler shares the grader's name.
    from graders.programmatic_grader import grade_episode as run_grader

    report = run_grader(transcript)
    # The grader's result is updated in place rather than copied.
    report["episode_id"] = env.episode_id
    return report