|
|
package controller |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"net" |
|
|
"net/http" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/QuantumNous/new-api/logger" |
|
|
|
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/model" |
|
|
"github.com/QuantumNous/new-api/setting/ratio_setting" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
// Tunables for fetching and comparing upstream ratio configuration.
const (
	// defaultTimeoutSeconds is the per-upstream timeout used when the request
	// does not specify a positive one.
	defaultTimeoutSeconds = 10
	// defaultEndpoint is the path appended to an upstream base URL when no
	// endpoint is given.
	defaultEndpoint = "/api/ratio_config"
	// maxConcurrentFetches caps how many upstreams are fetched at once.
	maxConcurrentFetches = 8
	// maxRatioConfigBytes caps how much of an upstream response body is read
	// (10 MiB).
	maxRatioConfigBytes = 10 * 1024 * 1024
	// floatEpsilon is the absolute tolerance used by nearlyEqual.
	floatEpsilon = 1e-9
)
|
|
|
|
|
func nearlyEqual(a, b float64) bool { |
|
|
if a > b { |
|
|
return a-b < floatEpsilon |
|
|
} |
|
|
return b-a < floatEpsilon |
|
|
} |
|
|
|
|
|
func valuesEqual(a, b interface{}) bool { |
|
|
af, aok := a.(float64) |
|
|
bf, bok := b.(float64) |
|
|
if aok && bok { |
|
|
return nearlyEqual(af, bf) |
|
|
} |
|
|
return a == b |
|
|
} |
|
|
|
|
|
// ratioTypes lists the ratio-data keys that are read from upstreams and
// compared against local settings, in the order they are processed.
var ratioTypes = []string{
	"model_ratio",
	"completion_ratio",
	"cache_ratio",
	"model_price",
}
|
|
|
|
|
// upstreamResult carries the outcome of fetching one upstream. A successful
// fetch sets Data; a failed one sets Err to a human-readable reason.
type upstreamResult struct {
	// Name identifies the upstream in test results and difference maps.
	Name string `json:"name"`
	// Data maps ratio type (e.g. "model_ratio") to per-model values.
	Data map[string]any `json:"data,omitempty"`
	// Err is non-empty when the fetch failed.
	Err string `json:"err,omitempty"`
}
|
|
|
|
|
// FetchUpstreamRatios fetches ratio configuration from one or more upstreams
// and responds with how it differs from the local configuration.
//
// The JSON body is bound into dto.UpstreamRequest. Upstreams come from
// req.Upstreams when present (only entries whose BaseURL starts with
// "http"); otherwise they are loaded from the database by req.ChannelIDs.
// Each upstream is fetched concurrently, at most maxConcurrentFetches at a
// time, with a per-upstream timeout of req.Timeout seconds
// (defaultTimeoutSeconds when req.Timeout <= 0).
//
// An upstream must answer {"success", "data", "message"}; "data" may be
// either an object keyed by ratio type ("model_ratio", "completion_ratio",
// ...) or an array of pricing items, which is converted into that shape.
//
// The response contains the per-model differences computed by
// buildDifferences and one test result per upstream.
func FetchUpstreamRatios(c *gin.Context) {
	var req dto.UpstreamRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"success": false, "message": err.Error()})
		return
	}

	if req.Timeout <= 0 {
		req.Timeout = defaultTimeoutSeconds
	}

	var upstreams []dto.UpstreamDTO

	if len(req.Upstreams) > 0 {
		for _, u := range req.Upstreams {
			if strings.HasPrefix(u.BaseURL, "http") {
				if u.Endpoint == "" {
					u.Endpoint = defaultEndpoint
				}
				u.BaseURL = strings.TrimRight(u.BaseURL, "/")
				upstreams = append(upstreams, u)
			}
		}
	} else if len(req.ChannelIDs) > 0 {
		intIds := make([]int, 0, len(req.ChannelIDs))
		for _, id64 := range req.ChannelIDs {
			intIds = append(intIds, int(id64))
		}
		dbChannels, err := model.GetChannelsByIds(intIds)
		if err != nil {
			logger.LogError(c.Request.Context(), "failed to query channels: "+err.Error())
			c.JSON(http.StatusInternalServerError, gin.H{"success": false, "message": "查询渠道失败"})
			return
		}
		for _, ch := range dbChannels {
			if base := ch.GetBaseURL(); strings.HasPrefix(base, "http") {
				upstreams = append(upstreams, dto.UpstreamDTO{
					ID:       ch.Id,
					Name:     ch.Name,
					BaseURL:  strings.TrimRight(base, "/"),
					Endpoint: "",
				})
			}
		}
	}

	if len(upstreams) == 0 {
		c.JSON(http.StatusOK, gin.H{"success": false, "message": "无有效上游渠道"})
		return
	}

	var wg sync.WaitGroup
	// Buffered to len(upstreams) so every goroutine can deliver its single
	// result without blocking, even before the collector starts reading.
	ch := make(chan upstreamResult, len(upstreams))

	// Counting semaphore bounding the number of in-flight fetches.
	sem := make(chan struct{}, maxConcurrentFetches)

	dialer := &net.Dialer{Timeout: 10 * time.Second}
	transport := &http.Transport{MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, ResponseHeaderTimeout: 10 * time.Second}
	transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			host = addr
		}

		// Hosts ending in "github.io" are dialed over IPv4 first, falling back
		// to IPv6 (presumably to sidestep broken IPv6 routes — TODO confirm).
		if strings.HasSuffix(host, "github.io") {
			if conn, err := dialer.DialContext(ctx, "tcp4", addr); err == nil {
				return conn, nil
			}
			return dialer.DialContext(ctx, "tcp6", addr)
		}
		return dialer.DialContext(ctx, network, addr)
	}
	client := &http.Client{Transport: transport}
	// The transport is private to this request. Release its pooled
	// keep-alive connections once all fetches are done (wg.Wait below runs
	// before return) instead of leaving them open until IdleConnTimeout.
	defer transport.CloseIdleConnections()

	for _, chn := range upstreams {
		wg.Add(1)
		go func(chItem dto.UpstreamDTO) {
			defer wg.Done()

			sem <- struct{}{}
			defer func() { <-sem }()

			// An absolute endpoint URL is used verbatim; otherwise the endpoint
			// is a path appended to the upstream's base URL.
			endpoint := chItem.Endpoint
			var fullURL string
			if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
				fullURL = endpoint
			} else {
				if endpoint == "" {
					endpoint = defaultEndpoint
				} else if !strings.HasPrefix(endpoint, "/") {
					endpoint = "/" + endpoint
				}
				fullURL = chItem.BaseURL + endpoint
			}

			// Append the ID (when set) so same-named channels stay distinct.
			uniqueName := chItem.Name
			if chItem.ID != 0 {
				uniqueName = fmt.Sprintf("%s(%d)", chItem.Name, chItem.ID)
			}

			ctx, cancel := context.WithTimeout(c.Request.Context(), time.Duration(req.Timeout)*time.Second)
			defer cancel()

			httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, fullURL, nil)
			if err != nil {
				logger.LogWarn(c.Request.Context(), "build request failed: "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
				return
			}

			// Retry transport-level failures (not HTTP error statuses) with
			// exponential backoff between attempts (200ms, then 400ms). There is
			// no pointless wait after the final attempt, and the backoff stops
			// as soon as ctx is done, because any later attempt would fail too.
			const maxAttempts = 3
			var resp *http.Response
			var lastErr error
			for attempt := 0; attempt < maxAttempts; attempt++ {
				if attempt > 0 {
					backoff := time.NewTimer(time.Duration(200*(1<<(attempt-1))) * time.Millisecond)
					select {
					case <-ctx.Done():
						backoff.Stop()
					case <-backoff.C:
					}
					if ctx.Err() != nil {
						break
					}
				}
				resp, lastErr = client.Do(httpReq)
				if lastErr == nil {
					break
				}
			}
			if lastErr != nil {
				logger.LogWarn(c.Request.Context(), "http error on "+chItem.Name+": "+lastErr.Error())
				ch <- upstreamResult{Name: uniqueName, Err: lastErr.Error()}
				return
			}
			defer resp.Body.Close()
			if resp.StatusCode != http.StatusOK {
				logger.LogWarn(c.Request.Context(), "non-200 from "+chItem.Name+": "+resp.Status)
				ch <- upstreamResult{Name: uniqueName, Err: resp.Status}
				return
			}

			// A non-JSON content type is only logged; decoding is still attempted.
			if ct := resp.Header.Get("Content-Type"); ct != "" && !strings.Contains(strings.ToLower(ct), "application/json") {
				logger.LogWarn(c.Request.Context(), "unexpected content-type from "+chItem.Name+": "+ct)
			}
			limited := io.LimitReader(resp.Body, maxRatioConfigBytes)

			var body struct {
				Success bool            `json:"success"`
				Data    json.RawMessage `json:"data"`
				Message string          `json:"message"`
			}

			if err := json.NewDecoder(limited).Decode(&body); err != nil {
				logger.LogWarn(c.Request.Context(), "json decode failed from "+chItem.Name+": "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: err.Error()}
				return
			}

			if !body.Success {
				ch <- upstreamResult{Name: uniqueName, Err: body.Message}
				return
			}

			// Shape 1: an object keyed by at least one known ratio type.
			var type1Data map[string]any
			if err := json.Unmarshal(body.Data, &type1Data); err == nil {
				isType1 := false
				for _, rt := range ratioTypes {
					if _, ok := type1Data[rt]; ok {
						isType1 = true
						break
					}
				}
				if isType1 {
					ch <- upstreamResult{Name: uniqueName, Data: type1Data}
					return
				}
			}

			// Shape 2: an array of pricing items.
			var pricingItems []struct {
				ModelName       string  `json:"model_name"`
				QuotaType       int     `json:"quota_type"`
				ModelRatio      float64 `json:"model_ratio"`
				ModelPrice      float64 `json:"model_price"`
				CompletionRatio float64 `json:"completion_ratio"`
			}
			if err := json.Unmarshal(body.Data, &pricingItems); err != nil {
				logger.LogWarn(c.Request.Context(), "unrecognized data format from "+chItem.Name+": "+err.Error())
				ch <- upstreamResult{Name: uniqueName, Err: "无法解析上游返回数据"}
				return
			}

			// Items with QuotaType 1 contribute a model price; all others
			// contribute a model ratio and a completion ratio.
			modelRatioMap := make(map[string]float64)
			completionRatioMap := make(map[string]float64)
			modelPriceMap := make(map[string]float64)

			for _, item := range pricingItems {
				if item.QuotaType == 1 {
					modelPriceMap[item.ModelName] = item.ModelPrice
				} else {
					modelRatioMap[item.ModelName] = item.ModelRatio
					completionRatioMap[item.ModelName] = item.CompletionRatio
				}
			}

			// Re-box as map[string]any so both shapes look identical to
			// buildDifferences; empty maps are omitted.
			converted := make(map[string]any)

			if len(modelRatioMap) > 0 {
				ratioAny := make(map[string]any, len(modelRatioMap))
				for k, v := range modelRatioMap {
					ratioAny[k] = v
				}
				converted["model_ratio"] = ratioAny
			}

			if len(completionRatioMap) > 0 {
				compAny := make(map[string]any, len(completionRatioMap))
				for k, v := range completionRatioMap {
					compAny[k] = v
				}
				converted["completion_ratio"] = compAny
			}

			if len(modelPriceMap) > 0 {
				priceAny := make(map[string]any, len(modelPriceMap))
				for k, v := range modelPriceMap {
					priceAny[k] = v
				}
				converted["model_price"] = priceAny
			}

			ch <- upstreamResult{Name: uniqueName, Data: converted}
		}(chn)
	}

	wg.Wait()
	close(ch)

	localData := ratio_setting.GetExposedData()

	var testResults []dto.TestResult
	var successfulChannels []struct {
		name string
		data map[string]any
	}

	for r := range ch {
		if r.Err != "" {
			testResults = append(testResults, dto.TestResult{
				Name:   r.Name,
				Status: "error",
				Error:  r.Err,
			})
		} else {
			testResults = append(testResults, dto.TestResult{
				Name:   r.Name,
				Status: "success",
			})
			successfulChannels = append(successfulChannels, struct {
				name string
				data map[string]any
			}{name: r.Name, data: r.Data})
		}
	}

	differences := buildDifferences(localData, successfulChannels)

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"data": gin.H{
			"differences":  differences,
			"test_results": testResults,
		},
	})
}
|
|
|
|
|
func buildDifferences(localData map[string]any, successfulChannels []struct { |
|
|
name string |
|
|
data map[string]any |
|
|
}) map[string]map[string]dto.DifferenceItem { |
|
|
differences := make(map[string]map[string]dto.DifferenceItem) |
|
|
|
|
|
allModels := make(map[string]struct{}) |
|
|
|
|
|
for _, ratioType := range ratioTypes { |
|
|
if localRatioAny, ok := localData[ratioType]; ok { |
|
|
if localRatio, ok := localRatioAny.(map[string]float64); ok { |
|
|
for modelName := range localRatio { |
|
|
allModels[modelName] = struct{}{} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
for _, channel := range successfulChannels { |
|
|
for _, ratioType := range ratioTypes { |
|
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { |
|
|
for modelName := range upstreamRatio { |
|
|
allModels[modelName] = struct{}{} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
confidenceMap := make(map[string]map[string]bool) |
|
|
|
|
|
|
|
|
for _, channel := range successfulChannels { |
|
|
confidenceMap[channel.name] = make(map[string]bool) |
|
|
|
|
|
modelRatios, hasModelRatio := channel.data["model_ratio"].(map[string]any) |
|
|
completionRatios, hasCompletionRatio := channel.data["completion_ratio"].(map[string]any) |
|
|
|
|
|
if hasModelRatio && hasCompletionRatio { |
|
|
|
|
|
for modelName := range allModels { |
|
|
|
|
|
confidenceMap[channel.name][modelName] = true |
|
|
|
|
|
|
|
|
if modelRatioVal, ok := modelRatios[modelName]; ok { |
|
|
if completionRatioVal, ok := completionRatios[modelName]; ok { |
|
|
|
|
|
if modelRatioFloat, ok := modelRatioVal.(float64); ok { |
|
|
if completionRatioFloat, ok := completionRatioVal.(float64); ok { |
|
|
if modelRatioFloat == 37.5 && completionRatioFloat == 1.0 { |
|
|
confidenceMap[channel.name][modelName] = false |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} else { |
|
|
|
|
|
for modelName := range allModels { |
|
|
confidenceMap[channel.name][modelName] = true |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
for modelName := range allModels { |
|
|
for _, ratioType := range ratioTypes { |
|
|
var localValue interface{} = nil |
|
|
if localRatioAny, ok := localData[ratioType]; ok { |
|
|
if localRatio, ok := localRatioAny.(map[string]float64); ok { |
|
|
if val, exists := localRatio[modelName]; exists { |
|
|
localValue = val |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
upstreamValues := make(map[string]interface{}) |
|
|
confidenceValues := make(map[string]bool) |
|
|
hasUpstreamValue := false |
|
|
hasDifference := false |
|
|
|
|
|
for _, channel := range successfulChannels { |
|
|
var upstreamValue interface{} = nil |
|
|
|
|
|
if upstreamRatio, ok := channel.data[ratioType].(map[string]any); ok { |
|
|
if val, exists := upstreamRatio[modelName]; exists { |
|
|
upstreamValue = val |
|
|
hasUpstreamValue = true |
|
|
|
|
|
if localValue != nil && !valuesEqual(localValue, val) { |
|
|
hasDifference = true |
|
|
} else if valuesEqual(localValue, val) { |
|
|
upstreamValue = "same" |
|
|
} |
|
|
} |
|
|
} |
|
|
if upstreamValue == nil && localValue == nil { |
|
|
upstreamValue = "same" |
|
|
} |
|
|
|
|
|
if localValue == nil && upstreamValue != nil && upstreamValue != "same" { |
|
|
hasDifference = true |
|
|
} |
|
|
|
|
|
upstreamValues[channel.name] = upstreamValue |
|
|
|
|
|
confidenceValues[channel.name] = confidenceMap[channel.name][modelName] |
|
|
} |
|
|
|
|
|
shouldInclude := false |
|
|
|
|
|
if localValue != nil { |
|
|
if hasDifference { |
|
|
shouldInclude = true |
|
|
} |
|
|
} else { |
|
|
if hasUpstreamValue { |
|
|
shouldInclude = true |
|
|
} |
|
|
} |
|
|
|
|
|
if shouldInclude { |
|
|
if differences[modelName] == nil { |
|
|
differences[modelName] = make(map[string]dto.DifferenceItem) |
|
|
} |
|
|
differences[modelName][ratioType] = dto.DifferenceItem{ |
|
|
Current: localValue, |
|
|
Upstreams: upstreamValues, |
|
|
Confidence: confidenceValues, |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
channelHasDiff := make(map[string]bool) |
|
|
for _, ratioMap := range differences { |
|
|
for _, item := range ratioMap { |
|
|
for chName, val := range item.Upstreams { |
|
|
if val != nil && val != "same" { |
|
|
channelHasDiff[chName] = true |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
for modelName, ratioMap := range differences { |
|
|
for ratioType, item := range ratioMap { |
|
|
for chName := range item.Upstreams { |
|
|
if !channelHasDiff[chName] { |
|
|
delete(item.Upstreams, chName) |
|
|
delete(item.Confidence, chName) |
|
|
} |
|
|
} |
|
|
|
|
|
allSame := true |
|
|
for _, v := range item.Upstreams { |
|
|
if v != "same" { |
|
|
allSame = false |
|
|
break |
|
|
} |
|
|
} |
|
|
if len(item.Upstreams) == 0 || allSame { |
|
|
delete(ratioMap, ratioType) |
|
|
} else { |
|
|
differences[modelName][ratioType] = item |
|
|
} |
|
|
} |
|
|
|
|
|
if len(ratioMap) == 0 { |
|
|
delete(differences, modelName) |
|
|
} |
|
|
} |
|
|
|
|
|
return differences |
|
|
} |
|
|
|
|
|
// GetSyncableChannels responds with the channels that can serve as
// ratio-sync upstreams, followed by the built-in official preset entry
// (ID -100, https://basellm.github.io).
//
// A channel is listed only when its base URL starts with "http". This is the
// same rule FetchUpstreamRatios applies when resolving channel IDs; before,
// channels with any non-empty base URL were listed even though the fetch
// would silently skip them. On a lookup error it responds with HTTP 200 and
// success=false, like the other handlers in this file.
func GetSyncableChannels(c *gin.Context) {
	channels, err := model.GetAllChannels(0, 0, true, false)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": err.Error(),
		})
		return
	}

	var syncableChannels []dto.SyncableChannel
	for _, channel := range channels {
		baseURL := channel.GetBaseURL()
		// Only list channels FetchUpstreamRatios is able to fetch.
		if !strings.HasPrefix(baseURL, "http") {
			continue
		}
		syncableChannels = append(syncableChannels, dto.SyncableChannel{
			ID:      channel.Id,
			Name:    channel.Name,
			BaseURL: baseURL,
			Status:  channel.Status,
		})
	}

	syncableChannels = append(syncableChannels, dto.SyncableChannel{
		ID:      -100,
		Name:    "官方倍率预设",
		BaseURL: "https://basellm.github.io",
		Status:  1,
	})

	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "",
		"data":    syncableChannels,
	})
}
|
|
|