// Package websocket provides API endpoints for AI Infrastructure Guard task management.
//
// This package implements RESTful APIs for:
//   - Task submission and management
//   - Task status monitoring
//   - Task result retrieval
//   - Support for multiple task types: MCP scan, AI infra scan, and model redteam testing
//
// API Endpoints:
//   - POST /api/v1/app/taskapi/tasks - Create new tasks
//   - GET /api/v1/app/taskapi/status/{id} - Get task status and logs
//   - GET /api/v1/app/taskapi/result/{id} - Get task results
package websocket

import (
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"github.com/Tencent/AI-Infra-Guard/common/agent"
	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"trpc.group/trpc-go/trpc-go/log"
)

// ModelParams represents model configuration parameters.
// It is used for both the models under test and the evaluation model of a
// PromptSecurityTaskRequest.
type ModelParams struct {
	// Model API base URL.
	BaseUrl string `json:"base_url" example:"https://api.openai.com/v1"`
	// API access token.
	Token string `json:"token" example:"sk-xxx"`
	// Model name.
	Model string `json:"model" example:"gpt-4"`
	// Request limit.
	Limit int `json:"limit" example:"1000"`
}

// MCPTaskRequest represents the MCP task request structure.
//
// NOTE: the `binding:"required"` tags below are only evaluated when the value
// is decoded through gin's binding layer; SubmitTask decodes this payload with
// json.Unmarshal, which ignores them.
// @Description MCP (Model Context Protocol) security scan task request parameters
type MCPTaskRequest struct {
	// Task content description.
	Content string `json:"content,omitempty" example:"扫描目标MCP服务器"`
	// Model configuration - required.
	Model struct {
		// Model name - required.
		Model string `json:"model" binding:"required" example:"gpt-4"`
		// API key - required.
		Token string `json:"token" binding:"required" example:"sk-xxx"`
		// Base URL - optional.
		BaseUrl string `json:"base_url,omitempty" example:"https://api.openai.com/v1"`
	} `json:"model" binding:"required"`
	// Number of concurrent threads.
	Thread int `json:"thread,omitempty" example:"4"`
	// Language code - optional.
	Language string `json:"language,omitempty" example:"zh"`
	// Attachment file path.
	Attachments string `json:"attachments,omitempty" example:"file1.zip"`
}

// AIInfraScanTaskRequest is the request structure of an AI infrastructure scan task.
// @Description AI infrastructure security scan task request parameters
type AIInfraScanTaskRequest struct {
	// List of target URLs to scan.
	Target []string `json:"target" example:"https://example.com"`
	// Custom request headers.
	Headers map[string]string `json:"headers" example:"{\"Authorization\":\"Bearer token\"}"`
	// Request timeout (seconds).
	Timeout int `json:"timeout" example:"30"`
}

// PromptSecurityTaskRequest is the request structure of a prompt security test task.
// @Description Prompt security test task request parameters
// @Description Supported datasets:
// @Description - JailBench-Tiny: small jailbreak benchmark dataset
// @Description - JailbreakPrompts-Tiny: small jailbreak prompt dataset
// @Description - ChatGPT-Jailbreak-Prompts: ChatGPT jailbreak prompt dataset
// @Description - JADE-db-v3.0: JADE database, version 3.0
// @Description - HarmfulEvalBenchmark: harmful content evaluation benchmark dataset
type PromptSecurityTaskRequest struct {
	// List of models under test.
	Model []ModelParams `json:"model"`
	// Evaluation model configuration.
	EvalModel ModelParams `json:"eval_model"`
	// Dataset configuration.
	Datasets struct {
		// Dataset file list; options: JailBench-Tiny, JailbreakPrompts-Tiny, ChatGPT-Jailbreak-Prompts, JADE-db-v3.0, HarmfulEvalBenchmark.
		DataFile []string `json:"dataFile" example:"[\"JailBench-Tiny\",\"JailbreakPrompts-Tiny\"]"`
		// Number of prompts.
		NumPrompts int `json:"numPrompts" example:"100"`
		// Random seed.
		RandomSeed int `json:"randomSeed" example:"42"`
	} `json:"dataset"`
}

