Spaces:
Build error
Build error
| package controller | |
| import ( | |
| "encoding/json" | |
| "sort" | |
| "strconv" | |
| "strings" | |
| "one-api/common" | |
| "one-api/constant" | |
| "one-api/model" | |
| "github.com/gin-gonic/gin" | |
| ) | |
| // GetAllModelsMeta 获取模型列表(分页) | |
| func GetAllModelsMeta(c *gin.Context) { | |
| pageInfo := common.GetPageQuery(c) | |
| modelsMeta, err := model.GetAllModels(pageInfo.GetStartIdx(), pageInfo.GetPageSize()) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| // 批量填充附加字段,提升列表接口性能 | |
| enrichModels(modelsMeta) | |
| var total int64 | |
| model.DB.Model(&model.Model{}).Count(&total) | |
| // 统计供应商计数(全部数据,不受分页影响) | |
| vendorCounts, _ := model.GetVendorModelCounts() | |
| pageInfo.SetTotal(int(total)) | |
| pageInfo.SetItems(modelsMeta) | |
| common.ApiSuccess(c, gin.H{ | |
| "items": modelsMeta, | |
| "total": total, | |
| "page": pageInfo.GetPage(), | |
| "page_size": pageInfo.GetPageSize(), | |
| "vendor_counts": vendorCounts, | |
| }) | |
| } | |
| // SearchModelsMeta 搜索模型列表 | |
| func SearchModelsMeta(c *gin.Context) { | |
| keyword := c.Query("keyword") | |
| vendor := c.Query("vendor") | |
| pageInfo := common.GetPageQuery(c) | |
| modelsMeta, total, err := model.SearchModels(keyword, vendor, pageInfo.GetStartIdx(), pageInfo.GetPageSize()) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| // 批量填充附加字段,提升列表接口性能 | |
| enrichModels(modelsMeta) | |
| pageInfo.SetTotal(int(total)) | |
| pageInfo.SetItems(modelsMeta) | |
| common.ApiSuccess(c, pageInfo) | |
| } | |
| // GetModelMeta 根据 ID 获取单条模型信息 | |
| func GetModelMeta(c *gin.Context) { | |
| idStr := c.Param("id") | |
| id, err := strconv.Atoi(idStr) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| var m model.Model | |
| if err := model.DB.First(&m, id).Error; err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| enrichModels([]*model.Model{&m}) | |
| common.ApiSuccess(c, &m) | |
| } | |
| // CreateModelMeta 新建模型 | |
| func CreateModelMeta(c *gin.Context) { | |
| var m model.Model | |
| if err := c.ShouldBindJSON(&m); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if m.ModelName == "" { | |
| common.ApiErrorMsg(c, "模型名称不能为空") | |
| return | |
| } | |
| // 名称冲突检查 | |
| if dup, err := model.IsModelNameDuplicated(0, m.ModelName); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } else if dup { | |
| common.ApiErrorMsg(c, "模型名称已存在") | |
| return | |
| } | |
| if err := m.Insert(); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| model.RefreshPricing() | |
| common.ApiSuccess(c, &m) | |
| } | |
| // UpdateModelMeta 更新模型 | |
| func UpdateModelMeta(c *gin.Context) { | |
| statusOnly := c.Query("status_only") == "true" | |
| var m model.Model | |
| if err := c.ShouldBindJSON(&m); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if m.Id == 0 { | |
| common.ApiErrorMsg(c, "缺少模型 ID") | |
| return | |
| } | |
| if statusOnly { | |
| // 只更新状态,防止误清空其他字段 | |
| if err := model.DB.Model(&model.Model{}).Where("id = ?", m.Id).Update("status", m.Status).Error; err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| } else { | |
| // 名称冲突检查 | |
| if dup, err := model.IsModelNameDuplicated(m.Id, m.ModelName); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } else if dup { | |
| common.ApiErrorMsg(c, "模型名称已存在") | |
| return | |
| } | |
| if err := m.Update(); err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| } | |
| model.RefreshPricing() | |
| common.ApiSuccess(c, &m) | |
| } | |
| // DeleteModelMeta 删除模型 | |
| func DeleteModelMeta(c *gin.Context) { | |
| idStr := c.Param("id") | |
| id, err := strconv.Atoi(idStr) | |
| if err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| if err := model.DB.Delete(&model.Model{}, id).Error; err != nil { | |
| common.ApiError(c, err) | |
| return | |
| } | |
| model.RefreshPricing() | |
| common.ApiSuccess(c, nil) | |
| } | |
| // enrichModels 批量填充附加信息:端点、渠道、分组、计费类型,避免 N+1 查询 | |
| func enrichModels(models []*model.Model) { | |
| if len(models) == 0 { | |
| return | |
| } | |
| // 1) 拆分精确与规则匹配 | |
| exactNames := make([]string, 0) | |
| exactIdx := make(map[string][]int) // modelName -> indices in models | |
| ruleIndices := make([]int, 0) | |
| for i, m := range models { | |
| if m == nil { | |
| continue | |
| } | |
| if m.NameRule == model.NameRuleExact { | |
| exactNames = append(exactNames, m.ModelName) | |
| exactIdx[m.ModelName] = append(exactIdx[m.ModelName], i) | |
| } else { | |
| ruleIndices = append(ruleIndices, i) | |
| } | |
| } | |
| // 2) 批量查询精确模型的绑定渠道 | |
| channelsByModel, _ := model.GetBoundChannelsByModelsMap(exactNames) | |
| // 3) 精确模型:端点从缓存、渠道批量映射、分组/计费类型从缓存 | |
| for name, indices := range exactIdx { | |
| chs := channelsByModel[name] | |
| for _, idx := range indices { | |
| mm := models[idx] | |
| if mm.Endpoints == "" { | |
| eps := model.GetModelSupportEndpointTypes(mm.ModelName) | |
| if b, err := json.Marshal(eps); err == nil { | |
| mm.Endpoints = string(b) | |
| } | |
| } | |
| mm.BoundChannels = chs | |
| mm.EnableGroups = model.GetModelEnableGroups(mm.ModelName) | |
| mm.QuotaTypes = model.GetModelQuotaTypes(mm.ModelName) | |
| } | |
| } | |
| if len(ruleIndices) == 0 { | |
| return | |
| } | |
| // 4) 一次性读取定价缓存,内存匹配所有规则模型 | |
| pricings := model.GetPricing() | |
| // 为全部规则模型收集匹配名集合、端点并集、分组并集、配额集合 | |
| matchedNamesByIdx := make(map[int][]string) | |
| endpointSetByIdx := make(map[int]map[constant.EndpointType]struct{}) | |
| groupSetByIdx := make(map[int]map[string]struct{}) | |
| quotaSetByIdx := make(map[int]map[int]struct{}) | |
| for _, p := range pricings { | |
| for _, idx := range ruleIndices { | |
| mm := models[idx] | |
| var matched bool | |
| switch mm.NameRule { | |
| case model.NameRulePrefix: | |
| matched = strings.HasPrefix(p.ModelName, mm.ModelName) | |
| case model.NameRuleSuffix: | |
| matched = strings.HasSuffix(p.ModelName, mm.ModelName) | |
| case model.NameRuleContains: | |
| matched = strings.Contains(p.ModelName, mm.ModelName) | |
| } | |
| if !matched { | |
| continue | |
| } | |
| matchedNamesByIdx[idx] = append(matchedNamesByIdx[idx], p.ModelName) | |
| es := endpointSetByIdx[idx] | |
| if es == nil { | |
| es = make(map[constant.EndpointType]struct{}) | |
| endpointSetByIdx[idx] = es | |
| } | |
| for _, et := range p.SupportedEndpointTypes { | |
| es[et] = struct{}{} | |
| } | |
| gs := groupSetByIdx[idx] | |
| if gs == nil { | |
| gs = make(map[string]struct{}) | |
| groupSetByIdx[idx] = gs | |
| } | |
| for _, g := range p.EnableGroup { | |
| gs[g] = struct{}{} | |
| } | |
| qs := quotaSetByIdx[idx] | |
| if qs == nil { | |
| qs = make(map[int]struct{}) | |
| quotaSetByIdx[idx] = qs | |
| } | |
| qs[p.QuotaType] = struct{}{} | |
| } | |
| } | |
| // 5) 汇总所有匹配到的模型名称,批量查询一次渠道 | |
| allMatchedSet := make(map[string]struct{}) | |
| for _, names := range matchedNamesByIdx { | |
| for _, n := range names { | |
| allMatchedSet[n] = struct{}{} | |
| } | |
| } | |
| allMatched := make([]string, 0, len(allMatchedSet)) | |
| for n := range allMatchedSet { | |
| allMatched = append(allMatched, n) | |
| } | |
| matchedChannelsByModel, _ := model.GetBoundChannelsByModelsMap(allMatched) | |
| // 6) 回填每个规则模型的并集信息 | |
| for _, idx := range ruleIndices { | |
| mm := models[idx] | |
| // 端点并集 -> 序列化 | |
| if es, ok := endpointSetByIdx[idx]; ok && mm.Endpoints == "" { | |
| eps := make([]constant.EndpointType, 0, len(es)) | |
| for et := range es { | |
| eps = append(eps, et) | |
| } | |
| if b, err := json.Marshal(eps); err == nil { | |
| mm.Endpoints = string(b) | |
| } | |
| } | |
| // 分组并集 | |
| if gs, ok := groupSetByIdx[idx]; ok { | |
| groups := make([]string, 0, len(gs)) | |
| for g := range gs { | |
| groups = append(groups, g) | |
| } | |
| mm.EnableGroups = groups | |
| } | |
| // 配额类型集合(保持去重并排序) | |
| if qs, ok := quotaSetByIdx[idx]; ok { | |
| arr := make([]int, 0, len(qs)) | |
| for k := range qs { | |
| arr = append(arr, k) | |
| } | |
| sort.Ints(arr) | |
| mm.QuotaTypes = arr | |
| } | |
| // 渠道并集 | |
| names := matchedNamesByIdx[idx] | |
| channelSet := make(map[string]model.BoundChannel) | |
| for _, n := range names { | |
| for _, ch := range matchedChannelsByModel[n] { | |
| key := ch.Name + "_" + strconv.Itoa(ch.Type) | |
| channelSet[key] = ch | |
| } | |
| } | |
| if len(channelSet) > 0 { | |
| chs := make([]model.BoundChannel, 0, len(channelSet)) | |
| for _, ch := range channelSet { | |
| chs = append(chs, ch) | |
| } | |
| mm.BoundChannels = chs | |
| } | |
| // 匹配信息 | |
| mm.MatchedModels = names | |
| mm.MatchedCount = len(names) | |
| } | |
| } | |