|
|
package websocket |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"net/http" |
|
|
|
|
|
"github.com/Tencent/AI-Infra-Guard/common/utils/models" |
|
|
|
|
|
"github.com/Tencent/AI-Infra-Guard/pkg/database" |
|
|
"github.com/gin-gonic/gin" |
|
|
"trpc.group/trpc-go/trpc-go/log" |
|
|
) |
|
|
|
|
|
|
|
|
// ModelInfo is the model configuration payload accepted when a model is
// created. Model, Token and BaseURL carry the gin "required" binding; Limit
// and Note are optional.
type ModelInfo struct {
	// Model is the model name (JSON key "model").
	Model string `json:"model" binding:"required"`
	// Token is the API token used to reach the model (JSON key "token").
	Token string `json:"token" binding:"required"`
	// BaseURL is the base URL of the model API (JSON key "base_url").
	BaseURL string `json:"base_url" binding:"required"`
	// Limit is an optional numeric limit (JSON key "limit").
	Limit int `json:"limit"`
	// Note is an optional free-form remark (JSON key "note").
	Note string `json:"note"`
}
|
|
|
|
|
|
|
|
type CreateModelRequest struct { |
|
|
ModelID string `json:"model_id" binding:"required"` |
|
|
Model ModelInfo `json:"model" binding:"required"` |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// UpdateModelInfo is the model configuration payload of an update request.
// It mirrors the fields of the create payload, but none of them carries a
// "required" binding, so every field may be omitted by the client.
type UpdateModelInfo struct {
	// Model is the model name (JSON key "model").
	Model string `json:"model"`
	// Token is the API token (JSON key "token").
	Token string `json:"token"`
	// BaseURL is the base URL of the model API (JSON key "base_url").
	BaseURL string `json:"base_url"`
	// Limit is a numeric limit (JSON key "limit").
	Limit int `json:"limit"`
	// Note is a free-form remark (JSON key "note").
	Note string `json:"note"`
}
|
|
|
|
|
|
|
|
type UpdateModelRequest struct { |
|
|
Model UpdateModelInfo `json:"model" binding:"required"` |
|
|
} |
|
|
|
|
|
|
|
|
// DeleteModelRequest is the JSON body of a batch delete request.
type DeleteModelRequest struct {
	// ModelIDs lists the identifiers of the models to delete; it carries the
	// gin "required" binding.
	ModelIDs []string `json:"model_ids" binding:"required"`
}
|
|
|
|
|
|
|
|
type ModelManager struct { |
|
|
modelStore *database.ModelStore |
|
|
} |
|
|
|
|
|
// maskedToken is the fixed placeholder returned to clients in place of a
// stored API token.
const maskedToken = "********"

// maskToken hides a token before it is sent back to a client. An empty token
// stays empty; any other value is replaced by maskedToken, so neither the
// content nor the length of the real token is revealed.
func maskToken(token string) string {
	if len(token) > 0 {
		return maskedToken
	}
	return ""
}
|
|
|
|
|
|
|
|
func NewModelManager(modelStore *database.ModelStore) *ModelManager { |
|
|
return &ModelManager{ |
|
|
modelStore: modelStore, |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func HandleGetModelList(c *gin.Context, mm *ModelManager) { |
|
|
traceID := getTraceID(c) |
|
|
username := c.GetString("username") |
|
|
|
|
|
log.Debugf("用户请求获取模型列表: trace_id=%s, username=%s", traceID, username) |
|
|
|
|
|
var userModels []*database.Model |
|
|
var publicModels []*database.Model |
|
|
var err error |
|
|
|
|
|
|
|
|
publicModels, err = mm.modelStore.GetUserModels("public_user") |
|
|
if err != nil { |
|
|
log.Errorf("获取公共模型列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err) |
|
|
|
|
|
log.Warnf("获取public_user模型失败,继续返回用户模型: %v", err) |
|
|
publicModels = []*database.Model{} |
|
|
} |
|
|
|
|
|
|
|
|
if username != "public_user" { |
|
|
userModels, err = mm.modelStore.GetUserModels(username) |
|
|
if err != nil { |
|
|
log.Errorf("获取用户模型列表失败: trace_id=%s, username=%s, error=%v", traceID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "获取模型列表失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var result []map[string]interface{} |
|
|
|
|
|
|
|
|
for _, model := range publicModels { |
|
|
|
|
|
|
|
|
item := map[string]interface{}{ |
|
|
"model_id": model.ModelID, |
|
|
"model": map[string]interface{}{ |
|
|
"model": model.ModelName, |
|
|
"token": maskToken(model.Token), |
|
|
"base_url": model.BaseURL, |
|
|
"note": model.Note, |
|
|
"limit": model.Limit, |
|
|
}, |
|
|
} |
|
|
result = append(result, item) |
|
|
} |
|
|
|
|
|
|
|
|
for _, model := range userModels { |
|
|
item := map[string]interface{}{ |
|
|
"model_id": model.ModelID, |
|
|
"model": map[string]interface{}{ |
|
|
"model": model.ModelName, |
|
|
|
|
|
"token": maskToken(model.Token), |
|
|
"base_url": model.BaseURL, |
|
|
"note": model.Note, |
|
|
"limit": model.Limit, |
|
|
}, |
|
|
} |
|
|
result = append(result, item) |
|
|
} |
|
|
|
|
|
log.Debugf("获取模型列表成功: trace_id=%s, username=%s, userModels=%d, publicModels=%d, total=%d", |
|
|
traceID, username, len(userModels), len(publicModels), len(result)) |
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 0, |
|
|
"message": "获取模型列表成功", |
|
|
"data": result, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func HandleGetModelDetail(c *gin.Context, mm *ModelManager) { |
|
|
traceID := getTraceID(c) |
|
|
modelID := c.Param("modelId") |
|
|
username := c.GetString("username") |
|
|
|
|
|
|
|
|
if modelID == "" { |
|
|
log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型ID不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Debugf("用户请求获取模型详情: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
|
|
|
|
|
|
model, err := mm.modelStore.GetModel(modelID) |
|
|
if err != nil { |
|
|
log.Errorf("获取模型详情失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型不存在", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
if model.Username != username { |
|
|
log.Errorf("无权限查看模型: trace_id=%s, modelID=%s, username=%s, owner=%s", traceID, modelID, username, model.Username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "无权限查看此模型", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Debugf("获取模型详情成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
|
|
|
|
|
|
result := map[string]interface{}{ |
|
|
"model_id": model.ModelID, |
|
|
"model": map[string]interface{}{ |
|
|
"model": model.ModelName, |
|
|
|
|
|
"token": maskToken(model.Token), |
|
|
"base_url": model.BaseURL, |
|
|
"note": model.Note, |
|
|
"limit": model.Limit, |
|
|
}, |
|
|
} |
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 0, |
|
|
"message": "获取模型详情成功", |
|
|
"data": result, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func HandleCreateModel(c *gin.Context, mm *ModelManager) { |
|
|
traceID := getTraceID(c) |
|
|
username := c.GetString("username") |
|
|
|
|
|
|
|
|
var req CreateModelRequest |
|
|
if err := c.ShouldBindJSON(&req); err != nil { |
|
|
log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "请求参数错误: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
if req.ModelID == "" { |
|
|
log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型ID不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if req.Model.Model == "" { |
|
|
log.Errorf("模型名称为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型名称不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if req.Model.Token == "" { |
|
|
log.Errorf("API Token为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "API Token不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if req.Model.BaseURL == "" { |
|
|
log.Errorf("基础URL为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "基础URL不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
if req.Model.Limit == 0 { |
|
|
req.Model.Limit = 1000 |
|
|
} |
|
|
|
|
|
log.Debugf("用户请求创建模型: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username) |
|
|
|
|
|
|
|
|
exists, err := mm.modelStore.CheckModelExists(req.ModelID) |
|
|
if err != nil { |
|
|
log.Errorf("检查模型是否存在失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "检查模型失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if exists { |
|
|
log.Errorf("模型已存在: trace_id=%s, modelID=%s, username=%s", traceID, req.ModelID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型ID已存在", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
ai := models.NewOpenAI(req.Model.Token, req.Model.Model, req.Model.BaseURL) |
|
|
err = ai.Vaild(context.Background()) |
|
|
if err != nil { |
|
|
log.Errorf("模型校验失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型校验失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
model := &database.Model{ |
|
|
ModelID: req.ModelID, |
|
|
Username: username, |
|
|
ModelName: req.Model.Model, |
|
|
Token: req.Model.Token, |
|
|
BaseURL: req.Model.BaseURL, |
|
|
Note: req.Model.Note, |
|
|
Limit: req.Model.Limit, |
|
|
} |
|
|
|
|
|
err = mm.modelStore.CreateModel(model) |
|
|
if err != nil { |
|
|
log.Errorf("创建模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, req.ModelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "创建模型失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Debugf("创建模型成功: trace_id=%s, modelID=%s, modelName=%s, username=%s", traceID, req.ModelID, req.Model.Model, username) |
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 0, |
|
|
"message": "模型创建成功", |
|
|
"data": nil, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func HandleUpdateModel(c *gin.Context, mm *ModelManager) { |
|
|
traceID := getTraceID(c) |
|
|
modelID := c.Param("modelId") |
|
|
username := c.GetString("username") |
|
|
|
|
|
|
|
|
if modelID == "" { |
|
|
log.Errorf("模型ID为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型ID不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
var req UpdateModelRequest |
|
|
if err := c.ShouldBindJSON(&req); err != nil { |
|
|
log.Errorf("请求参数解析失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "请求参数错误: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Infof("用户请求更新模型: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
|
|
|
|
|
|
exists, err := mm.modelStore.CheckModelExistsByUser(modelID, username) |
|
|
if err != nil { |
|
|
log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "检查模型权限失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if !exists { |
|
|
log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型不存在或无权限", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
updates := map[string]interface{}{ |
|
|
"model_name": req.Model.Model, |
|
|
"note": req.Model.Note, |
|
|
"limit": req.Model.Limit, |
|
|
} |
|
|
if req.Model.Token != "" && req.Model.Token != maskedToken { |
|
|
updates["token"] = req.Model.Token |
|
|
} |
|
|
if req.Model.BaseURL != "" { |
|
|
updates["base_url"] = req.Model.BaseURL |
|
|
} |
|
|
|
|
|
err = mm.modelStore.UpdateModel(modelID, username, updates) |
|
|
if err != nil { |
|
|
log.Errorf("更新模型失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "更新模型失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Infof("更新模型成功: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 0, |
|
|
"message": "模型更新成功", |
|
|
"data": nil, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func HandleDeleteModel(c *gin.Context, mm *ModelManager) { |
|
|
traceID := getTraceID(c) |
|
|
username := c.GetString("username") |
|
|
|
|
|
|
|
|
var req DeleteModelRequest |
|
|
if err := c.ShouldBindJSON(&req); err != nil { |
|
|
log.Errorf("请求参数解析失败: trace_id=%s, username=%s, error=%v", traceID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "请求参数错误: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if len(req.ModelIDs) == 0 { |
|
|
log.Errorf("模型ID列表为空: trace_id=%s, username=%s", traceID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型ID列表不能为空", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Infof("用户请求删除模型: trace_id=%s, modelIDs=%v, username=%s", traceID, req.ModelIDs, username) |
|
|
|
|
|
|
|
|
for _, modelID := range req.ModelIDs { |
|
|
exists, err := mm.modelStore.CheckModelExistsByUser(modelID, username) |
|
|
if err != nil { |
|
|
log.Errorf("检查模型权限失败: trace_id=%s, modelID=%s, username=%s, error=%v", traceID, modelID, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "检查模型权限失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
if !exists { |
|
|
log.Errorf("模型不存在或无权限: trace_id=%s, modelID=%s, username=%s", traceID, modelID, username) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "模型不存在或无权限", |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
deletedCount, err := mm.modelStore.BatchDeleteModels(req.ModelIDs, username) |
|
|
if err != nil { |
|
|
log.Errorf("删除模型失败: trace_id=%s, modelIDs=%v, username=%s, error=%v", traceID, req.ModelIDs, username, err) |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 1, |
|
|
"message": "删除模型失败: " + err.Error(), |
|
|
"data": nil, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
log.Infof("删除模型成功: trace_id=%s, modelIDs=%v, username=%s, deletedCount=%d", traceID, req.ModelIDs, username, deletedCount) |
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": 0, |
|
|
"message": "删除成功", |
|
|
"data": nil, |
|
|
}) |
|
|
} |
|
|
|