File size: 5,423 Bytes
5f7092b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import asyncio
import hmac
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Optional

from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Configure the root logger once at import time: INFO level, with each record
# rendered as "<timestamp> [<LEVEL>] <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
# Module-level logger; guard_check uses it to record the content being checked.
logger = logging.getLogger(__name__)
from config import load_config
from model import XGuardModel
# Runtime configuration. Code in this file reads config.model_path,
# config.device, config.api_key, config.host and config.port from it.
config = load_config()
# The ASGI application that all routes and lifecycle handlers attach to.
app = FastAPI(title="XGuard MaaS", version="1.0.0")
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; the endpoint authenticates via the x-api-key header, so confirm
# that credentialed cross-origin requests are actually needed.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Set by the startup handler; stays None until the model has been constructed.
xguard_model: Optional[XGuardModel] = None
# Thread pool that runs model inference off the event loop; created by the
# startup handler and shut down by the shutdown handler.
executor: Optional[ThreadPoolExecutor] = None
# Maximum number of requests allowed inside guard_check's inference section
# at the same time.
MAX_CONCURRENT_REQUESTS = 10
# NOTE(review): this semaphore is created at import time, before the server's
# event loop exists. On Python < 3.10 asyncio primitives bind to a loop at
# construction — confirm the deployed Python version.
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
# One conversation message as submitted by the client.
# (Documented with comments rather than a class docstring so the generated
# schema description is left unchanged.)
class Message(BaseModel):
    # Speaker role; build_check_content only keeps messages whose role is "user".
    role: str
    # Text of the message.
    content: str
# Description of a tool call that accompanies the conversation.
# (Documented with comments rather than a class docstring so the generated
# schema description is left unchanged.)
class Tool(BaseModel):
    # Tool identifier, rendered as "[Tool Call] <name>" in the checked content.
    name: str
    # Human-readable description; omitted from the checked content when empty.
    description: str
    # Arbitrary JSON payload; serialised with json.dumps when non-empty.
    parameters: Any
# Request body of POST /v1/guard/check.
# (Documented with comments rather than a class docstring so the generated
# schema description is left unchanged.)
class GuardCheckRequest(BaseModel):
    # Caller-supplied conversation identifier; only used in log output here.
    conversationId: str
    # Conversation messages to inspect.
    messages: List[Message]
    # Tool calls to inspect alongside the messages.
    tools: List[Tool]
    # Whether to enable attribution analysis; when true and the model returns
    # an "explanation", it is included in the response data.
    enableReasoning: bool = Field(False, description="是否启用归因分析")
# Response envelope returned by POST /v1/guard/check.
# (Documented with comments rather than a class docstring so the generated
# schema description is left unchanged.)
class GuardCheckResponse(BaseModel):
    # 0 on success (the only value this file ever sets).
    err_code: int
    # Verdict payload: is_safe, risk_level, confidence, risk_type, reason and,
    # optionally, explanation.
    data: Dict[str, Any]
    # Status message; "success" on the success path.
    msg: str
def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
    """Join user messages and tool-call details into one text to be checked.

    Only messages whose ``role`` is ``"user"`` contribute; their contents are
    joined with newlines in their original order. Each tool then adds a
    section of the form::

        [Tool Call] <name>
        Description: <description>      (only when non-empty)
        Parameters: <JSON parameters>   (only when non-empty)

    Args:
        messages: Message dicts with ``role`` and ``content`` keys. A missing
            or ``None`` content is treated as an empty string.
        tools: Tool dicts with optional ``name``, ``description`` and
            ``parameters`` keys. May be empty.

    Returns:
        The combined text with leading and trailing whitespace stripped.

    Raises:
        TypeError: If a tool's ``parameters`` cannot be serialised to JSON.
    """
    # A None content would make str.join raise TypeError, so normalise it to
    # the same empty string already used for a missing key.
    user_contents = [
        "" if msg.get("content") is None else msg["content"]
        for msg in messages
        if msg.get("role") == "user"
    ]
    content = "\n".join(user_contents)

    if tools:
        tool_infos = []
        for tool in tools:
            tool_name = tool.get("name", "")
            tool_desc = tool.get("description", "")
            tool_params = tool.get("parameters", {})

            # The leading newline leaves a blank line before every tool section.
            lines = [f"\n[Tool Call] {tool_name}"]
            if tool_desc:
                lines.append(f"Description: {tool_desc}")
            if tool_params:
                # ensure_ascii=False keeps non-ASCII text readable in the output.
                lines.append(
                    "Parameters: " + json.dumps(tool_params, ensure_ascii=False)
                )
            tool_infos.append("\n".join(lines))

        content += "\n" + "\n".join(tool_infos)

    return content.strip()
