new-api / controller /channel.go
liuzhao521
Deploy New API v0.9.25+ (commit b47cf4ef) to HuggingFace Spaces
4674012
package controller
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/dto"
"github.com/QuantumNous/new-api/model"
"github.com/QuantumNous/new-api/relay/channel/volcengine"
"github.com/QuantumNous/new-api/service"
"github.com/gin-gonic/gin"
)
// OpenAIModel is a single entry of an OpenAI-style model listing. It is the
// element type of OpenAIModelsResponse.Data and is populated by JSON decoding.
type OpenAIModel struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
	// Permission holds the entries of the "permission" array attached to the
	// model. Handlers visible in this file only read ID; the remaining fields
	// are decoded but not otherwise used here.
	Permission []struct {
		ID                 string `json:"id"`
		Object             string `json:"object"`
		Created            int64  `json:"created"`
		AllowCreateEngine  bool   `json:"allow_create_engine"`
		AllowSampling      bool   `json:"allow_sampling"`
		AllowLogprobs      bool   `json:"allow_logprobs"`
		AllowSearchIndices bool   `json:"allow_search_indices"`
		AllowView          bool   `json:"allow_view"`
		AllowFineTuning    bool   `json:"allow_fine_tuning"`
		Organization       string `json:"organization"`
		Group              string `json:"group"`
		IsBlocking         bool   `json:"is_blocking"`
	} `json:"permission"`
	Root   string `json:"root"`
	Parent string `json:"parent"`
}
// OpenAIModelsResponse is the envelope decoded from an upstream model-list
// response: the models live under "data"; "success" is decoded when present.
type OpenAIModelsResponse struct {
	Data    []OpenAIModel `json:"data"`
	Success bool          `json:"success"`
}
// parseStatusFilter maps the "status" query value to a filter code.
//
// The comparison is case-insensitive. "enabled" or "1" yields
// common.ChannelStatusEnabled, "disabled" or "0" yields 0, and every other
// value (including the empty string) yields -1, meaning "do not filter".
func parseStatusFilter(statusParam string) int {
	normalized := strings.ToLower(statusParam)
	if normalized == "enabled" || normalized == "1" {
		return common.ChannelStatusEnabled
	}
	if normalized == "disabled" || normalized == "0" {
		return 0
	}
	return -1
}
// clearChannelInfo drops the per-key disabled reason and disabled time maps
// from a multi-key channel before it is sent to the client. Channels that are
// not in multi-key mode are left untouched.
func clearChannelInfo(channel *model.Channel) {
	if !channel.ChannelInfo.IsMultiKey {
		return
	}
	channel.ChannelInfo.MultiKeyDisabledTime = nil
	channel.ChannelInfo.MultiKeyDisabledReason = nil
}
// GetAllChannels returns one page of channels together with per-type counts.
//
// Query parameters read here: id_sort, tag_mode, status and type (paging
// values are obtained through common.GetPageQuery). Unparsable id_sort,
// tag_mode and type values silently fall back to false / "no type filter".
//
// In tag mode the page is a page of tags: every channel of each tag on the
// page is loaded and then filtered in memory, and "total" is the tag count.
// Otherwise filtering, counting and paging are done in the database and the
// "key" column is omitted from the result.
func GetAllChannels(c *gin.Context) {
	pageInfo := common.GetPageQuery(c)
	channelData := make([]*model.Channel, 0)
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))
	statusParam := c.Query("status")
	// statusFilter: -1 all, 1 enabled, 0 disabled (include auto & manual)
	statusFilter := parseStatusFilter(statusParam)
	// type filter: -1 means "any type"; a non-numeric value is ignored
	typeStr := c.Query("type")
	typeFilter := -1
	if typeStr != "" {
		if t, err := strconv.Atoi(typeStr); err == nil {
			typeFilter = t
		}
	}
	var total int64
	if enableTagMode {
		tags, err := model.GetPaginatedTags(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
			if err != nil {
				// A tag whose channels cannot be loaded is skipped rather than
				// failing the whole listing.
				continue
			}
			filtered := make([]*model.Channel, 0)
			for _, ch := range tagChannels {
				if statusFilter == common.ChannelStatusEnabled && ch.Status != common.ChannelStatusEnabled {
					continue
				}
				// "disabled" means any status other than enabled.
				if statusFilter == 0 && ch.Status == common.ChannelStatusEnabled {
					continue
				}
				if typeFilter >= 0 && ch.Type != typeFilter {
					continue
				}
				filtered = append(filtered, ch)
			}
			channelData = append(channelData, filtered...)
		}
		// The total counts tags, not channels; a counting error is ignored.
		total, _ = model.CountAllTags()
	} else {
		baseQuery := model.DB.Model(&model.Channel{})
		if typeFilter >= 0 {
			baseQuery = baseQuery.Where("type = ?", typeFilter)
		}
		if statusFilter == common.ChannelStatusEnabled {
			baseQuery = baseQuery.Where("status = ?", common.ChannelStatusEnabled)
		} else if statusFilter == 0 {
			baseQuery = baseQuery.Where("status != ?", common.ChannelStatusEnabled)
		}
		// NOTE(review): the error from Count is not inspected; on failure total
		// stays at whatever Count left in it.
		baseQuery.Count(&total)
		order := "priority desc"
		if idSort {
			order = "id desc"
		}
		err := baseQuery.Order(order).Limit(pageInfo.GetPageSize()).Offset(pageInfo.GetStartIdx()).Omit("key").Find(&channelData).Error
		if err != nil {
			c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
			return
		}
	}
	for _, datum := range channelData {
		clearChannelInfo(datum)
	}
	// Per-type counts honour the status filter but deliberately not the type
	// filter, so every type is reported regardless of the selected one.
	countQuery := model.DB.Model(&model.Channel{})
	if statusFilter == common.ChannelStatusEnabled {
		countQuery = countQuery.Where("status = ?", common.ChannelStatusEnabled)
	} else if statusFilter == 0 {
		countQuery = countQuery.Where("status != ?", common.ChannelStatusEnabled)
	}
	var results []struct {
		Type  int64
		Count int64
	}
	// The query error is discarded; a failure yields an empty type_counts map.
	_ = countQuery.Select("type, count(*) as count").Group("type").Find(&results).Error
	typeCounts := make(map[int64]int64)
	for _, r := range results {
		typeCounts[r.Type] = r.Count
	}
	common.ApiSuccess(c, gin.H{
		"items":       channelData,
		"total":       total,
		"page":        pageInfo.GetPage(),
		"page_size":   pageInfo.GetPageSize(),
		"type_counts": typeCounts,
	})
	return
}
// FetchUpstreamModels asks the upstream provider of the channel identified by
// the ":id" path parameter for its model list and returns the model ids.
//
// The request URL is the channel's base URL (or the default for its type)
// plus a type-specific models path. The key comes from GetNextEnabledKey and
// is sent through the Claude auth header for Anthropic channels and the
// generic auth header otherwise. Gemini ids have their "models/" prefix
// removed.
func FetchUpstreamModels(c *gin.Context) {
	channelID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}

	channel, err := model.GetChannelById(channelID, true)
	if err != nil {
		common.ApiError(c, err)
		return
	}

	// Start from the per-type default and let a channel-level base URL win.
	baseURL := constant.ChannelBaseURLs[channel.Type]
	if custom := channel.GetBaseURL(); custom != "" {
		baseURL = custom
	}

	// Path of the model-list endpoint relative to the base URL.
	endpoint := "/v1/models"
	switch channel.Type {
	case constant.ChannelTypeGemini:
		// OpenAI-compatible listing; the key travels in the auth header
		// instead of the "?key=" query parameter.
		endpoint = "/v1beta/openai/models"
	case constant.ChannelTypeAli:
		endpoint = "/compatible-mode/v1/models"
	case constant.ChannelTypeZhipu_v4:
		endpoint = "/api/paas/v4/models"
	}

	var url string
	if channel.Type == constant.ChannelTypeVolcEngine && baseURL == volcengine.DoubaoCodingPlan {
		// The coding-plan marker is swapped for its OpenAI-compatible base URL.
		url = fmt.Sprintf("%s/v1/models", volcengine.DoubaoCodingPlanOpenAIBaseURL)
	} else {
		url = baseURL + endpoint
	}

	// Pick a usable key (multi-key channels prefer a key that is enabled).
	key, _, apiErr := channel.GetNextEnabledKey()
	if apiErr != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("获取渠道密钥失败: %s", apiErr.Error()),
		})
		return
	}
	key = strings.TrimSpace(key)

	// Fetch the listing; only the auth header differs between providers.
	var body []byte
	if channel.Type == constant.ChannelTypeAnthropic {
		body, err = GetResponseBody("GET", url, channel, GetClaudeAuthHeader(key))
	} else {
		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
	}
	if err != nil {
		common.ApiError(c, err)
		return
	}

	var result OpenAIModelsResponse
	if err = json.Unmarshal(body, &result); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
		})
		return
	}

	var ids []string
	stripPrefix := channel.Type == constant.ChannelTypeGemini
	for i := range result.Data {
		modelID := result.Data[i].ID
		if stripPrefix {
			modelID = strings.TrimPrefix(modelID, "models/")
		}
		ids = append(ids, modelID)
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    ids,
	})
}
// FixChannelsAbilities runs model.FixAbility and reports the success and
// failure values it returns under "data".
func FixChannelsAbilities(c *gin.Context) {
	success, fails, err := model.FixAbility()
	if err != nil {
		common.ApiError(c, err)
		return
	}

	summary := gin.H{
		"success": success,
		"fails":   fails,
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    summary,
	})
}
// SearchChannels searches channels by keyword, group and model and returns
// one page of the matches plus per-type counts.
//
// Query parameters: keyword, group, model, status, type, id_sort, tag_mode,
// p (page, default 1) and page_size (default 20). In tag mode matching tags
// are searched and all channels of those tags are collected; tags whose
// channels fail to load are skipped. Status and type filters and paging are
// applied in memory. type_counts is computed after the status filter but
// before the type filter, so it covers every type in the result set.
//
// Page bounds are computed without multiplying or adding unchecked values, so
// absurdly large p/page_size inputs yield an empty page instead of a negative
// slice index.
func SearchChannels(c *gin.Context) {
	keyword := c.Query("keyword")
	group := c.Query("group")
	modelKeyword := c.Query("model")
	statusFilter := parseStatusFilter(c.Query("status"))
	idSort, _ := strconv.ParseBool(c.Query("id_sort"))
	enableTagMode, _ := strconv.ParseBool(c.Query("tag_mode"))

	channelData := make([]*model.Channel, 0)
	if enableTagMode {
		tags, err := model.SearchTags(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		for _, tag := range tags {
			if tag == nil || *tag == "" {
				continue
			}
			tagChannels, err := model.GetChannelsByTag(*tag, idSort, false)
			if err != nil {
				// Best effort: one unreadable tag must not fail the search.
				continue
			}
			channelData = append(channelData, tagChannels...)
		}
	} else {
		channels, err := model.SearchChannels(keyword, group, modelKeyword, idSort)
		if err != nil {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": err.Error(),
			})
			return
		}
		channelData = channels
	}

	// Status filter: "disabled" (0) means any status other than enabled.
	if statusFilter == common.ChannelStatusEnabled || statusFilter == 0 {
		wantEnabled := statusFilter == common.ChannelStatusEnabled
		filtered := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if (ch.Status == common.ChannelStatusEnabled) == wantEnabled {
				filtered = append(filtered, ch)
			}
		}
		channelData = filtered
	}

	// calculate type counts for search results
	typeCounts := make(map[int64]int64)
	for _, channel := range channelData {
		typeCounts[int64(channel.Type)]++
	}

	// Type filter: a missing or non-numeric value disables it.
	typeFilter := -1
	if typeParam := c.Query("type"); typeParam != "" {
		if tp, err := strconv.Atoi(typeParam); err == nil {
			typeFilter = tp
		}
	}
	if typeFilter >= 0 {
		filtered := make([]*model.Channel, 0, len(channelData))
		for _, ch := range channelData {
			if ch.Type == typeFilter {
				filtered = append(filtered, ch)
			}
		}
		channelData = filtered
	}

	page, _ := strconv.Atoi(c.DefaultQuery("p", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if page < 1 {
		page = 1
	}
	if pageSize <= 0 {
		pageSize = 20
	}

	total := len(channelData)
	// strconv.Atoi returns the maximum int on a range error, so page and
	// pageSize can be huge. Compare against quotients/differences instead of
	// computing (page-1)*pageSize or startIdx+pageSize directly, which could
	// wrap around to a negative index.
	startIdx := total
	if page-1 <= total/pageSize {
		startIdx = (page - 1) * pageSize
	}
	endIdx := total
	if pageSize < total-startIdx {
		endIdx = startIdx + pageSize
	}
	pagedData := channelData[startIdx:endIdx]

	for _, datum := range pagedData {
		clearChannelInfo(datum)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data": gin.H{
			"items":       pagedData,
			"total":       total,
			"type_counts": typeCounts,
		},
	})
}
// GetChannel returns the channel identified by the ":id" path parameter. The
// channel is loaded with the second argument of model.GetChannelById set to
// false and has its multi-key disabled details cleared before being returned.
func GetChannel(c *gin.Context) {
	channelID, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		common.ApiError(c, parseErr)
		return
	}

	channel, loadErr := model.GetChannelById(channelID, false)
	if loadErr != nil {
		common.ApiError(c, loadErr)
		return
	}

	if channel != nil {
		clearChannelInfo(channel)
	}
	response := gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	}
	c.JSON(http.StatusOK, response)
}
// GetChannelKey returns the stored key of the channel identified by the ":id"
// path parameter and writes a system log entry recording who viewed it.
//
// The handler is meant to run behind the SecureVerificationRequired
// middleware, which is expected to have verified the user beforehand; no
// verification is performed here.
func GetChannelKey(c *gin.Context) {
	channelId, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, fmt.Errorf("渠道ID格式错误: %v", err))
		return
	}

	// Load the channel with its key included.
	channel, err := model.GetChannelById(channelId, true)
	switch {
	case err != nil:
		common.ApiError(c, fmt.Errorf("获取渠道信息失败: %v", err))
		return
	case channel == nil:
		common.ApiError(c, fmt.Errorf("渠道不存在"))
		return
	}

	// Audit trail: record which user looked at which channel's key.
	userId := c.GetInt("id")
	model.RecordLog(userId, model.LogTypeSystem, fmt.Sprintf("查看渠道密钥信息 (渠道ID: %d)", channelId))

	payload := map[string]interface{}{
		"key": channel.Key,
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "获取成功",
		"data":    payload,
	})
}
// validateTwoFactorAuth reports whether code is accepted by twoFA, either as
// a TOTP code or as a backup code.
//
// The TOTP check is attempted only when common.ValidateNumericCode accepts
// the input; if it does not succeed, the raw code is tried as a backup code.
func validateTwoFactorAuth(twoFA *model.TwoFA, code string) bool {
	// First attempt: TOTP, using the cleaned numeric form of the code.
	cleanCode, err := common.ValidateNumericCode(code)
	if err == nil {
		if totpOK, _ := twoFA.ValidateTOTPAndUpdateUsage(cleanCode); totpOK {
			return true
		}
	}

	// Second attempt: backup code, using the code exactly as supplied.
	backupOK, err := twoFA.ValidateBackupCodeAndUpdateUsage(code)
	return err == nil && backupOK
}
// validateChannel performs the checks shared by channel creation and update.
//
// It returns an error when:
//   - channel is nil;
//   - channel.ValidateSettings rejects the channel settings;
//   - isAdd is true and the key is empty or a model name exceeds 255 bytes;
//   - the channel is a Vertex AI channel and channel.Other is empty, is not
//     valid JSON, or lacks a "default" entry.
//
// The nil check runs first so that a request without a "channel" object is
// reported as an error instead of dereferencing a nil pointer.
func validateChannel(channel *model.Channel, isAdd bool) error {
	if channel == nil {
		return fmt.Errorf("channel cannot be empty")
	}

	// Validate the channel settings.
	if err := channel.ValidateSettings(); err != nil {
		return fmt.Errorf("渠道额外设置[channel setting] 格式错误:%s", err.Error())
	}

	// Creation-only checks: a key is mandatory and model names are bounded.
	if isAdd {
		if channel.Key == "" {
			return fmt.Errorf("channel cannot be empty")
		}
		// Model names longer than 255 bytes are rejected.
		for _, m := range channel.GetModels() {
			if len(m) > 255 {
				return fmt.Errorf("模型名称过长: %s", m)
			}
		}
	}

	// Vertex AI: Other must be a JSON object of regions with a "default" key.
	if channel.Type == constant.ChannelTypeVertexAi {
		if channel.Other == "" {
			return fmt.Errorf("部署地区不能为空")
		}
		regionMap, err := common.StrToMap(channel.Other)
		if err != nil {
			return fmt.Errorf("部署地区必须是标准的Json格式,例如{\"default\": \"us-central1\", \"region2\": \"us-east1\"}")
		}
		if regionMap["default"] == nil {
			return fmt.Errorf("部署地区必须包含default字段")
		}
	}
	return nil
}
// AddChannelRequest is the JSON body accepted by AddChannel.
type AddChannelRequest struct {
	// Mode selects how Channel.Key is interpreted: "single", "batch" (one
	// channel per key) or "multi_to_single" (one multi-key channel). Any other
	// value is rejected by AddChannel.
	Mode string `json:"mode"`
	// MultiKeyMode is copied into the channel info in "multi_to_single" mode.
	MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"`
	// BatchAddSetKeyPrefix2Name appends the first bytes of each key to the
	// channel name when more than one key is being added.
	BatchAddSetKeyPrefix2Name bool `json:"batch_add_set_key_prefix_2_name"`
	// Channel is the channel template; it is nil when the body omits it.
	Channel *model.Channel `json:"channel"`
}
// getVertexArrayKeys parses keys as a JSON array of Vertex AI credentials and
// returns one string per element.
//
// An empty input yields (nil, nil). String elements are whitespace-trimmed;
// every other element is re-encoded as JSON text. Elements that end up empty
// are dropped. An error is returned when the input is not a JSON array, when
// an element cannot be re-encoded, or when no key remains.
func getVertexArrayKeys(keys string) ([]string, error) {
	if keys == "" {
		return nil, nil
	}

	var rawKeys []interface{}
	if err := common.Unmarshal([]byte(keys), &rawKeys); err != nil {
		return nil, fmt.Errorf("批量添加 Vertex AI 必须使用标准的JsonArray格式,例如[{key1}, {key2}...],请检查输入: %w", err)
	}

	cleanKeys := make([]string, 0, len(rawKeys))
	for _, raw := range rawKeys {
		var keyStr string
		if text, isString := raw.(string); isString {
			keyStr = strings.TrimSpace(text)
		} else {
			// Objects (and any other non-string value) are stored as JSON text.
			encoded, err := json.Marshal(raw)
			if err != nil {
				return nil, fmt.Errorf("Vertex AI key JSON 编码失败: %w", err)
			}
			keyStr = string(encoded)
		}
		if keyStr == "" {
			continue
		}
		cleanKeys = append(cleanKeys, keyStr)
	}

	if len(cleanKeys) == 0 {
		return nil, fmt.Errorf("批量添加 Vertex AI 的 keys 不能为空")
	}
	return cleanKeys, nil
}
// AddChannel creates one or more channels from an AddChannelRequest.
//
// Modes:
//   - "single": one channel using the key as given.
//   - "batch": one channel per key. Keys are newline-separated, or a JSON
//     array for Vertex AI channels whose key type is not an API key.
//   - "multi_to_single": one multi-key channel; the cleaned keys are joined
//     with "\n" and their number is stored as MultiKeySize.
//
// Each created channel is built from its own copy of the request channel, so
// name suffixes added for one key never leak into the next channel. Lines are
// whitespace-trimmed before the empty check, so blank lines and "\r\n" line
// endings do not produce bogus keys.
func AddChannel(c *gin.Context) {
	addChannelRequest := AddChannelRequest{}
	if err := c.ShouldBindJSON(&addChannelRequest); err != nil {
		common.ApiError(c, err)
		return
	}
	// Shared validation for add and update.
	if err := validateChannel(addChannelRequest.Channel, true); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	channel := addChannelRequest.Channel
	channel.CreatedTime = common.GetTimestamp()

	// Vertex AI channels that do not use API keys carry JSON credentials, which
	// must be supplied as a JSON array when several are given.
	usesVertexJSONKeys := func() bool {
		return channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey
	}

	var keys []string
	switch addChannelRequest.Mode {
	case "multi_to_single":
		channel.ChannelInfo.IsMultiKey = true
		channel.ChannelInfo.MultiKeyMode = addChannelRequest.MultiKeyMode
		var cleanKeys []string
		if usesVertexJSONKeys() {
			array, err := getVertexArrayKeys(channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
			cleanKeys = array
		} else {
			for _, line := range strings.Split(channel.Key, "\n") {
				// Trim first so whitespace-only lines are not counted as keys.
				line = strings.TrimSpace(line)
				if line == "" {
					continue
				}
				cleanKeys = append(cleanKeys, line)
			}
		}
		channel.ChannelInfo.MultiKeySize = len(cleanKeys)
		channel.Key = strings.Join(cleanKeys, "\n")
		keys = []string{channel.Key}
	case "batch":
		if usesVertexJSONKeys() {
			// multi json
			array, err := getVertexArrayKeys(channel.Key)
			if err != nil {
				c.JSON(http.StatusOK, gin.H{
					"success": false,
					"message": err.Error(),
				})
				return
			}
			keys = array
		} else {
			keys = strings.Split(channel.Key, "\n")
			for i := range keys {
				// Strips "\r" from CRLF input; blank lines are skipped below.
				keys[i] = strings.TrimSpace(keys[i])
			}
		}
	case "single":
		keys = []string{channel.Key}
	default:
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "不支持的添加模式",
		})
		return
	}

	channels := make([]model.Channel, 0, len(keys))
	for _, key := range keys {
		if key == "" {
			continue
		}
		// Copy the struct, not the pointer: every channel starts from the
		// original name instead of the previous iteration's suffixed name.
		localChannel := *channel
		localChannel.Key = key
		if addChannelRequest.BatchAddSetKeyPrefix2Name && len(keys) > 1 {
			keyPrefix := key
			if len(key) > 8 {
				keyPrefix = key[:8]
			}
			localChannel.Name = fmt.Sprintf("%s %s", localChannel.Name, keyPrefix)
		}
		channels = append(channels, localChannel)
	}

	if err := model.BatchInsertChannels(channels); err != nil {
		common.ApiError(c, err)
		return
	}
	service.ResetProxyClientCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteChannel deletes the channel identified by the ":id" path parameter
