| | package controller |
| |
|
| | import ( |
| | "encoding/json" |
| | "sort" |
| | "strconv" |
| | "strings" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/model" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | |
| | func GetAllModelsMeta(c *gin.Context) { |
| |
|
| | pageInfo := common.GetPageQuery(c) |
| | modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
| | if err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | |
| | enrichModels(modelsMeta) |
| | var total int64 |
| | model.DB.Model(&model.Model{}).Count(&total) |
| |
|
| | |
| | vendorCounts, _ := model.GetVendorModelCounts() |
| |
|
| | pageInfo.SetTotal(int(total)) |
| | pageInfo.SetItems(modelsMeta) |
| | common.ApiSuccess(c, gin.H{ |
| | "items": modelsMeta, |
| | "total": total, |
| | "page": pageInfo.GetPage(), |
| | "page_size": pageInfo.GetPageSize(), |
| | "vendor_counts": vendorCounts, |
| | }) |
| | } |
| |
|
| | |
| | func SearchModelsMeta(c *gin.Context) { |
| |
|
| | keyword := c.Query("keyword") |
| | vendor := c.Query("vendor") |
| | pageInfo := common.GetPageQuery(c) |
| |
|
| | modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
| | if err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | |
| | enrichModels(modelsMeta) |
| | pageInfo.SetTotal(int(total)) |
| | pageInfo.SetItems(modelsMeta) |
| | common.ApiSuccess(c, pageInfo) |
| | } |
| |
|
| | |
| | func GetModelMeta(c *gin.Context) { |
| | idStr := c.Param("id") |
| | id, err := strconv.Atoi(idStr) |
| | if err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | var m model.Model |
| | if err := model.DB.First(&m, id).Error; err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | enrichModels([]*model.Model{&m}) |
| | common.ApiSuccess(c, &m) |
| | } |
| |
|
| | |
| | func CreateModelMeta(c *gin.Context) { |
| | var m model.Model |
| | if err := c.ShouldBindJSON(&m); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | if m.ModelName == "" { |
| | common.ApiErrorMsg(c, "模型名称不能为空") |
| | return |
| | } |
| | |
| | if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } else if dup { |
| | common.ApiErrorMsg(c, "模型名称已存在") |
| | return |
| | } |
| |
|
| | if err := m.Insert(); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | model.RefreshPricing() |
| | common.ApiSuccess(c, &m) |
| | } |
| |
|
| | |
| | func UpdateModelMeta(c *gin.Context) { |
| | statusOnly := c.Query("status_only") == "true" |
| |
|
| | var m model.Model |
| | if err := c.ShouldBindJSON(&m); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | if m.Id == 0 { |
| | common.ApiErrorMsg(c, "缺少模型 ID") |
| | return |
| | } |
| |
|
| | if statusOnly { |
| | |
| | if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | } else { |
| | |
| | if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } else if dup { |
| | common.ApiErrorMsg(c, "模型名称已存在") |
| | return |
| | } |
| |
|
| | if err := m.Update(); err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | } |
| | model.RefreshPricing() |
| | common.ApiSuccess(c, &m) |
| | } |
| |
|
| | |
| | func DeleteModelMeta(c *gin.Context) { |
| | idStr := c.Param("id") |
| | id, err := strconv.Atoi(idStr) |
| | if err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { |
| | common.ApiError(c, err) |
| | return |
| | } |
| | model.RefreshPricing() |
| | common.ApiSuccess(c, nil) |
| | } |
| |
|
| | |
| | func enrichModels(models []*model.Model) { |
| | if len(models) == 0 { |
| | return |
| | } |
| |
|
| | |
| | exactNames := make([]string, 0) |
| | exactIdx := make(map[string][]int) |
| | ruleIndices := make([]int, 0) |
| | for i, m := range models { |
| | if m == nil { |
| | continue |
| | } |
| | if m.NameRule == model.NameRuleExact { |
| | exactNames = append(exactNames, m.ModelName) |
| | exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) |
| | } else { |
| | ruleIndices = append(ruleIndices, i) |
| | } |
| | } |
| |
|
| | |
| | channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) |
| |
|
| | |
| | for name, indices := range exactIdx { |
| | chs := channelsByModel[name] |
| | for _, idx := range indices { |
| | mm := models[idx] |
| | if mm.Endpoints == "" { |
| | eps := model.GetModelSupportEndpointTypes(mm.ModelName) |
| | if b, err := json.Marshal(eps); err == nil { |
| | mm.Endpoints = string(b) |
| | } |
| | } |
| | mm.BoundChannels = chs |
| | mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) |
| | mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) |
| | } |
| | } |
| |
|
| | if len(ruleIndices) == 0 { |
| | return |
| | } |
| |
|
| | |
| | pricings := model.GetPricing() |
| |
|
| | |
| | matchedNamesByIdx := make(map[int][]string) |
| | endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) |
| | groupSetByIdx := make(map[int]map[string]struct{}) |
| | quotaSetByIdx := make(map[int]map[int]struct{}) |
| |
|
| | for _, p := range pricings { |
| | for _, idx := range ruleIndices { |
| | mm := models[idx] |
| | var matched bool |
| | switch mm.NameRule { |
| | case model.NameRulePrefix: |
| | matched = strings.HasPrefix(p.ModelName, mm.ModelName) |
| | case model.NameRuleSuffix: |
| | matched = strings.HasSuffix(p.ModelName, mm.ModelName) |
| | case model.NameRuleContains: |
| | matched = strings.Contains(p.ModelName, mm.ModelName) |
| | } |
| | if !matched { |
| | continue |
| | } |
| | matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) |
| |
|
| | es := endpointSetByIdx[idx] |
| | if es == nil { |
| | es = make(map[constant.EndpointType]struct{}) |
| | endpointSetByIdx[idx] = es |
| | } |
| | for _, et := range p.SupportedEndpointTypes { |
| | es[et] = struct{}{} |
| | } |
| |
|
| | gs := groupSetByIdx[idx] |
| | if gs == nil { |
| | gs = make(map[string]struct{}) |
| | groupSetByIdx[idx] = gs |
| | } |
| | for _, g := range p.EnableGroup { |
| | gs[g] = struct{}{} |
| | } |
| |
|
| | qs := quotaSetByIdx[idx] |
| | if qs == nil { |
| | qs = make(map[int]struct{}) |
| | quotaSetByIdx[idx] = qs |
| | } |
| | qs[p.QuotaType] = struct{}{} |
| | } |
| | } |
| |
|
| | |
| | allMatchedSet := make(map[string]struct{}) |
| | for _, names := range matchedNamesByIdx { |
| | for _, n := range names { |
| | allMatchedSet[n] = struct{}{} |
| | } |
| | } |
| | allMatched := make([]string, 0, len(allMatchedSet)) |
| | for n := range allMatchedSet { |
| | allMatched = append(allMatched, n) |
| | } |
| | matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) |
| |
|
| | |
| | for _, idx := range ruleIndices { |
| | mm := models[idx] |
| |
|
| | |
| | if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { |
| | eps := make([]constant.EndpointType, 0, len(es)) |
| | for et := range es { |
| | eps = append(eps, et) |
| | } |
| | if b, err := json.Marshal(eps); err == nil { |
| | mm.Endpoints = string(b) |
| | } |
| | } |
| |
|
| | |
| | if gs, ok := groupSetByIdx[idx]; ok { |
| | groups := make([]string, 0, len(gs)) |
| | for g := range gs { |
| | groups = append(groups, g) |
| | } |
| | mm.EnableGroups = groups |
| | } |
| |
|
| | |
| | if qs, ok := quotaSetByIdx[idx]; ok { |
| | arr := make([]int, 0, len(qs)) |
| | for k := range qs { |
| | arr = append(arr, k) |
| | } |
| | sort.Ints(arr) |
| | mm.QuotaTypes = arr |
| | } |
| |
|
| | |
| | names := matchedNamesByIdx[idx] |
| | channelSet := make(map[string]model.BoundChannel) |
| | for _, n := range names { |
| | for _, ch := range matchedChannelsByModel[n] { |
| | key := ch.Name + "_" + strconv.Itoa(ch.Type) |
| | channelSet[key] = ch |
| | } |
| | } |
| | if len(channelSet) > 0 { |
| | chs := make([]model.BoundChannel, 0, len(channelSet)) |
| | for _, ch := range channelSet { |
| | chs = append(chs, ch) |
| | } |
| | mm.BoundChannels = chs |
| | } |
| |
|
| | |
| | mm.MatchedModels = names |
| | mm.MatchedCount = len(names) |
| | } |
| | } |
| |
|