| | package model |
| |
|
| | import ( |
| | "errors" |
| | "fmt" |
| | "math/rand" |
| | "one-api/common" |
| | "sort" |
| | "strings" |
| | "sync" |
| | "time" |
| | ) |
| |
|
| | var group2model2channels map[string]map[string][]*Channel |
| | var channelsIDM map[int]*Channel |
| | var channelSyncLock sync.RWMutex |
| |
|
| | func InitChannelCache() { |
| | if !common.MemoryCacheEnabled { |
| | return |
| | } |
| | newChannelId2channel := make(map[int]*Channel) |
| | var channels []*Channel |
| | DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) |
| | for _, channel := range channels { |
| | newChannelId2channel[channel.Id] = channel |
| | } |
| | var abilities []*Ability |
| | DB.Find(&abilities) |
| | groups := make(map[string]bool) |
| | for _, ability := range abilities { |
| | groups[ability.Group] = true |
| | } |
| | newGroup2model2channels := make(map[string]map[string][]*Channel) |
| | newChannelsIDM := make(map[int]*Channel) |
| | for group := range groups { |
| | newGroup2model2channels[group] = make(map[string][]*Channel) |
| | } |
| | for _, channel := range channels { |
| | newChannelsIDM[channel.Id] = channel |
| | groups := strings.Split(channel.Group, ",") |
| | for _, group := range groups { |
| | models := strings.Split(channel.Models, ",") |
| | for _, model := range models { |
| | if _, ok := newGroup2model2channels[group][model]; !ok { |
| | newGroup2model2channels[group][model] = make([]*Channel, 0) |
| | } |
| | newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel) |
| | } |
| | } |
| | } |
| |
|
| | |
| | for group, model2channels := range newGroup2model2channels { |
| | for model, channels := range model2channels { |
| | sort.Slice(channels, func(i, j int) bool { |
| | return channels[i].GetPriority() > channels[j].GetPriority() |
| | }) |
| | newGroup2model2channels[group][model] = channels |
| | } |
| | } |
| |
|
| | channelSyncLock.Lock() |
| | group2model2channels = newGroup2model2channels |
| | channelsIDM = newChannelsIDM |
| | channelSyncLock.Unlock() |
| | common.SysLog("channels synced from database") |
| | } |
| |
|
| | func SyncChannelCache(frequency int) { |
| | for { |
| | time.Sleep(time.Duration(frequency) * time.Second) |
| | common.SysLog("syncing channels from database") |
| | InitChannelCache() |
| | } |
| | } |
| |
|
| | func CacheGetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { |
| | if strings.HasPrefix(model, "gpt-4-gizmo") { |
| | model = "gpt-4-gizmo-*" |
| | } |
| | if strings.HasPrefix(model, "gpt-4o-gizmo") { |
| | model = "gpt-4o-gizmo-*" |
| | } |
| |
|
| | |
| | if !common.MemoryCacheEnabled { |
| | return GetRandomSatisfiedChannel(group, model, retry) |
| | } |
| |
|
| | channelSyncLock.RLock() |
| | channels := group2model2channels[group][model] |
| | channelSyncLock.RUnlock() |
| |
|
| | if len(channels) == 0 { |
| | return nil, errors.New("channel not found") |
| | } |
| |
|
| | uniquePriorities := make(map[int]bool) |
| | for _, channel := range channels { |
| | uniquePriorities[int(channel.GetPriority())] = true |
| | } |
| | var sortedUniquePriorities []int |
| | for priority := range uniquePriorities { |
| | sortedUniquePriorities = append(sortedUniquePriorities, priority) |
| | } |
| | sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities))) |
| |
|
| | if retry >= len(uniquePriorities) { |
| | retry = len(uniquePriorities) - 1 |
| | } |
| | targetPriority := int64(sortedUniquePriorities[retry]) |
| |
|
| | |
| | var targetChannels []*Channel |
| | for _, channel := range channels { |
| | if channel.GetPriority() == targetPriority { |
| | targetChannels = append(targetChannels, channel) |
| | } |
| | } |
| |
|
| | |
| | smoothingFactor := 10 |
| | |
| | totalWeight := 0 |
| | for _, channel := range targetChannels { |
| | totalWeight += channel.GetWeight() + smoothingFactor |
| | } |
| | |
| | randomWeight := rand.Intn(totalWeight) |
| |
|
| | |
| | for _, channel := range targetChannels { |
| | randomWeight -= channel.GetWeight() + smoothingFactor |
| | if randomWeight < 0 { |
| | return channel, nil |
| | } |
| | } |
| | |
| | return nil, errors.New("channel not found") |
| | } |
| |
|
| | func CacheGetChannel(id int) (*Channel, error) { |
| | if !common.MemoryCacheEnabled { |
| | return GetChannelById(id, true) |
| | } |
| | channelSyncLock.RLock() |
| | defer channelSyncLock.RUnlock() |
| |
|
| | c, ok := channelsIDM[id] |
| | if !ok { |
| | return nil, errors.New(fmt.Sprintf("当前渠道# %d,已不存在", id)) |
| | } |
| | return c, nil |
| | } |
| |
|
| | func CacheUpdateChannelStatus(id int, status int) { |
| | if !common.MemoryCacheEnabled { |
| | return |
| | } |
| | channelSyncLock.Lock() |
| | defer channelSyncLock.Unlock() |
| | if channel, ok := channelsIDM[id]; ok { |
| | channel.Status = status |
| | } |
| | } |
| |
|