| package controller
|
|
|
| import (
|
| "context"
|
| "encoding/json"
|
| "fmt"
|
| "io"
|
| "net"
|
| "net/http"
|
| "strings"
|
| "sync"
|
| "time"
|
|
|
| "github.com/QuantumNous/new-api/common"
|
| "github.com/QuantumNous/new-api/logger"
|
|
|
| "github.com/QuantumNous/new-api/dto"
|
| "github.com/QuantumNous/new-api/model"
|
| "github.com/QuantumNous/new-api/setting/ratio_setting"
|
|
|
| "github.com/gin-gonic/gin"
|
| )
|
|
|
// Tunables for fetching and comparing upstream ratio configurations.
const (
	// defaultTimeoutSeconds is the per-upstream request timeout applied when
	// the caller does not supply a positive timeout.
	defaultTimeoutSeconds = 10

	// defaultEndpoint is the path appended to an upstream base URL when no
	// endpoint is configured.
	defaultEndpoint = "/api/ratio_config"

	// maxConcurrentFetches caps how many upstreams are queried at the same time.
	maxConcurrentFetches = 8

	// maxRatioConfigBytes bounds how much of an upstream response body is read
	// (10 MiB).
	maxRatioConfigBytes = 10 * 1024 * 1024

	// floatEpsilon is the absolute tolerance used when comparing float ratios.
	floatEpsilon = 1e-9
)
|
|
|
| func nearlyEqual(a, b float64) bool {
|
| if a > b {
|
| return a-b < floatEpsilon
|
| }
|
| return b-a < floatEpsilon
|
| }
|
|
|
| func valuesEqual(a, b interface{}) bool {
|
| af, aok := a.(float64)
|
| bf, bok := b.(float64)
|
| if aok && bok {
|
| return nearlyEqual(af, bf)
|
| }
|
| return a == b
|
| }
|
|
|
// ratioTypes lists the ratio categories that are compared between the local
// settings and upstream data, in the order they are examined.
var ratioTypes = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"model_price",
}
|
|
|
// upstreamResult is the outcome of querying a single upstream source.
// Exactly one of Data (on success) or Err (on failure) is populated by the
// fetch logic.
type upstreamResult struct {
	// Name identifies the upstream; it includes the channel ID in parentheses
	// when one is known.
	Name string `json:"name"`
	// Data holds the fetched ratio data keyed by ratio type
	// (e.g. "model_ratio").
	Data map[string]any `json:"data,omitempty"`
	// Err describes why the fetch failed; empty on success.
	Err string `json:"err,omitempty"`
}
|
|
|
| func FetchUpstreamRatios(c *gin.Context) {
|
| var req dto.UpstreamRequest
|
| if err := c.ShouldBindJSON(&req); err != nil {
|
| c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
|
| return
|
| }
|
|
|
| if req.Timeout <= 0 {
|
| req.Timeout = defaultTimeoutSeconds
|
| }
|
|
|
| var upstreams []dto.UpstreamDTO
|
|
|
| if len(req.Upstreams) > 0 {
|
| for _, u := range req.Upstreams {
|
| if strings.HasPrefix(u.BaseURL, "http") {
|
| if u.Endpoint == "" {
|
| u.Endpoint = defaultEndpoint
|
| }
|
| u.BaseURL = strings.TrimRight(u.BaseURL, "/")
|
| upstreams = append(upstreams, u)
|
| }
|
| }
|
| } else if len(req.ChannelIDs) > 0 {
|
| intIds := make([]int, 0, len(req.ChannelIDs))
|
| for _, id64 := range req.ChannelIDs {
|
| intIds = append(intIds, int(id64))
|
| }
|
| dbChannels, err := model.GetChannelsByIds(intIds)
|
| if err != nil {
|
| logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
|
| c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
|
| return
|
| }
|
| for _, ch := range dbChannels {
|
| if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
|
| upstreams = append(upstreams, dto.UpstreamDTO{
|
| ID: ch.Id,
|
| Name: ch.Name,
|
| BaseURL: strings.TrimRight(base, "/"),
|
| Endpoint: "",
|
| })
|
| }
|
| }
|
| }
|
|
|
| if len(upstreams) == 0 {
|
| c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
|
| return
|
| }
|
|
|
| var wg sync.WaitGroup
|
| ch := make(chan upstreamResult, len(upstreams))
|
|
|
| sem := make(chan struct{}, maxConcurrentFetches)
|
|
|
| dialer := &net.Dialer{Timeout: 10 * time.Second}
|
| transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
|
| if common.TLSInsecureSkipVerify {
|
| transport.TLSClientConfig = common.InsecureTLSConfig
|
| }
|
| transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
| host, _, err := net.SplitHostPort(addr)
|
| if err != nil {
|
| host = addr
|
| }
|
|
|
| if strings.HasSuffix(host, "github.io") {
|
| if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
|
| return conn, nil
|
| }
|
| return dialer.DialContext(ctx, "tcp6", addr)
|
| }
|
| return dialer.DialContext(ctx, network, addr)
|
| }
|
| client := &http.Client{Transport: transport}
|
|
|
| for _, chn := range upstreams {
|
| wg.Add(1)
|
| go func(chItem dto.UpstreamDTO) {
|
| defer wg.Done()
|
|
|
| sem <- struct{}{}
|
| defer func() { <-sem }()
|
|
|
| endpoint := chItem.Endpoint
|
| var fullURL string
|
| if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
| fullURL = endpoint
|
| } else {
|
| if endpoint == "" {
|
| endpoint = defaultEndpoint
|
| } else if !strings.HasPrefix(endpoint, "/") {
|
| endpoint = "/" + endpoint
|
| }
|
| fullURL = chItem.BaseURL + endpoint
|
| }
|
|
|
| uniqueName := chItem.Name
|
| if chItem.ID != 0 {
|
| uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
|
| }
|
|
|
| ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
|
| defer cancel()
|
|
|
| httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
|
| if err != nil {
|
| logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
|
| ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
| return
|
| }
|
|
|
|
|
| var resp *http.Response
|
| var lastErr error
|
| for attempt := 0; attempt < 3; attempt++ {
|
| resp, lastErr = client.Do(httpReq)
|
| if lastErr == nil {
|
| break
|
| }
|
| time.Sleep(time.Duration(200*(1<<attempt)) * time.Millisecond)
|
| }
|
| if lastErr != nil {
|
| logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
|
| ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
|
| return
|
| }
|
| defer resp.Body.Close()
|
| if resp.StatusCode != http.StatusOK {
|
| logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
|
| ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
|
| return
|
| }
|
|
|
|
|
| if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
|
| logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
|
| }
|
| limited := io.LimitReader(resp.Body, maxRatioConfigBytes)
|
|
|
|
|
|
|
| var body struct {
|
| Success bool `json:"success"`
|
| Data json.RawMessage `json:"data"`
|
| Message string `json:"message"`
|
| }
|
|
|
| if err := json.NewDecoder(limited).Decode(&body); err != nil {
|
| logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
|
| ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
|
| return
|
| }
|
|
|
| if !body.Success {
|
| ch <- upstreamResult{Name: uniqueName, Err: body.Message}
|
| return
|
| }
|
|
|
|
|
|
|
|
|
| var type1Data map[string]any
|
| if err := json.Unmarshal(body.Data, &type1Data); err == nil {
|
|
|
| isType1 := false
|
| for _, rt := range ratioTypes {
|
| if _, ok := type1Data[rt]; ok {
|
| isType1 = true
|
| break
|
| }
|
| }
|
| if isType1 {
|
| ch <- upstreamResult{Name: uniqueName, Data: type1Data}
|
| return
|
| }
|
| }
|
|
|
|
|
| var pricingItems []struct {
|
| ModelName string `json:"model_name"`
|
| QuotaType int `json:"quota_type"`
|
| ModelRatio float64 `json:"model_ratio"`
|
| ModelPrice float64 `json:"model_price"`
|
| CompletionRatio float64 `json:"completion_ratio"`
|
| }
|
| if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
|
| logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
|
| ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
|
| return
|
| }
|
|
|
| modelRatioMap := make(map[string]float64)
|
| completionRatioMap := make(map[string]float64)
|
| modelPriceMap := make(map[string]float64)
|
|
|
| for _, item := range pricingItems {
|
| if item.QuotaType == 1 {
|
| modelPriceMap[item.ModelName] = item.ModelPrice
|
| } else {
|
| modelRatioMap[item.ModelName] = item.ModelRatio
|
|
|
| completionRatioMap[item.ModelName] = item.CompletionRatio
|
| }
|
| }
|
|
|
| converted := make(map[string]any)
|
|
|
| if len(modelRatioMap) > 0 {
|
| ratioAny := make(map[string]any, len(modelRatioMap))
|
| for k, v := range modelRatioMap {
|
| ratioAny[k] = v
|
| }
|
| converted["model_ratio"] = ratioAny
|
| }
|
|
|
| if len(completionRatioMap) > 0 {
|
| compAny := make(map[string]any, len(completionRatioMap))
|
| for k, v := range completionRatioMap {
|
| compAny[k] = v
|
| }
|
| converted["completion_ratio"] = compAny
|
| }
|
|
|
| if len(modelPriceMap) > 0 {
|
| priceAny := make(map[string]any, len(modelPriceMap))
|
| for k, v := range modelPriceMap {
|
| priceAny[k] = v
|
| }
|
| converted["model_price"] = priceAny
|
| }
|
|
|
| ch <- upstreamResult{Name: uniqueName, Data: converted}
|
| }(chn)
|
| }
|
|
|
| wg.Wait()
|
| close(ch)
|
|
|
| localData := ratio_setting.GetExposedData()
|
|
|
| var testResults []dto.TestResult
|
| var successfulChannels []struct {
|
| name string
|
| data map[string]any
|
| }
|
|
|
| for r := range ch {
|
| if r.Err != "" {
|
| testResults = append(testResults, dto.TestResult{
|
| Name: r.Name,
|
| Status: "error",
|
| Error: r.Err,
|
| })
|
| } else {
|
| testResults = append(testResults, dto.TestResult{
|
| Name: r.Name,
|
| Status: "success",
|
| })
|
| successfulChannels = append(successfulChannels, struct {
|
| name string
|
| data map[string]any
|
| }{name: r.Name, data: r.Data})
|
| }
|
| }
|
|
|
| differences := buildDifferences(localData, successfulChannels)
|
|
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": true,
|
| "data": gin.H{
|
| "differences": differences,
|
| "test_results": testResults,
|
| },
|
| })
|
| }
|
|
|
| func buildDifferences(localData map[string]any, successfulChannels []struct {
|
| name string
|
| data map[string]any
|
| }) map[string]map[string]dto.DifferenceItem {
|
| differences := make(map[string]map[string]dto.DifferenceItem)
|
|
|
| allModels := make(map[string]struct{})
|
|
|
| for _, ratioType := range ratioTypes {
|
| if localRatioAny, ok := localData[ratioType]; ok {
|
| if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
| for modelName := range localRatio {
|
| allModels[modelName] = struct{}{}
|
| }
|
| }
|
| }
|
| }
|
|
|
| for _, channel := range successfulChannels {
|
| for _, ratioType := range ratioTypes {
|
| if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
| for modelName := range upstreamRatio {
|
| allModels[modelName] = struct{}{}
|
| }
|
| }
|
| }
|
| }
|
|
|
| confidenceMap := make(map[string]map[string]bool)
|
|
|
|
|
| for _, channel := range successfulChannels {
|
| confidenceMap[channel.name] = make(map[string]bool)
|
|
|
| modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any)
|
| completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any)
|
|
|
| if hasModelRatio && hasCompletionRatio {
|
|
|
| for modelName := range allModels {
|
|
|
| confidenceMap[channel.name][modelName] = true
|
|
|
|
|
| if modelRatioVal, ok := modelRatios[modelName]; ok {
|
| if completionRatioVal, ok := completionRatios[modelName]; ok {
|
|
|
| if modelRatioFloat, ok := modelRatioVal.(float64); ok {
|
| if completionRatioFloat, ok := completionRatioVal.(float64); ok {
|
| if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 {
|
| confidenceMap[channel.name][modelName] = false
|
| }
|
| }
|
| }
|
| }
|
| }
|
| }
|
| } else {
|
|
|
| for modelName := range allModels {
|
| confidenceMap[channel.name][modelName] = true
|
| }
|
| }
|
| }
|
|
|
| for modelName := range allModels {
|
| for _, ratioType := range ratioTypes {
|
| var localValue interface{} = nil
|
| if localRatioAny, ok := localData[ratioType]; ok {
|
| if localRatio, ok := localRatioAny.(map[string]float64); ok {
|
| if val, exists := localRatio[modelName]; exists {
|
| localValue = val
|
| }
|
| }
|
| }
|
|
|
| upstreamValues := make(map[string]interface{})
|
| confidenceValues := make(map[string]bool)
|
| hasUpstreamValue := false
|
| hasDifference := false
|
|
|
| for _, channel := range successfulChannels {
|
| var upstreamValue interface{} = nil
|
|
|
| if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok {
|
| if val, exists := upstreamRatio[modelName]; exists {
|
| upstreamValue = val
|
| hasUpstreamValue = true
|
|
|
| if localValue != nil && !valuesEqual(localValue, val) {
|
| hasDifference = true
|
| } else if valuesEqual(localValue, val) {
|
| upstreamValue = "same"
|
| }
|
| }
|
| }
|
| if upstreamValue == nil && localValue == nil {
|
| upstreamValue = "same"
|
| }
|
|
|
| if localValue == nil && upstreamValue != nil && upstreamValue != "same" {
|
| hasDifference = true
|
| }
|
|
|
| upstreamValues[channel.name] = upstreamValue
|
|
|
| confidenceValues[channel.name] = confidenceMap[channel.name][modelName]
|
| }
|
|
|
| shouldInclude := false
|
|
|
| if localValue != nil {
|
| if hasDifference {
|
| shouldInclude = true
|
| }
|
| } else {
|
| if hasUpstreamValue {
|
| shouldInclude = true
|
| }
|
| }
|
|
|
| if shouldInclude {
|
| if differences[modelName] == nil {
|
| differences[modelName] = make(map[string]dto.DifferenceItem)
|
| }
|
| differences[modelName][ratioType] = dto.DifferenceItem{
|
| Current: localValue,
|
| Upstreams: upstreamValues,
|
| Confidence: confidenceValues,
|
| }
|
| }
|
| }
|
| }
|
|
|
| channelHasDiff := make(map[string]bool)
|
| for _, ratioMap := range differences {
|
| for _, item := range ratioMap {
|
| for chName, val := range item.Upstreams {
|
| if val != nil && val != "same" {
|
| channelHasDiff[chName] = true
|
| }
|
| }
|
| }
|
| }
|
|
|
| for modelName, ratioMap := range differences {
|
| for ratioType, item := range ratioMap {
|
| for chName := range item.Upstreams {
|
| if !channelHasDiff[chName] {
|
| delete(item.Upstreams, chName)
|
| delete(item.Confidence, chName)
|
| }
|
| }
|
|
|
| allSame := true
|
| for _, v := range item.Upstreams {
|
| if v != "same" {
|
| allSame = false
|
| break
|
| }
|
| }
|
| if len(item.Upstreams) == 0 || allSame {
|
| delete(ratioMap, ratioType)
|
| } else {
|
| differences[modelName][ratioType] = item
|
| }
|
| }
|
|
|
| if len(ratioMap) == 0 {
|
| delete(differences, modelName)
|
| }
|
| }
|
|
|
| return differences
|
| }
|
|
|
| func GetSyncableChannels(c *gin.Context) {
|
| channels, err := model.GetAllChannels(0, 0, true, false)
|
| if err != nil {
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": false,
|
| "message": err.Error(),
|
| })
|
| return
|
| }
|
|
|
| var syncableChannels []dto.SyncableChannel
|
| for _, channel := range channels {
|
| if channel.GetBaseURL() != "" {
|
| syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
| ID: channel.Id,
|
| Name: channel.Name,
|
| BaseURL: channel.GetBaseURL(),
|
| Status: channel.Status,
|
| })
|
| }
|
| }
|
|
|
| syncableChannels = append(syncableChannels, dto.SyncableChannel{
|
| ID: -100,
|
| Name: "官方倍率预设",
|
| BaseURL: "https://basellm.github.io",
|
| Status: 1,
|
| })
|
|
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": true,
|
| "message": "",
|
| "data": syncableChannels,
|
| })
|
| }
|
|
|