| package model |
|
|
| import ( |
| "encoding/json" |
| "fmt" |
| "strings" |
|
|
| "sync" |
| "time" |
|
|
| "github.com/QuantumNous/new-api/common" |
| "github.com/QuantumNous/new-api/constant" |
| "github.com/QuantumNous/new-api/setting/ratio_setting" |
| "github.com/QuantumNous/new-api/types" |
| ) |
|
|
// Pricing is one entry of the list returned by GetPricing. updatePricing
// creates one entry per model that has at least one enabled ability, omitting
// models whose metadata row has a Status other than 1.
type Pricing struct {
	ModelName   string `json:"model_name"`
	Description string `json:"description,omitempty"` // copied from the model metadata row, when one matches
	Icon        string `json:"icon,omitempty"`        // copied from the model metadata row, when one matches
	Tags        string `json:"tags,omitempty"`        // copied from the model metadata row, when one matches
	VendorID    int    `json:"vendor_id,omitempty"`   // copied from the model metadata row, when one matches
	// QuotaType is 1 when ratio_setting.GetModelPrice finds a price for the
	// model (ModelPrice is filled in) and 0 otherwise (ModelRatio and
	// CompletionRatio are filled in instead).
	QuotaType  int     `json:"quota_type"`
	ModelRatio float64 `json:"model_ratio"`
	ModelPrice float64 `json:"model_price"`
	// OwnerBy is not populated by updatePricing.
	OwnerBy         string   `json:"owner_by"`
	CompletionRatio float64  `json:"completion_ratio"`
	EnableGroup     []string `json:"enable_groups"` // groups that have an enabled ability for this model
	// SupportedEndpointTypes is the model's entry in modelSupportEndpointTypes
	// at refresh time.
	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}
