|
|
package controller |
|
|
|
|
|
import ( |
|
|
"encoding/json" |
|
|
"sort" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/model" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
|
|
|
func GetAllModelsMeta(c *gin.Context) { |
|
|
|
|
|
pageInfo := common.GetPageQuery(c) |
|
|
modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
|
|
if err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
|
|
|
enrichModels(modelsMeta) |
|
|
var total int64 |
|
|
model.DB.Model(&model.Model{}).Count(&total) |
|
|
|
|
|
|
|
|
vendorCounts, _ := model.GetVendorModelCounts() |
|
|
|
|
|
pageInfo.SetTotal(int(total)) |
|
|
pageInfo.SetItems(modelsMeta) |
|
|
common.ApiSuccess(c, gin.H{ |
|
|
"items": modelsMeta, |
|
|
"total": total, |
|
|
"page": pageInfo.GetPage(), |
|
|
"page_size": pageInfo.GetPageSize(), |
|
|
"vendor_counts": vendorCounts, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func SearchModelsMeta(c *gin.Context) { |
|
|
|
|
|
keyword := c.Query("keyword") |
|
|
vendor := c.Query("vendor") |
|
|
pageInfo := common.GetPageQuery(c) |
|
|
|
|
|
modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) |
|
|
if err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
|
|
|
enrichModels(modelsMeta) |
|
|
pageInfo.SetTotal(int(total)) |
|
|
pageInfo.SetItems(modelsMeta) |
|
|
common.ApiSuccess(c, pageInfo) |
|
|
} |
|
|
|
|
|
|
|
|
func GetModelMeta(c *gin.Context) { |
|
|
idStr := c.Param("id") |
|
|
id, err := strconv.Atoi(idStr) |
|
|
if err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
var m model.Model |
|
|
if err := model.DB.First(&m, id).Error; err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
enrichModels([]*model.Model{&m}) |
|
|
common.ApiSuccess(c, &m) |
|
|
} |
|
|
|
|
|
|
|
|
func CreateModelMeta(c *gin.Context) { |
|
|
var m model.Model |
|
|
if err := c.ShouldBindJSON(&m); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
if m.ModelName == "" { |
|
|
common.ApiErrorMsg(c, "模型名称不能为空") |
|
|
return |
|
|
} |
|
|
|
|
|
if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} else if dup { |
|
|
common.ApiErrorMsg(c, "模型名称已存在") |
|
|
return |
|
|
} |
|
|
|
|
|
if err := m.Insert(); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
model.RefreshPricing() |
|
|
common.ApiSuccess(c, &m) |
|
|
} |
|
|
|
|
|
|
|
|
func UpdateModelMeta(c *gin.Context) { |
|
|
statusOnly := c.Query("status_only") == "true" |
|
|
|
|
|
var m model.Model |
|
|
if err := c.ShouldBindJSON(&m); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
if m.Id == 0 { |
|
|
common.ApiErrorMsg(c, "缺少模型 ID") |
|
|
return |
|
|
} |
|
|
|
|
|
if statusOnly { |
|
|
|
|
|
if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
} else { |
|
|
|
|
|
if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} else if dup { |
|
|
common.ApiErrorMsg(c, "模型名称已存在") |
|
|
return |
|
|
} |
|
|
|
|
|
if err := m.Update(); err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
} |
|
|
model.RefreshPricing() |
|
|
common.ApiSuccess(c, &m) |
|
|
} |
|
|
|
|
|
|
|
|
func DeleteModelMeta(c *gin.Context) { |
|
|
idStr := c.Param("id") |
|
|
id, err := strconv.Atoi(idStr) |
|
|
if err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { |
|
|
common.ApiError(c, err) |
|
|
return |
|
|
} |
|
|
model.RefreshPricing() |
|
|
common.ApiSuccess(c, nil) |
|
|
} |
|
|
|
|
|
|
|
|
func enrichModels(models []*model.Model) { |
|
|
if len(models) == 0 { |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
exactNames := make([]string, 0) |
|
|
exactIdx := make(map[string][]int) |
|
|
ruleIndices := make([]int, 0) |
|
|
for i, m := range models { |
|
|
if m == nil { |
|
|
continue |
|
|
} |
|
|
if m.NameRule == model.NameRuleExact { |
|
|
exactNames = append(exactNames, m.ModelName) |
|
|
exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) |
|
|
} else { |
|
|
ruleIndices = append(ruleIndices, i) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) |
|
|
|
|
|
|
|
|
for name, indices := range exactIdx { |
|
|
chs := channelsByModel[name] |
|
|
for _, idx := range indices { |
|
|
mm := models[idx] |
|
|
if mm.Endpoints == "" { |
|
|
eps := model.GetModelSupportEndpointTypes(mm.ModelName) |
|
|
if b, err := json.Marshal(eps); err == nil { |
|
|
mm.Endpoints = string(b) |
|
|
} |
|
|
} |
|
|
mm.BoundChannels = chs |
|
|
mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) |
|
|
mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) |
|
|
} |
|
|
} |
|
|
|
|
|
if len(ruleIndices) == 0 { |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
pricings := model.GetPricing() |
|
|
|
|
|
|
|
|
matchedNamesByIdx := make(map[int][]string) |
|
|
endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) |
|
|
groupSetByIdx := make(map[int]map[string]struct{}) |
|
|
quotaSetByIdx := make(map[int]map[int]struct{}) |
|
|
|
|
|
for _, p := range pricings { |
|
|
for _, idx := range ruleIndices { |
|
|
mm := models[idx] |
|
|
var matched bool |
|
|
switch mm.NameRule { |
|
|
case model.NameRulePrefix: |
|
|
matched = strings.HasPrefix(p.ModelName, mm.ModelName) |
|
|
case model.NameRuleSuffix: |
|
|
matched = strings.HasSuffix(p.ModelName, mm.ModelName) |
|
|
case model.NameRuleContains: |
|
|
matched = strings.Contains(p.ModelName, mm.ModelName) |
|
|
} |
|
|
if !matched { |
|
|
continue |
|
|
} |
|
|
matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) |
|
|
|
|
|
es := endpointSetByIdx[idx] |
|
|
if es == nil { |
|
|
es = make(map[constant.EndpointType]struct{}) |
|
|
endpointSetByIdx[idx] = es |
|
|
} |
|
|
for _, et := range p.SupportedEndpointTypes { |
|
|
es[et] = struct{}{} |
|
|
} |
|
|
|
|
|
gs := groupSetByIdx[idx] |
|
|
if gs == nil { |
|
|
gs = make(map[string]struct{}) |
|
|
groupSetByIdx[idx] = gs |
|
|
} |
|
|
for _, g := range p.EnableGroup { |
|
|
gs[g] = struct{}{} |
|
|
} |
|
|
|
|
|
qs := quotaSetByIdx[idx] |
|
|
if qs == nil { |
|
|
qs = make(map[int]struct{}) |
|
|
quotaSetByIdx[idx] = qs |
|
|
} |
|
|
qs[p.QuotaType] = struct{}{} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
allMatchedSet := make(map[string]struct{}) |
|
|
for _, names := range matchedNamesByIdx { |
|
|
for _, n := range names { |
|
|
allMatchedSet[n] = struct{}{} |
|
|
} |
|
|
} |
|
|
allMatched := make([]string, 0, len(allMatchedSet)) |
|
|
for n := range allMatchedSet { |
|
|
allMatched = append(allMatched, n) |
|
|
} |
|
|
matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) |
|
|
|
|
|
|
|
|
for _, idx := range ruleIndices { |
|
|
mm := models[idx] |
|
|
|
|
|
|
|
|
if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { |
|
|
eps := make([]constant.EndpointType, 0, len(es)) |
|
|
for et := range es { |
|
|
eps = append(eps, et) |
|
|
} |
|
|
if b, err := json.Marshal(eps); err == nil { |
|
|
mm.Endpoints = string(b) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if gs, ok := groupSetByIdx[idx]; ok { |
|
|
groups := make([]string, 0, len(gs)) |
|
|
for g := range gs { |
|
|
groups = append(groups, g) |
|
|
} |
|
|
mm.EnableGroups = groups |
|
|
} |
|
|
|
|
|
|
|
|
if qs, ok := quotaSetByIdx[idx]; ok { |
|
|
arr := make([]int, 0, len(qs)) |
|
|
for k := range qs { |
|
|
arr = append(arr, k) |
|
|
} |
|
|
sort.Ints(arr) |
|
|
mm.QuotaTypes = arr |
|
|
} |
|
|
|
|
|
|
|
|
names := matchedNamesByIdx[idx] |
|
|
channelSet := make(map[string]model.BoundChannel) |
|
|
for _, n := range names { |
|
|
for _, ch := range matchedChannelsByModel[n] { |
|
|
key := ch.Name + "_" + strconv.Itoa(ch.Type) |
|
|
channelSet[key] = ch |
|
|
} |
|
|
} |
|
|
if len(channelSet) > 0 { |
|
|
chs := make([]model.BoundChannel, 0, len(channelSet)) |
|
|
for _, ch := range channelSet { |
|
|
chs = append(chs, ch) |
|
|
} |
|
|
mm.BoundChannels = chs |
|
|
} |
|
|
|
|
|
|
|
|
mm.MatchedModels = names |
|
|
mm.MatchedCount = len(names) |
|
|
} |
|
|
} |
|
|
|