// and refreshes the channel cache.
//
// A non-numeric id is reported as an error, matching the other id-based
// handlers, instead of being silently treated as id 0.
func DeleteChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		common.ApiError(c, err)
		return
	}

	channel := model.Channel{Id: id}
	if err := channel.Delete(); err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// DeleteDisabledChannel calls model.DeleteDisabledChannel, refreshes the
// channel cache and returns the value reported by the model layer as "data".
func DeleteDisabledChannel(c *gin.Context) {
	deleted, err := model.DeleteDisabledChannel()
	if err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	response := gin.H{
		"success": true,
		"message": "",
		"data":    deleted,
	}
	c.JSON(http.StatusOK, response)
}
// ChannelTag is the JSON body of the tag-level handlers (disable, enable and
// edit). Tag names the tag to act on; the pointer fields are optional and are
// passed through to model.EditChannelByTag by EditTagChannels, with nil
// meaning the field was absent from the request.
type ChannelTag struct {
	Tag          string  `json:"tag"`
	NewTag       *string `json:"new_tag"`
	Priority     *int64  `json:"priority"`
	Weight       *uint   `json:"weight"`
	ModelMapping *string `json:"model_mapping"`
	Models       *string `json:"models"`
	Groups       *string `json:"groups"`
	// ParamOverride and HeaderOverride must be empty or valid JSON; this is
	// enforced by EditTagChannels.
	ParamOverride  *string `json:"param_override"`
	HeaderOverride *string `json:"header_override"`
}
// DisableTagChannels disables every channel carrying the tag named in the
// JSON body and refreshes the channel cache. A body that cannot be bound or
// has an empty tag is rejected with a parameter error.
func DisableTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	if bindErr := c.ShouldBindJSON(&channelTag); bindErr != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}

	if err := model.DisableChannelByTag(channelTag.Tag); err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// EnableTagChannels enables every channel carrying the tag named in the JSON
