AbdulElahGwaith's picture
Upload folder using huggingface_hub
ffb6330 verified
// Package websocket provides API endpoints for AI Infrastructure Guard task management
//
// This package implements RESTful APIs for:
// - Task submission and management
// - Task status monitoring
// - Task result retrieval
// - Support for multiple task types: MCP scan, AI infra scan, and model redteam testing
//
// API Endpoints:
// - POST /api/v1/app/taskapi/tasks - Create new tasks
// - GET /api/v1/app/taskapi/status/{id} - Get task status and logs
// - GET /api/v1/app/taskapi/result/{id} - Get task results
package websocket
import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/Tencent/AI-Infra-Guard/common/agent"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"trpc.group/trpc-go/trpc-go/log"
)
// ModelParams represents model configuration parameters.
//
// It describes one model endpoint; PromptSecurityTaskRequest uses it both for
// the list of models under test and for the evaluation model.
type ModelParams struct {
	BaseUrl string `json:"base_url" example:"https://api.openai.com/v1"` // Model API base URL
	Token   string `json:"token" example:"sk-xxx"`                       // API access token
	Model   string `json:"model" example:"gpt-4"`                        // Model name
	Limit   int    `json:"limit" example:"1000"`                         // Request limit; exact meaning is defined by the task consumer (not visible here) — TODO confirm
}
// MCPTaskRequest represents MCP task request structure: the "content" payload
// of an "mcp_scan" task submission.
//
// Note: `binding:"required"` tags are honored only by gin's binding helpers
// (e.g. c.ShouldBindJSON). Plain json.Unmarshal ignores them, so code that
// decodes this type with json.Unmarshal must validate required fields itself.
// @Description MCP (Model Context Protocol) security scan task request parameters
type MCPTaskRequest struct {
	Content string `json:"content,omitempty" example:"扫描目标MCP服务器"` // Task content description
	Model   struct {
		Model   string `json:"model" binding:"required" example:"gpt-4"`                // Model name - required
		Token   string `json:"token" binding:"required" example:"sk-xxx"`               // API key - required
		BaseUrl string `json:"base_url,omitempty" example:"https://api.openai.com/v1"` // Base URL - optional
	} `json:"model" binding:"required"` // Model configuration - required
	Thread      int    `json:"thread,omitempty" example:"4"`              // Number of concurrent threads
	Language    string `json:"language,omitempty" example:"zh"`           // Language code - optional
	Attachments string `json:"attachments,omitempty" example:"file1.zip"` // Attachment file path (a single path)
}
// AIInfraScanTaskRequest is the "content" payload of an "ai_infra_scan" task
// submission (AI infrastructure scan request).
// @Description AI infrastructure security scan task request parameters
type AIInfraScanTaskRequest struct {
	Target  []string          `json:"target" example:"https://example.com"`                          // List of target URLs to scan
	Headers map[string]string `json:"headers" example:"{\"Authorization\":\"Bearer token\"}"` // Custom request headers
	Timeout int               `json:"timeout" example:"30"`                                          // Request timeout (documented as seconds)
}
// PromptSecurityTaskRequest is the "content" payload of a
// "model_redteam_report" task submission (prompt security / red-team test).
//
// Note that the Go field Datasets is serialized under the JSON key "dataset".
// @Description Prompt security test task request parameters
// @Description Supported datasets:
// @Description - JailBench-Tiny: small jailbreak benchmark dataset
// @Description - JailbreakPrompts-Tiny: small jailbreak prompt dataset
// @Description - ChatGPT-Jailbreak-Prompts: ChatGPT jailbreak prompt dataset
// @Description - JADE-db-v3.0: JADE database, version v3.0
// @Description - HarmfulEvalBenchmark: harmful content evaluation benchmark dataset
type PromptSecurityTaskRequest struct {
	Model     []ModelParams `json:"model"`      // Models under test
	EvalModel ModelParams   `json:"eval_model"` // Evaluation model configuration
	Datasets  struct {
		DataFile   []string `json:"dataFile" example:"[\"JailBench-Tiny\",\"JailbreakPrompts-Tiny\"]"` // Dataset names; options: JailBench-Tiny, JailbreakPrompts-Tiny, ChatGPT-Jailbreak-Prompts, JADE-db-v3.0, HarmfulEvalBenchmark
		NumPrompts int      `json:"numPrompts" example:"100"`                                           // Number of prompts
		RandomSeed int      `json:"randomSeed" example:"42"`                                            // Random seed
	} `json:"dataset"` // Dataset configuration
}
// APIResponse is the common response envelope of the task API.
//
// The handlers in this file build the same three keys with gin.H and always
// reply with HTTP 200; success or failure is signalled by Status.
type APIResponse struct {
	Status  int         `json:"status" example:"0"`       // Status code: 0 = success, 1 = failure
	Message string      `json:"message" example:"操作成功"` // Human-readable response message
	Data    interface{} `json:"data"`                      // Response payload (nil on failure)
}
// TaskStatusResponse describes the data returned by the task status endpoint.
// It documents the shape for Swagger; GetTaskStatus builds the same keys with gin.H.
type TaskStatusResponse struct {
	SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"` // Task session ID
	Status    string `json:"status" example:"running"`                                  // Task status; documented values: pending, running, completed, failed (actual values come from the task store)
	Title     string `json:"title" example:"MCP安全扫描任务"`                               // Task title
	CreatedAt int64  `json:"created_at" example:"1640995200000"`                        // Creation timestamp (documented as milliseconds)
	UpdatedAt int64  `json:"updated_at" example:"1640995200000"`                        // Last update timestamp (documented as milliseconds)
	Log       string `json:"log" example:"任务执行日志..."`                                // Concatenated task execution log
}
// TaskCreateResponse describes the data returned when a task is created.
// The session ID is the {id} accepted by the status and result endpoints.
type TaskCreateResponse struct {
	SessionID string `json:"session_id" example:"550e8400-e29b-41d4-a716-446655440000"` // Session ID of the created task
}
// Task Types and Parameters Documentation:
//
// 1. MCP Scan Task (type: "mcp_scan")
// - Purpose: Model Context Protocol security scanning
// - Request structure:
// {
// "type": "mcp_scan",
// "content": {
// "content": "任务描述", // 可选: 任务内容描述
// "model": {
// "model": "gpt-4", // 必需: 模型名称
// "token": "sk-xxx", // 必需: API密钥
// "base_url": "https://api.openai.com/v1" // 可选: 基础URL
// },
// "thread": 4, // 可选: 并发线程数
// "language": "zh", // 可选: 语言代码
// "attachments": "file.zip" // 可选: 附件文件路径
// }
// }
//
// 2. AI Infra Scan Task (type: "ai_infra_scan")
// - Purpose: AI infrastructure security scanning
// - Request structure:
// {
// "type": "ai_infra_scan",
// "content": {
// "target": ["https://example.com"], // 必需: 扫描目标URL列表
// "headers": { // 可选: 自定义请求头
// "Authorization": "Bearer token"
// },
// "timeout": 30 // 可选: 请求超时时间(秒)
// }
// }
//
// 3. Model Redteam Task (type: "model_redteam_report")
// - Purpose: AI model red team testing and security assessment
// - Request structure:
// {
// "type": "model_redteam_report",
// "content": {
// "model": [ // 必需: 测试模型列表
// {
// "model": "gpt-4",
// "token": "sk-xxx",
// "base_url": "https://api.openai.com/v1"
// }
// ],
// "eval_model": { // 必需: 评估模型配置
// "model": "gpt-4",
// "token": "sk-xxx"
// },
// "dataset": { // 必需: 数据集配置
// "dataFile": ["JailBench-Tiny", "JailbreakPrompts-Tiny"], // 数据集文件列表,可选: JailBench-Tiny, JailbreakPrompts-Tiny, ChatGPT-Jailbreak-Prompts, JADE-db-v3.0, HarmfulEvalBenchmark
// "numPrompts": 100, // 提示词数量
// "randomSeed": 42 // 随机种子
// }
// }
// }
// SubmitTask handles task creation for the developer task API.
//
// The body has the form {"type": "...", "content": {...}}. "content" is decoded
// according to "type" (mcp_scan, ai_infra_scan or model_redteam_report),
// required fields are checked, and the resulting TaskCreateRequest is queued
// via tm.AddTaskApi. Every outcome is answered with HTTP 200 and an
// APIResponse-shaped body whose "status" is 0 on success and 1 on failure
// (the @Failure codes below describe the logical error, not the HTTP code).
// On success "data" carries the generated session_id.
//
// @Summary Create a new task
// @Description Submit a new task for processing. Supports three types of tasks:
// @Description 1. MCP Scan (mcp_scan): Model Context Protocol security scanning
// @Description 2. AI Infra Scan (ai_infra_scan): AI infrastructure security scanning
// @Description 3. Model Redteam Report (model_redteam_report): AI model red team testing
// @Description
// @Description Request Body Examples:
// @Description
// @Description MCP Scan Task:
// @Description {
// @Description "type": "mcp_scan",
// @Description "content": {
// @Description "content": "扫描MCP服务器",
// @Description "model": {
// @Description "model": "gpt-4",
// @Description "token": "sk-xxx",
// @Description "base_url": "https://api.openai.com/v1"
// @Description },
// @Description "thread": 4,
// @Description "language": "zh",
// @Description "attachments": "file.zip"
// @Description }
// @Description }
// @Description
// @Description AI Infra Scan Task:
// @Description {
// @Description "type": "ai_infra_scan",
// @Description "content": {
// @Description "target": ["https://example.com"],
// @Description "headers": {
// @Description "Authorization": "Bearer token"
// @Description },
// @Description "timeout": 30
// @Description }
// @Description }
// @Description
// @Description Model Redteam Task:
// @Description {
// @Description "type": "model_redteam_report",
// @Description "content": {
// @Description "model": [{
// @Description "model": "gpt-4",
// @Description "token": "sk-xxx",
// @Description "base_url": "https://api.openai.com/v1"
// @Description }],
// @Description "eval_model": {
// @Description "model": "gpt-4",
// @Description "token": "sk-xxx"
// @Description },
// @Description "dataset": {
// @Description "dataFile": ["JailBench-Tiny", "JailbreakPrompts-Tiny"],
// @Description "numPrompts": 100,
// @Description "randomSeed": 42
// @Description }
// @Description }
// @Description }
// @Tags taskapi
// @Accept json
// @Produce json
// @Param request body object{content=object,type=string} true "Task request body. Content should be JSON object containing task-specific parameters based on type"
// @Success 200 {object} APIResponse{data=TaskCreateResponse} "Task created successfully"
// @Failure 400 {object} APIResponse "Invalid request parameters"
// @Failure 500 {object} APIResponse "Internal server error"
// @Router /api/v1/app/taskapi/tasks [post]
func SubmitTask(c *gin.Context, tm *TaskManager) {
	// fail writes the uniform failure envelope (HTTP 200, status 1) used by
	// every rejection path of this handler.
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": message,
			"data":    nil,
		})
	}

	// The task-specific payload is kept raw until the type is known.
	var envelope struct {
		Content json.RawMessage `json:"content"`
		Type    string          `json:"type"`
	}
	if err := c.ShouldBindJSON(&envelope); err != nil {
		fail("参数错误: " + err.Error())
		return
	}

	// sessionId identifies the task; messageId identifies this creation request.
	sessionId := uuid.New().String()
	messageId := uuid.New().String()
	// Caller identity read from the gin context key "api_user" (presumably set
	// by auth middleware outside this file); empty string when absent.
	username := c.GetString("api_user")

	// Fields shared by every task type; the switch fills in the rest.
	taskReq := TaskCreateRequest{
		ID:          messageId,
		SessionID:   sessionId,
		Username:    username,
		Timestamp:   time.Now().UnixMilli(),
		Attachments: []string{},
	}

	switch envelope.Type {
	case "mcp_scan":
		var req MCPTaskRequest
		if err := json.Unmarshal(envelope.Content, &req); err != nil {
			fail("参数错误: " + err.Error())
			return
		}
		// json.Unmarshal does not apply the binding:"required" tags on
		// MCPTaskRequest.Model, so enforce them here.
		if req.Model.Model == "" || req.Model.Token == "" {
			fail("参数错误: model.model 和 model.token 不能为空")
			return
		}
		// NOTE(review): req.Thread and req.Language are accepted but not
		// forwarded in Params; confirm whether the MCP scanner needs them.
		taskReq.Task = agent.TaskTypeMcpScan
		taskReq.Content = req.Content
		taskReq.Params = map[string]interface{}{
			"model": map[string]interface{}{
				"model":    req.Model.Model,
				"token":    req.Model.Token,
				"base_url": req.Model.BaseUrl,
			},
			"quick": false,
			"plugins": []string{
				"auth_bypass", "cmd_injection", "credential_theft", "hardcoded_api_key", "indirect_prompt_injection", "name_confusion", "rug_pull", "tool_poisoning", "tool_shadowing",
			},
		}
		// mcp_scan leaves Attachments nil when no attachment path is supplied.
		var attachments []string
		if req.Attachments != "" {
			attachments = append(attachments, req.Attachments)
		}
		taskReq.Attachments = attachments

	case "ai_infra_scan":
		var req AIInfraScanTaskRequest
		if err := json.Unmarshal(envelope.Content, &req); err != nil {
			fail("参数错误: " + err.Error())
			return
		}
		// "target" is documented as required; an empty list would queue a
		// scan with no targets.
		if len(req.Target) == 0 {
			fail("参数错误: target 不能为空")
			return
		}
		taskReq.Task = agent.TaskTypeAIInfraScan
		// Targets are handed to the scanner as newline-separated content.
		taskReq.Content = strings.Join(req.Target, "\n")
		taskReq.Params = map[string]interface{}{
			"headers": req.Headers,
			"timeout": req.Timeout,
		}

	case "model_redteam_report":
		var req PromptSecurityTaskRequest
		if err := json.Unmarshal(envelope.Content, &req); err != nil {
			fail("参数错误: " + err.Error())
			return
		}
		// At least one model under test is required.
		if len(req.Model) == 0 {
			fail("参数错误: model 不能为空")
			return
		}
		taskReq.Task = agent.TaskTypeModelRedteamReport
		taskReq.Params = map[string]interface{}{
			"model":      req.Model,
			"eval_model": req.EvalModel,
			"dataset":    req.Datasets,
		}

	default:
		fail("无效的任务类型")
		return
	}

	if err := tm.AddTaskApi(&taskReq); err != nil {
		log.Errorf("任务创建失败: sessionId=%s, error=%v", sessionId, err)
		fail("任务创建失败: " + err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "任务创建成功,正在后台处理",
		"data": gin.H{
			"session_id": sessionId,
		},
	})
}
// GetTaskStatus returns a task's session metadata plus its concatenated action
// log (developer API).
//
// The session is identified by the "id" path parameter, which must be
// non-empty and pass isValidSessionID. The "log" field is the concatenation of
// the "actionLog" string of every "actionLog" event of the session, in the
// order the task store returns them; events whose payload cannot be decoded
// are skipped. Every outcome is answered with HTTP 200 and "status" 0 (success)
// or 1 (failure); the @Failure codes below describe the logical error only.
//
// @Summary Get task status
// @Description Retrieve the current status and logs of a task by session ID. Returns task metadata and execution logs.
// @Tags taskapi
// @Produce json
// @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000"
// @Success 200 {object} APIResponse{data=TaskStatusResponse} "Task status retrieved successfully"
// @Failure 400 {object} APIResponse "Invalid session ID format"
// @Failure 404 {object} APIResponse "Task not found"
// @Failure 500 {object} APIResponse "Internal server error"
// @Router /api/v1/app/taskapi/status/{id} [get]
func GetTaskStatus(c *gin.Context, tm *TaskManager) {
	// fail writes the uniform failure envelope (HTTP 200, status 1).
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": message,
			"data":    nil,
		})
	}

	sessionId := c.Param("id")
	if sessionId == "" {
		fail("任务ID不能为空")
		return
	}
	if !isValidSessionID(sessionId) {
		fail("无效的任务ID格式")
		return
	}

	// Any lookup error is reported to the client as "task not found".
	session, err := tm.taskStore.GetSession(sessionId)
	if err != nil {
		fail("任务不存在")
		return
	}

	events, err := tm.taskStore.GetSessionEventsByType(sessionId, "actionLog")
	if err != nil {
		// The session exists, so this is a storage problem worth recording.
		log.Errorf("获取任务日志失败: sessionId=%s, error=%v", sessionId, err)
		fail("获取任务结果失败")
		return
	}

	// Build the log with a Builder: repeated string += is quadratic in the
	// number of events.
	var logText strings.Builder
	var entry struct {
		ActionLog string `json:"actionLog"`
	}
	for _, ev := range events {
		entry.ActionLog = ""
		if err := json.Unmarshal(ev.EventData, &entry); err != nil {
			continue // best-effort: skip undecodable events
		}
		logText.WriteString(entry.ActionLog)
	}

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "获取任务状态成功",
		"data": gin.H{
			"session_id": session.ID,
			"status":     session.Status,
			"title":      session.Title,
			"created_at": session.CreatedAt,
			"updated_at": session.UpdatedAt,
			"log":        logText.String(),
		},
	})
}
// GetTaskResult returns the final result of a task (developer API).
//
// The session is identified by the "id" path parameter, which must be
// non-empty and pass isValidSessionID. The result is the JSON payload of the
// first "resultUpdate" event returned by the task store, decoded into a
// generic map. If there is no such event yet, the client is told the task may
// not be finished. Every outcome is answered with HTTP 200 and "status" 0
// (success) or 1 (failure); the @Failure codes below describe the logical
// error only.
//
// @Summary Get task result
// @Description Retrieve the final result of a completed task. Returns detailed scan results, vulnerabilities found, and security assessment data.
// @Tags taskapi
// @Produce json
// @Param id path string true "Task Session ID" example:"550e8400-e29b-41d4-a716-446655440000"
// @Success 200 {object} APIResponse "Task result retrieved successfully. Data contains scan results, vulnerabilities, and security findings"
// @Failure 400 {object} APIResponse "Invalid session ID format"
// @Failure 404 {object} APIResponse "Task not found or not completed"
// @Failure 500 {object} APIResponse "Internal server error"
// @Router /api/v1/app/taskapi/result/{id} [get]
func GetTaskResult(c *gin.Context, tm *TaskManager) {
	// fail writes the uniform failure envelope (HTTP 200, status 1).
	fail := func(message string) {
		c.JSON(http.StatusOK, gin.H{
			"status":  1,
			"message": message,
			"data":    nil,
		})
	}

	traceID := getTraceID(c)
	sessionId := c.Param("id")
	if sessionId == "" {
		fail("任务ID不能为空")
		return
	}
	if !isValidSessionID(sessionId) {
		fail("无效的任务ID格式")
		return
	}
	log.Infof("开始获取任务结果: trace_id=%s, sessionId=%s", traceID, sessionId)

	events, err := tm.taskStore.GetSessionEventsByType(sessionId, "resultUpdate")
	if err != nil {
		// Log the store failure so it is distinguishable from "not finished
		// yet" on the server side; the client message is unchanged.
		log.Errorf("获取任务结果失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err)
		fail("获取任务结果失败,任务可能尚未完成")
		return
	}
	if len(events) == 0 {
		fail("获取任务结果失败,任务可能尚未完成")
		return
	}

	// NOTE(review): only the first "resultUpdate" event is returned; if a
	// session can emit several, confirm the store's ordering makes this the
	// intended one.
	var eventData map[string]interface{}
	if err := json.Unmarshal(events[0].EventData, &eventData); err != nil {
		log.Errorf("解析任务结果失败: trace_id=%s, sessionId=%s, error=%v", traceID, sessionId, err)
		fail("获取任务结果失败")
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"status":  0,
		"message": "ok",
		"data":    eventData,
	})
}