File size: 5,423 Bytes
5f7092b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import asyncio
import hmac
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Root logging setup: INFO and above, one "timestamp [LEVEL] message" line per record.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

from config import load_config
from model import XGuardModel

# Configuration is loaded once at import time; the rest of this module reads
# api_key, model_path, device, host and port from it.
config = load_config()
app = FastAPI(title="XGuard MaaS", version="1.0.0")

# CORS: any origin, method and header is accepted. Credentialed requests are
# disabled because wildcard origins must not be combined with
# allow_credentials=True; clients authenticate with the "x-api-key" header,
# which is a plain request header and does not need credentialed CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Both are assigned by the startup hook and stay None until it has run.
xguard_model: Optional[XGuardModel] = None
executor: Optional[ThreadPoolExecutor] = None

# Upper bound on guard-check requests processed at the same time.
MAX_CONCURRENT_REQUESTS = 10
# NOTE(review): this semaphore is constructed at import time, before the
# server's event loop exists. On Python < 3.10 asyncio primitives bind to the
# loop that is current at construction - confirm the deployment runs 3.10+.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


class Message(BaseModel):
    # One conversation turn as submitted in the request body.
    # role: speaker label; build_check_content only keeps turns whose role is "user".
    role: str
    # content: text of the turn.
    content: str


class Tool(BaseModel):
    # Descriptor of a tool call submitted alongside the messages.
    # name: rendered after "[Tool Call]" in the text that gets checked.
    name: str
    # description: rendered on a "Description:" line when non-empty.
    description: str
    # parameters: unconstrained value; rendered with json.dumps when truthy.
    parameters: Any


class GuardCheckRequest(BaseModel):
    # Request body of POST /v1/guard/check.
    # conversationId: written to the log together with the checked content.
    conversationId: str
    # messages: conversation turns; only the "user" turns end up in the checked text.
    messages: List[Message]
    # tools: required field (no default); each entry is appended to the checked text.
    tools: List[Tool]
    # enableReasoning: forwarded to the model as enable_reasoning. The Chinese
    # description reads "whether to enable attribution analysis".
    enableReasoning: bool = Field(default=False, description="是否启用归因分析")


class GuardCheckResponse(BaseModel):
    # Response envelope of POST /v1/guard/check.
    # err_code: the endpoint sets 0 on the success path.
    err_code: int
    # data: is_safe, risk_level, confidence, risk_type, reason and, when
    # reasoning was requested and produced, explanation.
    data: Dict[str, Any]
    # msg: "success" on the success path.
    msg: str


def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
    """Flatten user messages and tool descriptors into one text to be checked.

    The contents of all messages whose ``role`` is ``"user"`` are joined with
    newlines. Each tool then contributes a section that starts with an empty
    line followed by ``[Tool Call] <name>``, plus a ``Description:`` line and a
    ``Parameters:`` line (JSON, non-ASCII kept as-is) when those values are
    truthy. Leading and trailing whitespace of the final text is stripped.

    Args:
        messages: Dicts read through the ``role`` and ``content`` keys.
        tools: Dicts read through the ``name``, ``description`` and
            ``parameters`` keys.

    Returns:
        The combined text.
    """
    user_text = "\n".join(
        entry.get("content", "") for entry in messages if entry.get("role") == "user"
    )
    if not tools:
        return user_text.strip()

    sections = []
    for spec in tools:
        # The leading empty element yields the blank line that separates sections.
        lines = ["", f"[Tool Call] {spec.get('name', '')}"]
        description = spec.get("description", "")
        if description:
            lines.append(f"Description: {description}")
        params = spec.get("parameters", {})
        if params:
            lines.append(f"Parameters: {json.dumps(params, ensure_ascii=False)}")
        sections.append("\n".join(lines))

    return "\n".join([user_text, *sections]).strip()


