Spaces:
Build error
Build error
| package model | |
| import ( | |
| "errors" | |
| "fmt" | |
| "math/rand" | |
| "one-api/common" | |
| "one-api/constant" | |
| "one-api/setting" | |
| "one-api/setting/ratio_setting" | |
| "sort" | |
| "strings" | |
| "sync" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| ) | |
| var group2model2channels map[string]map[string][]int // enabled channel | |
| var channelsIDM map[int]*Channel // all channels include disabled | |
| var channelSyncLock sync.RWMutex | |
| func InitChannelCache() { | |
| if !common.MemoryCacheEnabled { | |
| return | |
| } | |
| newChannelId2channel := make(map[int]*Channel) | |
| var channels []*Channel | |
| DB.Find(&channels) | |
| for _, channel := range channels { | |
| newChannelId2channel[channel.Id] = channel | |
| } | |
| var abilities []*Ability | |
| DB.Find(&abilities) | |
| groups := make(map[string]bool) | |
| for _, ability := range abilities { | |
| groups[ability.Group] = true | |
| } | |
| newGroup2model2channels := make(map[string]map[string][]int) | |
| for group := range groups { | |
| newGroup2model2channels[group] = make(map[string][]int) | |
| } | |
| for _, channel := range channels { | |
| if channel.Status != common.ChannelStatusEnabled { | |
| continue // skip disabled channels | |
| } | |
| groups := strings.Split(channel.Group, ",") | |
| for _, group := range groups { | |
| models := strings.Split(channel.Models, ",") | |
| for _, model := range models { | |
| if _, ok := newGroup2model2channels[group][model]; !ok { | |
| newGroup2model2channels[group][model] = make([]int, 0) | |
| } | |
| newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id) | |
| } | |
| } | |
| } | |
| // sort by priority | |
| for group, model2channels := range newGroup2model2channels { | |
| for model, channels := range model2channels { | |
| sort.Slice(channels, func(i, j int) bool { | |
| return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority() | |
| }) | |
| newGroup2model2channels[group][model] = channels | |
| } | |
| } | |
| channelSyncLock.Lock() | |
| group2model2channels = newGroup2model2channels | |
| //channelsIDM = newChannelId2channel | |
| for i, channel := range newChannelId2channel { | |
| if channel.ChannelInfo.IsMultiKey { | |
| channel.Keys = channel.GetKeys() | |
| if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { | |
| if oldChannel, ok := channelsIDM[i]; ok { | |
| // 存在旧的渠道,如果是多key且轮询,保留轮询索引信息 | |
| if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { | |
| channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex | |
| } | |
| } | |
| } | |
| } | |
| } | |
| channelsIDM = newChannelId2channel | |
| channelSyncLock.Unlock() | |
| common.SysLog("channels synced from database") | |
| } | |
| func SyncChannelCache(frequency int) { | |
| for { | |
| time.Sleep(time.Duration(frequency) * time.Second) | |
| common.SysLog("syncing channels from database") | |
| InitChannelCache() | |
| } | |
| } | |
| func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string, retry int) (*Channel, string, error) { | |
| var channel *Channel | |
| var err error | |
| selectGroup := group | |
| if group == "auto" { | |
| if len(setting.AutoGroups) == 0 { | |
| return nil, selectGroup, errors.New("auto groups is not enabled") | |
| } | |
| for _, autoGroup := range setting.AutoGroups { | |
| if common.DebugEnabled { | |
| println("autoGroup:", autoGroup) | |
| } | |
| channel, _ = getRandomSatisfiedChannel(autoGroup, model, retry) | |
| if channel == nil { | |
| continue | |
| } else { | |
| c.Set("auto_group", autoGroup) | |
| selectGroup = autoGroup | |
| if common.DebugEnabled { | |
| println("selectGroup:", selectGroup) | |
| } | |
| break | |
| } | |
| } | |
| } else { | |
| channel, err = getRandomSatisfiedChannel(group, model, retry) | |
| if err != nil { | |
| return nil, group, err | |
| } | |
| } | |
| return channel, selectGroup, nil | |
| } | |
| func getRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { | |
| // if memory cache is disabled, get channel directly from database | |
| if !common.MemoryCacheEnabled { | |
| return GetRandomSatisfiedChannel(group, model, retry) | |
| } | |
| channelSyncLock.RLock() | |
| defer channelSyncLock.RUnlock() | |
| // First, try to find channels with the exact model name. | |
| channels := group2model2channels[group][model] | |
| // If no channels found, try to find channels with the normalized model name. | |
| if len(channels) == 0 { | |
| normalizedModel := ratio_setting.FormatMatchingModelName(model) | |
| channels = group2model2channels[group][normalizedModel] | |
| } | |
| if len(channels) == 0 { | |
| return nil, nil | |
| } | |
| if len(channels) == 1 { | |
| if channel, ok := channelsIDM[channels[0]]; ok { | |
| return channel, nil | |
| } | |
| return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]) | |
| } | |
| uniquePriorities := make(map[int]bool) | |
| for _, channelId := range channels { | |
| if channel, ok := channelsIDM[channelId]; ok { | |
| uniquePriorities[int(channel.GetPriority())] = true | |
| } else { | |
| return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) | |
| } | |
| } | |
| var sortedUniquePriorities []int | |
| for priority := range uniquePriorities { | |
| sortedUniquePriorities = append(sortedUniquePriorities, priority) | |
| } | |
| sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities))) | |
| if retry >= len(uniquePriorities) { | |
| retry = len(uniquePriorities) - 1 | |
| } | |
| targetPriority := int64(sortedUniquePriorities[retry]) | |
| // get the priority for the given retry number | |
| var targetChannels []*Channel | |
| for _, channelId := range channels { | |
| if channel, ok := channelsIDM[channelId]; ok { | |
| if channel.GetPriority() == targetPriority { | |
| targetChannels = append(targetChannels, channel) | |
| } | |
| } else { | |
| return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) | |
| } | |
| } | |
| // 平滑系数 | |
| smoothingFactor := 10 | |
| // Calculate the total weight of all channels up to endIdx | |
| totalWeight := 0 | |
| for _, channel := range targetChannels { | |
| totalWeight += channel.GetWeight() + smoothingFactor | |
| } | |
| // Generate a random value in the range [0, totalWeight) | |
| randomWeight := rand.Intn(totalWeight) | |
| // Find a channel based on its weight | |
| for _, channel := range targetChannels { | |
| randomWeight -= channel.GetWeight() + smoothingFactor | |
| if randomWeight < 0 { | |
| return channel, nil | |
| } | |
| } | |
| // return null if no channel is not found | |
| return nil, errors.New("channel not found") | |
| } | |
| func CacheGetChannel(id int) (*Channel, error) { | |
| if !common.MemoryCacheEnabled { | |
| return GetChannelById(id, true) | |
| } | |
| channelSyncLock.RLock() | |
| defer channelSyncLock.RUnlock() | |
| c, ok := channelsIDM[id] | |
| if !ok { | |
| return nil, fmt.Errorf("渠道# %d,已不存在", id) | |
| } | |
| return c, nil | |
| } | |
| func CacheGetChannelInfo(id int) (*ChannelInfo, error) { | |
| if !common.MemoryCacheEnabled { | |
| channel, err := GetChannelById(id, true) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return &channel.ChannelInfo, nil | |
| } | |
| channelSyncLock.RLock() | |
| defer channelSyncLock.RUnlock() | |
| c, ok := channelsIDM[id] | |
| if !ok { | |
| return nil, fmt.Errorf("渠道# %d,已不存在", id) | |
| } | |
| return &c.ChannelInfo, nil | |
| } | |
| func CacheUpdateChannelStatus(id int, status int) { | |
| if !common.MemoryCacheEnabled { | |
| return | |
| } | |
| channelSyncLock.Lock() | |
| defer channelSyncLock.Unlock() | |
| if channel, ok := channelsIDM[id]; ok { | |
| channel.Status = status | |
| } | |
| if status != common.ChannelStatusEnabled { | |
| // delete the channel from group2model2channels | |
| for group, model2channels := range group2model2channels { | |
| for model, channels := range model2channels { | |
| for i, channelId := range channels { | |
| if channelId == id { | |
| // remove the channel from the slice | |
| group2model2channels[group][model] = append(channels[:i], channels[i+1:]...) | |
| break | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| func CacheUpdateChannel(channel *Channel) { | |
| if !common.MemoryCacheEnabled { | |
| return | |
| } | |
| channelSyncLock.Lock() | |
| defer channelSyncLock.Unlock() | |
| if channel == nil { | |
| return | |
| } | |
| println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex) | |
| println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) | |
| channelsIDM[channel.Id] = channel | |
| println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) | |
| } | |