| package model
|
|
|
| import (
|
| "errors"
|
| "fmt"
|
| "strings"
|
| "sync"
|
|
|
| "github.com/QuantumNous/new-api/common"
|
|
|
| "github.com/samber/lo"
|
| "gorm.io/gorm"
|
| "gorm.io/gorm/clause"
|
| )
|
|
|
// Ability records that a channel may serve a given model for a given user
// group. Group, Model and ChannelId together form the composite primary key,
// so each (group, model, channel) combination is stored at most once.
type Ability struct {
	Group     string `json:"group" gorm:"type:varchar(64);primaryKey;autoIncrement:false"`
	Model     string `json:"model" gorm:"type:varchar(255);primaryKey;autoIncrement:false"`
	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
	// Enabled is copied from the owning channel's status when abilities are
	// (re)built; the lookup queries in this file only consider enabled rows.
	Enabled bool `json:"enabled"`
	// Priority is the channel's tier: selection starts at the highest
	// priority and moves to lower tiers on retries.
	// NOTE(review): the tag says `bigint` rather than `type:bigint`; GORM may
	// not treat a bare `bigint` as the column type — confirm before relying on it.
	Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
	// Weight is the relative weight for random selection among abilities in
	// the same priority tier.
	Weight uint `json:"weight" gorm:"default:0;index"`
	// Tag is copied from the channel and lets bulk operations target every
	// ability of the channels sharing a tag.
	Tag *string `json:"tag" gorm:"index"`
}
|
|
|
// AbilityWithChannel is an Ability row joined with the type of its channel,
// as produced by GetAllEnableAbilityWithChannels.
type AbilityWithChannel struct {
	Ability
	// ChannelType is channels.type of the joined channel.
	// NOTE(review): the join is a LEFT JOIN, so an ability whose channel row is
	// missing yields NULL here — confirm how that scans into an int.
	ChannelType int `json:"channel_type"`
}
|
|
|
| func GetAllEnableAbilityWithChannels() ([]AbilityWithChannel, error) {
|
| var abilities []AbilityWithChannel
|
| err := DB.Table("abilities").
|
| Select("abilities.*, channels.type as channel_type").
|
| Joins("left join channels on abilities.channel_id = channels.id").
|
| Where("abilities.enabled = ?", true).
|
| Scan(&abilities).Error
|
| return abilities, err
|
| }
|
|
|
| func GetGroupEnabledModels(group string) []string {
|
| var models []string
|
|
|
| DB.Table("abilities").Where(commonGroupCol+" = ? and enabled = ?", group, true).Distinct("model").Pluck("model", &models)
|
| return models
|
| }
|
|
|
| func GetEnabledModels() []string {
|
| var models []string
|
|
|
| DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models)
|
| return models
|
| }
|
|
|
| func GetAllEnableAbilities() []Ability {
|
| var abilities []Ability
|
| DB.Find(&abilities, "enabled = ?", true)
|
| return abilities
|
| }
|
|
|
| func getPriority(group string, model string, retry int) (int, error) {
|
|
|
| var priorities []int
|
| err := DB.Model(&Ability{}).
|
| Select("DISTINCT(priority)").
|
| Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true).
|
| Order("priority DESC").
|
| Pluck("priority", &priorities).Error
|
|
|
| if err != nil {
|
|
|
| return 0, err
|
| }
|
|
|
| if len(priorities) == 0 {
|
|
|
| return 0, errors.New("数据库一致性被破坏")
|
| }
|
|
|
|
|
| var priorityToUse int
|
| if retry >= len(priorities) {
|
|
|
| priorityToUse = priorities[len(priorities)-1]
|
| } else {
|
| priorityToUse = priorities[retry]
|
| }
|
| return priorityToUse, nil
|
| }
|
|
|
| func getChannelQuery(group string, model string, retry int) (*gorm.DB, error) {
|
| maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(commonGroupCol+" = ? and model = ? and enabled = ?", group, model, true)
|
| channelQuery := DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = (?)", group, model, true, maxPrioritySubQuery)
|
| if retry != 0 {
|
| priority, err := getPriority(group, model, retry)
|
| if err != nil {
|
| return nil, err
|
| } else {
|
| channelQuery = DB.Where(commonGroupCol+" = ? and model = ? and enabled = ? and priority = ?", group, model, true, priority)
|
| }
|
| }
|
|
|
| return channelQuery, nil
|
| }
|
|
|
| func GetChannel(group string, model string, retry int) (*Channel, error) {
|
| var abilities []Ability
|
|
|
| var err error = nil
|
| channelQuery, err := getChannelQuery(group, model, retry)
|
| if err != nil {
|
| return nil, err
|
| }
|
| if common.UsingSQLite || common.UsingPostgreSQL {
|
| err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
| } else {
|
| err = channelQuery.Order("weight DESC").Find(&abilities).Error
|
| }
|
| if err != nil {
|
| return nil, err
|
| }
|
| channel := Channel{}
|
| if len(abilities) > 0 {
|
|
|
| weightSum := uint(0)
|
| for _, ability_ := range abilities {
|
| weightSum += ability_.Weight + 10
|
| }
|
|
|
| weight := common.GetRandomInt(int(weightSum))
|
| for _, ability_ := range abilities {
|
| weight -= int(ability_.Weight) + 10
|
|
|
| if weight <= 0 {
|
| channel.Id = ability_.ChannelId
|
| break
|
| }
|
| }
|
| } else {
|
| return nil, nil
|
| }
|
| err = DB.First(&channel, "id = ?", channel.Id).Error
|
| return &channel, err
|
| }
|
|
|
| func (channel *Channel) AddAbilities(tx *gorm.DB) error {
|
| models_ := strings.Split(channel.Models, ",")
|
| groups_ := strings.Split(channel.Group, ",")
|
| abilitySet := make(map[string]struct{})
|
| abilities := make([]Ability, 0, len(models_))
|
| for _, model := range models_ {
|
| for _, group := range groups_ {
|
| key := group + "|" + model
|
| if _, exists := abilitySet[key]; exists {
|
| continue
|
| }
|
| abilitySet[key] = struct{}{}
|
| ability := Ability{
|
| Group: group,
|
| Model: model,
|
| ChannelId: channel.Id,
|
| Enabled: channel.Status == common.ChannelStatusEnabled,
|
| Priority: channel.Priority,
|
| Weight: uint(channel.GetWeight()),
|
| Tag: channel.Tag,
|
| }
|
| abilities = append(abilities, ability)
|
| }
|
| }
|
| if len(abilities) == 0 {
|
| return nil
|
| }
|
|
|
| useDB := DB
|
| if tx != nil {
|
| useDB = tx
|
| }
|
| for _, chunk := range lo.Chunk(abilities, 50) {
|
| err := useDB.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
| if err != nil {
|
| return err
|
| }
|
| }
|
| return nil
|
| }
|
|
|
| func (channel *Channel) DeleteAbilities() error {
|
| return DB.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
|
| }
|
|
|
|
|
|
|
| func (channel *Channel) UpdateAbilities(tx *gorm.DB) error {
|
| isNewTx := false
|
|
|
| if tx == nil {
|
| tx = DB.Begin()
|
| if tx.Error != nil {
|
| return tx.Error
|
| }
|
| isNewTx = true
|
| defer func() {
|
| if r := recover(); r != nil {
|
| tx.Rollback()
|
| }
|
| }()
|
| }
|
|
|
|
|
| err := tx.Where("channel_id = ?", channel.Id).Delete(&Ability{}).Error
|
| if err != nil {
|
| if isNewTx {
|
| tx.Rollback()
|
| }
|
| return err
|
| }
|
|
|
|
|
| models_ := strings.Split(channel.Models, ",")
|
| groups_ := strings.Split(channel.Group, ",")
|
| abilitySet := make(map[string]struct{})
|
| abilities := make([]Ability, 0, len(models_))
|
| for _, model := range models_ {
|
| for _, group := range groups_ {
|
| key := group + "|" + model
|
| if _, exists := abilitySet[key]; exists {
|
| continue
|
| }
|
| abilitySet[key] = struct{}{}
|
| ability := Ability{
|
| Group: group,
|
| Model: model,
|
| ChannelId: channel.Id,
|
| Enabled: channel.Status == common.ChannelStatusEnabled,
|
| Priority: channel.Priority,
|
| Weight: uint(channel.GetWeight()),
|
| Tag: channel.Tag,
|
| }
|
| abilities = append(abilities, ability)
|
| }
|
| }
|
|
|
| if len(abilities) > 0 {
|
| for _, chunk := range lo.Chunk(abilities, 50) {
|
| err = tx.Clauses(clause.OnConflict{DoNothing: true}).Create(&chunk).Error
|
| if err != nil {
|
| if isNewTx {
|
| tx.Rollback()
|
| }
|
| return err
|
| }
|
| }
|
| }
|
|
|
|
|
| if isNewTx {
|
| return tx.Commit().Error
|
| }
|
|
|
| return nil
|
| }
|
|
|
| func UpdateAbilityStatus(channelId int, status bool) error {
|
| return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
|
| }
|
|
|
| func UpdateAbilityStatusByTag(tag string, status bool) error {
|
| return DB.Model(&Ability{}).Where("tag = ?", tag).Select("enabled").Update("enabled", status).Error
|
| }
|
|
|
| func UpdateAbilityByTag(tag string, newTag *string, priority *int64, weight *uint) error {
|
| ability := Ability{}
|
| if newTag != nil {
|
| ability.Tag = newTag
|
| }
|
| if priority != nil {
|
| ability.Priority = priority
|
| }
|
| if weight != nil {
|
| ability.Weight = *weight
|
| }
|
| return DB.Model(&Ability{}).Where("tag = ?", tag).Updates(ability).Error
|
| }
|
|
|
| var fixLock = sync.Mutex{}
|
|
|
| func FixAbility() (int, int, error) {
|
| lock := fixLock.TryLock()
|
| if !lock {
|
| return 0, 0, errors.New("已经有一个修复任务在运行中,请稍后再试")
|
| }
|
| defer fixLock.Unlock()
|
|
|
|
|
| if common.UsingSQLite {
|
| err := DB.Exec("DELETE FROM abilities").Error
|
| if err != nil {
|
| common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
| return 0, 0, err
|
| }
|
| } else {
|
| err := DB.Exec("TRUNCATE TABLE abilities").Error
|
| if err != nil {
|
| common.SysLog(fmt.Sprintf("Truncate abilities failed: %s", err.Error()))
|
| return 0, 0, err
|
| }
|
| }
|
| var channels []*Channel
|
|
|
| err := DB.Model(&Channel{}).Find(&channels).Error
|
| if err != nil {
|
| return 0, 0, err
|
| }
|
| if len(channels) == 0 {
|
| return 0, 0, nil
|
| }
|
| successCount := 0
|
| failCount := 0
|
| for _, chunk := range lo.Chunk(channels, 50) {
|
| ids := lo.Map(chunk, func(c *Channel, _ int) int { return c.Id })
|
|
|
| err = DB.Where("channel_id IN ?", ids).Delete(&Ability{}).Error
|
| if err != nil {
|
| common.SysLog(fmt.Sprintf("Delete abilities failed: %s", err.Error()))
|
| failCount += len(chunk)
|
| continue
|
| }
|
|
|
| for _, channel := range chunk {
|
| err = channel.AddAbilities(nil)
|
| if err != nil {
|
| common.SysLog(fmt.Sprintf("Add abilities for channel %d failed: %s", channel.Id, err.Error()))
|
| failCount++
|
| } else {
|
| successCount++
|
| }
|
| }
|
| }
|
| InitChannelCache()
|
| return successCount, failCount, nil
|
| }
|
|
|