|
|
|
|
|
|
|
|
|
|
|
|
|
|
package websocket |
|
|
|
|
|
import ( |
|
|
"embed" |
|
|
"mime" |
|
|
"path/filepath" |
|
|
"strings" |
|
|
|
|
|
"github.com/Tencent/AI-Infra-Guard/common/trpc" |
|
|
_ "github.com/Tencent/AI-Infra-Guard/docs" |
|
|
"github.com/Tencent/AI-Infra-Guard/internal/options" |
|
|
"github.com/Tencent/AI-Infra-Guard/pkg/database" |
|
|
"github.com/gin-gonic/gin" |
|
|
swaggerFiles "github.com/swaggo/files" |
|
|
ginSwagger "github.com/swaggo/gin-swagger" |
|
|
"trpc.group/trpc-go/trpc-go/log" |
|
|
) |
|
|
|
|
|
|
|
|
// staticFS holds the front-end assets served by the NoRoute handler in
// RunWebServer, which reads files under "static/" and falls back to
// "static/index.html" when a lookup fails.
//
// NOTE(review): no go:embed directive is visible directly above this
// declaration. Without one the FS is empty: every ReadFile fails, including
// the index.html fallback, and the handler answers 500. Confirm the directive
// exists in the real file.
var staticFS embed.FS
|
|
|
|
|
func RunWebServer(options *options.Options) { |
|
|
|
|
|
if err := trpc.InitTrpc("./trpc_go.yaml"); err != nil { |
|
|
log.Fatalf("Trpc-go初始化失败: %v", err) |
|
|
} |
|
|
log.Infof("Trpc-go initialized successfully: trace_id=system_startup") |
|
|
|
|
|
r := gin.Default() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dbConfig := database.LoadConfigFromEnv() |
|
|
db, err := database.InitDB(dbConfig) |
|
|
if err != nil { |
|
|
log.Errorf("数据库初始化失败: trace_id=system_startup, error=%v", err) |
|
|
|
|
|
} |
|
|
taskStore := database.NewTaskStore(db) |
|
|
if err := taskStore.Init(); err != nil { |
|
|
log.Errorf("初始化tasks表失败: trace_id=system_startup, error=%v", err) |
|
|
log.Fatalf("初始化tasks表失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
modelStore := database.NewModelStore(db) |
|
|
if err := modelStore.Init(); err != nil { |
|
|
log.Errorf("初始化models表失败: trace_id=system_startup, error=%v", err) |
|
|
|
|
|
} |
|
|
|
|
|
modelStore.AutoAddModels() |
|
|
|
|
|
|
|
|
agentManager := NewAgentManager() |
|
|
|
|
|
|
|
|
modelManager := NewModelManager(modelStore) |
|
|
|
|
|
|
|
|
fileConfig := LoadFileUploadConfigFromEnv() |
|
|
|
|
|
|
|
|
if err := fileConfig.ValidateConfig(); err != nil { |
|
|
log.Errorf("文件上传配置验证失败: trace_id=system_startup, error=%v", err) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
sseManager := NewSSEManager() |
|
|
|
|
|
taskManager := NewTaskManager(agentManager, taskStore, modelStore, fileConfig, sseManager) |
|
|
err = taskManager.taskStore.ResetRunningTasks() |
|
|
if err != nil { |
|
|
log.Fatalf("重置运行中的任务失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
agentManager.SetTaskManager(taskManager) |
|
|
|
|
|
|
|
|
v1 := r.Group("/api/v1") |
|
|
{ |
|
|
v1.GET("/images/:path", func(context *gin.Context) { |
|
|
path := context.Param("path") |
|
|
if strings.Contains(path, "..") { |
|
|
context.String(403, "Forbidden") |
|
|
return |
|
|
} |
|
|
context.File(filepath.Join("uploads", path)) |
|
|
}) |
|
|
|
|
|
knowledge := v1.Group("/knowledge") |
|
|
{ |
|
|
|
|
|
fingerprints := knowledge.Group("/fingerprints") |
|
|
{ |
|
|
|
|
|
fingerprints.GET("", HandleListFingerprints) |
|
|
fingerprints.POST("", HandleCreateFingerprint) |
|
|
fingerprints.PUT("/:name", HandleEditFingerprint) |
|
|
fingerprints.DELETE("", HandleDeleteFingerprint) |
|
|
} |
|
|
|
|
|
vulnerabilities := knowledge.Group("/vulnerabilities") |
|
|
{ |
|
|
|
|
|
vulnerabilities.GET("", HandleListVulnerabilities()) |
|
|
vulnerabilities.POST("", HandleCreateVulnerability()) |
|
|
vulnerabilities.PUT("/:cve", HandleEditVulnerability) |
|
|
vulnerabilities.DELETE("", HandleBatchDeleteVulnerabilities) |
|
|
} |
|
|
|
|
|
evaluations := knowledge.Group("/evaluations") |
|
|
{ |
|
|
|
|
|
evaluations.GET("/:name", HandleGetEvaluationDetail) |
|
|
evaluations.GET("", HandleListEvaluations) |
|
|
evaluations.POST("", HandleCreateEvaluation) |
|
|
evaluations.PUT("/:name", HandleEditEvaluation) |
|
|
evaluations.DELETE("", HandleDeleteEvaluation) |
|
|
} |
|
|
|
|
|
mcp := knowledge.Group("/mcp") |
|
|
{ |
|
|
mcp.GET("names", GetMcpPluginList) |
|
|
mcp.GET("", HandleList(MCPROOT, McpLoadFile)) |
|
|
mcp.POST("", HandleCreate(mcpReadAndSave)) |
|
|
mcp.PUT("/:id", HandleEdit(mcpUpdateFunc)) |
|
|
mcp.DELETE("/:id", HandleDelete(mcpDeleteFunc)) |
|
|
} |
|
|
|
|
|
collections := knowledge.Group("/prompt_collections") |
|
|
{ |
|
|
collections.GET("", HandleList(PromptCollectionsRoot, promptCollectionLoadFile)) |
|
|
collections.POST("", HandleCreate(promptCollectionReadAndSave)) |
|
|
collections.PUT("/:id", HandleEdit(promptCollectionUpdateFunc)) |
|
|
collections.DELETE("", HandleDelete(promptCollectionDeleteFunc)) |
|
|
} |
|
|
|
|
|
knowledge.GET("/jailbreak", GetJailBreak) |
|
|
} |
|
|
appSecurity := v1.Group("/app") |
|
|
{ |
|
|
appSecurity.Use(setupIdentityMiddleware()) |
|
|
|
|
|
tasks := appSecurity.Group("/tasks") |
|
|
{ |
|
|
|
|
|
tasks.GET("", func(c *gin.Context) { |
|
|
HandleGetTaskList(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.GET("/:sessionId", func(c *gin.Context) { |
|
|
HandleGetTaskDetail(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.POST("/share", func(c *gin.Context) { |
|
|
HandleShare(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.GET("/sse/:sessionId", func(c *gin.Context) { |
|
|
HandleTaskSSE(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.POST("", func(c *gin.Context) { |
|
|
HandleTaskCreate(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.POST("/uploadFile", func(c *gin.Context) { |
|
|
HandleUploadFile(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.POST("/:sessionId/downloadFile", func(c *gin.Context) { |
|
|
HandleDownloadFile(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.PUT("/:sessionId", func(c *gin.Context) { |
|
|
HandleUpdateTask(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.DELETE("/:sessionId", func(c *gin.Context) { |
|
|
HandleDeleteTask(c, taskManager) |
|
|
}) |
|
|
|
|
|
tasks.POST("/:sessionId/terminate", func(c *gin.Context) { |
|
|
HandleTerminateTask(c, taskManager) |
|
|
}) |
|
|
} |
|
|
|
|
|
models := appSecurity.Group("/models") |
|
|
{ |
|
|
|
|
|
models.GET("", func(c *gin.Context) { |
|
|
HandleGetModelList(c, modelManager) |
|
|
}) |
|
|
|
|
|
models.GET("/:modelId", func(c *gin.Context) { |
|
|
HandleGetModelDetail(c, modelManager) |
|
|
}) |
|
|
|
|
|
models.POST("", func(c *gin.Context) { |
|
|
HandleCreateModel(c, modelManager) |
|
|
}) |
|
|
|
|
|
models.PUT("/:modelId", func(c *gin.Context) { |
|
|
HandleUpdateModel(c, modelManager) |
|
|
}) |
|
|
|
|
|
models.DELETE("", func(c *gin.Context) { |
|
|
HandleDeleteModel(c, modelManager) |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
agents := v1.Group("/agents") |
|
|
{ |
|
|
|
|
|
agents.GET("/ws", agentManager.HandleAgentWebSocket()) |
|
|
} |
|
|
|
|
|
taskApi := appSecurity.Group("/taskapi") |
|
|
{ |
|
|
|
|
|
taskApi.POST("/tasks", func(c *gin.Context) { |
|
|
SubmitTask(c, taskManager) |
|
|
}) |
|
|
|
|
|
taskApi.GET("/status/:id", func(c *gin.Context) { |
|
|
GetTaskStatus(c, taskManager) |
|
|
}) |
|
|
|
|
|
taskApi.GET("/result/:id", func(c *gin.Context) { |
|
|
GetTaskResult(c, taskManager) |
|
|
}) |
|
|
taskApi.POST("/upload", func(c *gin.Context) { |
|
|
HandleUploadFile(c, taskManager) |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
r.GET("/docs/*any", func(c *gin.Context) { |
|
|
if c.Request.URL.Path == "/docs/" { |
|
|
c.Redirect(302, "/docs/index.html") |
|
|
} else { |
|
|
ginSwagger.WrapHandler(swaggerFiles.Handler)(c) |
|
|
} |
|
|
}) |
|
|
|
|
|
|
|
|
r.NoRoute(func(c *gin.Context) { |
|
|
assetPath := "static" + c.Request.URL.Path |
|
|
if c.Request.URL.Path == "/" { |
|
|
assetPath = "static/index.html" |
|
|
} |
|
|
|
|
|
assetData, err := staticFS.ReadFile(assetPath) |
|
|
if err != nil { |
|
|
assetData, err = staticFS.ReadFile("static/index.html") |
|
|
if err != nil { |
|
|
c.String(500, "Internal Server Error") |
|
|
return |
|
|
} |
|
|
c.Header("Content-Type", "text/html") |
|
|
c.Data(200, "text/html", assetData) |
|
|
return |
|
|
} |
|
|
|
|
|
mimeType := mime.TypeByExtension(filepath.Ext(assetPath)) |
|
|
if mimeType == "" { |
|
|
mimeType = "text/plain" |
|
|
} |
|
|
c.Header("Content-Type", mimeType) |
|
|
c.Data(200, mimeType, assetData) |
|
|
}) |
|
|
|
|
|
log.Infof("Starting WebServer: trace_id=system_startup, addr=%s", options.WebServerAddr) |
|
|
if err := r.Run(options.WebServerAddr); err != nil { |
|
|
log.Errorf("Could not start WebSocket server: trace_id=system_startup, error=%s", err) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func setupIdentityMiddleware() gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
|
|
|
username := c.GetHeader("username") |
|
|
|
|
|
|
|
|
if username == "" { |
|
|
username = "public_user" |
|
|
} |
|
|
|
|
|
c.Set("username", username) |
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|