new-api / model /channel_cache.go
liuzhao521
Deploy New API v0.9.25+ (commit b47cf4ef) to HuggingFace Spaces
4674012
package model
import (
"errors"
"fmt"
"math/rand"
"sort"
"strings"
"sync"
"time"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/setting/ratio_setting"
)
// group2model2channels indexes the enabled channels as
// group -> model -> channel IDs. InitChannelCache orders every ID list by
// descending channel priority.
var group2model2channels map[string]map[string][]int // enabled channels only
// channelsIDM maps a channel ID to its cached channel record.
var channelsIDM map[int]*Channel // all channels, including disabled ones
// channelSyncLock guards group2model2channels and channelsIDM.
var channelSyncLock sync.RWMutex
// InitChannelCache rebuilds the in-memory channel cache from the database.
//
// It does nothing when common.MemoryCacheEnabled is false. Otherwise it loads
// all channels and abilities, builds the group -> model -> channel-ID index of
// enabled channels (each list ordered by descending priority), and publishes
// the new index together with the new ID -> channel map under channelSyncLock.
//
// If either query fails, the error is logged and the previously published
// cache is kept, so a transient database error cannot empty the cache.
func InitChannelCache() {
	if !common.MemoryCacheEnabled {
		return
	}

	var channels []*Channel
	if err := DB.Find(&channels).Error; err != nil {
		common.SysLog(fmt.Sprintf("failed to load channels from database, keeping previous cache: %v", err))
		return
	}
	var abilities []*Ability
	if err := DB.Find(&abilities).Error; err != nil {
		common.SysLog(fmt.Sprintf("failed to load abilities from database, keeping previous cache: %v", err))
		return
	}

	newChannelId2channel := make(map[int]*Channel, len(channels))
	for _, channel := range channels {
		newChannelId2channel[channel.Id] = channel
	}

	// Every group known to the abilities table gets an entry, even if no
	// enabled channel currently serves it.
	newGroup2model2channels := make(map[string]map[string][]int)
	for _, ability := range abilities {
		if _, ok := newGroup2model2channels[ability.Group]; !ok {
			newGroup2model2channels[ability.Group] = make(map[string][]int)
		}
	}

	for _, channel := range channels {
		if channel.Status != common.ChannelStatusEnabled {
			continue // skip disabled channels
		}
		modelNames := strings.Split(channel.Models, ",")
		for _, group := range strings.Split(channel.Group, ",") {
			model2channels, ok := newGroup2model2channels[group]
			if !ok {
				// The channel names a group that has no ability row. Create the
				// inner map on demand; writing through the missing entry would
				// panic on a nil map.
				model2channels = make(map[string][]int)
				newGroup2model2channels[group] = model2channels
			}
			for _, modelName := range modelNames {
				model2channels[modelName] = append(model2channels[modelName], channel.Id)
			}
		}
	}

	// Order every ID list by descending priority. sort.Slice sorts in place, so
	// the map entries do not need to be reassigned.
	for _, model2channels := range newGroup2model2channels {
		for _, ids := range model2channels {
			sort.Slice(ids, func(i, j int) bool {
				return newChannelId2channel[ids[i]].GetPriority() > newChannelId2channel[ids[j]].GetPriority()
			})
		}
	}

	channelSyncLock.Lock()
	group2model2channels = newGroup2model2channels
	for id, channel := range newChannelId2channel {
		if !channel.ChannelInfo.IsMultiKey {
			continue
		}
		channel.Keys = channel.GetKeys()
		if channel.ChannelInfo.MultiKeyMode != constant.MultiKeyModePolling {
			continue
		}
		// If the previously cached channel was also multi-key in polling mode,
		// carry its polling index over so rotation continues where it left off.
		oldChannel, ok := channelsIDM[id]
		if ok && oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling {
			channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex
		}
	}
	channelsIDM = newChannelId2channel
	channelSyncLock.Unlock()

	common.SysLog("channels synced from database")
}
// SyncChannelCache refreshes the channel cache forever, waiting frequency
// seconds before each refresh. It never returns, so callers are expected to
// run it in its own goroutine.
func SyncChannelCache(frequency int) {
	interval := time.Duration(frequency) * time.Second
	for {
		time.Sleep(interval)
		common.SysLog("syncing channels from database")
		InitChannelCache()
	}
}
// GetRandomSatisfiedChannel picks a channel that serves model for group.
//
// When the memory cache is disabled the lookup is delegated to GetChannel.
// Otherwise the cached index is consulted, first with the exact model name and
// then with the name normalized by ratio_setting.FormatMatchingModelName.
//
// retry selects the priority tier: 0 is the highest priority among the
// candidates, 1 the next distinct priority, and so on. Values below 0 are
// treated as 0 and values past the last tier are clamped to the lowest tier.
// Within the chosen tier a channel is drawn at random, weighted by GetWeight.
//
// It returns (nil, nil) when no channel serves the model, and an error when
// the index references a channel ID that is missing from the cache.
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) {
	// if memory cache is disabled, get channel directly from database
	if !common.MemoryCacheEnabled {
		return GetChannel(group, model, retry)
	}

	channelSyncLock.RLock()
	defer channelSyncLock.RUnlock()

	// First, try to find channels with the exact model name.
	channels := group2model2channels[group][model]
	// If no channels found, try to find channels with the normalized model name.
	if len(channels) == 0 {
		normalizedModel := ratio_setting.FormatMatchingModelName(model)
		channels = group2model2channels[group][normalizedModel]
	}
	if len(channels) == 0 {
		return nil, nil
	}

	// Resolve every ID once; a missing ID means index and channel map disagree.
	candidates := make([]*Channel, 0, len(channels))
	for _, channelId := range channels {
		channel, ok := channelsIDM[channelId]
		if !ok {
			return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId)
		}
		candidates = append(candidates, channel)
	}
	if len(candidates) == 1 {
		return candidates[0], nil
	}

	// Collect the distinct priorities, highest first.
	seenPriorities := make(map[int64]struct{}, len(candidates))
	priorities := make([]int64, 0, len(candidates))
	for _, channel := range candidates {
		priority := channel.GetPriority()
		if _, dup := seenPriorities[priority]; dup {
			continue
		}
		seenPriorities[priority] = struct{}{}
		priorities = append(priorities, priority)
	}
	sort.Slice(priorities, func(i, j int) bool { return priorities[i] > priorities[j] })

	// Clamp retry into the valid tier range so it can never index out of bounds.
	if retry < 0 {
		retry = 0
	}
	if retry >= len(priorities) {
		retry = len(priorities) - 1
	}
	targetPriority := priorities[retry]

	// Gather the channels of the selected tier and their combined weight.
	sumWeight := 0
	var targetChannels []*Channel
	for _, channel := range candidates {
		if channel.GetPriority() == targetPriority {
			sumWeight += channel.GetWeight()
			targetChannels = append(targetChannels, channel)
		}
	}
	if len(targetChannels) == 0 {
		return nil, fmt.Errorf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority)
	}

	// smoothing factor and adjustment
	smoothingFactor := 1
	smoothingAdjustment := 0
	if sumWeight == 0 {
		// All weights are 0: give every channel an effective weight of 100 so
		// the draw is uniform.
		sumWeight = len(targetChannels) * 100
		smoothingAdjustment = 100
	} else if sumWeight/len(targetChannels) < 10 {
		// Small average weight: scale all weights by 100.
		smoothingFactor = 100
	}

	// Draw a value in [0, totalWeight) and walk the tier until it is used up.
	totalWeight := sumWeight * smoothingFactor
	randomWeight := rand.Intn(totalWeight)
	for _, channel := range targetChannels {
		randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment
		if randomWeight < 0 {
			return channel, nil
		}
	}

	// No channel absorbed the drawn weight.
	return nil, errors.New("channel not found")
}
// CacheGetChannel returns the channel with the given id.
//
// With the memory cache enabled the channel comes from channelsIDM, and an
// error is returned if the id is not cached. With the cache disabled the
// lookup is delegated to GetChannelById(id, true).
func CacheGetChannel(id int) (*Channel, error) {
	if common.MemoryCacheEnabled {
		channelSyncLock.RLock()
		defer channelSyncLock.RUnlock()

		if cached, found := channelsIDM[id]; found {
			return cached, nil
		}
		return nil, fmt.Errorf("渠道# %d,已不存在", id)
	}
	return GetChannelById(id, true)
}
// CacheGetChannelInfo returns a pointer to the ChannelInfo of the channel with
// the given id.
//
// With the memory cache enabled the pointer refers to the cached channel's
// field, and an error is returned if the id is not cached. With the cache
// disabled the channel is fetched through GetChannelById(id, true).
func CacheGetChannelInfo(id int) (*ChannelInfo, error) {
	if common.MemoryCacheEnabled {
		channelSyncLock.RLock()
		defer channelSyncLock.RUnlock()

		if cached, found := channelsIDM[id]; found {
			return &cached.ChannelInfo, nil
		}
		return nil, fmt.Errorf("渠道# %d,已不存在", id)
	}

	fetched, err := GetChannelById(id, true)
	if err != nil {
		return nil, err
	}
	return &fetched.ChannelInfo, nil
}
// CacheUpdateChannelStatus records a new status for the cached channel id.
//
// It does nothing when the memory cache is disabled. If the channel is cached
// its Status field is updated. When the new status is anything other than
// common.ChannelStatusEnabled, every occurrence of id is also removed from the
// group -> model -> channel-ID index so the channel is no longer selectable.
// Setting the status back to enabled does not re-add the channel to the index
// here.
func CacheUpdateChannelStatus(id int, status int) {
	if !common.MemoryCacheEnabled {
		return
	}
	channelSyncLock.Lock()
	defer channelSyncLock.Unlock()

	if channel, ok := channelsIDM[id]; ok {
		channel.Status = status
	}
	if status == common.ChannelStatusEnabled {
		return
	}

	for _, model2channels := range group2model2channels {
		for model, channels := range model2channels {
			// Filter in place and drop every occurrence: the same ID can appear
			// more than once when a channel's model or group list has duplicates.
			kept := channels[:0]
			for _, channelId := range channels {
				if channelId != id {
					kept = append(kept, channelId)
				}
			}
			if len(kept) != len(channels) {
				model2channels[model] = kept
			}
		}
	}
}
// CacheUpdateChannel stores channel in the ID -> channel cache, replacing any
// entry with the same Id.
//
// It does nothing when the memory cache is disabled or channel is nil. The
// channel may be one that is not cached yet. The group -> model index is not
// touched.
func CacheUpdateChannel(channel *Channel) {
	if !common.MemoryCacheEnabled || channel == nil {
		return
	}
	channelSyncLock.Lock()
	defer channelSyncLock.Unlock()

	// The map is nil until the first InitChannelCache run; writing to a nil map
	// would panic.
	if channelsIDM == nil {
		channelsIDM = make(map[int]*Channel)
	}
	channelsIDM[channel.Id] = channel
}