Spaces:
Build error
Build error
| package model | |
| import ( | |
| "encoding/json" | |
| "fmt" | |
| "strings" | |
| "one-api/common" | |
| "one-api/constant" | |
| "one-api/setting/ratio_setting" | |
| "one-api/types" | |
| "sync" | |
| "time" | |
| ) | |
| type Pricing struct { | |
| ModelName string `json:"model_name"` | |
| Description string `json:"description,omitempty"` | |
| Icon string `json:"icon,omitempty"` | |
| Tags string `json:"tags,omitempty"` | |
| VendorID int `json:"vendor_id,omitempty"` | |
| QuotaType int `json:"quota_type"` | |
| ModelRatio float64 `json:"model_ratio"` | |
| ModelPrice float64 `json:"model_price"` | |
| OwnerBy string `json:"owner_by"` | |
| CompletionRatio float64 `json:"completion_ratio"` | |
| EnableGroup []string `json:"enable_groups"` | |
| SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"` | |
| } | |
// PricingVendor is the frontend-facing summary of a model vendor, as returned
// by GetVendors. Description and Icon are omitted from JSON when empty.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"` // optional
	Icon        string `json:"icon,omitempty"`        // optional
}
| var ( | |
| pricingMap []Pricing | |
| vendorsList []PricingVendor | |
| supportedEndpointMap map[string]common.EndpointInfo | |
| lastGetPricingTime time.Time | |
| updatePricingLock sync.Mutex | |
| // 缓存映射:模型名 -> 启用分组 / 计费类型 | |
| modelEnableGroups = make(map[string][]string) | |
| modelQuotaTypeMap = make(map[string]int) | |
| modelEnableGroupsLock = sync.RWMutex{} | |
| ) | |
| var ( | |
| modelSupportEndpointTypes = make(map[string][]constant.EndpointType) | |
| modelSupportEndpointsLock = sync.RWMutex{} | |
| ) | |
| func GetPricing() []Pricing { | |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { | |
| updatePricingLock.Lock() | |
| defer updatePricingLock.Unlock() | |
| // Double check after acquiring the lock | |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { | |
| modelSupportEndpointsLock.Lock() | |
| defer modelSupportEndpointsLock.Unlock() | |
| updatePricing() | |
| } | |
| } | |
| return pricingMap | |
| } | |
| // GetVendors 返回当前定价接口使用到的供应商信息 | |
| func GetVendors() []PricingVendor { | |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { | |
| // 保证先刷新一次 | |
| GetPricing() | |
| } | |
| return vendorsList | |
| } | |
| func GetModelSupportEndpointTypes(model string) []constant.EndpointType { | |
| if model == "" { | |
| return make([]constant.EndpointType, 0) | |
| } | |
| modelSupportEndpointsLock.RLock() | |
| defer modelSupportEndpointsLock.RUnlock() | |
| if endpoints, ok := modelSupportEndpointTypes[model]; ok { | |
| return endpoints | |
| } | |
| return make([]constant.EndpointType, 0) | |
| } | |
| func updatePricing() { | |
| //modelRatios := common.GetModelRatios() | |
| enableAbilities, err := GetAllEnableAbilityWithChannels() | |
| if err != nil { | |
| common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) | |
| return | |
| } | |
| // 预加载模型元数据与供应商一次,避免循环查询 | |
| var allMeta []Model | |
| _ = DB.Find(&allMeta).Error | |
| metaMap := make(map[string]*Model) | |
| prefixList := make([]*Model, 0) | |
| suffixList := make([]*Model, 0) | |
| containsList := make([]*Model, 0) | |
| for i := range allMeta { | |
| m := &allMeta[i] | |
| if m.NameRule == NameRuleExact { | |
| metaMap[m.ModelName] = m | |
| } else { | |
| switch m.NameRule { | |
| case NameRulePrefix: | |
| prefixList = append(prefixList, m) | |
| case NameRuleSuffix: | |
| suffixList = append(suffixList, m) | |
| case NameRuleContains: | |
| containsList = append(containsList, m) | |
| } | |
| } | |
| } | |
| // 将非精确规则模型匹配到 metaMap | |
| for _, m := range prefixList { | |
| for _, pricingModel := range enableAbilities { | |
| if strings.HasPrefix(pricingModel.Model, m.ModelName) { | |
| if _, exists := metaMap[pricingModel.Model]; !exists { | |
| metaMap[pricingModel.Model] = m | |
| } | |
| } | |
| } | |
| } | |
| for _, m := range suffixList { | |
| for _, pricingModel := range enableAbilities { | |
| if strings.HasSuffix(pricingModel.Model, m.ModelName) { | |
| if _, exists := metaMap[pricingModel.Model]; !exists { | |
| metaMap[pricingModel.Model] = m | |
| } | |
| } | |
| } | |
| } | |
| for _, m := range containsList { | |
| for _, pricingModel := range enableAbilities { | |
| if strings.Contains(pricingModel.Model, m.ModelName) { | |
| if _, exists := metaMap[pricingModel.Model]; !exists { | |
| metaMap[pricingModel.Model] = m | |
| } | |
| } | |
| } | |
| } | |
| // 预加载供应商 | |
| var vendors []Vendor | |
| _ = DB.Find(&vendors).Error | |
| vendorMap := make(map[int]*Vendor) | |
| for i := range vendors { | |
| vendorMap[vendors[i].Id] = &vendors[i] | |
| } | |
| // 初始化默认供应商映射 | |
| initDefaultVendorMapping(metaMap, vendorMap, enableAbilities) | |
| // 构建对前端友好的供应商列表 | |
| vendorsList = make([]PricingVendor, 0, len(vendorMap)) | |
| for _, v := range vendorMap { | |
| vendorsList = append(vendorsList, PricingVendor{ | |
| ID: v.Id, | |
| Name: v.Name, | |
| Description: v.Description, | |
| Icon: v.Icon, | |
| }) | |
| } | |
| modelGroupsMap := make(map[string]*types.Set[string]) | |
| for _, ability := range enableAbilities { | |
| groups, ok := modelGroupsMap[ability.Model] | |
| if !ok { | |
| groups = types.NewSet[string]() | |
| modelGroupsMap[ability.Model] = groups | |
| } | |
| groups.Add(ability.Group) | |
| } | |
| //这里使用切片而不是Set,因为一个模型可能支持多个端点类型,并且第一个端点是优先使用端点 | |
| modelSupportEndpointsStr := make(map[string][]string) | |
| // 先根据已有能力填充原生端点 | |
| for _, ability := range enableAbilities { | |
| endpoints := modelSupportEndpointsStr[ability.Model] | |
| channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) | |
| for _, channelType := range channelTypes { | |
| if !common.StringsContains(endpoints, string(channelType)) { | |
| endpoints = append(endpoints, string(channelType)) | |
| } | |
| } | |
| modelSupportEndpointsStr[ability.Model] = endpoints | |
| } | |
| // 再补充模型自定义端点 | |
| for modelName, meta := range metaMap { | |
| if strings.TrimSpace(meta.Endpoints) == "" { | |
| continue | |
| } | |
| var raw map[string]interface{} | |
| if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { | |
| endpoints := modelSupportEndpointsStr[modelName] | |
| for k := range raw { | |
| if !common.StringsContains(endpoints, k) { | |
| endpoints = append(endpoints, k) | |
| } | |
| } | |
| modelSupportEndpointsStr[modelName] = endpoints | |
| } | |
| } | |
| modelSupportEndpointTypes = make(map[string][]constant.EndpointType) | |
| for model, endpoints := range modelSupportEndpointsStr { | |
| supportedEndpoints := make([]constant.EndpointType, 0) | |
| for _, endpointStr := range endpoints { | |
| endpointType := constant.EndpointType(endpointStr) | |
| supportedEndpoints = append(supportedEndpoints, endpointType) | |
| } | |
| modelSupportEndpointTypes[model] = supportedEndpoints | |
| } | |
| // 构建全局 supportedEndpointMap(默认 + 自定义覆盖) | |
| supportedEndpointMap = make(map[string]common.EndpointInfo) | |
| // 1. 默认端点 | |
| for _, endpoints := range modelSupportEndpointTypes { | |
| for _, et := range endpoints { | |
| if info, ok := common.GetDefaultEndpointInfo(et); ok { | |
| if _, exists := supportedEndpointMap[string(et)]; !exists { | |
| supportedEndpointMap[string(et)] = info | |
| } | |
| } | |
| } | |
| } | |
| // 2. 自定义端点(models 表)覆盖默认 | |
| for _, meta := range metaMap { | |
| if strings.TrimSpace(meta.Endpoints) == "" { | |
| continue | |
| } | |
| var raw map[string]interface{} | |
| if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { | |
| for k, v := range raw { | |
| switch val := v.(type) { | |
| case string: | |
| supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} | |
| case map[string]interface{}: | |
| ep := common.EndpointInfo{Method: "POST"} | |
| if p, ok := val["path"].(string); ok { | |
| ep.Path = p | |
| } | |
| if m, ok := val["method"].(string); ok { | |
| ep.Method = strings.ToUpper(m) | |
| } | |
| supportedEndpointMap[k] = ep | |
| default: | |
| // ignore unsupported types | |
| } | |
| } | |
| } | |
| } | |
| pricingMap = make([]Pricing, 0) | |
| for model, groups := range modelGroupsMap { | |
| pricing := Pricing{ | |
| ModelName: model, | |
| EnableGroup: groups.Items(), | |
| SupportedEndpointTypes: modelSupportEndpointTypes[model], | |
| } | |
| // 补充模型元数据(描述、标签、供应商、状态) | |
| if meta, ok := metaMap[model]; ok { | |
| // 若模型被禁用(status!=1),则直接跳过,不返回给前端 | |
| if meta.Status != 1 { | |
| continue | |
| } | |
| pricing.Description = meta.Description | |
| pricing.Icon = meta.Icon | |
| pricing.Tags = meta.Tags | |
| pricing.VendorID = meta.VendorID | |
| } | |
| modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) | |
| if findPrice { | |
| pricing.ModelPrice = modelPrice | |
| pricing.QuotaType = 1 | |
| } else { | |
| modelRatio, _, _ := ratio_setting.GetModelRatio(model) | |
| pricing.ModelRatio = modelRatio | |
| pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) | |
| pricing.QuotaType = 0 | |
| } | |
| pricingMap = append(pricingMap, pricing) | |
| } | |
| // 刷新缓存映射,供高并发快速查询 | |
| modelEnableGroupsLock.Lock() | |
| modelEnableGroups = make(map[string][]string) | |
| modelQuotaTypeMap = make(map[string]int) | |
| for _, p := range pricingMap { | |
| modelEnableGroups[p.ModelName] = p.EnableGroup | |
| modelQuotaTypeMap[p.ModelName] = p.QuotaType | |
| } | |
| modelEnableGroupsLock.Unlock() | |
| lastGetPricingTime = time.Now() | |
| } | |
| // GetSupportedEndpointMap 返回全局端点到路径的映射 | |
| func GetSupportedEndpointMap() map[string]common.EndpointInfo { | |
| return supportedEndpointMap | |
| } | |