@app.on_event("startup")
async def startup_event():
    """Load the XGuard model and create the inference thread pool.

    Assigns the module-level ``xguard_model`` and ``executor``. A failure is
    logged through the module logger together with its traceback and then
    re-raised, so the exception still propagates out of the startup hook.
    """
    global xguard_model, executor
    try:
        xguard_model = XGuardModel(config.model_path, config.device)
        executor = ThreadPoolExecutor(max_workers=4)
    except Exception:
        # logger.exception records the traceback, which print() did not.
        logger.exception("Failed to load model")
        raise
    logger.info("XGuard model loaded on %s", config.device)


@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the inference thread pool, waiting for submitted work."""
    pool = executor
    if pool is None:
        # The startup hook never created a pool; nothing to release.
        return
    pool.shutdown(wait=True)


@app.get("/health")
async def health_check():
    # Liveness probe: always reports "ok" and tells whether the startup hook
    # has already assigned the model.
    model_ready = xguard_model is not None
    return {"status": "ok", "model_loaded": model_ready}


@app.post("/v1/guard/check", response_model=GuardCheckResponse)
async def guard_check(
    request: GuardCheckRequest,
    x_api_key: str = Header(..., alias="x-api-key"),
):
    """Run the XGuard safety check on a conversation and its tool calls.

    The user messages and tool descriptors are flattened into a single text,
    which is sent to the model as one user message. Inference runs in the
    thread pool, and the number of requests processed at once is bounded by
    the module-level semaphore.

    Args:
        request: Conversation id, messages, tool descriptors and the
            ``enableReasoning`` flag.
        x_api_key: Value of the ``x-api-key`` request header.

    Returns:
        A ``GuardCheckResponse`` with ``err_code=0`` and ``msg="success"``.
        ``data`` holds ``is_safe``, ``risk_level``, ``confidence``,
        ``risk_type`` and ``reason``, plus ``explanation`` when reasoning was
        requested and the model returned one.

    Raises:
        HTTPException: 401 when the API key does not match, 503 when the model
            has not been loaded, 500 when the analysis or the construction of
            the response fails.
    """
    expected_key = config.api_key
    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Comparing bytes keeps non-ASCII keys working; a
    # configured key that is not a string can never match, as before.
    key_matches = isinstance(expected_key, str) and hmac.compare_digest(
        x_api_key.encode("utf-8"), expected_key.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Bind once so the worker thread uses the instance that passed this check.
    model = xguard_model
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async with request_semaphore:
        try:
            messages = [{"role": m.role, "content": m.content} for m in request.messages]
            tools = [
                {"name": t.name, "description": t.description, "parameters": t.parameters}
                for t in request.tools
            ]

            # Flatten the messages and the tool information into the text to check.
            check_content = build_check_content(messages, tools)
            logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)

            # The model receives the flattened text as a single user message.
            check_messages = [{"role": "user", "content": check_content}]

            # Inside a coroutine the running loop is the one to use.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: model.analyze(
                    check_messages,
                    [],  # tools are already part of the text, so none are passed separately
                    enable_reasoning=request.enableReasoning,
                ),
            )

            # Build the response payload.
            response_data = {
                "is_safe": result["is_safe"],
                "risk_level": result.get("risk_level", "safe" if result["is_safe"] == 1 else "medium"),
                "confidence": result.get("confidence", 0.0),
                "risk_type": result["risk_type"],
                "reason": result["reason"],
            }

            # Attach the explanation only when reasoning was requested and produced.
            if request.enableReasoning and "explanation" in result:
                response_data["explanation"] = result["explanation"]

            return GuardCheckResponse(
                err_code=0,
                data=response_data,
                msg="success",
            )
        except Exception as e:
            # Keep the traceback in the server log; the client only gets the message.
            # NOTE(review): the detail echoes the exception text to the caller -
            # confirm that this is acceptable for external clients.
            logger.exception("Guard check failed for conversation [%s]", request.conversationId)
            raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e


# Script entry point: when this file is executed directly, serve the app with
# uvicorn on the host and port taken from the loaded configuration.
if __name__ == "__main__":
    uvicorn.run(app, host=config.host, port=config.port)