|
|
package model |
|
|
|
|
|
import ( |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"strings" |
|
|
|
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/constant" |
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
"github.com/QuantumNous/new-api/types" |
|
|
) |
|
|
|
|
|
// Pricing is one model's entry in the list returned by GetPricing.
// updatePricing builds one per model that has at least one enabled ability,
// combining the ability data, the model's metadata record (when one matches)
// and the ratio settings.
type Pricing struct {
	ModelName string `json:"model_name"`

	// Description, Icon, Tags and VendorID are copied from the matched model
	// metadata record and stay empty/zero when there is none.
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
	Tags        string `json:"tags,omitempty"`
	VendorID    int    `json:"vendor_id,omitempty"`

	// QuotaType is 1 when ratio_setting.GetModelPrice finds a fixed price for
	// the model (ModelPrice is then set) and 0 otherwise (ModelRatio and
	// CompletionRatio are set instead).
	QuotaType  int     `json:"quota_type"`
	ModelRatio float64 `json:"model_ratio"`
	ModelPrice float64 `json:"model_price"`

	// OwnerBy is not set by updatePricing.
	OwnerBy string `json:"owner_by"`

	CompletionRatio float64 `json:"completion_ratio"`

	// EnableGroup holds the distinct groups that have an enabled ability for
	// the model.
	EnableGroup []string `json:"enable_groups"`

	// SupportedEndpointTypes combines the endpoint types of the channels
	// serving the model with the keys of the metadata's Endpoints JSON object
	// (when it parses).
	SupportedEndpointTypes []constant.EndpointType `json:"supported_endpoint_types"`
}
|
|
|
|
|
// PricingVendor is the vendor summary returned by GetVendors. updatePricing
// builds one from each Vendor in its vendor map, which is loaded from the
// database and then passed through initDefaultVendorMapping.
type PricingVendor struct {
	ID          int    `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description,omitempty"`
	Icon        string `json:"icon,omitempty"`
}
|
|
|
|
|
// Package-level pricing cache, rebuilt by updatePricing.
var (
	// pricingMap is the cached list returned by GetPricing (a slice, despite
	// its name).
	pricingMap []Pricing

	// vendorsList is the cached vendor list returned by GetVendors.
	vendorsList []PricingVendor

	// supportedEndpointMap maps an endpoint type name to its EndpointInfo and
	// is returned by GetSupportedEndpointMap.
	supportedEndpointMap map[string]common.EndpointInfo

	// lastGetPricingTime records when updatePricing last completed; GetPricing
	// refreshes the cache once it is more than a minute old.
	lastGetPricingTime time.Time

	// updatePricingLock serializes refreshes performed by GetPricing.
	updatePricingLock sync.Mutex

	// modelEnableGroups (model -> enabled groups) and modelQuotaTypeMap
	// (model -> Pricing.QuotaType) are derived from the pricing list on each
	// refresh and replaced while modelEnableGroupsLock is held for writing.
	modelEnableGroups = make(map[string][]string)

	modelQuotaTypeMap = make(map[string]int)

	modelEnableGroupsLock = sync.RWMutex{}
)
|
|
|
|
|
var (
	// modelSupportEndpointTypes maps a model name to its supported endpoint
	// types. GetPricing holds modelSupportEndpointsLock for writing while
	// updatePricing rebuilds it; GetModelSupportEndpointTypes reads it under
	// the read lock.
	modelSupportEndpointTypes = make(map[string][]constant.EndpointType)

	modelSupportEndpointsLock = sync.RWMutex{}
)
|
|
|
|
|
func GetPricing() []Pricing { |
|
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
|
|
updatePricingLock.Lock() |
|
|
defer updatePricingLock.Unlock() |
|
|
|
|
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
|
|
modelSupportEndpointsLock.Lock() |
|
|
defer modelSupportEndpointsLock.Unlock() |
|
|
updatePricing() |
|
|
} |
|
|
} |
|
|
return pricingMap |
|
|
} |
|
|
|
|
|
|
|
|
func GetVendors() []PricingVendor { |
|
|
if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { |
|
|
|
|
|
GetPricing() |
|
|
} |
|
|
return vendorsList |
|
|
} |
|
|
|
|
|
func GetModelSupportEndpointTypes(model string) []constant.EndpointType { |
|
|
if model == "" { |
|
|
return make([]constant.EndpointType, 0) |
|
|
} |
|
|
modelSupportEndpointsLock.RLock() |
|
|
defer modelSupportEndpointsLock.RUnlock() |
|
|
if endpoints, ok := modelSupportEndpointTypes[model]; ok { |
|
|
return endpoints |
|
|
} |
|
|
return make([]constant.EndpointType, 0) |
|
|
} |
|
|
|
|
|
func updatePricing() { |
|
|
|
|
|
enableAbilities, err := GetAllEnableAbilityWithChannels() |
|
|
if err != nil { |
|
|
common.SysLog(fmt.Sprintf("GetAllEnableAbilityWithChannels error: %v", err)) |
|
|
return |
|
|
} |
|
|
|
|
|
var allMeta []Model |
|
|
_ = DB.Find(&allMeta).Error |
|
|
metaMap := make(map[string]*Model) |
|
|
prefixList := make([]*Model, 0) |
|
|
suffixList := make([]*Model, 0) |
|
|
containsList := make([]*Model, 0) |
|
|
for i := range allMeta { |
|
|
m := &allMeta[i] |
|
|
if m.NameRule == NameRuleExact { |
|
|
metaMap[m.ModelName] = m |
|
|
} else { |
|
|
switch m.NameRule { |
|
|
case NameRulePrefix: |
|
|
prefixList = append(prefixList, m) |
|
|
case NameRuleSuffix: |
|
|
suffixList = append(suffixList, m) |
|
|
case NameRuleContains: |
|
|
containsList = append(containsList, m) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for _, m := range prefixList { |
|
|
for _, pricingModel := range enableAbilities { |
|
|
if strings.HasPrefix(pricingModel.Model, m.ModelName) { |
|
|
if _, exists := metaMap[pricingModel.Model]; !exists { |
|
|
metaMap[pricingModel.Model] = m |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
for _, m := range suffixList { |
|
|
for _, pricingModel := range enableAbilities { |
|
|
if strings.HasSuffix(pricingModel.Model, m.ModelName) { |
|
|
if _, exists := metaMap[pricingModel.Model]; !exists { |
|
|
metaMap[pricingModel.Model] = m |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
for _, m := range containsList { |
|
|
for _, pricingModel := range enableAbilities { |
|
|
if strings.Contains(pricingModel.Model, m.ModelName) { |
|
|
if _, exists := metaMap[pricingModel.Model]; !exists { |
|
|
metaMap[pricingModel.Model] = m |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
var vendors []Vendor |
|
|
_ = DB.Find(&vendors).Error |
|
|
vendorMap := make(map[int]*Vendor) |
|
|
for i := range vendors { |
|
|
vendorMap[vendors[i].Id] = &vendors[i] |
|
|
} |
|
|
|
|
|
|
|
|
initDefaultVendorMapping(metaMap, vendorMap, enableAbilities) |
|
|
|
|
|
|
|
|
vendorsList = make([]PricingVendor, 0, len(vendorMap)) |
|
|
for _, v := range vendorMap { |
|
|
vendorsList = append(vendorsList, PricingVendor{ |
|
|
ID: v.Id, |
|
|
Name: v.Name, |
|
|
Description: v.Description, |
|
|
Icon: v.Icon, |
|
|
}) |
|
|
} |
|
|
|
|
|
modelGroupsMap := make(map[string]*types.Set[string]) |
|
|
|
|
|
for _, ability := range enableAbilities { |
|
|
groups, ok := modelGroupsMap[ability.Model] |
|
|
if !ok { |
|
|
groups = types.NewSet[string]() |
|
|
modelGroupsMap[ability.Model] = groups |
|
|
} |
|
|
groups.Add(ability.Group) |
|
|
} |
|
|
|
|
|
|
|
|
modelSupportEndpointsStr := make(map[string][]string) |
|
|
|
|
|
|
|
|
for _, ability := range enableAbilities { |
|
|
endpoints := modelSupportEndpointsStr[ability.Model] |
|
|
channelTypes := common.GetEndpointTypesByChannelType(ability.ChannelType, ability.Model) |
|
|
for _, channelType := range channelTypes { |
|
|
if !common.StringsContains(endpoints, string(channelType)) { |
|
|
endpoints = append(endpoints, string(channelType)) |
|
|
} |
|
|
} |
|
|
modelSupportEndpointsStr[ability.Model] = endpoints |
|
|
} |
|
|
|
|
|
|
|
|
for modelName, meta := range metaMap { |
|
|
if strings.TrimSpace(meta.Endpoints) == "" { |
|
|
continue |
|
|
} |
|
|
var raw map[string]interface{} |
|
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { |
|
|
endpoints := modelSupportEndpointsStr[modelName] |
|
|
for k := range raw { |
|
|
if !common.StringsContains(endpoints, k) { |
|
|
endpoints = append(endpoints, k) |
|
|
} |
|
|
} |
|
|
modelSupportEndpointsStr[modelName] = endpoints |
|
|
} |
|
|
} |
|
|
|
|
|
modelSupportEndpointTypes = make(map[string][]constant.EndpointType) |
|
|
for model, endpoints := range modelSupportEndpointsStr { |
|
|
supportedEndpoints := make([]constant.EndpointType, 0) |
|
|
for _, endpointStr := range endpoints { |
|
|
endpointType := constant.EndpointType(endpointStr) |
|
|
supportedEndpoints = append(supportedEndpoints, endpointType) |
|
|
} |
|
|
modelSupportEndpointTypes[model] = supportedEndpoints |
|
|
} |
|
|
|
|
|
|
|
|
supportedEndpointMap = make(map[string]common.EndpointInfo) |
|
|
|
|
|
for _, endpoints := range modelSupportEndpointTypes { |
|
|
for _, et := range endpoints { |
|
|
if info, ok := common.GetDefaultEndpointInfo(et); ok { |
|
|
if _, exists := supportedEndpointMap[string(et)]; !exists { |
|
|
supportedEndpointMap[string(et)] = info |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
for _, meta := range metaMap { |
|
|
if strings.TrimSpace(meta.Endpoints) == "" { |
|
|
continue |
|
|
} |
|
|
var raw map[string]interface{} |
|
|
if err := json.Unmarshal([]byte(meta.Endpoints), &raw); err == nil { |
|
|
for k, v := range raw { |
|
|
switch val := v.(type) { |
|
|
case string: |
|
|
supportedEndpointMap[k] = common.EndpointInfo{Path: val, Method: "POST"} |
|
|
case map[string]interface{}: |
|
|
ep := common.EndpointInfo{Method: "POST"} |
|
|
if p, ok := val["path"].(string); ok { |
|
|
ep.Path = p |
|
|
} |
|
|
if m, ok := val["method"].(string); ok { |
|
|
ep.Method = strings.ToUpper(m) |
|
|
} |
|
|
supportedEndpointMap[k] = ep |
|
|
default: |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
pricingMap = make([]Pricing, 0) |
|
|
for model, groups := range modelGroupsMap { |
|
|
pricing := Pricing{ |
|
|
ModelName: model, |
|
|
EnableGroup: groups.Items(), |
|
|
SupportedEndpointTypes: modelSupportEndpointTypes[model], |
|
|
} |
|
|
|
|
|
|
|
|
if meta, ok := metaMap[model]; ok { |
|
|
|
|
|
if meta.Status != 1 { |
|
|
continue |
|
|
} |
|
|
pricing.Description = meta.Description |
|
|
pricing.Icon = meta.Icon |
|
|
pricing.Tags = meta.Tags |
|
|
pricing.VendorID = meta.VendorID |
|
|
} |
|
|
modelPrice, findPrice := ratio_setting.GetModelPrice(model, false) |
|
|
if findPrice { |
|
|
pricing.ModelPrice = modelPrice |
|
|
pricing.QuotaType = 1 |
|
|
} else { |
|
|
modelRatio, _, _ := ratio_setting.GetModelRatio(model) |
|
|
pricing.ModelRatio = modelRatio |
|
|
pricing.CompletionRatio = ratio_setting.GetCompletionRatio(model) |
|
|
pricing.QuotaType = 0 |
|
|
} |
|
|
pricingMap = append(pricingMap, pricing) |
|
|
} |
|
|
|
|
|
|
|
|
modelEnableGroupsLock.Lock() |
|
|
modelEnableGroups = make(map[string][]string) |
|
|
modelQuotaTypeMap = make(map[string]int) |
|
|
for _, p := range pricingMap { |
|
|
modelEnableGroups[p.ModelName] = p.EnableGroup |
|
|
modelQuotaTypeMap[p.ModelName] = p.QuotaType |
|
|
} |
|
|
modelEnableGroupsLock.Unlock() |
|
|
|
|
|
lastGetPricingTime = time.Now() |
|
|
} |
|
|
|
|
|
|
|
|
func GetSupportedEndpointMap() map[string]common.EndpointInfo { |
|
|
return supportedEndpointMap |
|
|
} |
|
|
|