"""XGuard MaaS: a FastAPI service exposing a guard-check endpoint backed by XGuardModel."""

import asyncio
import hmac
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# Project-local imports are kept after logging.basicConfig() so the import-time
# ordering of the original module is preserved.
from config import load_config  # noqa: E402
from model import XGuardModel  # noqa: E402

config = load_config()

app = FastAPI(title="XGuard MaaS", version="1.0.0")

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Both are populated by startup_event(); they stay None until then.
xguard_model: Optional[XGuardModel] = None
executor: Optional[ThreadPoolExecutor] = None

# Upper bound on requests allowed inside the inference section of guard_check
# at the same time. The thread pool created at startup has fewer workers (4),
# so requests admitted beyond that wait in the executor's queue.
MAX_CONCURRENT_REQUESTS = 10
request_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


class Message(BaseModel):
    """A single conversation message (role plus text content)."""

    role: str
    content: str


class Tool(BaseModel):
    """A tool entry attached to a guard-check request."""

    name: str
    description: str
    parameters: Any


class GuardCheckRequest(BaseModel):
    """Request body accepted by POST /v1/guard/check."""

    conversationId: str
    messages: List[Message]
    tools: List[Tool]
    enableReasoning: bool = Field(default=False, description="是否启用归因分析")


class GuardCheckResponse(BaseModel):
    """Response envelope returned by POST /v1/guard/check."""

    err_code: int
    data: Dict[str, Any]
    msg: str


def build_check_content(messages: List[Dict], tools: List[Dict]) -> str:
    """Join user messages and tool details into one text blob to be checked.

    Args:
        messages: Message dicts; only entries whose ``role`` is ``"user"``
            contribute their ``content``.
        tools: Tool dicts; ``name``, ``description`` and ``parameters`` are
            read from each entry when present.

    Returns:
        The user contents joined by newlines, followed by one
        ``[Tool Call]`` section per tool, with surrounding whitespace
        stripped.
    """
    # Collect the content of user messages only.
    user_contents = [
        msg.get("content", "") for msg in messages if msg.get("role") == "user"
    ]
    content = "\n".join(user_contents)

    # When tools are supplied, append a section describing each of them.
    if tools:
        tool_infos = []
        for tool in tools:
            tool_desc = tool.get("description", "")
            tool_params = tool.get("parameters", {})

            tool_info = f"\n[Tool Call] {tool.get('name', '')}"
            if tool_desc:
                tool_info += f"\nDescription: {tool_desc}"
            if tool_params:
                tool_info += f"\nParameters: {json.dumps(tool_params, ensure_ascii=False)}"
            tool_infos.append(tool_info)
        content += "\n" + "\n".join(tool_infos)

    return content.strip()


def _is_valid_api_key(provided: str) -> bool:
    """Return True when *provided* equals the configured API key.

    The comparison uses ``hmac.compare_digest`` so that its duration does not
    reveal how many leading characters matched. A configured key that is not
    a string never matches, which is also what a plain ``!=`` against a
    header string would conclude.
    """
    expected = config.api_key
    if not isinstance(expected, str):
        return False
    return hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


def _build_response_data(result: Any, enable_reasoning: bool) -> Dict[str, Any]:
    """Translate the model's analysis result into the response ``data`` payload.

    Args:
        result: Value returned by ``XGuardModel.analyze``; it is accessed with
            ``[]``, ``.get`` and ``in`` (presumably a dict -- confirm against
            the model implementation).
        enable_reasoning: Whether the caller asked for attribution analysis.

    Returns:
        A dict with ``is_safe``, ``risk_level``, ``confidence``, ``risk_type``
        and ``reason``, plus ``explanation`` when reasoning was requested and
        the result provides one.

    Raises:
        KeyError: If ``is_safe``, ``risk_type`` or ``reason`` is missing.
    """
    is_safe = result["is_safe"]
    response_data = {
        "is_safe": is_safe,
        # Fall back to a level derived from is_safe when the model gives none.
        "risk_level": result.get("risk_level", "safe" if is_safe == 1 else "medium"),
        "confidence": result.get("confidence", 0.0),
        "risk_type": result["risk_type"],
        "reason": result["reason"],
    }

    # Only expose the explanation when attribution analysis was requested.
    if enable_reasoning and "explanation" in result:
        response_data["explanation"] = result["explanation"]

    return response_data


@app.on_event("startup")
async def startup_event():
    """Load the XGuard model and create the inference thread pool.

    Any exception raised while doing so is logged and re-raised, so the
    application does not start without a model.
    """
    global xguard_model, executor
    try:
        xguard_model = XGuardModel(config.model_path, config.device)
        executor = ThreadPoolExecutor(max_workers=4)
        logger.info("XGuard model loaded on %s", config.device)
    except Exception:
        logger.exception("Failed to load model")
        raise


@app.on_event("shutdown")
async def shutdown_event():
    """Shut the inference thread pool down, waiting for running work to finish."""
    if executor:
        executor.shutdown(wait=True)


@app.get("/health")
async def health_check():
    """Report service status and whether the model has been loaded."""
    return {"status": "ok", "model_loaded": xguard_model is not None}


@app.post("/v1/guard/check", response_model=GuardCheckResponse)
async def guard_check(
    request: GuardCheckRequest,
    x_api_key: str = Header(..., alias="x-api-key")
):
    """Run a guard check over the request's messages and tools.

    Args:
        request: Conversation id, messages, tools and the reasoning flag.
        x_api_key: Value of the ``x-api-key`` header; must equal the
            configured API key.

    Returns:
        A ``GuardCheckResponse`` with ``err_code`` 0 and the analysis data.

    Raises:
        HTTPException: 401 when the API key does not match, 503 when the
            model has not been loaded, 500 when anything fails while
            preparing, running or post-processing the inference.
    """
    if not _is_valid_api_key(x_api_key):
        raise HTTPException(status_code=401, detail="Invalid API key")

    if xguard_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    async with request_semaphore:
        try:
            messages = [{"role": m.role, "content": m.content} for m in request.messages]
            tools = [
                {"name": t.name, "description": t.description, "parameters": t.parameters}
                for t in request.tools
            ]

            # Merge the messages and tool details into one text blob.
            check_content = build_check_content(messages, tools)
            logger.info("会话 [%s] 检测内容:\n%s", request.conversationId, check_content)

            # The model receives the merged blob as a single user message.
            check_messages = [{"role": "user", "content": check_content}]

            # Bind what the worker thread needs to locals before dispatching,
            # so the callable does not read module globals when it runs.
            model = xguard_model
            enable_reasoning = request.enableReasoning

            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                executor,
                lambda: model.analyze(
                    check_messages,
                    [],  # tools are already merged into the content
                    enable_reasoning=enable_reasoning,
                ),
            )

            return GuardCheckResponse(
                err_code=0,
                data=_build_response_data(result, enable_reasoning),
                msg="success",
            )

        except Exception as e:
            # Keep the traceback in the server log; the client only receives
            # the exception message.
            logger.exception("Guard check failed for conversation [%s]", request.conversationId)
            raise HTTPException(status_code=500, detail=f"Inference error: {str(e)}") from e


if __name__ == "__main__":
    uvicorn.run(app, host=config.host, port=config.port)