File size: 3,370 Bytes
37ab50f
f61eeae
0e662d7
a3c4cc9
f61eeae
 
abd2333
37ab50f
 
 
b0ac794
905ac2f
abd2333
b0ac794
37ab50f
 
 
 
 
 
 
b0ac794
 
 
 
37ab50f
f183aea
 
4df57fe
f183aea
37ab50f
f183aea
a3c4cc9
 
 
 
 
 
 
 
574fde3
f61eeae
 
 
 
 
 
 
 
574fde3
 
 
0e662d7
 
 
 
905ac2f
0e662d7
 
 
 
202f5d9
 
 
 
 
 
 
 
 
 
fc70913
 
 
 
7d5c46d
 
 
 
f5810e8
 
 
 
 
 
 
 
 
 
 
 
 
 
abd2333
 
 
 
 
 
 
2cbf425
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from models import AttackAction, StepResult, ResetResponse, EpisodeState, AutoAttackRequest
from server.environment import RedTeamEnvironment
from server.config import get_settings

from rewards.compute_rewards import RewardComputer
from llm.pipeline import run_llm_pipeline, reset_conversation
from llm.automated_attacker import generate_automated_attack

# Shared environment instance, created by ``lifespan`` at application startup.
# It is ``None`` until startup has run, so the annotation says so explicitly
# (quoted so the ``X | None`` syntax is never evaluated on older Pythons).
env: "RedTeamEnvironment | None" = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build and wire the module-level ``env`` before the app serves requests.

    Reads ``max_turns`` from ``get_settings()``, constructs a
    ``RedTeamEnvironment``, attaches a fresh ``RewardComputer`` and the
    ``run_llm_pipeline`` callable, then yields control to FastAPI.
    Nothing is torn down after the ``yield``.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            unused here).
    """
    global env
    settings = get_settings()
    env = RedTeamEnvironment(max_turns=settings.max_turns)

    reward_computer = RewardComputer()
    env.set_reward_computer(reward_computer)
    env.set_llm_pipeline(run_llm_pipeline)
    yield

app = FastAPI(
    title       = "BreachOS",
    version     = "0.1.0",
    lifespan    = lifespan,
)

# Any origin may call the API. Credentials are deliberately disabled:
# - The CORS spec forbids pairing a wildcard origin with credentialed requests.
# - FastAPI/Starlette document that allow_credentials=True requires explicit
#   origin/method/header lists rather than ["*"].
# None of the routes defined in this module read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins     = ["*"],
    allow_credentials = False,
    allow_methods     = ["*"],
    allow_headers     = ["*"],
)

# Frontend assets live in a "frontend" directory two levels above this file.
# They are mounted under /static only when that path exists, so the API still
# starts without a bundled UI.
_FRONTEND = Path(__file__).parent.parent / "frontend"
if _FRONTEND.exists():
    app.mount("/static", StaticFiles(directory=str(_FRONTEND)), name="static")

@app.get("/", include_in_schema=False)
async def serve_ui():
    """Serve the frontend's ``index.html`` at the site root.

    Returns:
        A ``FileResponse`` for ``<frontend>/index.html``.

    Raises:
        HTTPException: 404 when ``index.html`` is not present. Without this
            check, ``FileResponse`` fails only while sending the response,
            and the client receives an unexplained 500.
    """
    index_html = _FRONTEND / "index.html"
    if not index_html.is_file():
        raise HTTPException(status_code=404, detail="Frontend not available.")
    return FileResponse(str(index_html))

@app.get("/health")
async def health_check():
    """Report a static healthy status together with the API version.

    The version is read from the FastAPI app itself, so it cannot drift from
    the version advertised in the OpenAPI schema.
    """
    return {"status": "healthy", "version": app.version}

@app.post("/reset", response_model=ResetResponse)
async def reset_episode():
    """Start a new episode.

    Calls ``reset_conversation()`` from the LLM pipeline, then awaits
    ``env.reset()``.

    Returns:
        ``ResetResponse`` with the initial observation and its ``episode_id``.

    Raises:
        HTTPException: 500 carrying the underlying error message if any step
            fails.
    """
    try:
        reset_conversation()
        observation = await env.reset()
        return ResetResponse(observation=observation, episode_id=observation.episode_id)
    except Exception as e:
        # Chain the original error so server logs keep its traceback as the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.post("/step", response_model=StepResult)
async def step_episode(action: AttackAction):
    """Apply one attacker action to the current episode.

    Args:
        action: The attack action, forwarded unchanged to ``env.step``.

    Returns:
        The ``StepResult`` produced by the environment.

    Raises:
        HTTPException: 400 when the environment raises ``ValueError``;
            500 for any other exception.
    """
    try:
        return await env.step(action)
    except ValueError as e:
        # ValueError from the environment is treated as a client error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/state", response_model=EpisodeState)
async def get_state():
    """Return whatever ``env.get_state()`` reports, validated as ``EpisodeState``."""
    snapshot = env.get_state()
    return snapshot

@app.get("/history")
async def get_history():
    """Return the environment's recorded history wrapped under a ``"history"`` key."""
    recorded = env.get_history()
    return dict(history=recorded)

@app.post("/grade")
async def grade_episode():
    """Grade the finished episode's history with the programmatic grader.

    Returns:
        The object returned by the grader, with ``episode_id`` set to the
        environment's current episode id (the grader's result is mutated
        in place).

    Raises:
        HTTPException: 400 if the episode is still active, or if there is no
            history to grade.
    """
    if env.is_active:
        raise HTTPException(status_code=400, detail="Episode still active — finish it before grading.")

    recorded_turns = env.get_history()
    if not recorded_turns:
        raise HTTPException(status_code=400, detail="No episode history to grade.")

    # Imported lazily, only after the guards above pass; aliased so it does not
    # shadow this endpoint's own name.
    # NOTE(review): presumably deferred to avoid import cost or a circular import — confirm.
    from graders.programmatic_grader import grade_episode as run_grader

    graded = run_grader(recorded_turns)
    graded["episode_id"] = env.episode_id
    return graded

@app.post("/auto-attack")
async def auto_attack(request: AutoAttackRequest):
    """Generate an automated attack framing for the requested strategy and target.

    Args:
        request: Carries ``strategy_type`` and ``target_category``; their
            ``.value`` attributes are passed to the generator.

    Returns:
        A dict with the generated text under ``"framing"``.

    Raises:
        HTTPException: 400 when no episode is active.
    """
    if not env.is_active:
        raise HTTPException(status_code=400, detail="No active episode.")

    strategy = request.strategy_type.value
    category = request.target_category.value
    # NOTE(review): called synchronously inside an async endpoint. If it performs
    # blocking LLM I/O, it stalls the event loop while it runs — confirm, and
    # consider offloading to a threadpool.
    return {"framing": generate_automated_attack(strategy, category)}


def main(*, host: str = "0.0.0.0", port: int = 7860, reload: bool = False):
    """Run the API under uvicorn.

    All parameters are keyword-only; their defaults reproduce the previous
    hard-coded values, so ``main()`` with no arguments behaves as before.

    Args:
        host: Interface to bind ("0.0.0.0" listens on all interfaces).
        port: TCP port to listen on.
        reload: Enable uvicorn's auto-reload.
    """
    import uvicorn

    # The app is passed as an import string rather than an object because
    # uvicorn needs an import string to support reload/workers.
    uvicorn.run("server.app:app", host=host, port=port, reload=reload)


if __name__ == "__main__":
    main()