Spaces:
Build error
Build error
| package model | |
| import ( | |
| "database/sql/driver" | |
| "encoding/json" | |
| "errors" | |
| "fmt" | |
| "math/rand" | |
| "one-api/common" | |
| "one-api/constant" | |
| "one-api/dto" | |
| "one-api/types" | |
| "strings" | |
| "sync" | |
| "github.com/samber/lo" | |
| "gorm.io/gorm" | |
| ) | |
| type Channel struct { | |
| Id int `json:"id"` | |
| Type int `json:"type" gorm:"default:0"` | |
| Key string `json:"key" gorm:"not null"` | |
| OpenAIOrganization *string `json:"openai_organization"` | |
| TestModel *string `json:"test_model"` | |
| Status int `json:"status" gorm:"default:1"` | |
| Name string `json:"name" gorm:"index"` | |
| Weight *uint `json:"weight" gorm:"default:0"` | |
| CreatedTime int64 `json:"created_time" gorm:"bigint"` | |
| TestTime int64 `json:"test_time" gorm:"bigint"` | |
| ResponseTime int `json:"response_time"` // in milliseconds | |
| BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` | |
| Other string `json:"other"` | |
| Balance float64 `json:"balance"` // in USD | |
| BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` | |
| Models string `json:"models"` | |
| Group string `json:"group" gorm:"type:varchar(64);default:'default'"` | |
| UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` | |
| ModelMapping *string `json:"model_mapping" gorm:"type:text"` | |
| //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` | |
| StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` | |
| Priority *int64 `json:"priority" gorm:"bigint;default:0"` | |
| AutoBan *int `json:"auto_ban" gorm:"default:1"` | |
| OtherInfo string `json:"other_info"` | |
| Tag *string `json:"tag" gorm:"index"` | |
| Setting *string `json:"setting" gorm:"type:text"` // 渠道额外设置 | |
| ParamOverride *string `json:"param_override" gorm:"type:text"` | |
| HeaderOverride *string `json:"header_override" gorm:"type:text"` | |
| Remark string `json:"remark,omitempty" gorm:"type:varchar(255)" validate:"max=255"` | |
| // add after v0.8.5 | |
| ChannelInfo ChannelInfo `json:"channel_info" gorm:"type:json"` | |
| OtherSettings string `json:"settings" gorm:"column:settings"` // 其他设置,存储azure版本等不需要检索的信息,详见dto.ChannelOtherSettings | |
| // cache info | |
| Keys []string `json:"-" gorm:"-"` | |
| } | |
| type ChannelInfo struct { | |
| IsMultiKey bool `json:"is_multi_key"` // 是否多Key模式 | |
| MultiKeySize int `json:"multi_key_size"` // 多Key模式下的Key数量 | |
| MultiKeyStatusList map[int]int `json:"multi_key_status_list"` // key状态列表,key index -> status | |
| MultiKeyDisabledReason map[int]string `json:"multi_key_disabled_reason,omitempty"` // key禁用原因列表,key index -> reason | |
| MultiKeyDisabledTime map[int]int64 `json:"multi_key_disabled_time,omitempty"` // key禁用时间列表,key index -> time | |
| MultiKeyPollingIndex int `json:"multi_key_polling_index"` // 多Key模式下轮询的key索引 | |
| MultiKeyMode constant.MultiKeyMode `json:"multi_key_mode"` | |
| } | |
| // Value implements driver.Valuer interface | |
| func (c ChannelInfo) Value() (driver.Value, error) { | |
| return common.Marshal(&c) | |
| } | |
| // Scan implements sql.Scanner interface | |
| func (c *ChannelInfo) Scan(value interface{}) error { | |
| bytesValue, _ := value.([]byte) | |
| return common.Unmarshal(bytesValue, c) | |
| } | |
| func (channel *Channel) GetKeys() []string { | |
| if channel.Key == "" { | |
| return []string{} | |
| } | |
| if len(channel.Keys) > 0 { | |
| return channel.Keys | |
| } | |
| trimmed := strings.TrimSpace(channel.Key) | |
| // If the key starts with '[', try to parse it as a JSON array (e.g., for Vertex AI scenarios) | |
| if strings.HasPrefix(trimmed, "[") { | |
| var arr []json.RawMessage | |
| if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { | |
| res := make([]string, len(arr)) | |
| for i, v := range arr { | |
| res[i] = string(v) | |
| } | |
| return res | |
| } | |
| } | |
| // Otherwise, fall back to splitting by newline | |
| keys := strings.Split(strings.Trim(channel.Key, "\n"), "\n") | |
| return keys | |
| } | |
| func (channel *Channel) GetNextEnabledKey() (string, int, *types.NewAPIError) { | |
| // If not in multi-key mode, return the original key string directly. | |
| if !channel.ChannelInfo.IsMultiKey { | |
| return channel.Key, 0, nil | |
| } | |
| // Obtain all keys (split by \n) | |
| keys := channel.GetKeys() | |
| if len(keys) == 0 { | |
| // No keys available, return error, should disable the channel | |
| return "", 0, types.NewError(errors.New("no keys available"), types.ErrorCodeChannelNoAvailableKey) | |
| } | |
| lock := GetChannelPollingLock(channel.Id) | |
| lock.Lock() | |
| defer lock.Unlock() | |
| statusList := channel.ChannelInfo.MultiKeyStatusList | |
| // helper to get key status, default to enabled when missing | |
| getStatus := func(idx int) int { | |
| if statusList == nil { | |
| return common.ChannelStatusEnabled | |
| } | |
| if status, ok := statusList[idx]; ok { | |
| return status | |
| } | |
| return common.ChannelStatusEnabled | |
| } | |
| // Collect indexes of enabled keys | |
| enabledIdx := make([]int, 0, len(keys)) | |
| for i := range keys { | |
| if getStatus(i) == common.ChannelStatusEnabled { | |
| enabledIdx = append(enabledIdx, i) | |
| } | |
| } | |
| // If no specific status list or none enabled, fall back to first key | |
| if len(enabledIdx) == 0 { | |
| return keys[0], 0, nil | |
| } | |
| switch channel.ChannelInfo.MultiKeyMode { | |
| case constant.MultiKeyModeRandom: | |
| // Randomly pick one enabled key | |
| selectedIdx := enabledIdx[rand.Intn(len(enabledIdx))] | |
| return keys[selectedIdx], selectedIdx, nil | |
| case constant.MultiKeyModePolling: | |
| // Use channel-specific lock to ensure thread-safe polling | |
| channelInfo, err := CacheGetChannelInfo(channel.Id) | |
| if err != nil { | |
| return "", 0, types.NewError(err, types.ErrorCodeGetChannelFailed, types.ErrOptionWithSkipRetry()) | |
| } | |
| //println("before polling index:", channel.ChannelInfo.MultiKeyPollingIndex) | |
| defer func() { | |
| if common.DebugEnabled { | |
| println(fmt.Sprintf("channel %d polling index: %d", channel.Id, channel.ChannelInfo.MultiKeyPollingIndex)) | |
| } | |
| if !common.MemoryCacheEnabled { | |
| _ = channel.SaveChannelInfo() | |
| } else { | |
| // CacheUpdateChannel(channel) | |
| } | |
| }() | |
| // Start from the saved polling index and look for the next enabled key | |
| start := channelInfo.MultiKeyPollingIndex | |
| if start < 0 || start >= len(keys) { | |
| start = 0 | |
| } | |
| for i := 0; i < len(keys); i++ { | |
| idx := (start + i) % len(keys) | |
| if getStatus(idx) == common.ChannelStatusEnabled { | |
| // update polling index for next call (point to the next position) | |
| channel.ChannelInfo.MultiKeyPollingIndex = (idx + 1) % len(keys) | |
| return keys[idx], idx, nil | |
| } | |
| } | |
| // Fallback – should not happen, but return first enabled key | |
| return keys[enabledIdx[0]], enabledIdx[0], nil | |
| default: | |
| // Unknown mode, default to first enabled key (or original key string) | |
| return keys[enabledIdx[0]], enabledIdx[0], nil | |
| } | |
| } | |
| func (channel *Channel) SaveChannelInfo() error { | |
| return DB.Model(channel).Update("channel_info", channel.ChannelInfo).Error | |
| } | |
| func (channel *Channel) GetModels() []string { | |
| if channel.Models == "" { | |
| return []string{} | |
| } | |
| return strings.Split(strings.Trim(channel.Models, ","), ",") | |
| } | |
| func (channel *Channel) GetGroups() []string { | |
| if channel.Group == "" { | |
| return []string{} | |
| } | |
| groups := strings.Split(strings.Trim(channel.Group, ","), ",") | |
| for i, group := range groups { | |
| groups[i] = strings.TrimSpace(group) | |
| } | |
| return groups | |
| } | |
| func (channel *Channel) GetOtherInfo() map[string]interface{} { | |
| otherInfo := make(map[string]interface{}) | |
| if channel.OtherInfo != "" { | |
| err := common.Unmarshal([]byte(channel.OtherInfo), &otherInfo) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to unmarshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) | |
| } | |
| } | |
| return otherInfo | |
| } | |
| func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { | |
| otherInfoBytes, err := json.Marshal(otherInfo) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to marshal other info: channel_id=%d, tag=%s, name=%s, error=%v", channel.Id, channel.GetTag(), channel.Name, err)) | |
| return | |
| } | |
| channel.OtherInfo = string(otherInfoBytes) | |
| } | |
| func (channel *Channel) GetTag() string { | |
| if channel.Tag == nil { | |
| return "" | |
| } | |
| return *channel.Tag | |
| } | |
| func (channel *Channel) SetTag(tag string) { | |
| channel.Tag = &tag | |
| } | |
| func (channel *Channel) GetAutoBan() bool { | |
| if channel.AutoBan == nil { | |
| return false | |
| } | |
| return *channel.AutoBan == 1 | |
| } | |
| func (channel *Channel) Save() error { | |
| return DB.Save(channel).Error | |
| } | |
| func (channel *Channel) SaveWithoutKey() error { | |
| return DB.Omit("key").Save(channel).Error | |
| } | |
| func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { | |
| var channels []*Channel | |
| var err error | |
| order := "priority desc" | |
| if idSort { | |
| order = "id desc" | |
| } | |
| if selectAll { | |
| err = DB.Order(order).Find(&channels).Error | |
| } else { | |
| err = DB.Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | |
| } | |
| return channels, err | |
| } | |
| func GetChannelsByTag(tag string, idSort bool) ([]*Channel, error) { | |
| var channels []*Channel | |
| order := "priority desc" | |
| if idSort { | |
| order = "id desc" | |
| } | |
| err := DB.Where("tag = ?", tag).Order(order).Find(&channels).Error | |
| return channels, err | |
| } | |
| func SearchChannels(keyword string, group string, model string, idSort bool) ([]*Channel, error) { | |
| var channels []*Channel | |
| modelsCol := "`models`" | |
| // 如果是 PostgreSQL,使用双引号 | |
| if common.UsingPostgreSQL { | |
| modelsCol = `"models"` | |
| } | |
| baseURLCol := "`base_url`" | |
| // 如果是 PostgreSQL,使用双引号 | |
| if common.UsingPostgreSQL { | |
| baseURLCol = `"base_url"` | |
| } | |
| order := "priority desc" | |
| if idSort { | |
| order = "id desc" | |
| } | |
| // 构造基础查询 | |
| baseQuery := DB.Model(&Channel{}).Omit("key") | |
| // 构造WHERE子句 | |
| var whereClause string | |
| var args []interface{} | |
| if group != "" && group != "null" { | |
| var groupCondition string | |
| if common.UsingMySQL { | |
| groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` | |
| } else { | |
| // sqlite, PostgreSQL | |
| groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` | |
| } | |
| whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition | |
| args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") | |
| } else { | |
| whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" | |
| args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") | |
| } | |
| // 执行查询 | |
| err := baseQuery.Where(whereClause, args...).Order(order).Find(&channels).Error | |
| if err != nil { | |
| return nil, err | |
| } | |
| return channels, nil | |
| } | |
| func GetChannelById(id int, selectAll bool) (*Channel, error) { | |
| channel := &Channel{Id: id} | |
| var err error = nil | |
| if selectAll { | |
| err = DB.First(channel, "id = ?", id).Error | |
| } else { | |
| err = DB.Omit("key").First(channel, "id = ?", id).Error | |
| } | |
| if err != nil { | |
| return nil, err | |
| } | |
| if channel == nil { | |
| return nil, errors.New("channel not found") | |
| } | |
| return channel, nil | |
| } | |
| func BatchInsertChannels(channels []Channel) error { | |
| if len(channels) == 0 { | |
| return nil | |
| } | |
| tx := DB.Begin() | |
| if tx.Error != nil { | |
| return tx.Error | |
| } | |
| defer func() { | |
| if r := recover(); r != nil { | |
| tx.Rollback() | |
| } | |
| }() | |
| for _, chunk := range lo.Chunk(channels, 50) { | |
| if err := tx.Create(&chunk).Error; err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| for _, channel_ := range chunk { | |
| if err := channel_.AddAbilities(tx); err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| } | |
| } | |
| return tx.Commit().Error | |
| } | |
| func BatchDeleteChannels(ids []int) error { | |
| if len(ids) == 0 { | |
| return nil | |
| } | |
| // 使用事务 分批删除channel表和abilities表 | |
| tx := DB.Begin() | |
| if tx.Error != nil { | |
| return tx.Error | |
| } | |
| for _, chunk := range lo.Chunk(ids, 200) { | |
| if err := tx.Where("id in (?)", chunk).Delete(&Channel{}).Error; err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| if err := tx.Where("channel_id in (?)", chunk).Delete(&Ability{}).Error; err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| } | |
| return tx.Commit().Error | |
| } | |
| func (channel *Channel) GetPriority() int64 { | |
| if channel.Priority == nil { | |
| return 0 | |
| } | |
| return *channel.Priority | |
| } | |
| func (channel *Channel) GetWeight() int { | |
| if channel.Weight == nil { | |
| return 0 | |
| } | |
| return int(*channel.Weight) | |
| } | |
| func (channel *Channel) GetBaseURL() string { | |
| if channel.BaseURL == nil { | |
| return "" | |
| } | |
| url := *channel.BaseURL | |
| if url == "" { | |
| url = constant.ChannelBaseURLs[channel.Type] | |
| } | |
| return url | |
| } | |
| func (channel *Channel) GetModelMapping() string { | |
| if channel.ModelMapping == nil { | |
| return "" | |
| } | |
| return *channel.ModelMapping | |
| } | |
| func (channel *Channel) GetStatusCodeMapping() string { | |
| if channel.StatusCodeMapping == nil { | |
| return "" | |
| } | |
| return *channel.StatusCodeMapping | |
| } | |
| func (channel *Channel) Insert() error { | |
| var err error | |
| err = DB.Create(channel).Error | |
| if err != nil { | |
| return err | |
| } | |
| err = channel.AddAbilities(nil) | |
| return err | |
| } | |
| func (channel *Channel) Update() error { | |
| // If this is a multi-key channel, recalculate MultiKeySize based on the current key list to avoid inconsistency after editing keys | |
| if channel.ChannelInfo.IsMultiKey { | |
| var keyStr string | |
| if channel.Key != "" { | |
| keyStr = channel.Key | |
| } else { | |
| // If key is not provided, read the existing key from the database | |
| if existing, err := GetChannelById(channel.Id, true); err == nil { | |
| keyStr = existing.Key | |
| } | |
| } | |
| // Parse the key list (supports newline separation or JSON array) | |
| keys := []string{} | |
| if keyStr != "" { | |
| trimmed := strings.TrimSpace(keyStr) | |
| if strings.HasPrefix(trimmed, "[") { | |
| var arr []json.RawMessage | |
| if err := common.Unmarshal([]byte(trimmed), &arr); err == nil { | |
| keys = make([]string, len(arr)) | |
| for i, v := range arr { | |
| keys[i] = string(v) | |
| } | |
| } | |
| } | |
| if len(keys) == 0 { // fallback to newline split | |
| keys = strings.Split(strings.Trim(keyStr, "\n"), "\n") | |
| } | |
| } | |
| channel.ChannelInfo.MultiKeySize = len(keys) | |
| // Clean up status data that exceeds the new key count to prevent index out of range | |
| if channel.ChannelInfo.MultiKeyStatusList != nil { | |
| for idx := range channel.ChannelInfo.MultiKeyStatusList { | |
| if idx >= channel.ChannelInfo.MultiKeySize { | |
| delete(channel.ChannelInfo.MultiKeyStatusList, idx) | |
| } | |
| } | |
| } | |
| } | |
| var err error | |
| err = DB.Model(channel).Updates(channel).Error | |
| if err != nil { | |
| return err | |
| } | |
| DB.Model(channel).First(channel, "id = ?", channel.Id) | |
| err = channel.UpdateAbilities(nil) | |
| return err | |
| } | |
| func (channel *Channel) UpdateResponseTime(responseTime int64) { | |
| err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ | |
| TestTime: common.GetTimestamp(), | |
| ResponseTime: int(responseTime), | |
| }).Error | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update response time: channel_id=%d, error=%v", channel.Id, err)) | |
| } | |
| } | |
| func (channel *Channel) UpdateBalance(balance float64) { | |
| err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ | |
| BalanceUpdatedTime: common.GetTimestamp(), | |
| Balance: balance, | |
| }).Error | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update balance: channel_id=%d, error=%v", channel.Id, err)) | |
| } | |
| } | |
| func (channel *Channel) Delete() error { | |
| var err error | |
| err = DB.Delete(channel).Error | |
| if err != nil { | |
| return err | |
| } | |
| err = channel.DeleteAbilities() | |
| return err | |
| } | |
| var channelStatusLock sync.Mutex | |
| // channelPollingLocks stores locks for each channel.id to ensure thread-safe polling | |
| var channelPollingLocks sync.Map | |
| // GetChannelPollingLock returns or creates a mutex for the given channel ID | |
| func GetChannelPollingLock(channelId int) *sync.Mutex { | |
| if lock, exists := channelPollingLocks.Load(channelId); exists { | |
| return lock.(*sync.Mutex) | |
| } | |
| // Create new lock for this channel | |
| newLock := &sync.Mutex{} | |
| actual, _ := channelPollingLocks.LoadOrStore(channelId, newLock) | |
| return actual.(*sync.Mutex) | |
| } | |
| // CleanupChannelPollingLocks removes locks for channels that no longer exist | |
| // This is optional and can be called periodically to prevent memory leaks | |
| func CleanupChannelPollingLocks() { | |
| var activeChannelIds []int | |
| DB.Model(&Channel{}).Pluck("id", &activeChannelIds) | |
| activeChannelSet := make(map[int]bool) | |
| for _, id := range activeChannelIds { | |
| activeChannelSet[id] = true | |
| } | |
| channelPollingLocks.Range(func(key, value interface{}) bool { | |
| channelId := key.(int) | |
| if !activeChannelSet[channelId] { | |
| channelPollingLocks.Delete(channelId) | |
| } | |
| return true | |
| }) | |
| } | |
| func handlerMultiKeyUpdate(channel *Channel, usingKey string, status int, reason string) { | |
| keys := channel.GetKeys() | |
| if len(keys) == 0 { | |
| channel.Status = status | |
| } else { | |
| var keyIndex int | |
| for i, key := range keys { | |
| if key == usingKey { | |
| keyIndex = i | |
| break | |
| } | |
| } | |
| if channel.ChannelInfo.MultiKeyStatusList == nil { | |
| channel.ChannelInfo.MultiKeyStatusList = make(map[int]int) | |
| } | |
| if status == common.ChannelStatusEnabled { | |
| delete(channel.ChannelInfo.MultiKeyStatusList, keyIndex) | |
| } else { | |
| channel.ChannelInfo.MultiKeyStatusList[keyIndex] = status | |
| if channel.ChannelInfo.MultiKeyDisabledReason == nil { | |
| channel.ChannelInfo.MultiKeyDisabledReason = make(map[int]string) | |
| } | |
| if channel.ChannelInfo.MultiKeyDisabledTime == nil { | |
| channel.ChannelInfo.MultiKeyDisabledTime = make(map[int]int64) | |
| } | |
| channel.ChannelInfo.MultiKeyDisabledReason[keyIndex] = reason | |
| channel.ChannelInfo.MultiKeyDisabledTime[keyIndex] = common.GetTimestamp() | |
| } | |
| if len(channel.ChannelInfo.MultiKeyStatusList) >= channel.ChannelInfo.MultiKeySize { | |
| channel.Status = common.ChannelStatusAutoDisabled | |
| info := channel.GetOtherInfo() | |
| info["status_reason"] = "All keys are disabled" | |
| info["status_time"] = common.GetTimestamp() | |
| channel.SetOtherInfo(info) | |
| } | |
| } | |
| } | |
| func UpdateChannelStatus(channelId int, usingKey string, status int, reason string) bool { | |
| if common.MemoryCacheEnabled { | |
| channelStatusLock.Lock() | |
| defer channelStatusLock.Unlock() | |
| channelCache, _ := CacheGetChannel(channelId) | |
| if channelCache == nil { | |
| return false | |
| } | |
| if channelCache.ChannelInfo.IsMultiKey { | |
| // Use per-channel lock to prevent concurrent map read/write with GetNextEnabledKey | |
| pollingLock := GetChannelPollingLock(channelId) | |
| pollingLock.Lock() | |
| // 如果是多Key模式,更新缓存中的状态 | |
| handlerMultiKeyUpdate(channelCache, usingKey, status, reason) | |
| pollingLock.Unlock() | |
| //CacheUpdateChannel(channelCache) | |
| //return true | |
| } else { | |
| // 如果缓存渠道存在,且状态已是目标状态,直接返回 | |
| if channelCache.Status == status { | |
| return false | |
| } | |
| CacheUpdateChannelStatus(channelId, status) | |
| } | |
| } | |
| shouldUpdateAbilities := false | |
| defer func() { | |
| if shouldUpdateAbilities { | |
| err := UpdateAbilityStatus(channelId, status == common.ChannelStatusEnabled) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update ability status: channel_id=%d, error=%v", channelId, err)) | |
| } | |
| } | |
| }() | |
| channel, err := GetChannelById(channelId, true) | |
| if err != nil { | |
| return false | |
| } else { | |
| if channel.Status == status { | |
| return false | |
| } | |
| if channel.ChannelInfo.IsMultiKey { | |
| beforeStatus := channel.Status | |
| // Protect map writes with the same per-channel lock used by readers | |
| pollingLock := GetChannelPollingLock(channelId) | |
| pollingLock.Lock() | |
| handlerMultiKeyUpdate(channel, usingKey, status, reason) | |
| pollingLock.Unlock() | |
| if beforeStatus != channel.Status { | |
| shouldUpdateAbilities = true | |
| } | |
| } else { | |
| info := channel.GetOtherInfo() | |
| info["status_reason"] = reason | |
| info["status_time"] = common.GetTimestamp() | |
| channel.SetOtherInfo(info) | |
| channel.Status = status | |
| shouldUpdateAbilities = true | |
| } | |
| err = channel.SaveWithoutKey() | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update channel status: channel_id=%d, status=%d, error=%v", channel.Id, status, err)) | |
| return false | |
| } | |
| } | |
| return true | |
| } | |
| func EnableChannelByTag(tag string) error { | |
| err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusEnabled).Error | |
| if err != nil { | |
| return err | |
| } | |
| err = UpdateAbilityStatusByTag(tag, true) | |
| return err | |
| } | |
| func DisableChannelByTag(tag string) error { | |
| err := DB.Model(&Channel{}).Where("tag = ?", tag).Update("status", common.ChannelStatusManuallyDisabled).Error | |
| if err != nil { | |
| return err | |
| } | |
| err = UpdateAbilityStatusByTag(tag, false) | |
| return err | |
| } | |
| func EditChannelByTag(tag string, newTag *string, modelMapping *string, models *string, group *string, priority *int64, weight *uint) error { | |
| updateData := Channel{} | |
| shouldReCreateAbilities := false | |
| updatedTag := tag | |
| // 如果 newTag 不为空且不等于 tag,则更新 tag | |
| if newTag != nil && *newTag != tag { | |
| updateData.Tag = newTag | |
| updatedTag = *newTag | |
| } | |
| if modelMapping != nil && *modelMapping != "" { | |
| updateData.ModelMapping = modelMapping | |
| } | |
| if models != nil && *models != "" { | |
| shouldReCreateAbilities = true | |
| updateData.Models = *models | |
| } | |
| if group != nil && *group != "" { | |
| shouldReCreateAbilities = true | |
| updateData.Group = *group | |
| } | |
| if priority != nil { | |
| updateData.Priority = priority | |
| } | |
| if weight != nil { | |
| updateData.Weight = weight | |
| } | |
| err := DB.Model(&Channel{}).Where("tag = ?", tag).Updates(updateData).Error | |
| if err != nil { | |
| return err | |
| } | |
| if shouldReCreateAbilities { | |
| channels, err := GetChannelsByTag(updatedTag, false) | |
| if err == nil { | |
| for _, channel := range channels { | |
| err = channel.UpdateAbilities(nil) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update abilities: channel_id=%d, tag=%s, error=%v", channel.Id, channel.GetTag(), err)) | |
| } | |
| } | |
| } | |
| } else { | |
| err := UpdateAbilityByTag(tag, newTag, priority, weight) | |
| if err != nil { | |
| return err | |
| } | |
| } | |
| return nil | |
| } | |
| func UpdateChannelUsedQuota(id int, quota int) { | |
| if common.BatchUpdateEnabled { | |
| addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) | |
| return | |
| } | |
| updateChannelUsedQuota(id, quota) | |
| } | |
| func updateChannelUsedQuota(id int, quota int) { | |
| err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to update channel used quota: channel_id=%d, delta_quota=%d, error=%v", id, quota, err)) | |
| } | |
| } | |
| func DeleteChannelByStatus(status int64) (int64, error) { | |
| result := DB.Where("status = ?", status).Delete(&Channel{}) | |
| return result.RowsAffected, result.Error | |
| } | |
| func DeleteDisabledChannel() (int64, error) { | |
| result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) | |
| return result.RowsAffected, result.Error | |
| } | |
| func GetPaginatedTags(offset int, limit int) ([]*string, error) { | |
| var tags []*string | |
| err := DB.Model(&Channel{}).Select("DISTINCT tag").Where("tag != ''").Offset(offset).Limit(limit).Find(&tags).Error | |
| return tags, err | |
| } | |
| func SearchTags(keyword string, group string, model string, idSort bool) ([]*string, error) { | |
| var tags []*string | |
| modelsCol := "`models`" | |
| // 如果是 PostgreSQL,使用双引号 | |
| if common.UsingPostgreSQL { | |
| modelsCol = `"models"` | |
| } | |
| baseURLCol := "`base_url`" | |
| // 如果是 PostgreSQL,使用双引号 | |
| if common.UsingPostgreSQL { | |
| baseURLCol = `"base_url"` | |
| } | |
| order := "priority desc" | |
| if idSort { | |
| order = "id desc" | |
| } | |
| // 构造基础查询 | |
| baseQuery := DB.Model(&Channel{}).Omit("key") | |
| // 构造WHERE子句 | |
| var whereClause string | |
| var args []interface{} | |
| if group != "" && group != "null" { | |
| var groupCondition string | |
| if common.UsingMySQL { | |
| groupCondition = `CONCAT(',', ` + commonGroupCol + `, ',') LIKE ?` | |
| } else { | |
| // sqlite, PostgreSQL | |
| groupCondition = `(',' || ` + commonGroupCol + ` || ',') LIKE ?` | |
| } | |
| whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + ` LIKE ? AND ` + groupCondition | |
| args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%", "%,"+group+",%") | |
| } else { | |
| whereClause = "(id = ? OR name LIKE ? OR " + commonKeyCol + " = ? OR " + baseURLCol + " LIKE ?) AND " + modelsCol + " LIKE ?" | |
| args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+keyword+"%", "%"+model+"%") | |
| } | |
| subQuery := baseQuery.Where(whereClause, args...). | |
| Select("tag"). | |
| Where("tag != ''"). | |
| Order(order) | |
| err := DB.Table("(?) as sub", subQuery). | |
| Select("DISTINCT tag"). | |
| Find(&tags).Error | |
| if err != nil { | |
| return nil, err | |
| } | |
| return tags, nil | |
| } | |
| func (channel *Channel) ValidateSettings() error { | |
| channelParams := &dto.ChannelSettings{} | |
| if channel.Setting != nil && *channel.Setting != "" { | |
| err := common.Unmarshal([]byte(*channel.Setting), channelParams) | |
| if err != nil { | |
| return err | |
| } | |
| } | |
| return nil | |
| } | |
| func (channel *Channel) GetSetting() dto.ChannelSettings { | |
| setting := dto.ChannelSettings{} | |
| if channel.Setting != nil && *channel.Setting != "" { | |
| err := common.Unmarshal([]byte(*channel.Setting), &setting) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) | |
| channel.Setting = nil // 清空设置以避免后续错误 | |
| _ = channel.Save() // 保存修改 | |
| } | |
| } | |
| return setting | |
| } | |
| func (channel *Channel) SetSetting(setting dto.ChannelSettings) { | |
| settingBytes, err := common.Marshal(setting) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) | |
| return | |
| } | |
| channel.Setting = common.GetPointer[string](string(settingBytes)) | |
| } | |
| func (channel *Channel) GetOtherSettings() dto.ChannelOtherSettings { | |
| setting := dto.ChannelOtherSettings{} | |
| if channel.OtherSettings != "" { | |
| err := common.UnmarshalJsonStr(channel.OtherSettings, &setting) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to unmarshal setting: channel_id=%d, error=%v", channel.Id, err)) | |
| channel.OtherSettings = "{}" // 清空设置以避免后续错误 | |
| _ = channel.Save() // 保存修改 | |
| } | |
| } | |
| return setting | |
| } | |
| func (channel *Channel) SetOtherSettings(setting dto.ChannelOtherSettings) { | |
| settingBytes, err := common.Marshal(setting) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to marshal setting: channel_id=%d, error=%v", channel.Id, err)) | |
| return | |
| } | |
| channel.OtherSettings = string(settingBytes) | |
| } | |
| func (channel *Channel) GetParamOverride() map[string]interface{} { | |
| paramOverride := make(map[string]interface{}) | |
| if channel.ParamOverride != nil && *channel.ParamOverride != "" { | |
| err := common.Unmarshal([]byte(*channel.ParamOverride), ¶mOverride) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to unmarshal param override: channel_id=%d, error=%v", channel.Id, err)) | |
| } | |
| } | |
| return paramOverride | |
| } | |
| func (channel *Channel) GetHeaderOverride() map[string]interface{} { | |
| headerOverride := make(map[string]interface{}) | |
| if channel.HeaderOverride != nil && *channel.HeaderOverride != "" { | |
| err := common.Unmarshal([]byte(*channel.HeaderOverride), &headerOverride) | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("failed to unmarshal header override: channel_id=%d, error=%v", channel.Id, err)) | |
| } | |
| } | |
| return headerOverride | |
| } | |
| func GetChannelsByIds(ids []int) ([]*Channel, error) { | |
| var channels []*Channel | |
| err := DB.Where("id in (?)", ids).Find(&channels).Error | |
| return channels, err | |
| } | |
| func BatchSetChannelTag(ids []int, tag *string) error { | |
| // 开启事务 | |
| tx := DB.Begin() | |
| if tx.Error != nil { | |
| return tx.Error | |
| } | |
| // 更新标签 | |
| err := tx.Model(&Channel{}).Where("id in (?)", ids).Update("tag", tag).Error | |
| if err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| // update ability status | |
| channels, err := GetChannelsByIds(ids) | |
| if err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| for _, channel := range channels { | |
| err = channel.UpdateAbilities(tx) | |
| if err != nil { | |
| tx.Rollback() | |
| return err | |
| } | |
| } | |
| // 提交事务 | |
| return tx.Commit().Error | |
| } | |
| // CountAllChannels returns total channels in DB | |
| func CountAllChannels() (int64, error) { | |
| var total int64 | |
| err := DB.Model(&Channel{}).Count(&total).Error | |
| return total, err | |
| } | |
| // CountAllTags returns number of non-empty distinct tags | |
| func CountAllTags() (int64, error) { | |
| var total int64 | |
| err := DB.Model(&Channel{}).Where("tag is not null AND tag != ''").Distinct("tag").Count(&total).Error | |
| return total, err | |
| } | |
| // Get channels of specified type with pagination | |
| func GetChannelsByType(startIdx int, num int, idSort bool, channelType int) ([]*Channel, error) { | |
| var channels []*Channel | |
| order := "priority desc" | |
| if idSort { | |
| order = "id desc" | |
| } | |
| err := DB.Where("type = ?", channelType).Order(order).Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error | |
| return channels, err | |
| } | |
| // Count channels of specific type | |
| func CountChannelsByType(channelType int) (int64, error) { | |
| var count int64 | |
| err := DB.Model(&Channel{}).Where("type = ?", channelType).Count(&count).Error | |
| return count, err | |
| } | |
| // Return map[type]count for all channels | |
| func CountChannelsGroupByType() (map[int64]int64, error) { | |
| type result struct { | |
| Type int64 `gorm:"column:type"` | |
| Count int64 `gorm:"column:count"` | |
| } | |
| var results []result | |
| err := DB.Model(&Channel{}).Select("type, count(*) as count").Group("type").Find(&results).Error | |
| if err != nil { | |
| return nil, err | |
| } | |
| counts := make(map[int64]int64) | |
| for _, r := range results { | |
| counts[r.Type] = r.Count | |
| } | |
| return counts, nil | |
| } | |