|
|
import asyncio
|
|
|
import json
|
|
|
import logging
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
from fastapi import FastAPI, HTTPException, Header
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from pydantic import BaseModel, Field
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
import uvicorn
|
|
|
|
|
|
# Root logging configuration: timestamped, level-tagged messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")

logger = logging.getLogger(__name__)

# Project-local modules: service configuration loader and the guard-model wrapper.
from config import load_config

from model import XGuardModel

# Configuration is loaded once at import time; read by startup and request handlers.
config = load_config()

app = FastAPI(title="XGuard MaaS", version="1.0.0")
|
|
|
|
|
|
# CORS: fully open to any origin, method, and header.
# NOTE(review): the CORS spec forbids combining a wildcard origin with
# allow_credentials=True (the middleware has to echo the request Origin
# instead) — confirm this permissive setup is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
# Populated by startup_event(); stays None until the model has loaded.
xguard_model: Optional[XGuardModel] = None
# Thread pool used to run blocking model inference off the event loop.
executor: Optional[ThreadPoolExecutor] = None

# Upper bound on simultaneous in-flight inference requests.
MAX_CONCURRENT_REQUESTS = 10
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
|
|
|
|
|
|
|
|
class Message(BaseModel):
    """A single chat message in the conversation under inspection."""

    # Conversation role; only messages with role "user" contribute text
    # to the detection content (see build_check_content).
    role: str
    content: str
|
|
|
|
|
|
|
|
|
class Tool(BaseModel):
    """A tool/function-call definition attached to the conversation."""

    name: str
    description: str
    # Free-form parameter payload; serialized with json.dumps when truthy
    # (see build_check_content), so it should be JSON-serializable.
    parameters: Any
|
|
|
|
|
|
|
|
|
class GuardCheckRequest(BaseModel):
    """Request body for POST /v1/guard/check."""

    conversationId: str
    messages: List[Message]
    tools: List[Tool]
    # Whether to enable attribution analysis; when True the response may
    # include an "explanation" field.
    enableReasoning: bool = Field(default=False, description="是否启用归因分析")
|
|
|
|
|
|
|
|
|
class GuardCheckResponse(BaseModel):
    """Response envelope for /v1/guard/check; err_code 0 means success."""

    err_code: int
    # Detection payload: is_safe, risk_level, confidence, risk_type, reason,
    # and optionally explanation (see guard_check).
    data: Dict[str, Any]
    msg: str
|
|
|
|
|
|
|
|
|
def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
    """Concatenate user messages and tool-call metadata into one detection string.

    Args:
        messages: Chat messages as dicts with "role" and "content" keys;
            only messages whose role is "user" contribute text.
        tools: Tool definitions as dicts with "name", "description" and
            "parameters" keys; falsy description/parameters are omitted.

    Returns:
        The combined content, stripped of leading/trailing whitespace.
    """
    # Only user-authored turns are scanned; assistant/system turns are ignored.
    content = "\n".join(
        msg.get("content", "") for msg in messages if msg.get("role") == "user"
    )

    if tools:
        tool_infos = []
        for tool in tools:
            parts = [f"\n[Tool Call] {tool.get('name', '')}"]
            if desc := tool.get("description", ""):
                parts.append(f"Description: {desc}")
            if params := tool.get("parameters", {}):
                # ensure_ascii=False keeps non-ASCII (e.g. Chinese) readable.
                parts.append(f"Parameters: {json.dumps(params, ensure_ascii=False)}")
            tool_infos.append("\n".join(parts))
        content += "\n" + "\n".join(tool_infos)

    return content.strip()
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the guard model and create the inference thread pool.

    Raises:
        Exception: re-raised when model loading fails so the server
            refuses to start with a broken model.
    """
    global xguard_model, executor
    try:
        xguard_model = XGuardModel(config.model_path, config.device)
        executor = ThreadPoolExecutor(max_workers=4)
        # Use the configured module logger instead of bare print so the
        # message carries the standard timestamp/level formatting.
        logger.info("XGuard model loaded on %s", config.device)
    except Exception:
        # logger.exception records the full traceback before re-raising.
        logger.exception("Failed to load model")
        raise
|
|
|
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Drain and release the inference thread pool on server shutdown."""
    global executor
    pool = executor
    if pool is not None:
        # Block until queued inference tasks finish before exiting.
        pool.shutdown(wait=True)
|
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the guard model has been loaded."""
    model_ready = xguard_model is not None
    return {"status": "ok", "model_loaded": model_ready}
|
|
|
|
|
|
|
|
|
@app.post("/v1/guard/check", response_model=GuardCheckResponse)
async def guard_check(
    request: GuardCheckRequest,
    x_api_key: str = Header(..., alias="x-api-key")
):
    """Run a safety check over a conversation and its tool-call definitions.

    Authentication is a shared API key in the ``x-api-key`` header.
    Inference runs in the module thread pool so the event loop is never
    blocked; ``request_semaphore`` bounds concurrent in-flight requests.

    Raises:
        HTTPException: 401 on a bad key, 503 before the model is loaded,
            500 on any inference failure.
    """
    if x_api_key != config.api_key:
        raise HTTPException(status_code=401, detail="Invalid API key")

    if xguard_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async with request_semaphore:
        try:
            messages = [{"role": m.role, "content": m.content} for m in request.messages]
            tools = [{"name": t.name, "description": t.description, "parameters": t.parameters} for t in request.tools]

            check_content = build_check_content(messages, tools)
            logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)

            # The model sees one synthetic user turn holding the combined content.
            check_messages = [{"role": "user", "content": check_content}]

            # get_running_loop() is the supported API inside a coroutine;
            # get_event_loop() is deprecated in this context since Python 3.10.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: xguard_model.analyze(
                    check_messages,
                    [],
                    enable_reasoning=request.enableReasoning
                )
            )

            response_data = {
                "is_safe": result["is_safe"],
                # Coarse fallback when the model omits risk_level:
                # is_safe == 1 maps to "safe", anything else to "medium".
                "risk_level": result.get("risk_level", "safe" if result["is_safe"] == 1 else "medium"),
                "confidence": result.get("confidence", 0.0),
                "risk_type": result["risk_type"],
                "reason": result["reason"]
            }

            # Attribution analysis is optional and attached only when requested
            # AND produced by the model.
            if request.enableReasoning and "explanation" in result:
                response_data["explanation"] = result["explanation"]

            return GuardCheckResponse(
                err_code=0,
                data=response_data,
                msg="success"
            )
        except Exception as e:
            # Keep the full traceback server-side; clients get a generic 500.
            logger.exception("Inference failed for conversation %s", request.conversationId)
            raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct-launch entry point; host/port come from the loaded config.
    uvicorn.run(app, host=config.host, port=config.port)
|
|
|
|