"""FastAPI application exposing the RedTeamOS environment over HTTP."""

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from models import AttackAction, StepResult, ResetResponse, EpisodeState
from server.environment import RedTeamEnvironment
from server.config import get_settings

from rewards.compute_rewards import RewardComputer
from llm.pipeline import run_llm_pipeline

# Shared environment instance used by the route handlers. It is None at import
# time and is assigned inside `lifespan` when the application starts.
env: RedTeamEnvironment = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build and wire the shared environment before the app starts serving.

    Creates the module-level ``env`` with ``max_turns`` taken from the
    settings, attaches a ``RewardComputer`` and the ``run_llm_pipeline``
    callable to it, then yields control to the application. Nothing is torn
    down after the yield.
    """
    global env
    environment = RedTeamEnvironment(max_turns=get_settings().max_turns)
    env = environment

    environment.set_reward_computer(RewardComputer())
    environment.set_llm_pipeline(run_llm_pipeline)
    yield

# Application instance; `lifespan` prepares the shared environment at startup.
app = FastAPI(title="RedTeamOS", version="0.1.0", lifespan=lifespan)

# CORS is fully open: any origin, method and header, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True should be
# checked against the FastAPI CORS documentation — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/health")
async def health_check():
    """Report a static health status together with the version string."""
    payload = {"status": "healthy", "version": "0.1.0"}
    return payload

@app.post("/reset", response_model=ResetResponse)
async def reset_episode():
    """Reset the environment and return the observation it produces.

    Returns:
        A ``ResetResponse`` carrying the observation returned by
        ``env.reset()`` and that observation's ``episode_id``.

    Raises:
        HTTPException: 500, with the error text as detail, if resetting the
            environment or building the response fails.
    """
    try:
        observation = await env.reset()
        return ResetResponse(observation=observation, episode_id=observation.episode_id)
    except Exception as exc:
        # Chain the original error so it stays attached as __cause__ instead
        # of being reported as an exception raised while handling another.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

@app.post("/step", response_model=StepResult)
async def step_episode(action: AttackAction):
    """Apply one action to the environment and return the step result.

    Args:
        action: The request body, passed unchanged to ``env.step``.

    Returns:
        Whatever ``env.step(action)`` returns.

    Raises:
        HTTPException: 400 when ``env.step`` raises ``ValueError``; 500 for
            any other exception. The error text is used as the detail.
    """
    try:
        return await env.step(action)
    except ValueError as exc:
        # ValueError is reported to the client as a bad request.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        # Chain the original error so it stays attached as __cause__.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

@app.get("/state", response_model=EpisodeState)
async def get_state():
    """Return the current state as reported by ``env.get_state()``."""
    current_state = env.get_state()
    return current_state

@app.get("/history")
async def get_history():
    """Return the environment's history wrapped under the ``"history"`` key."""
    return dict(history=env.get_history())

@app.post("/grade")
async def grade_episode():
    """Grade the recorded episode with the programmatic grader.

    Responds with 400 while ``env.is_active`` is truthy or when the history is
    empty; otherwise returns the grader's result with ``episode_id`` added.
    """
    if env.is_active:
        raise HTTPException(
            status_code=400,
            detail="Episode still active — finish it before grading.",
        )

    episode_log = env.get_history()
    if not episode_log:
        raise HTTPException(status_code=400, detail="No episode history to grade.")

    # Imported here, after both precondition checks, as in the original flow.
    from graders.programmatic_grader import grade_episode as run_grader

    report = run_grader(episode_log)
    report["episode_id"] = env.episode_id
    return report