// body and refreshes the channel cache. A body that cannot be bound or has an
// empty tag is rejected with a parameter error.
func EnableTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	if bindErr := c.ShouldBindJSON(&channelTag); bindErr != nil || channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}

	if err := model.EnableChannelByTag(channelTag.Tag); err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// EditTagChannels applies the optional fields of a ChannelTag body to all
// channels of the given tag through model.EditChannelByTag, then refreshes
// the channel cache.
//
// The tag must be non-empty. param_override and header_override, when
// present, are whitespace-trimmed and must be either empty or valid JSON.
func EditTagChannels(c *gin.Context) {
	var channelTag ChannelTag
	if err := c.ShouldBindJSON(&channelTag); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}
	if channelTag.Tag == "" {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}

	// normalizeOverride trims an optional JSON override. It returns ok=false
	// after writing the failure response when the value is not valid JSON.
	normalizeOverride := func(value *string, invalidMessage string) (*string, bool) {
		if value == nil {
			return nil, true
		}
		trimmed := strings.TrimSpace(*value)
		if trimmed != "" && !json.Valid([]byte(trimmed)) {
			c.JSON(http.StatusOK, gin.H{
				"success": false,
				"message": invalidMessage,
			})
			return nil, false
		}
		return common.GetPointer[string](trimmed), true
	}

	var ok bool
	if channelTag.ParamOverride, ok = normalizeOverride(channelTag.ParamOverride, "参数覆盖必须是合法的 JSON 格式"); !ok {
		return
	}
	if channelTag.HeaderOverride, ok = normalizeOverride(channelTag.HeaderOverride, "请求头覆盖必须是合法的 JSON 格式"); !ok {
		return
	}

	if err := model.EditChannelByTag(channelTag.Tag, channelTag.NewTag, channelTag.ModelMapping, channelTag.Models, channelTag.Groups, channelTag.Priority, channelTag.Weight, channelTag.ParamOverride, channelTag.HeaderOverride); err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
	})
}
// ChannelBatch is the JSON body of the batch handlers: Ids lists the channels
// to act on (handlers reject an empty list) and Tag is the tag passed to
// model.BatchSetChannelTag by BatchSetChannelTag; it is nil when omitted.
type ChannelBatch struct {
	Ids []int   `json:"ids"`
	Tag *string `json:"tag"`
}
// DeleteChannelBatch deletes the channels listed in the "ids" array of the
// JSON body and refreshes the channel cache. The response's "data" is the
// number of ids that were submitted.
func DeleteChannelBatch(c *gin.Context) {
	var channelBatch ChannelBatch
	if bindErr := c.ShouldBindJSON(&channelBatch); bindErr != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}

	if err := model.BatchDeleteChannels(channelBatch.Ids); err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    len(channelBatch.Ids),
	})
}
// PatchChannel is the JSON body accepted by UpdateChannel: the channel fields
// themselves (embedded) plus two optional multi-key controls.
type PatchChannel struct {
	model.Channel
	// MultiKeyMode, when non-nil and non-empty, replaces the stored multi-key mode.
	MultiKeyMode *string `json:"multi_key_mode"`
	// KeyMode tells multi-key channels whether the submitted key is appended
	// to ("append") or overwrites ("replace") the stored keys.
	KeyMode *string `json:"key_mode"`
}
// UpdateChannel updates an existing channel from a PatchChannel body.
//
// The stored ChannelInfo always replaces whatever the client sent, with only
// MultiKeyMode overridable through the request. For multi-key channels with
// key_mode "append", the submitted key(s) are merged with the stored ones and
// joined with "\n"; with "replace" (or no key_mode) the submitted key is used
// as is. After a successful update the channel and proxy-client caches are
// refreshed and the channel is echoed back with its key blanked.
func UpdateChannel(c *gin.Context) {
	channel := PatchChannel{}
	err := c.ShouldBindJSON(&channel)
	if err != nil {
		common.ApiError(c, err)
		return
	}
	// Shared validation for add and update.
	if err := validateChannel(&channel.Channel, false); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Preserve existing ChannelInfo to ensure multi-key channels keep correct state even if the client does not send ChannelInfo in the request.
	originChannel, err := model.GetChannelById(channel.Id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Always copy the original ChannelInfo so that fields like IsMultiKey and MultiKeySize are retained.
	channel.ChannelInfo = originChannel.ChannelInfo
	// If the request explicitly specifies a new MultiKeyMode, apply it on top of the original info.
	if channel.MultiKeyMode != nil && *channel.MultiKeyMode != "" {
		channel.ChannelInfo.MultiKeyMode = constant.MultiKeyMode(*channel.MultiKeyMode)
	}
	// Key append/replace handling for multi-key channels.
	if channel.KeyMode != nil && channel.ChannelInfo.IsMultiKey {
		switch *channel.KeyMode {
		case "append":
			// Append mode: add the new keys to the existing key list. Nothing is
			// merged when the stored key is empty; the submitted key is then
			// used unchanged.
			if originChannel.Key != "" {
				var newKeys []string
				var existingKeys []string
				// Parse the stored keys.
				if strings.HasPrefix(strings.TrimSpace(originChannel.Key), "[") {
					// JSON array format: each element is kept as its raw JSON text.
					// NOTE(review): if this unmarshal fails, existingKeys stays nil
					// and the merged result below contains only the new keys.
					var arr []json.RawMessage
					if err := json.Unmarshal([]byte(strings.TrimSpace(originChannel.Key)), &arr); err == nil {
						existingKeys = make([]string, len(arr))
						for i, v := range arr {
							existingKeys[i] = string(v)
						}
					}
				} else {
					// Newline-separated format.
					existingKeys = strings.Split(strings.Trim(originChannel.Key, "\n"), "\n")
				}
				// Vertex AI channels that do not use API keys need special parsing.
				if channel.Type == constant.ChannelTypeVertexAi && channel.GetOtherSettings().VertexKeyType != dto.VertexKeyTypeAPIKey {
					// Try to parse the new key as a JSON array.
					if strings.HasPrefix(strings.TrimSpace(channel.Key), "[") {
						array, err := getVertexArrayKeys(channel.Key)
						if err != nil {
							c.JSON(http.StatusOK, gin.H{
								"success": false,
								"message": "追加密钥解析失败: " + err.Error(),
							})
							return
						}
						newKeys = array
					} else {
						// A single JSON key, taken verbatim.
						newKeys = []string{channel.Key}
					}
					// Merge the keys.
					allKeys := append(existingKeys, newKeys...)
					channel.Key = strings.Join(allKeys, "\n")
				} else {
					// Regular channels: one key per line, blank lines dropped.
					inputKeys := strings.Split(channel.Key, "\n")
					for _, key := range inputKeys {
						key = strings.TrimSpace(key)
						if key != "" {
							newKeys = append(newKeys, key)
						}
					}
					// Merge the keys.
					allKeys := append(existingKeys, newKeys...)
					channel.Key = strings.Join(allKeys, "\n")
				}
			}
		case "replace":
			// Replace mode: use the new key directly (default behaviour, nothing to do).
		}
	}
	err = channel.Update()
	if err != nil {
		common.ApiError(c, err)
		return
	}
	model.InitChannelCache()
	service.ResetProxyClientCache()
	// Never echo the key back to the client.
	channel.Key = ""
	clearChannelInfo(&channel.Channel)
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    channel,
	})
	return
}
// FetchModels queries "<base_url>/v1/models" with the key supplied in the
// JSON body and returns the model ids found under "data" in the response.
//
// When base_url is empty the default for the given channel type is used. Only
// the first line of the key is sent, as a Bearer token. The upstream response
// body is closed on every path, including the non-200 early return.
func FetchModels(c *gin.Context) {
	var req struct {
		BaseURL string `json:"base_url"`
		Type    int    `json:"type"`
		Key     string `json:"key"`
	}
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "Invalid request",
		})
		return
	}

	baseURL := req.BaseURL
	if baseURL == "" {
		// NOTE(review): req.Type comes straight from the client and is used as
		// a lookup key here without validation — confirm this cannot misbehave
		// for unknown types.
		baseURL = constant.ChannelBaseURLs[req.Type]
	}

	client := &http.Client{}
	url := fmt.Sprintf("%s/v1/models", baseURL)
	request, err := http.NewRequest("GET", url, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// remove line breaks and extra spaces.
	key := strings.TrimSpace(req.Key)
	// If the key contains a line break, only take the first part.
	key = strings.Split(key, "\n")[0]
	request.Header.Set("Authorization", "Bearer "+key)

	response, err := client.Do(request)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}
	// Close the body before any early return so the connection is not leaked
	// when the upstream answers with a non-200 status.
	defer response.Body.Close()

	// check status code
	if response.StatusCode != http.StatusOK {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": "Failed to fetch models",
		})
		return
	}

	var result struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	var models []string
	for _, item := range result.Data {
		models = append(models, item.ID)
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data":    models,
	})
}
// BatchSetChannelTag passes the "ids" and "tag" values of the JSON body to
// model.BatchSetChannelTag and refreshes the channel cache. The response's
// "data" is the number of ids that were submitted.
func BatchSetChannelTag(c *gin.Context) {
	var channelBatch ChannelBatch
	if bindErr := c.ShouldBindJSON(&channelBatch); bindErr != nil || len(channelBatch.Ids) == 0 {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": "参数错误",
		})
		return
	}

	if err := model.BatchSetChannelTag(channelBatch.Ids, channelBatch.Tag); err != nil {
		common.ApiError(c, err)
		return
	}

	model.InitChannelCache()
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    len(channelBatch.Ids),
	})
}
// GetTagModels returns, for the tag given in the "tag" query parameter, the
// models string of the channel that lists the most comma-separated models.
// On a tie the first such channel wins; the result is empty when no channel
// of the tag has models. An empty tag yields 400, a lookup failure 500.
func GetTagModels(c *gin.Context) {
	tag := c.Query("tag")
	if tag == "" {
		c.JSON(http.StatusBadRequest, gin.H{
			"success": false,
			"message": "tag不能为空",
		})
		return
	}

	// idSort=false, selectAll=false
	channels, err := model.GetChannelsByTag(tag, false, false)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	// Keep the models string with the highest entry count seen so far.
	bestModels, bestCount := "", 0
	for _, ch := range channels {
		if ch.Models == "" {
			continue
		}
		// A non-empty list with k commas has k+1 entries.
		if count := strings.Count(ch.Models, ",") + 1; count > bestCount {
			bestCount = count
			bestModels = ch.Models
		}
	}

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    bestModels,
	})
}
// CopyChannel handles cloning an existing channel with its key.
// POST /api/channel/copy/:id
// Optional query params:
//
//	suffix        - string appended to the original name (default "_复制")
//	reset_balance - bool, when true will reset balance & used_quota to 0 (default true)
//
// An unparsable reset_balance value keeps the default. The response's
// "data.id" is read from the slice element handed to
// model.BatchInsertChannels rather than from the local copy, which the insert
// can never update.
func CopyChannel(c *gin.Context) {
	id, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "invalid id"})
		return
	}

	suffix := c.DefaultQuery("suffix", "_复制")
	resetBalance := true
	if rbStr := c.DefaultQuery("reset_balance", "true"); rbStr != "" {
		if v, err := strconv.ParseBool(rbStr); err == nil {
			resetBalance = v
		}
	}

	// fetch original channel with key
	origin, err := model.GetChannelById(id, true)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}
	if origin == nil {
		// Defensive: other handlers in this file guard against a nil channel too.
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "渠道不存在"})
		return
	}

	// clone channel
	clone := *origin // shallow copy is sufficient as we will overwrite primitives
	clone.Id = 0     // let DB auto-generate
	clone.CreatedTime = common.GetTimestamp()
	clone.Name = origin.Name + suffix
	clone.TestTime = 0
	clone.ResponseTime = 0
	if resetBalance {
		clone.Balance = 0
		clone.UsedQuota = 0
	}

	// insert
	// The slice element is what the insert operates on, so the generated id
	// (if the model layer fills it in — TODO confirm) appears in clones[0],
	// not in clone.
	clones := []model.Channel{clone}
	if err := model.BatchInsertChannels(clones); err != nil {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": err.Error()})
		return
	}
	model.InitChannelCache()

	// success
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "", "data": gin.H{"id": clones[0].Id}})
}
// MultiKeyManageRequest represents the request for multi-key management operations
// handled by ManageMultiKeys. Which optional fields are consulted depends on Action.
type MultiKeyManageRequest struct {
	ChannelId int    `json:"channel_id"`
	Action    string `json:"action"`              // "disable_key", "enable_key", "delete_key", "delete_disabled_keys", "get_key_status"; "enable_all_keys" and "disable_all_keys" are also handled
	KeyIndex  *int   `json:"key_index,omitempty"` // for disable_key, enable_key, and delete_key actions
	Page      int    `json:"page,omitempty"`      // for get_key_status pagination; values <= 0 fall back to page 1
	PageSize  int    `json:"page_size,omitempty"` // for get_key_status pagination; values <= 0 fall back to 50
	Status    *int   `json:"status,omitempty"`    // for get_key_status filtering: 1=enabled, 2=manual_disabled, 3=auto_disabled, nil=all
}
// MultiKeyStatusResponse represents the response for key status query
// ("get_key_status"). Keys, Total and TotalPages describe the keys remaining
// after the optional status filter; the three counters always cover every key
// of the channel.
type MultiKeyStatusResponse struct {
	Keys       []KeyStatus `json:"keys"`        // keys on the requested page
	Total      int         `json:"total"`       // number of keys after filtering
	Page       int         `json:"page"`        // page actually served (clamped to TotalPages)
	PageSize   int         `json:"page_size"`   // page size actually used
	TotalPages int         `json:"total_pages"` // at least 1, even when there are no keys
	// Statistics
	EnabledCount        int `json:"enabled_count"`
	ManualDisabledCount int `json:"manual_disabled_count"`
	AutoDisabledCount   int `json:"auto_disabled_count"`
}
// KeyStatus describes one key of a multi-key channel as reported by the
// "get_key_status" action.
type KeyStatus struct {
	Index        int    `json:"index"`                   // position of the key in the channel's key list
	Status       int    `json:"status"`                  // 1: enabled, 2: manually disabled, 3: automatically disabled
	DisabledTime int64  `json:"disabled_time,omitempty"` // only filled in for keys that are not enabled
	Reason       string `json:"reason,omitempty"`        // only filled in for keys that are not enabled
	KeyPreview   string `json:"key_preview"`             // first 10 bytes of the key followed by "..." (whole key if shorter), for identification
}
// ManageMultiKeys handles multi-key management operations
func ManageMultiKeys(c *gin.Context) {
request := MultiKeyManageRequest{}
err := c.ShouldBindJSON(&request)
if err != nil {
common.ApiError(c, err)
return
}
channel, err := model.GetChannelById(request.ChannelId, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "渠道不存在",
})
return
}
if !channel.ChannelInfo.IsMultiKey {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该渠道不是多密钥模式",
})
return
}
lock := model.GetChannelPollingLock(channel.Id)
lock.Lock()
defer lock.Unlock()
switch request.Action {
case "get_key_status":
keys := channel.GetKeys()
// Default pagination parameters
page := request.Page
pageSize := request.PageSize
if page <= 0 {
page = 1
}
if pageSize <= 0 {
pageSize = 50 // Default page size
}
// Statistics for all keys (unchanged by filtering)
var enabledCount, manualDisabledCount, autoDisabledCount int
// Build all key status data first
var allKeyStatusList []KeyStatus
for i, key := range keys {
status := 1 // default enabled
var disabledTime int64
var reason string
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// Count for statistics (all keys)
switch status {
case 1:
enabledCount++
case 2:
manualDisabledCount++
case 3:
autoDisabledCount++
}
if status != 1 {
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
disabledTime = channel.ChannelInfo.MultiKeyDisabledTime[i]
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
reason = channel.ChannelInfo.MultiKeyDisabledReason[i]
}
}
// Create key preview (first 10 chars)
keyPreview := key
if len(key) > 10 {
keyPreview = key[:10] + "..."
}
allKeyStatusList = append(allKeyStatusList, KeyStatus{
Index: i,
Status: status,
DisabledTime: disabledTime,
Reason: reason,
KeyPreview: keyPreview,
})
}
// Apply status filter if specified
var filteredKeyStatusList []KeyStatus
if request.Status != nil {
for _, keyStatus := range allKeyStatusList {
if keyStatus.Status == *request.Status {
filteredKeyStatusList = append(filteredKeyStatusList, keyStatus)
}
}
} else {
filteredKeyStatusList = allKeyStatusList
}
// Calculate pagination based on filtered results
filteredTotal := len(filteredKeyStatusList)
totalPages := (filteredTotal + pageSize - 1) / pageSize
if totalPages == 0 {
totalPages = 1
}
if page > totalPages {
page = totalPages
}
// Calculate range for current page
start := (page - 1) * pageSize
end := start + pageSize
if end > filteredTotal {
end = filteredTotal
}
// Get the page data
var pageKeyStatusList []KeyStatus
if start < filteredTotal {
pageKeyStatusList = filteredKeyStatusList[start:end]
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": MultiKeyStatusResponse{
Keys: pageKeyStatusList,
Total: filteredTotal, // Total of filtered results
Page: page,
PageSize: pageSize,
TotalPages: totalPages,
EnabledCount: enabledCount, // Overall statistics
ManualDisabledCount: manualDisabledCount, // Overall statistics
AutoDisabledCount: autoDisabledCount, // Overall statistics
},
})
return
case "disable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要禁用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
channel.ChannelInfo.MultiKeyStatusList[keyIndex] = 2 // disabled
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已禁用",
})
return
case "enable_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要启用的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
// 从状态列表中删除该密钥的记录,使其回到默认启用状态
if channel.ChannelInfo.MultiKeyStatusList != nil {
delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
delete(channel.ChannelInfo.MultiKeyDisabledTime, keyIndex)
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
delete(channel.ChannelInfo.MultiKeyDisabledReason, keyIndex)
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已启用",
})
return
case "enable_all_keys":
// 清空所有禁用状态,使所有密钥回到默认启用状态
var enabledCount int
if channel.ChannelInfo.MultiKeyStatusList != nil {
enabledCount = len(channel.ChannelInfo.MultiKeyStatusList)
}
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已启用 %d 个密钥", enabledCount),
})
return
case "disable_all_keys":
// 禁用所有启用的密钥
if channel.ChannelInfo.MultiKeyStatusList == nil {
channel.ChannelInfo.MultiKeyStatusList = make(map[int]int)
}
if channel.ChannelInfo.MultiKeyDisabledTime == nil {
channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64)
}
if channel.ChannelInfo.MultiKeyDisabledReason == nil {
channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string)
}
var disabledCount int
for i := 0; i < channel.ChannelInfo.MultiKeySize; i++ {
status := 1 // default enabled
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
// 只禁用当前启用的密钥
if status == 1 {
channel.ChannelInfo.MultiKeyStatusList[i] = 2 // disabled
disabledCount++
}
}
if disabledCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有可禁用的密钥",
})
return
}
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已禁用 %d 个密钥", disabledCount),
})
return
case "delete_key":
if request.KeyIndex == nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "未指定要删除的密钥索引",
})
return
}
keyIndex := *request.KeyIndex
if keyIndex < 0 || keyIndex >= channel.ChannelInfo.MultiKeySize {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "密钥索引超出范围",
})
return
}
keys := channel.GetKeys()
var remainingKeys []string
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
// 跳过要删除的密钥
if i == keyIndex {
continue
}
remainingKeys = append(remainingKeys, key)
// 保留其他密钥的状态信息,重新索引
if channel.ChannelInfo.MultiKeyStatusList != nil {
if status, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists && status != 1 {
newStatusList[newIndex] = status
}
}
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
newIndex++
}
if len(remainingKeys) == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除最后一个密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "密钥已删除",
})
return
case "delete_disabled_keys":
keys := channel.GetKeys()
var remainingKeys []string
var deletedCount int
var newStatusList = make(map[int]int)
var newDisabledTime = make(map[int]int64)
var newDisabledReason = make(map[int]string)
newIndex := 0
for i, key := range keys {
status := 1 // default enabled
if channel.ChannelInfo.MultiKeyStatusList != nil {
if s, exists := channel.ChannelInfo.MultiKeyStatusList[i]; exists {
status = s
}
}
// 只删除自动禁用(status == 3)的密钥,保留启用(status == 1)和手动禁用(status == 2)的密钥
if status == 3 {
deletedCount++
} else {
remainingKeys = append(remainingKeys, key)
// 保留非自动禁用密钥的状态信息,重新索引
if status != 1 {
newStatusList[newIndex] = status
if channel.ChannelInfo.MultiKeyDisabledTime != nil {
if t, exists := channel.ChannelInfo.MultiKeyDisabledTime[i]; exists {
newDisabledTime[newIndex] = t
}
}
if channel.ChannelInfo.MultiKeyDisabledReason != nil {
if r, exists := channel.ChannelInfo.MultiKeyDisabledReason[i]; exists {
newDisabledReason[newIndex] = r
}
}
}
newIndex++
}
}
if deletedCount == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "没有需要删除的自动禁用密钥",
})
return
}
// Update channel with remaining keys
channel.Key = strings.Join(remainingKeys, "\n")
channel.ChannelInfo.MultiKeySize = len(remainingKeys)
channel.ChannelInfo.MultiKeyStatusList = newStatusList
channel.ChannelInfo.MultiKeyDisabledTime = newDisabledTime
channel.ChannelInfo.MultiKeyDisabledReason = newDisabledReason
err = channel.Update()
if err != nil {
common.ApiError(c, err)
return
}
model.InitChannelCache()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": fmt.Sprintf("已删除 %d 个自动禁用的密钥", deletedCount),
"data": deletedCount,
})
return
default:
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不支持的操作",
})
return
}
}