| package controller
|
|
|
| import (
|
| "encoding/json"
|
| "sort"
|
| "strconv"
|
| "strings"
|
|
|
| "github.com/QuantumNous/new-api/common"
|
| "github.com/QuantumNous/new-api/constant"
|
| "github.com/QuantumNous/new-api/model"
|
|
|
| "github.com/gin-gonic/gin"
|
| )
|
|
|
|
|
| func GetAllModelsMeta(c *gin.Context) {
|
|
|
| pageInfo := common.GetPageQuery(c)
|
| modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
|
|
| enrichModels(modelsMeta)
|
| var total int64
|
| model.DB.Model(&model.Model{}).Count(&total)
|
|
|
|
|
| vendorCounts, _ := model.GetVendorModelCounts()
|
|
|
| pageInfo.SetTotal(int(total))
|
| pageInfo.SetItems(modelsMeta)
|
| common.ApiSuccess(c, gin.H{
|
| "items": modelsMeta,
|
| "total": total,
|
| "page": pageInfo.GetPage(),
|
| "page_size": pageInfo.GetPageSize(),
|
| "vendor_counts": vendorCounts,
|
| })
|
| }
|
|
|
|
|
| func SearchModelsMeta(c *gin.Context) {
|
|
|
| keyword := c.Query("keyword")
|
| vendor := c.Query("vendor")
|
| pageInfo := common.GetPageQuery(c)
|
|
|
| modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize())
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
|
|
| enrichModels(modelsMeta)
|
| pageInfo.SetTotal(int(total))
|
| pageInfo.SetItems(modelsMeta)
|
| common.ApiSuccess(c, pageInfo)
|
| }
|
|
|
|
|
| func GetModelMeta(c *gin.Context) {
|
| idStr := c.Param("id")
|
| id, err := strconv.Atoi(idStr)
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| var m model.Model
|
| if err := model.DB.First(&m, id).Error; err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| enrichModels([]*model.Model{&m})
|
| common.ApiSuccess(c, &m)
|
| }
|
|
|
|
|
| func CreateModelMeta(c *gin.Context) {
|
| var m model.Model
|
| if err := c.ShouldBindJSON(&m); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| if m.ModelName == "" {
|
| common.ApiErrorMsg(c, "模型名称不能为空")
|
| return
|
| }
|
|
|
| if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| } else if dup {
|
| common.ApiErrorMsg(c, "模型名称已存在")
|
| return
|
| }
|
|
|
| if err := m.Insert(); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| model.RefreshPricing()
|
| common.ApiSuccess(c, &m)
|
| }
|
|
|
|
|
| func UpdateModelMeta(c *gin.Context) {
|
| statusOnly := c.Query("status_only") == "true"
|
|
|
| var m model.Model
|
| if err := c.ShouldBindJSON(&m); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| if m.Id == 0 {
|
| common.ApiErrorMsg(c, "缺少模型 ID")
|
| return
|
| }
|
|
|
| if statusOnly {
|
|
|
| if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| } else {
|
|
|
| if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| } else if dup {
|
| common.ApiErrorMsg(c, "模型名称已存在")
|
| return
|
| }
|
|
|
| if err := m.Update(); err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| }
|
| model.RefreshPricing()
|
| common.ApiSuccess(c, &m)
|
| }
|
|
|
|
|
// DeleteModelMeta deletes the model metadata record whose primary key is given
// by the ":id" path parameter, then refreshes the pricing cache with
// model.RefreshPricing.
//
// A non-numeric id or a database error is reported with common.ApiError;
// otherwise the response carries no data.
func DeleteModelMeta(c *gin.Context) {
	id, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		common.ApiError(c, convErr)
		return
	}

	result := model.DB.Delete(&model.Model{}, id)
	if result.Error != nil {
		common.ApiError(c, result.Error)
		return
	}

	model.RefreshPricing()
	common.ApiSuccess(c, nil)
}
|
|
|
|
|
| func enrichModels(models []*model.Model) {
|
| if len(models) == 0 {
|
| return
|
| }
|
|
|
|
|
| exactNames := make([]string, 0)
|
| exactIdx := make(map[string][]int)
|
| ruleIndices := make([]int, 0)
|
| for i, m := range models {
|
| if m == nil {
|
| continue
|
| }
|
| if m.NameRule == model.NameRuleExact {
|
| exactNames = append(exactNames, m.ModelName)
|
| exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i)
|
| } else {
|
| ruleIndices = append(ruleIndices, i)
|
| }
|
| }
|
|
|
|
|
| channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames)
|
|
|
|
|
| for name, indices := range exactIdx {
|
| chs := channelsByModel[name]
|
| for _, idx := range indices {
|
| mm := models[idx]
|
| if mm.Endpoints == "" {
|
| eps := model.GetModelSupportEndpointTypes(mm.ModelName)
|
| if b, err := json.Marshal(eps); err == nil {
|
| mm.Endpoints = string(b)
|
| }
|
| }
|
| mm.BoundChannels = chs
|
| mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName)
|
| mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName)
|
| }
|
| }
|
|
|
| if len(ruleIndices) == 0 {
|
| return
|
| }
|
|
|
|
|
| pricings := model.GetPricing()
|
|
|
|
|
| matchedNamesByIdx := make(map[int][]string)
|
| endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{})
|
| groupSetByIdx := make(map[int]map[string]struct{})
|
| quotaSetByIdx := make(map[int]map[int]struct{})
|
|
|
| for _, p := range pricings {
|
| for _, idx := range ruleIndices {
|
| mm := models[idx]
|
| var matched bool
|
| switch mm.NameRule {
|
| case model.NameRulePrefix:
|
| matched = strings.HasPrefix(p.ModelName, mm.ModelName)
|
| case model.NameRuleSuffix:
|
| matched = strings.HasSuffix(p.ModelName, mm.ModelName)
|
| case model.NameRuleContains:
|
| matched = strings.Contains(p.ModelName, mm.ModelName)
|
| }
|
| if !matched {
|
| continue
|
| }
|
| matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName)
|
|
|
| es := endpointSetByIdx[idx]
|
| if es == nil {
|
| es = make(map[constant.EndpointType]struct{})
|
| endpointSetByIdx[idx] = es
|
| }
|
| for _, et := range p.SupportedEndpointTypes {
|
| es[et] = struct{}{}
|
| }
|
|
|
| gs := groupSetByIdx[idx]
|
| if gs == nil {
|
| gs = make(map[string]struct{})
|
| groupSetByIdx[idx] = gs
|
| }
|
| for _, g := range p.EnableGroup {
|
| gs[g] = struct{}{}
|
| }
|
|
|
| qs := quotaSetByIdx[idx]
|
| if qs == nil {
|
| qs = make(map[int]struct{})
|
| quotaSetByIdx[idx] = qs
|
| }
|
| qs[p.QuotaType] = struct{}{}
|
| }
|
| }
|
|
|
|
|
| allMatchedSet := make(map[string]struct{})
|
| for _, names := range matchedNamesByIdx {
|
| for _, n := range names {
|
| allMatchedSet[n] = struct{}{}
|
| }
|
| }
|
| allMatched := make([]string, 0, len(allMatchedSet))
|
| for n := range allMatchedSet {
|
| allMatched = append(allMatched, n)
|
| }
|
| matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched)
|
|
|
|
|
| for _, idx := range ruleIndices {
|
| mm := models[idx]
|
|
|
|
|
| if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" {
|
| eps := make([]constant.EndpointType, 0, len(es))
|
| for et := range es {
|
| eps = append(eps, et)
|
| }
|
| if b, err := json.Marshal(eps); err == nil {
|
| mm.Endpoints = string(b)
|
| }
|
| }
|
|
|
|
|
| if gs, ok := groupSetByIdx[idx]; ok {
|
| groups := make([]string, 0, len(gs))
|
| for g := range gs {
|
| groups = append(groups, g)
|
| }
|
| mm.EnableGroups = groups
|
| }
|
|
|
|
|
| if qs, ok := quotaSetByIdx[idx]; ok {
|
| arr := make([]int, 0, len(qs))
|
| for k := range qs {
|
| arr = append(arr, k)
|
| }
|
| sort.Ints(arr)
|
| mm.QuotaTypes = arr
|
| }
|
|
|
|
|
| names := matchedNamesByIdx[idx]
|
| channelSet := make(map[string]model.BoundChannel)
|
| for _, n := range names {
|
| for _, ch := range matchedChannelsByModel[n] {
|
| key := ch.Name + "_" + strconv.Itoa(ch.Type)
|
| channelSet[key] = ch
|
| }
|
| }
|
| if len(channelSet) > 0 {
|
| chs := make([]model.BoundChannel, 0, len(channelSet))
|
| for _, ch := range channelSet {
|
| chs = append(chs, ch)
|
| }
|
| mm.BoundChannels = chs
|
| }
|
|
|
|
|
| mm.MatchedModels = names
|
| mm.MatchedCount = len(names)
|
| }
|
| }
|
|
|