|
|
package model |
|
|
|
|
|
import ( |
|
|
"errors" |
|
|
"fmt" |
|
|
"math/rand" |
|
|
"sort" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
) |
|
|
|
|
|
var group2model2channels map[string]map[string][]int |
|
|
var channelsIDM map[int]*Channel |
|
|
var channelSyncLock sync.RWMutex |
|
|
|
|
|
func InitChannelCache() { |
|
|
if !common.MemoryCacheEnabled { |
|
|
return |
|
|
} |
|
|
newChannelId2channel := make(map[int]*Channel) |
|
|
var channels []*Channel |
|
|
DB.Find(&channels) |
|
|
for _, channel := range channels { |
|
|
newChannelId2channel[channel.Id] = channel |
|
|
} |
|
|
var abilities []*Ability |
|
|
DB.Find(&abilities) |
|
|
groups := make(map[string]bool) |
|
|
for _, ability := range abilities { |
|
|
groups[ability.Group] = true |
|
|
} |
|
|
newGroup2model2channels := make(map[string]map[string][]int) |
|
|
for group := range groups { |
|
|
newGroup2model2channels[group] = make(map[string][]int) |
|
|
} |
|
|
for _, channel := range channels { |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
continue |
|
|
} |
|
|
groups := strings.Split(channel.Group, ",") |
|
|
for _, group := range groups { |
|
|
models := strings.Split(channel.Models, ",") |
|
|
for _, model := range models { |
|
|
if _, ok := newGroup2model2channels[group][model]; !ok { |
|
|
newGroup2model2channels[group][model] = make([]int, 0) |
|
|
} |
|
|
newGroup2model2channels[group][model] = append(newGroup2model2channels[group][model], channel.Id) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for group, model2channels := range newGroup2model2channels { |
|
|
for model, channels := range model2channels { |
|
|
sort.Slice(channels, func(i, j int) bool { |
|
|
return newChannelId2channel[channels[i]].GetPriority() > newChannelId2channel[channels[j]].GetPriority() |
|
|
}) |
|
|
newGroup2model2channels[group][model] = channels |
|
|
} |
|
|
} |
|
|
|
|
|
channelSyncLock.Lock() |
|
|
group2model2channels = newGroup2model2channels |
|
|
|
|
|
for i, channel := range newChannelId2channel { |
|
|
if channel.ChannelInfo.IsMultiKey { |
|
|
channel.Keys = channel.GetKeys() |
|
|
if channel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { |
|
|
if oldChannel, ok := channelsIDM[i]; ok { |
|
|
|
|
|
if oldChannel.ChannelInfo.IsMultiKey && oldChannel.ChannelInfo.MultiKeyMode == constant.MultiKeyModePolling { |
|
|
channel.ChannelInfo.MultiKeyPollingIndex = oldChannel.ChannelInfo.MultiKeyPollingIndex |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
channelsIDM = newChannelId2channel |
|
|
channelSyncLock.Unlock() |
|
|
common.SysLog("channels synced from database") |
|
|
} |
|
|
|
|
|
func SyncChannelCache(frequency int) { |
|
|
for { |
|
|
time.Sleep(time.Duration(frequency) * time.Second) |
|
|
common.SysLog("syncing channels from database") |
|
|
InitChannelCache() |
|
|
} |
|
|
} |
|
|
|
|
|
func GetRandomSatisfiedChannel(group string, model string, retry int) (*Channel, error) { |
|
|
|
|
|
if !common.MemoryCacheEnabled { |
|
|
return GetChannel(group, model, retry) |
|
|
} |
|
|
|
|
|
channelSyncLock.RLock() |
|
|
defer channelSyncLock.RUnlock() |
|
|
|
|
|
|
|
|
channels := group2model2channels[group][model] |
|
|
|
|
|
|
|
|
if len(channels) == 0 { |
|
|
normalizedModel := ratio_setting.FormatMatchingModelName(model) |
|
|
channels = group2model2channels[group][normalizedModel] |
|
|
} |
|
|
|
|
|
if len(channels) == 0 { |
|
|
return nil, nil |
|
|
} |
|
|
|
|
|
if len(channels) == 1 { |
|
|
if channel, ok := channelsIDM[channels[0]]; ok { |
|
|
return channel, nil |
|
|
} |
|
|
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channels[0]) |
|
|
} |
|
|
|
|
|
uniquePriorities := make(map[int]bool) |
|
|
for _, channelId := range channels { |
|
|
if channel, ok := channelsIDM[channelId]; ok { |
|
|
uniquePriorities[int(channel.GetPriority())] = true |
|
|
} else { |
|
|
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) |
|
|
} |
|
|
} |
|
|
var sortedUniquePriorities []int |
|
|
for priority := range uniquePriorities { |
|
|
sortedUniquePriorities = append(sortedUniquePriorities, priority) |
|
|
} |
|
|
sort.Sort(sort.Reverse(sort.IntSlice(sortedUniquePriorities))) |
|
|
|
|
|
if retry >= len(uniquePriorities) { |
|
|
retry = len(uniquePriorities) - 1 |
|
|
} |
|
|
targetPriority := int64(sortedUniquePriorities[retry]) |
|
|
|
|
|
|
|
|
var sumWeight = 0 |
|
|
var targetChannels []*Channel |
|
|
for _, channelId := range channels { |
|
|
if channel, ok := channelsIDM[channelId]; ok { |
|
|
if channel.GetPriority() == targetPriority { |
|
|
sumWeight += channel.GetWeight() |
|
|
targetChannels = append(targetChannels, channel) |
|
|
} |
|
|
} else { |
|
|
return nil, fmt.Errorf("数据库一致性错误,渠道# %d 不存在,请联系管理员修复", channelId) |
|
|
} |
|
|
} |
|
|
|
|
|
if len(targetChannels) == 0 { |
|
|
return nil, errors.New(fmt.Sprintf("no channel found, group: %s, model: %s, priority: %d", group, model, targetPriority)) |
|
|
} |
|
|
|
|
|
|
|
|
smoothingFactor := 1 |
|
|
smoothingAdjustment := 0 |
|
|
|
|
|
if sumWeight == 0 { |
|
|
|
|
|
|
|
|
sumWeight = len(targetChannels) * 100 |
|
|
smoothingAdjustment = 100 |
|
|
} else if sumWeight/len(targetChannels) < 10 { |
|
|
|
|
|
smoothingFactor = 100 |
|
|
} |
|
|
|
|
|
|
|
|
totalWeight := sumWeight * smoothingFactor |
|
|
|
|
|
|
|
|
randomWeight := rand.Intn(totalWeight) |
|
|
|
|
|
|
|
|
for _, channel := range targetChannels { |
|
|
randomWeight -= channel.GetWeight()*smoothingFactor + smoothingAdjustment |
|
|
if randomWeight < 0 { |
|
|
return channel, nil |
|
|
} |
|
|
} |
|
|
|
|
|
return nil, errors.New("channel not found") |
|
|
} |
|
|
|
|
|
func CacheGetChannel(id int) (*Channel, error) { |
|
|
if !common.MemoryCacheEnabled { |
|
|
return GetChannelById(id, true) |
|
|
} |
|
|
channelSyncLock.RLock() |
|
|
defer channelSyncLock.RUnlock() |
|
|
|
|
|
c, ok := channelsIDM[id] |
|
|
if !ok { |
|
|
return nil, fmt.Errorf("渠道# %d,已不存在", id) |
|
|
} |
|
|
return c, nil |
|
|
} |
|
|
|
|
|
func CacheGetChannelInfo(id int) (*ChannelInfo, error) { |
|
|
if !common.MemoryCacheEnabled { |
|
|
channel, err := GetChannelById(id, true) |
|
|
if err != nil { |
|
|
return nil, err |
|
|
} |
|
|
return &channel.ChannelInfo, nil |
|
|
} |
|
|
channelSyncLock.RLock() |
|
|
defer channelSyncLock.RUnlock() |
|
|
|
|
|
c, ok := channelsIDM[id] |
|
|
if !ok { |
|
|
return nil, fmt.Errorf("渠道# %d,已不存在", id) |
|
|
} |
|
|
return &c.ChannelInfo, nil |
|
|
} |
|
|
|
|
|
func CacheUpdateChannelStatus(id int, status int) { |
|
|
if !common.MemoryCacheEnabled { |
|
|
return |
|
|
} |
|
|
channelSyncLock.Lock() |
|
|
defer channelSyncLock.Unlock() |
|
|
if channel, ok := channelsIDM[id]; ok { |
|
|
channel.Status = status |
|
|
} |
|
|
if status != common.ChannelStatusEnabled { |
|
|
|
|
|
for group, model2channels := range group2model2channels { |
|
|
for model, channels := range model2channels { |
|
|
for i, channelId := range channels { |
|
|
if channelId == id { |
|
|
|
|
|
group2model2channels[group][model] = append(channels[:i], channels[i+1:]...) |
|
|
break |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func CacheUpdateChannel(channel *Channel) { |
|
|
if !common.MemoryCacheEnabled { |
|
|
return |
|
|
} |
|
|
channelSyncLock.Lock() |
|
|
defer channelSyncLock.Unlock() |
|
|
if channel == nil { |
|
|
return |
|
|
} |
|
|
|
|
|
println("CacheUpdateChannel:", channel.Id, channel.Name, channel.Status, channel.ChannelInfo.MultiKeyPollingIndex) |
|
|
|
|
|
println("before:", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) |
|
|
channelsIDM[channel.Id] = channel |
|
|
println("after :", channelsIDM[channel.Id].ChannelInfo.MultiKeyPollingIndex) |
|
|
} |
|
|
|