| | package controller |
| |
|
| | import ( |
| | "context" |
| | "encoding/json" |
| | "fmt" |
| | "io" |
| | "net" |
| | "net/http" |
| | "strings" |
| | "sync" |
| | "time" |
| |
|
| | "github.com/QuantumNous/new-api/logger" |
| |
|
| | "github.com/QuantumNous/new-api/dto" |
| | "github.com/QuantumNous/new-api/model" |
| | "github.com/QuantumNous/new-api/setting/ratio_setting" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | const ( |
| | defaultTimeoutSeconds = 10 |
| | defaultEndpoint = "/api/ratio_config" |
| | maxConcurrentFetches = 8 |
| | maxRatioConfigBytes = 10 << 20 |
| | floatEpsilon = 1e-9 |
| | ) |
| |
|
| | func nearlyEqual(a, b float64) bool { |
| | if a > b { |
| | return a-b < floatEpsilon |
| | } |
| | return b-a < floatEpsilon |
| | } |
| |
|
| | func valuesEqual(a, b interface{}) bool { |
| | af, aok := a.(float64) |
| | bf, bok := b.(float64) |
| | if aok && bok { |
| | return nearlyEqual(af, bf) |
| | } |
| | return a == b |
| | } |
| |
|
| | var ratioTypes = []string{"model_ratio", "completion_ratio", "cache_ratio", "model_price"} |
| |
|
| | type upstreamResult struct { |
| | Name string `json:"name"` |
| | Data map[string]any `json:"data,omitempty"` |
| | Err string `json:"err,omitempty"` |
| | } |
| |
|
| | func FetchUpstreamRatios(c *gin.Context) { |
| | var req dto.UpstreamRequest |
| | if err := c.ShouldBindJSON(&req); err != nil { |
| | c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()}) |
| | return |
| | } |
| |
|
| | if req.Timeout <= 0 { |
| | req.Timeout = defaultTimeoutSeconds |
| | } |
| |
|
| | var upstreams []dto.UpstreamDTO |
| |
|
| | if len(req.Upstreams) > 0 { |
| | for _, u := range req.Upstreams { |
| | if strings.HasPrefix(u.BaseURL, "http") { |
| | if u.Endpoint == "" { |
| | u.Endpoint = defaultEndpoint |
| | } |
| | u.BaseURL = strings.TrimRight(u.BaseURL, "/") |
| | upstreams = append(upstreams, u) |
| | } |
| | } |
| | } else if len(req.ChannelIDs) > 0 { |
| | intIds := make([]int, 0, len(req.ChannelIDs)) |
| | for _, id64 := range req.ChannelIDs { |
| | intIds = append(intIds, int(id64)) |
| | } |
| | dbChannels, err := model.GetChannelsByIds(intIds) |
| | if err != nil { |
| | logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error()) |
| | c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"}) |
| | return |
| | } |
| | for _, ch := range dbChannels { |
| | if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") { |
| | upstreams = append(upstreams, dto.UpstreamDTO{ |
| | ID: ch.Id, |
| | Name: ch.Name, |
| | BaseURL: strings.TrimRight(base, "/"), |
| | Endpoint: "", |
| | }) |
| | } |
| | } |
| | } |
| |
|
| | if len(upstreams) == 0 { |
| | c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"}) |
| | return |
| | } |
| |
|
| | var wg sync.WaitGroup |
| | ch := make(chan upstreamResult, len(upstreams)) |
| |
|
| | sem := make(chan struct{}, maxConcurrentFetches) |
| |
|
| | dialer := &net.Dialer{Timeout: 10 * time.Second} |
| | transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second} |
| | transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { |
| | host, _, err := net.SplitHostPort(addr) |
| | if err != nil { |
| | host = addr |
| | } |
| | |
| | if strings.HasSuffix(host, "github.io") { |
| | if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil { |
| | return conn, nil |
| | } |
| | return dialer.DialContext(ctx, "tcp6", addr) |
| | } |
| | return dialer.DialContext(ctx, network, addr) |
| | } |
| | client := &http.Client{Transport: transport} |
| |
|
| | for _, chn := range upstreams { |
| | wg.Add(1) |
| | go func(chItem dto.UpstreamDTO) { |
| | defer wg.Done() |
| |
|
| | sem <- struct{}{} |
| | defer func() { <-sem }() |
| |
|
| | endpoint := chItem.Endpoint |
| | var fullURL string |
| | if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { |
| | fullURL = endpoint |
| | } else { |
| | if endpoint == "" { |
| | endpoint = defaultEndpoint |
| | } else if !strings.HasPrefix(endpoint, "/") { |
| | endpoint = "/" + endpoint |
| | } |
| | fullURL = chItem.BaseURL + endpoint |
| | } |
| |
|
| | uniqueName := chItem.Name |
| | if chItem.ID != 0 { |
| | uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID) |
| | } |
| |
|
| | ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second) |
| | defer cancel() |
| |
|
| | httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil) |
| | if err != nil { |
| | logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error()) |
| | ch <- upstreamResult{Name: uniqueName, Err: err.Error()} |
| | return |
| | } |
| |
|
| | |
| | var resp *http.Response |
| | var lastErr error |
| | for attempt := 0; attempt < 3; attempt++ { |
| | resp, lastErr = client.Do(httpReq) |
| | if lastErr == nil { |
| | break |
| | } |
| | time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond) |
| | } |
| | if lastErr != nil { |
| | logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error()) |
| | ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()} |
| | return |
| | } |
| | defer resp.Body.Close() |
| | if resp.StatusCode != http.StatusOK { |
| | logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status) |
| | ch <- upstreamResult{Name: uniqueName, Err: resp.Status} |
| | return |
| | } |
| |
|
| | |
| | if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") { |
| | logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct) |
| | } |
| | limited := io.LimitReader(resp.Body, maxRatioConfigBytes) |
| | |
| | |
| | |
| | var body struct { |
| | Success bool `json:"success"` |
| | Data json.RawMessage `json:"data"` |
| | Message string `json:"message"` |
| | } |
| |
|
| | if err := json.NewDecoder(limited).Decode(&body); err != nil { |
| | logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error()) |
| | ch <- upstreamResult{Name: uniqueName, Err: err.Error()} |
| | return |
| | } |
| |
|
| | if !body.Success { |
| | ch <- upstreamResult{Name: uniqueName, Err: body.Message} |
| | return |
| | } |
| |
|
| | |
| |
|
| | |
| | var type1Data map[string]any |
| | if err := json.Unmarshal(body.Data, &type1Data); err == nil { |
| | |
| | isType1 := false |
| | for _, rt := range ratioTypes { |
| | if _, ok := type1Data[rt]; ok { |
| | isType1 = true |
| | break |
| | } |
| | } |
| | if isType1 { |
| | ch <- upstreamResult{Name: uniqueName, Data: type1Data} |
| | return |
| | } |
| | } |
| |
|
| | |
| | var pricingItems []struct { |
| | ModelName string `json:"model_name"` |
| | QuotaType int `json:"quota_type"` |
| | ModelRatio float64 `json:"model_ratio"` |
| | ModelPrice float64 `json:"model_price"` |
| | CompletionRatio float64 `json:"completion_ratio"` |
| | } |
| | if err := json.Unmarshal(body.Data, &pricingItems); err != nil { |
| | logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error()) |
| | ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"} |
| | return |
| | } |
| |
|
| | modelRatioMap := make(map[string]float64) |
| | completionRatioMap := make(map[string]float64) |
| | modelPriceMap := make(map[string]float64) |
| |
|
| | for _, item := range pricingItems { |
| | if item.QuotaType == 1 { |
| | modelPriceMap[item.ModelName] = item.ModelPrice |
| | } else { |
| | modelRatioMap[item.ModelName] = item.ModelRatio |
| | |
| | completionRatioMap[item.ModelName] = item.CompletionRatio |
| | } |
| | } |
| |
|
| | converted := make(map[string]any) |
| |
|
| | if len(modelRatioMap) > 0 { |
| | ratioAny := make(map[string]any, len(modelRatioMap)) |
| | for k, v := range modelRatioMap { |
| | ratioAny[k] = v |
| | } |
| | converted["model_ratio"] = ratioAny |
| | } |
| |
|
| | if len(completionRatioMap) > 0 { |
| | compAny := make(map[string]any, len(completionRatioMap)) |
| | for k, v := range completionRatioMap { |
| | compAny[k] = v |
| | } |
| | converted["completion_ratio"] = compAny |
| | } |
| |
|
| | if len(modelPriceMap) > 0 { |
| | priceAny := make(map[string]any, len(modelPriceMap)) |
| | for k, v := range modelPriceMap { |
| | priceAny[k] = v |
| | } |
| | converted["model_price"] = priceAny |
| | } |
| |
|
| | ch <- upstreamResult{Name: uniqueName, Data: converted} |
| | }(chn) |
| | } |
| |
|
| | wg.Wait() |
| | close(ch) |
| |
|
| | localData := ratio_setting.GetExposedData() |
| |
|
| | var testResults []dto.TestResult |
| | var successfulChannels []struct { |
| | name string |
| | data map[string]any |
| | } |
| |
|
| | for r := range ch { |
| | if r.Err != "" { |
| | testResults = append(testResults, dto.TestResult{ |
| | Name: r.Name, |
| | Status: "error", |
| | Error: r.Err, |
| | }) |
| | } else { |
| | testResults = append(testResults, dto.TestResult{ |
| | Name: r.Name, |
| | Status: "success", |
| | }) |
| | successfulChannels = append(successfulChannels, struct { |
| | name string |
| | data map[string]any |
| | }{name: r.Name, data: r.Data}) |
| | } |
| | } |
| |
|
| | differences := buildDifferences(localData, successfulChannels) |
| |
|
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "data": gin.H{ |
| | "differences": differences, |
| | "test_results": testResults, |
| | }, |
| | }) |
| | } |
| |
|
| | func buildDifferences(localData map[string]any, successfulChannels []struct { |
| | name string |
| | data map[string]any |
| | }) map[string]map[string]dto.DifferenceItem { |
| | differences := make(map[string]map[string]dto.DifferenceItem) |
| |
|
| | allModels := make(map[string]struct{}) |
| |
|
| | for _, ratioType := range ratioTypes { |
| | if localRatioAny, ok := localData[ratioType]; ok { |
| | if localRatio, ok := localRatioAny.(map[string]float64); ok { |
| | for modelName := range localRatio { |
| | allModels[modelName] = struct{}{} |
| | } |
| | } |
| | } |
| | } |
| |
|
| | for _, channel := range successfulChannels { |
| | for _, ratioType := range ratioTypes { |
| | if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { |
| | for modelName := range upstreamRatio { |
| | allModels[modelName] = struct{}{} |
| | } |
| | } |
| | } |
| | } |
| |
|
| | confidenceMap := make(map[string]map[string]bool) |
| |
|
| | |
| | for _, channel := range successfulChannels { |
| | confidenceMap[channel.name] = make(map[string]bool) |
| |
|
| | modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) |
| | completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) |
| |
|
| | if hasModelRatio && hasCompletionRatio { |
| | |
| | for modelName := range allModels { |
| | |
| | confidenceMap[channel.name][modelName] = true |
| |
|
| | |
| | if modelRatioVal, ok := modelRatios[modelName]; ok { |
| | if completionRatioVal, ok := completionRatios[modelName]; ok { |
| | |
| | if modelRatioFloat, ok := modelRatioVal.(float64); ok { |
| | if completionRatioFloat, ok := completionRatioVal.(float64); ok { |
| | if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { |
| | confidenceMap[channel.name][modelName] = false |
| | } |
| | } |
| | } |
| | } |
| | } |
| | } |
| | } else { |
| | |
| | for modelName := range allModels { |
| | confidenceMap[channel.name][modelName] = true |
| | } |
| | } |
| | } |
| |
|
| | for modelName := range allModels { |
| | for _, ratioType := range ratioTypes { |
| | var localValue interface{} = nil |
| | if localRatioAny, ok := localData[ratioType]; ok { |
| | if localRatio, ok := localRatioAny.(map[string]float64); ok { |
| | if val, exists := localRatio[modelName]; exists { |
| | localValue = val |
| | } |
| | } |
| | } |
| |
|
| | upstreamValues := make(map[string]interface{}) |
| | confidenceValues := make(map[string]bool) |
| | hasUpstreamValue := false |
| | hasDifference := false |
| |
|
| | for _, channel := range successfulChannels { |
| | var upstreamValue interface{} = nil |
| |
|
| | if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { |
| | if val, exists := upstreamRatio[modelName]; exists { |
| | upstreamValue = val |
| | hasUpstreamValue = true |
| |
|
| | if localValue != nil && !valuesEqual(localValue, val) { |
| | hasDifference = true |
| | } else if valuesEqual(localValue, val) { |
| | upstreamValue = "same" |
| | } |
| | } |
| | } |
| | if upstreamValue == nil && localValue == nil { |
| | upstreamValue = "same" |
| | } |
| |
|
| | if localValue == nil && upstreamValue != nil && upstreamValue != "same" { |
| | hasDifference = true |
| | } |
| |
|
| | upstreamValues[channel.name] = upstreamValue |
| |
|
| | confidenceValues[channel.name] = confidenceMap[channel.name][modelName] |
| | } |
| |
|
| | shouldInclude := false |
| |
|
| | if localValue != nil { |
| | if hasDifference { |
| | shouldInclude = true |
| | } |
| | } else { |
| | if hasUpstreamValue { |
| | shouldInclude = true |
| | } |
| | } |
| |
|
| | if shouldInclude { |
| | if differences[modelName] == nil { |
| | differences[modelName] = make(map[string]dto.DifferenceItem) |
| | } |
| | differences[modelName][ratioType] = dto.DifferenceItem{ |
| | Current: localValue, |
| | Upstreams: upstreamValues, |
| | Confidence: confidenceValues, |
| | } |
| | } |
| | } |
| | } |
| |
|
| | channelHasDiff := make(map[string]bool) |
| | for _, ratioMap := range differences { |
| | for _, item := range ratioMap { |
| | for chName, val := range item.Upstreams { |
| | if val != nil && val != "same" { |
| | channelHasDiff[chName] = true |
| | } |
| | } |
| | } |
| | } |
| |
|
| | for modelName, ratioMap := range differences { |
| | for ratioType, item := range ratioMap { |
| | for chName := range item.Upstreams { |
| | if !channelHasDiff[chName] { |
| | delete(item.Upstreams, chName) |
| | delete(item.Confidence, chName) |
| | } |
| | } |
| |
|
| | allSame := true |
| | for _, v := range item.Upstreams { |
| | if v != "same" { |
| | allSame = false |
| | break |
| | } |
| | } |
| | if len(item.Upstreams) == 0 || allSame { |
| | delete(ratioMap, ratioType) |
| | } else { |
| | differences[modelName][ratioType] = item |
| | } |
| | } |
| |
|
| | if len(ratioMap) == 0 { |
| | delete(differences, modelName) |
| | } |
| | } |
| |
|
| | return differences |
| | } |
| |
|
| | func GetSyncableChannels(c *gin.Context) { |
| | channels, err := model.GetAllChannels(0, 0, true, false) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| |
|
| | var syncableChannels []dto.SyncableChannel |
| | for _, channel := range channels { |
| | if channel.GetBaseURL() != "" { |
| | syncableChannels = append(syncableChannels, dto.SyncableChannel{ |
| | ID: channel.Id, |
| | Name: channel.Name, |
| | BaseURL: channel.GetBaseURL(), |
| | Status: channel.Status, |
| | }) |
| | } |
| | } |
| |
|
| | syncableChannels = append(syncableChannels, dto.SyncableChannel{ |
| | ID: -100, |
| | Name: "官方倍率预设", |
| | BaseURL: "https://basellm.github.io", |
| | Status: 1, |
| | }) |
| |
|
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "data": syncableChannels, |
| | }) |
| | } |
| |
|