// APIResponse is the generic API response envelope.
type APIResponse struct {
	// Status code: 0 = success, 1 = failure.
	Status int `json:"status" example:"0"`
	// Response message.
	Message string `json:"message" example:"操作成功"`
	// Response data.
	Data interface{} `json:"data"`
}

// TaskStatusResponse is the task status response structure.
type TaskStatusResponse struct {
	// Task session ID.
	SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"`
	// Task status: pending, running, completed, failed.
	Status string `json:"status" example:"running"`
	// Task title.
	Title string `json:"title" example:"MCP安全扫描任务"`
	// Creation timestamp (milliseconds).
	CreatedAt int64 `json:"created_at" example:"1640995200000"`
	// Update timestamp (milliseconds).
	UpdatedAt int64 `json:"updated_at" example:"1640995200000"`
	// Task execution log.
	Log string `json:"log" example:"任务执行日志..."`
}

// TaskCreateResponse is the task creation response structure.
type TaskCreateResponse struct {
	// Task session ID.
	SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"`
}

// Task Types and Parameters Documentation:
//
// 1.
MCP Scan Task (type: "mcp_scan") // - Purpose: Model Context Protocol security scanning // - Request structure: // { // "type": "mcp_scan", // "content": { // "content": "任务描述", // 可选: 任务内容描述 // "model": { // "model": "gpt-4", // 必需: 模型名称 // "token": "sk-xxx", // 必需: API密钥 // "base_url": "https://api.openai.com/v1" // 可选: 基础URL // }, // "thread": 4, // 可选: 并发线程数 // "language": "zh", // 可选: 语言代码 // "attachments": "file.zip" // 可选: 附件文件路径 // } // } // // 2. AI Infra Scan Task (type: "ai_infra_scan") // - Purpose: AI infrastructure security scanning // - Request structure: // { // "type": "ai_infra_scan", // "content": { // "target": ["https://example.com"], // 必需: 扫描目标URL列表 // "headers": { // 可选: 自定义请求头 // "Authorization": "Bearer token" // }, // "timeout": 30 // 可选: 请求超时时间(秒) // } // } // // 3. Model Redteam Task (type: "model_redteam_report") // - Purpose: AI model red team testing and security assessment // - Request structure: // { // "type": "model_redteam_report", // "content": { // "model": [ // 必需: 测试模型列表 // { // "model": "gpt-4", // "token": "sk-xxx", // "base_url": "https://api.openai.com/v1" // } // ], // "eval_model": { // 必需: 评估模型配置 // "model": "gpt-4", // "token": "sk-xxx" // }, // "dataset": { // 必需: 数据集配置 // "dataFile": ["JailBench-Tiny", "JailbreakPrompts-Tiny"], // 数据集文件列表,可选: JailBench-Tiny, JailbreakPrompts-Tiny, ChatGPT-Jailbreak-Prompts, JADE-db-v3.0, HarmfulEvalBenchmark // "numPrompts": 100, // 提示词数量 // "randomSeed": 42 // 随机种子 // } // } // } // SubmitTask 创建任务接口 // @Summary Create a new task // @Description Submit a new task for processing. Supports three types of tasks: // @Description 1. MCP Scan (mcp_scan): Model Context Protocol security scanning // @Description 2. AI Infra Scan (ai_infra_scan): AI infrastructure security scanning // @Description 3. 
Model Redteam Report (model_redteam_report): AI model red team testing // @Description // @Description Request Body Examples: // @Description // @Description MCP Scan Task: // @Description { // @Description "type": "mcp_scan", // @Description "content": { // @Description "content": "扫描MCP服务器", // @Description "model": { // @Description "model": "gpt-4", // @Description "token": "sk-xxx", // @Description "base_url": "https://api.openai.com/v1" // @Description }, // @Description "thread": 4, // @Description "language": "zh", // @Description "attachments": "file.zip" // @Description } // @Description } // @Description // @Description AI Infra Scan Task: // @Description { // @Description "type": "ai_infra_scan", // @Description "content": { // @Description "target": ["https://example.com"], // @Description "headers": { // @Description "Authorization": "Bearer token" // @Description }, // @Description "timeout": 30 // @Description } // @Description } // @Description // @Description Model Redteam Task: // @Description { // @Description "type": "model_redteam_report", // @Description "content": { // @Description "model": [{ // @Description "model": "gpt-4", // @Description "token": "sk-xxx", // @Description "base_url": "https://api.openai.com/v1" // @Description }], // @Description "eval_model": { // @Description "model": "gpt-4", // @Description "token": "sk-xxx" // @Description }, // @Description "dataset": { // @Description "dataFile": ["JailBench-Tiny", "JailbreakPrompts-Tiny"], // @Description "numPrompts": 100, // @Description "randomSeed": 42 // @Description } // @Description } // @Description } // @Tags taskapi // @Accept json // @Produce json // @Param request body object{content=object,type=string} true "Task request body. 
Content should be JSON object containing task-specific parameters based on type" // @Success 200 {object} APIResponse{data=TaskCreateResponse} "Task created successfully" // @Failure 400 {object} APIResponse "Invalid request parameters" // @Failure 500 {object} APIResponse "Internal server error" // @Router /api/v1/app/taskapi/tasks [post] func SubmitTask(c *gin.Context, tm *TaskManager) { var content struct { Content json.RawMessage `json:"content"` Type string `json:"type"` } if err := c.ShouldBindJSON(&content); err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "参数错误: " + err.Error(), "data": nil, }) return } // 生成sessionId sessionId := uuid.New().String() // 生成消息ID messageId := uuid.New().String() // 设置默认用户名为开发者API用户 username := c.GetString("api_user") var taskReq TaskCreateRequest // content interface to byte switch content.Type { case "mcp_scan": var req MCPTaskRequest err := json.Unmarshal(content.Content, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "参数错误: " + err.Error(), "data": nil, }) return } // 构建任务参数 params := map[string]interface{}{ "model": map[string]interface{}{ "model": req.Model.Model, "token": req.Model.Token, "base_url": req.Model.BaseUrl, }, "quick": false, "plugins": []string{ "auth_bypass", "cmd_injection", "credential_theft", "hardcoded_api_key", "indirect_prompt_injection", "name_confusion", "rug_pull", "tool_poisoning", "tool_shadowing", }, } var attachments []string if req.Attachments != "" { attachments = append(attachments, req.Attachments) } // 构建TaskCreateRequest taskReq = TaskCreateRequest{ ID: messageId, SessionID: sessionId, Username: username, Task: agent.TaskTypeMcpScan, Timestamp: time.Now().UnixMilli(), Content: req.Content, Params: params, Attachments: attachments, } case "ai_infra_scan": var req AIInfraScanTaskRequest err := json.Unmarshal(content.Content, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "参数错误: " + err.Error(), "data": nil, }) return } 
scanParams := map[string]interface{}{ "headers": req.Headers, "timeout": req.Timeout, } taskReq = TaskCreateRequest{ ID: messageId, SessionID: sessionId, Username: username, Task: agent.TaskTypeAIInfraScan, Timestamp: time.Now().UnixMilli(), Params: scanParams, Content: strings.Join(req.Target, "\n"), Attachments: []string{}, } case "model_redteam_report": var req PromptSecurityTaskRequest err := json.Unmarshal(content.Content, &req) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "参数错误: " + err.Error(), "data": nil, }) return } params := map[string]interface{}{ "model": req.Model, "eval_model": req.EvalModel, "dataset": req.Datasets, } taskReq = TaskCreateRequest{ ID: messageId, SessionID: sessionId, Username: username, Task: agent.TaskTypeModelRedteamReport, Timestamp: time.Now().UnixMilli(), Content: "", Attachments: []string{}, Params: params, } default: c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "无效的任务类型", "data": nil, }) return } err := tm.AddTaskApi(&taskReq) if err != nil { log.Errorf("任务创建失败: sessionId=%s, error=%v", sessionId, err) c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "任务创建失败: " + err.Error(), "data": nil, }) return } c.JSON(http.StatusOK, gin.H{ "status": 0, "message": "任务创建成功,正在后台处理", "data": gin.H{ "session_id": sessionId, }, }) } // GetTaskStatus 获取任务状态接口(开发者API) // @Summary Get task status // @Description Retrieve the current status and logs of a task by session ID. Returns task metadata and execution logs. 
// @Tags taskapi // @Produce json // @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000" // @Success 200 {object} APIResponse{data=TaskStatusResponse} "Task status retrieved successfully" // @Failure 400 {object} APIResponse "Invalid session ID format" // @Failure 404 {object} APIResponse "Task not found" // @Failure 500 {object} APIResponse "Internal server error" // @Router /api/v1/app/taskapi/status/{id} [get] func GetTaskStatus(c *gin.Context, tm *TaskManager) { sessionId := c.Param("id") if sessionId == "" { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "任务ID不能为空", "data": nil, }) return } // 验证sessionId格式 if !isValidSessionID(sessionId) { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "无效的任务ID格式", "data": nil, }) return } // 从数据库获取任务信息 session, err := tm.taskStore.GetSession(sessionId) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "任务不存在", "data": nil, }) return } // 获取任务的所有消息/事件 messages, err := tm.taskStore.GetSessionEventsByType(sessionId, "actionLog") if err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "获取任务结果失败", "data": nil, }) return } msg := "" type logStruct struct { ActionLog string `json:"actionLog"` } for _, m := range messages { var x logStruct err = json.Unmarshal([]byte(m.EventData.String()), &x) if err != nil { continue } msg += x.ActionLog } // 构建状态响应 statusData := gin.H{ "session_id": session.ID, "status": session.Status, "title": session.Title, "created_at": session.CreatedAt, "updated_at": session.UpdatedAt, "log": msg, } c.JSON(http.StatusOK, gin.H{ "status": 0, "message": "获取任务状态成功", "data": statusData, }) } // GetTaskResult 获取任务结果接口(开发者API) // @Summary Get task result // @Description Retrieve the final result of a completed task. Returns detailed scan results, vulnerabilities found, and security assessment data. 
// @Tags taskapi // @Produce json // @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000" // @Success 200 {object} APIResponse "Task result retrieved successfully. Data contains scan results, vulnerabilities, and security findings" // @Failure 400 {object} APIResponse "Invalid session ID format" // @Failure 404 {object} APIResponse "Task not found or not completed" // @Failure 500 {object} APIResponse "Internal server error" // @Router /api/v1/app/taskapi/result/{id} [get] func GetTaskResult(c *gin.Context, tm *TaskManager) { traceID := getTraceID(c) sessionId := c.Param("id") if sessionId == "" { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "任务ID不能为空", "data": nil, }) return } // 验证sessionId格式 if !isValidSessionID(sessionId) { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "无效的任务ID格式", "data": nil, }) return } log.Infof("开始获取任务结果: trace_id=%s, sessionId=%s", traceID, sessionId) // 获取任务的所有消息/事件 messages, err := tm.taskStore.GetSessionEventsByType(sessionId, "resultUpdate") if err != nil || len(messages) == 0 { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "获取任务结果失败,任务可能尚未完成", "data": nil, }) return } msg := messages[0] // 解析事件数据 var eventData map[string]interface{} if err := json.Unmarshal(msg.EventData, &eventData); err != nil { c.JSON(http.StatusOK, gin.H{ "status": 1, "message": "获取任务结果失败", "data": nil, }) return } c.JSON(http.StatusOK, gin.H{ "status": 0, "message": "ok", "data": eventData, }) }