@app.on_event("startup")
async def startup_event():
    """Load the XGuard model and create the inference thread pool.

    Assigns the module-level ``xguard_model`` and ``executor`` globals.

    Raises:
        Exception: Any error raised while constructing the model or the
            thread pool is logged with its traceback and then re-raised.
    """
    global xguard_model, executor
    try:
        xguard_model = XGuardModel(config.model_path, config.device)
        executor = ThreadPoolExecutor(max_workers=4)
    except Exception:
        # Go through the configured logger (not print) so the failure carries
        # a timestamp, level and full traceback like every other log record.
        logger.exception("Failed to load model")
        raise
    logger.info("XGuard model loaded on %s", config.device)
@app.on_event("shutdown")
async def shutdown_event():
    """Shut down the inference thread pool, waiting for it to finish.

    Does nothing when no pool was created.
    """
    if not executor:
        return
    executor.shutdown(wait=True)
@app.get("/health")
async def health_check():
    # Health probe: always reports "ok" and says whether the model global has
    # been populated. (A comment rather than a docstring, so the route's
    # generated API description is left unchanged.)
    model_ready = xguard_model is not None
    return {"status": "ok", "model_loaded": model_ready}
@app.post("/v1/guard/check", response_model=GuardCheckResponse)
async def guard_check(
    request: GuardCheckRequest,
    x_api_key: str = Header(..., alias="x-api-key")
):
    """Run a safety check over a conversation and its tool calls.

    The user messages and tool details are flattened into a single text,
    which is passed to the model as one user message. Inference runs in the
    thread pool so the event loop is not blocked, and at most
    ``MAX_CONCURRENT_REQUESTS`` requests are inside that section at once.

    Args:
        request: Conversation id, messages, tools and the reasoning flag.
        x_api_key: Value of the ``x-api-key`` request header.

    Returns:
        A ``GuardCheckResponse`` with ``err_code`` 0, ``msg`` "success" and a
        ``data`` dict holding ``is_safe``, ``risk_level``, ``confidence``,
        ``risk_type`` and ``reason``, plus ``explanation`` when reasoning was
        requested and the model returned one.

    Raises:
        HTTPException: 401 when the API key does not match, 503 when the
            model has not been loaded, 500 when building the content,
            running inference or assembling the response fails.
    """
    expected_key = config.api_key
    # Compare in constant time so response timing does not reveal how much of
    # a guessed key matched. Bytes are compared because compare_digest rejects
    # non-ASCII str values. A non-string configured key can never match, as
    # with the previous "!=" comparison.
    if not isinstance(expected_key, str) or not hmac.compare_digest(
        x_api_key.encode("utf-8"), expected_key.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")

    if xguard_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async with request_semaphore:
        try:
            messages = [{"role": m.role, "content": m.content} for m in request.messages]
            tools = [
                {"name": t.name, "description": t.description, "parameters": t.parameters}
                for t in request.tools
            ]

            # Flatten messages and tool information into the text to check.
            check_content = build_check_content(messages, tools)
            logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)

            # The model receives the flattened text as a single user message.
            check_messages = [{"role": "user", "content": check_content}]

            # Inside a coroutine the running loop is the one to use.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: xguard_model.analyze(
                    check_messages,
                    [],  # tools are already embedded in the content
                    enable_reasoning=request.enableReasoning
                )
            )

            # When the model gives no risk_level, derive one from is_safe.
            response_data = {
                "is_safe": result["is_safe"],
                "risk_level": result.get("risk_level", "safe" if result["is_safe"] == 1 else "medium"),
                "confidence": result.get("confidence", 0.0),
                "risk_type": result["risk_type"],
                "reason": result["reason"]
            }

            # Only expose the explanation when reasoning was requested.
            if request.enableReasoning and "explanation" in result:
                response_data["explanation"] = result["explanation"]

            return GuardCheckResponse(
                err_code=0,
                data=response_data,
                msg="success"
            )
        except Exception as e:
            # Record the traceback server-side; the client only gets the message.
            logger.exception("Inference failed for conversation [%s]", request.conversationId)
            raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e
if __name__ == "__main__":
    # Executed directly: serve the app with uvicorn on the configured address.
    uvicorn.run(
        app,
        host=config.host,
        port=config.port,
    )
|