|
|
// PricingVendor is a vendor entry returned by GetVendors. updatePricing builds
// one per entry of its vendor map, which is loaded from the vendor table and
// then passed through initDefaultVendorMapping.
type PricingVendor struct {
	ID          int    `json:"id"` // Vendor.Id; presumably what Pricing.VendorID refers to — confirm
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
|
|
| var ( |
| pricingMap []Pricing |
| vendorsList []PricingVendor |
| supportedEndpointMap map[string]common.EndpointInfo |
| lastGetPricingTime time.Time |
| updatePricingLock sync.Mutex |
|
|
| |
| modelEnableGroups = make(map[string][]string) |
| modelQuotaTypeMap = make(map[string]int) |
| modelEnableGroupsLock = sync.RWMutex{} |
| ) |
|
|
| var ( |
| modelSupportEndpointTypes = make(map[string][]constant.EndpointType) |
| modelSupportEndpointsLock = sync.RWMutex{} |
| ) |
|
|
| func GetPricing() []Pricing { |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
| updatePricingLock.Lock() |
| defer updatePricingLock.Unlock() |
| |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
| modelSupportEndpointsLock.Lock() |
| defer modelSupportEndpointsLock.Unlock() |
| updatePricing() |
| } |
| } |
| return pricingMap |
| } |
|
|
| |
| func GetVendors() []PricingVendor { |
| if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
| |
| GetPricing() |
| } |
| return vendorsList |
| } |
|
|
| func GetModelSupportEndpointTypes(model string) []constant.EndpointType { |
| if model == "" { |
| return make([]constant.EndpointType, 0) |
| } |
| modelSupportEndpointsLock.RLock() |
| defer modelSupportEndpointsLock.RUnlock() |
| if endpoints, ok := modelSupportEndpointTypes[model]; ok { |
| return endpoints |
| } |
| return make([]constant.EndpointType, 0) |
| } |
|
|
| func updatePricing() { |
| |
| enableAbilities, err := GetAllEnableAbilityWithChannels() |
| if err != nil { |
| common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) |
| return |
| } |
| |
| var allMeta []Model |
| _ = DB.Find(&allMeta).Error |
| metaMap := make(map[string]*Model) |
| prefixList := make([]*Model, 0) |
| suffixList := make([]*Model, 0) |
| containsList := make([]*Model, 0) |
| for i := range allMeta { |
| m := &allMeta[i] |
| if m.NameRule == NameRuleExact { |
| metaMap[m.ModelName] = m |
| } else { |
| switch m.NameRule { |
| case NameRulePrefix: |
| prefixList = append(prefixList, m) |
| case NameRuleSuffix: |
| suffixList = append(suffixList, m) |
| case NameRuleContains: |
| containsList = append(containsList, m) |
| } |
| } |
| } |
|
|
| |
| for _, m := range prefixList { |
| for _, pricingModel := range enableAbilities { |
| if strings.HasPrefix(pricingModel.Model, m.ModelName) { |
| if _, exists := metaMap[pricingModel.Model]; !exists { |
| metaMap[pricingModel.Model] = m |
| } |
| } |
| } |
| } |
| for _, m := range suffixList { |
| for _, pricingModel := range enableAbilities { |
| if strings.HasSuffix(pricingModel.Model, m.ModelName) { |
| if _, exists := metaMap[pricingModel.Model]; !exists { |
| metaMap[pricingModel.Model] = m |
| } |
| } |
| } |
| } |
| for _, m := range containsList { |
| for _, pricingModel := range enableAbilities { |
| if strings.Contains(pricingModel.Model, m.ModelName) { |
| if _, exists := metaMap[pricingModel.Model]; !exists { |
| metaMap[pricingModel.Model] = m |
| } |
| } |
| } |
| } |
|
|
| |
| var vendors []Vendor |
| _ = DB.Find(&vendors).Error |
| vendorMap := make(map[int]*Vendor) |
| for i := range vendors { |
| vendorMap[vendors[i].Id] = &vendors[i] |
| } |
|
|
| |
| initDefaultVendorMapping(metaMap, vendorMap, enableAbilities) |
|
|
| |
| vendorsList = make([]PricingVendor, 0, len(vendorMap)) |
| for _, v := range vendorMap { |
| vendorsList = append(vendorsList, PricingVendor{ |
| ID: v.Id, |
| Name: v.Name, |
| Description: v.Description, |
| Icon: v.Icon, |
| }) |
| } |
|
|
| modelGroupsMap := make(map[string]*types.Set[string]) |
|
|
| for _, ability := range enableAbilities { |
| groups, ok := modelGroupsMap[ability.Model] |
| if !ok { |
| groups = types.NewSet[string]() |
| modelGroupsMap[ability.Model] = groups |
| } |
| groups.Add(ability.Group) |
| } |
|
|
| |
| modelSupportEndpointsStr := make(map[string][]string) |
|
|
| |
| for _, ability := range enableAbilities { |
| endpoints := modelSupportEndpointsStr[ability.Model] |
| channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) |
| for _, channelType := range channelTypes { |
| if !common.StringsContains(endpoints, string(channelType)) { |
| endpoints = append(endpoints, string(channelType)) |
| } |
| } |
| modelSupportEndpointsStr[ability.Model] = endpoints |
| } |
|
|
| |
| for modelName, meta := range metaMap { |
| if strings.TrimSpace(meta.Endpoints) == "" { |
| continue |
| } |
| var raw map[string]interface{} |
| if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { |
| endpoints := modelSupportEndpointsStr[modelName] |
| for k := range raw { |
| if !common.StringsContains(endpoints, k) { |
| endpoints = append(endpoints, k) |
| } |
| } |
| modelSupportEndpointsStr[modelName] = endpoints |
| } |
| } |
|
|
| modelSupportEndpointTypes = make(map[string][]constant.EndpointType) |
| for model, endpoints := range modelSupportEndpointsStr { |
| supportedEndpoints := make([]constant.EndpointType, 0) |
| for _, endpointStr := range endpoints { |
| endpointType := constant.EndpointType(endpointStr) |
| supportedEndpoints = append(supportedEndpoints, endpointType) |
| } |
| modelSupportEndpointTypes[model] = supportedEndpoints |
| } |
|
|
| |
| supportedEndpointMap = make(map[string]common.EndpointInfo) |
| |
| for _, endpoints := range modelSupportEndpointTypes { |
| for _, et := range endpoints { |
| if info, ok := common.GetDefaultEndpointInfo(et); ok { |
| if _, exists := supportedEndpointMap[string(et)]; !exists { |
| supportedEndpointMap[string(et)] = info |
| } |
| } |
| } |
| } |
| |
| for _, meta := range metaMap { |
| if strings.TrimSpace(meta.Endpoints) == "" { |
| continue |
| } |
| var raw map[string]interface{} |
| if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { |
| for k, v := range raw { |
| switch val := v.(type) { |
| case string: |
| supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} |
| case map[string]interface{}: |
| ep := common.EndpointInfo{Method: "POST"} |
| if p, ok := val["path"].(string); ok { |
| ep.Path = p |
| } |
| if m, ok := val["method"].(string); ok { |
| ep.Method = strings.ToUpper(m) |
| } |
| supportedEndpointMap[k] = ep |
| default: |
| |
| } |
| } |
| } |
| } |
|
|
| pricingMap = make([]Pricing, 0) |
| for model, groups := range modelGroupsMap { |
| pricing := Pricing{ |
| ModelName: model, |
| EnableGroup: groups.Items(), |
| SupportedEndpointTypes: modelSupportEndpointTypes[model], |
| } |
|
|
| |
| if meta, ok := metaMap[model]; ok { |
| |
| if meta.Status != 1 { |
| continue |
| } |
| pricing.Description = meta.Description |
| pricing.Icon = meta.Icon |
| pricing.Tags = meta.Tags |
| pricing.VendorID = meta.VendorID |
| } |
| modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) |
| if findPrice { |
| pricing.ModelPrice = modelPrice |
| pricing.QuotaType = 1 |
| } else { |
| modelRatio, _, _ := ratio_setting.GetModelRatio(model) |
| pricing.ModelRatio = modelRatio |
| pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) |
| pricing.QuotaType = 0 |
| } |
| pricingMap = append(pricingMap, pricing) |
| } |
|
|
| |
| modelEnableGroupsLock.Lock() |
| modelEnableGroups = make(map[string][]string) |
| modelQuotaTypeMap = make(map[string]int) |
| for _, p := range pricingMap { |
| modelEnableGroups[p.ModelName] = p.EnableGroup |
| modelQuotaTypeMap[p.ModelName] = p.QuotaType |
| } |
| modelEnableGroupsLock.Unlock() |
|
|
| lastGetPricingTime = time.Now() |
| } |
|
|
| |
| func GetSupportedEndpointMap() map[string]common.EndpointInfo { |
| return supportedEndpointMap |
| } |
|
|