| package controller
|
|
|
| import (
|
| "bytes"
|
| "encoding/json"
|
| "errors"
|
| "fmt"
|
| "io"
|
| "math"
|
| "net/http"
|
| "net/http/httptest"
|
| "net/url"
|
| "strconv"
|
| "strings"
|
| "sync"
|
| "time"
|
|
|
| "github.com/QuantumNous/new-api/common"
|
| "github.com/QuantumNous/new-api/constant"
|
| "github.com/QuantumNous/new-api/dto"
|
| "github.com/QuantumNous/new-api/middleware"
|
| "github.com/QuantumNous/new-api/model"
|
| "github.com/QuantumNous/new-api/relay"
|
| relaycommon "github.com/QuantumNous/new-api/relay/common"
|
| relayconstant "github.com/QuantumNous/new-api/relay/constant"
|
| "github.com/QuantumNous/new-api/relay/helper"
|
| "github.com/QuantumNous/new-api/service"
|
| "github.com/QuantumNous/new-api/setting/operation_setting"
|
| "github.com/QuantumNous/new-api/setting/ratio_setting"
|
| "github.com/QuantumNous/new-api/types"
|
|
|
| "github.com/bytedance/gopkg/util/gopool"
|
| "github.com/samber/lo"
|
|
|
| "github.com/gin-gonic/gin"
|
| )
|
|
|
| type testResult struct {
|
| context *gin.Context
|
| localErr error
|
| newAPIError *types.NewAPIError
|
| }
|
|
|
| func testChannel(channel *model.Channel, testModel string, endpointType string) testResult {
|
| tik := time.Now()
|
| var unsupportedTestChannelTypes = []int{
|
| constant.ChannelTypeMidjourney,
|
| constant.ChannelTypeMidjourneyPlus,
|
| constant.ChannelTypeSunoAPI,
|
| constant.ChannelTypeKling,
|
| constant.ChannelTypeJimeng,
|
| constant.ChannelTypeDoubaoVideo,
|
| constant.ChannelTypeVidu,
|
| }
|
| if lo.Contains(unsupportedTestChannelTypes, channel.Type) {
|
| channelTypeName := constant.GetChannelTypeName(channel.Type)
|
| return testResult{
|
| localErr: fmt.Errorf("%s channel test is not supported", channelTypeName),
|
| }
|
| }
|
| w := httptest.NewRecorder()
|
| c, _ := gin.CreateTestContext(w)
|
|
|
| testModel = strings.TrimSpace(testModel)
|
| if testModel == "" {
|
| if channel.TestModel != nil && *channel.TestModel != "" {
|
| testModel = strings.TrimSpace(*channel.TestModel)
|
| } else {
|
| models := channel.GetModels()
|
| if len(models) > 0 {
|
| testModel = strings.TrimSpace(models[0])
|
| }
|
| if testModel == "" {
|
| testModel = "gpt-4o-mini"
|
| }
|
| }
|
| }
|
|
|
| requestPath := "/v1/chat/completions"
|
|
|
|
|
| if endpointType != "" {
|
| if endpointInfo, ok := common.GetDefaultEndpointInfo(constant.EndpointType(endpointType)); ok {
|
| requestPath = endpointInfo.Path
|
| }
|
| } else {
|
|
|
|
|
| if strings.Contains(strings.ToLower(testModel), "rerank") {
|
| requestPath = "/v1/rerank"
|
| }
|
|
|
|
|
| if strings.Contains(strings.ToLower(testModel), "embedding") ||
|
| strings.HasPrefix(testModel, "m3e") ||
|
| strings.Contains(testModel, "bge-") ||
|
| strings.Contains(testModel, "embed") ||
|
| channel.Type == constant.ChannelTypeMokaAI {
|
| requestPath = "/v1/embeddings"
|
| }
|
|
|
|
|
| if channel.Type == constant.ChannelTypeVolcEngine && strings.Contains(testModel, "seedream") {
|
| requestPath = "/v1/images/generations"
|
| }
|
|
|
|
|
| if strings.Contains(strings.ToLower(testModel), "codex") {
|
| requestPath = "/v1/responses"
|
| }
|
|
|
|
|
| if strings.HasSuffix(testModel, ratio_setting.CompactModelSuffix) {
|
| requestPath = "/v1/responses/compact"
|
| }
|
| }
|
| if strings.HasPrefix(requestPath, "/v1/responses/compact") {
|
| testModel = ratio_setting.WithCompactModelSuffix(testModel)
|
| }
|
|
|
| c.Request = &http.Request{
|
| Method: "POST",
|
| URL: &url.URL{Path: requestPath},
|
| Body: nil,
|
| Header: make(http.Header),
|
| }
|
|
|
| cache, err := model.GetUserCache(1)
|
| if err != nil {
|
| return testResult{
|
| localErr: err,
|
| newAPIError: nil,
|
| }
|
| }
|
| cache.WriteContext(c)
|
|
|
|
|
| c.Request.Header.Set("Content-Type", "application/json")
|
| c.Set("channel", channel.Type)
|
| c.Set("base_url", channel.GetBaseURL())
|
| group, _ := model.GetUserGroup(1, false)
|
| c.Set("group", group)
|
|
|
| newAPIError := middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
| if newAPIError != nil {
|
| return testResult{
|
| context: c,
|
| localErr: newAPIError,
|
| newAPIError: newAPIError,
|
| }
|
| }
|
|
|
|
|
| var relayFormat types.RelayFormat
|
| if endpointType != "" {
|
|
|
| switch constant.EndpointType(endpointType) {
|
| case constant.EndpointTypeOpenAI:
|
| relayFormat = types.RelayFormatOpenAI
|
| case constant.EndpointTypeOpenAIResponse:
|
| relayFormat = types.RelayFormatOpenAIResponses
|
| case constant.EndpointTypeOpenAIResponseCompact:
|
| relayFormat = types.RelayFormatOpenAIResponsesCompaction
|
| case constant.EndpointTypeAnthropic:
|
| relayFormat = types.RelayFormatClaude
|
| case constant.EndpointTypeGemini:
|
| relayFormat = types.RelayFormatGemini
|
| case constant.EndpointTypeJinaRerank:
|
| relayFormat = types.RelayFormatRerank
|
| case constant.EndpointTypeImageGeneration:
|
| relayFormat = types.RelayFormatOpenAIImage
|
| case constant.EndpointTypeEmbeddings:
|
| relayFormat = types.RelayFormatEmbedding
|
| default:
|
| relayFormat = types.RelayFormatOpenAI
|
| }
|
| } else {
|
|
|
| relayFormat = types.RelayFormatOpenAI
|
| if c.Request.URL.Path == "/v1/embeddings" {
|
| relayFormat = types.RelayFormatEmbedding
|
| }
|
| if c.Request.URL.Path == "/v1/images/generations" {
|
| relayFormat = types.RelayFormatOpenAIImage
|
| }
|
| if c.Request.URL.Path == "/v1/messages" {
|
| relayFormat = types.RelayFormatClaude
|
| }
|
| if strings.Contains(c.Request.URL.Path, "/v1beta/models") {
|
| relayFormat = types.RelayFormatGemini
|
| }
|
| if c.Request.URL.Path == "/v1/rerank" || c.Request.URL.Path == "/rerank" {
|
| relayFormat = types.RelayFormatRerank
|
| }
|
| if c.Request.URL.Path == "/v1/responses" {
|
| relayFormat = types.RelayFormatOpenAIResponses
|
| }
|
| if strings.HasPrefix(c.Request.URL.Path, "/v1/responses/compact") {
|
| relayFormat = types.RelayFormatOpenAIResponsesCompaction
|
| }
|
| }
|
|
|
| request := buildTestRequest(testModel, endpointType, channel)
|
|
|
| info, err := relaycommon.GenRelayInfo(c, relayFormat, request, nil)
|
|
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeGenRelayInfoFailed),
|
| }
|
| }
|
|
|
| info.IsChannelTest = true
|
| info.InitChannelMeta(c)
|
|
|
| err = helper.ModelMappedHelper(c, info, request)
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeChannelModelMappedError),
|
| }
|
| }
|
|
|
| testModel = info.UpstreamModelName
|
|
|
| request.SetModelName(testModel)
|
|
|
| apiType, _ := common.ChannelType2APIType(channel.Type)
|
| if info.RelayMode == relayconstant.RelayModeResponsesCompact &&
|
| apiType != constant.APITypeOpenAI &&
|
| apiType != constant.APITypeCodex {
|
| return testResult{
|
| context: c,
|
| localErr: fmt.Errorf("responses compaction test only supports openai/codex channels, got api type %d", apiType),
|
| newAPIError: types.NewError(fmt.Errorf("unsupported api type: %d", apiType), types.ErrorCodeInvalidApiType),
|
| }
|
| }
|
| adaptor := relay.GetAdaptor(apiType)
|
| if adaptor == nil {
|
| return testResult{
|
| context: c,
|
| localErr: fmt.Errorf("invalid api type: %d, adaptor is nil", apiType),
|
| newAPIError: types.NewError(fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), types.ErrorCodeInvalidApiType),
|
| }
|
| }
|
|
|
|
|
|
|
|
|
| common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, info.ToString()))
|
|
|
| priceData, err := helper.ModelPriceHelper(c, info, 0, request.GetTokenCountMeta())
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeModelPriceError),
|
| }
|
| }
|
|
|
| adaptor.Init(info)
|
|
|
| var convertedRequest any
|
|
|
| switch info.RelayMode {
|
| case relayconstant.RelayModeEmbeddings:
|
|
|
| if embeddingReq, ok := request.(*dto.EmbeddingRequest); ok {
|
| convertedRequest, err = adaptor.ConvertEmbeddingRequest(c, info, *embeddingReq)
|
| } else {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid embedding request type"),
|
| newAPIError: types.NewError(errors.New("invalid embedding request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| case relayconstant.RelayModeImagesGenerations:
|
|
|
| if imageReq, ok := request.(*dto.ImageRequest); ok {
|
| convertedRequest, err = adaptor.ConvertImageRequest(c, info, *imageReq)
|
| } else {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid image request type"),
|
| newAPIError: types.NewError(errors.New("invalid image request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| case relayconstant.RelayModeRerank:
|
|
|
| if rerankReq, ok := request.(*dto.RerankRequest); ok {
|
| convertedRequest, err = adaptor.ConvertRerankRequest(c, info.RelayMode, *rerankReq)
|
| } else {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid rerank request type"),
|
| newAPIError: types.NewError(errors.New("invalid rerank request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| case relayconstant.RelayModeResponses:
|
|
|
| if responseReq, ok := request.(*dto.OpenAIResponsesRequest); ok {
|
| convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *responseReq)
|
| } else {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid response request type"),
|
| newAPIError: types.NewError(errors.New("invalid response request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| case relayconstant.RelayModeResponsesCompact:
|
|
|
| switch req := request.(type) {
|
| case *dto.OpenAIResponsesCompactionRequest:
|
| convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, dto.OpenAIResponsesRequest{
|
| Model: req.Model,
|
| Input: req.Input,
|
| Instructions: req.Instructions,
|
| PreviousResponseID: req.PreviousResponseID,
|
| })
|
| case *dto.OpenAIResponsesRequest:
|
| convertedRequest, err = adaptor.ConvertOpenAIResponsesRequest(c, info, *req)
|
| default:
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid response compaction request type"),
|
| newAPIError: types.NewError(errors.New("invalid response compaction request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| default:
|
|
|
| if generalReq, ok := request.(*dto.GeneralOpenAIRequest); ok {
|
| convertedRequest, err = adaptor.ConvertOpenAIRequest(c, info, generalReq)
|
| } else {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("invalid general request type"),
|
| newAPIError: types.NewError(errors.New("invalid general request type"), types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| }
|
|
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeConvertRequestFailed),
|
| }
|
| }
|
| jsonData, err := json.Marshal(convertedRequest)
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeJsonMarshalFailed),
|
| }
|
| }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| if len(info.ParamOverride) > 0 {
|
| jsonData, err = relaycommon.ApplyParamOverride(jsonData, info.ParamOverride, relaycommon.BuildParamOverrideContext(info))
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewError(err, types.ErrorCodeChannelParamOverrideInvalid),
|
| }
|
| }
|
| }
|
|
|
| requestBody := bytes.NewBuffer(jsonData)
|
| c.Request.Body = io.NopCloser(bytes.NewBuffer(jsonData))
|
| resp, err := adaptor.DoRequest(c, info, requestBody)
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError),
|
| }
|
| }
|
| var httpResp *http.Response
|
| if resp != nil {
|
| httpResp = resp.(*http.Response)
|
| if httpResp.StatusCode != http.StatusOK {
|
| err := service.RelayErrorHandler(c.Request.Context(), httpResp, true)
|
| common.SysError(fmt.Sprintf(
|
| "channel test bad response: channel_id=%d name=%s type=%d model=%s endpoint_type=%s status=%d err=%v",
|
| channel.Id,
|
| channel.Name,
|
| channel.Type,
|
| testModel,
|
| endpointType,
|
| httpResp.StatusCode,
|
| err,
|
| ))
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeBadResponse, http.StatusInternalServerError),
|
| }
|
| }
|
| }
|
| usageA, respErr := adaptor.DoResponse(c, httpResp, info)
|
| if respErr != nil {
|
| return testResult{
|
| context: c,
|
| localErr: respErr,
|
| newAPIError: respErr,
|
| }
|
| }
|
| if usageA == nil {
|
| return testResult{
|
| context: c,
|
| localErr: errors.New("usage is nil"),
|
| newAPIError: types.NewOpenAIError(errors.New("usage is nil"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError),
|
| }
|
| }
|
| usage := usageA.(*dto.Usage)
|
| result := w.Result()
|
| respBody, err := io.ReadAll(result.Body)
|
| if err != nil {
|
| return testResult{
|
| context: c,
|
| localErr: err,
|
| newAPIError: types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError),
|
| }
|
| }
|
| info.SetEstimatePromptTokens(usage.PromptTokens)
|
|
|
| quota := 0
|
| if !priceData.UsePrice {
|
| quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio))
|
| quota = int(math.Round(float64(quota) * priceData.ModelRatio))
|
| if priceData.ModelRatio != 0 && quota <= 0 {
|
| quota = 1
|
| }
|
| } else {
|
| quota = int(priceData.ModelPrice * common.QuotaPerUnit)
|
| }
|
| tok := time.Now()
|
| milliseconds := tok.Sub(tik).Milliseconds()
|
| consumedTime := float64(milliseconds) / 1000.0
|
| other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatioInfo.GroupRatio, priceData.CompletionRatio,
|
| usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
| model.RecordConsumeLog(c, 1, model.RecordConsumeLogParams{
|
| ChannelId: channel.Id,
|
| PromptTokens: usage.PromptTokens,
|
| CompletionTokens: usage.CompletionTokens,
|
| ModelName: info.OriginModelName,
|
| TokenName: "模型测试",
|
| Quota: quota,
|
| Content: "模型测试",
|
| UseTimeSeconds: int(consumedTime),
|
| IsStream: info.IsStream,
|
| Group: info.UsingGroup,
|
| Other: other,
|
| })
|
| common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
| return testResult{
|
| context: c,
|
| localErr: nil,
|
| newAPIError: nil,
|
| }
|
| }
|
|
|
| func buildTestRequest(model string, endpointType string, channel *model.Channel) dto.Request {
|
| testResponsesInput := json.RawMessage(`[{"role":"user","content":"hi"}]`)
|
|
|
|
|
| if endpointType != "" {
|
| switch constant.EndpointType(endpointType) {
|
| case constant.EndpointTypeEmbeddings:
|
|
|
| return &dto.EmbeddingRequest{
|
| Model: model,
|
| Input: []any{"hello world"},
|
| }
|
| case constant.EndpointTypeImageGeneration:
|
|
|
| return &dto.ImageRequest{
|
| Model: model,
|
| Prompt: "a cute cat",
|
| N: 1,
|
| Size: "1024x1024",
|
| }
|
| case constant.EndpointTypeJinaRerank:
|
|
|
| return &dto.RerankRequest{
|
| Model: model,
|
| Query: "What is Deep Learning?",
|
| Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
|
| TopN: 2,
|
| }
|
| case constant.EndpointTypeOpenAIResponse:
|
|
|
| return &dto.OpenAIResponsesRequest{
|
| Model: model,
|
| Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
| }
|
| case constant.EndpointTypeOpenAIResponseCompact:
|
|
|
| return &dto.OpenAIResponsesCompactionRequest{
|
| Model: model,
|
| Input: testResponsesInput,
|
| }
|
| case constant.EndpointTypeAnthropic, constant.EndpointTypeGemini, constant.EndpointTypeOpenAI:
|
|
|
| maxTokens := uint(16)
|
| if constant.EndpointType(endpointType) == constant.EndpointTypeGemini {
|
| maxTokens = 3000
|
| }
|
| return &dto.GeneralOpenAIRequest{
|
| Model: model,
|
| Stream: false,
|
| Messages: []dto.Message{
|
| {
|
| Role: "user",
|
| Content: "hi",
|
| },
|
| },
|
| MaxTokens: maxTokens,
|
| }
|
| }
|
| }
|
|
|
|
|
| if strings.Contains(strings.ToLower(model), "rerank") {
|
| return &dto.RerankRequest{
|
| Model: model,
|
| Query: "What is Deep Learning?",
|
| Documents: []any{"Deep Learning is a subset of machine learning.", "Machine learning is a field of artificial intelligence."},
|
| TopN: 2,
|
| }
|
| }
|
|
|
|
|
| if strings.Contains(strings.ToLower(model), "embedding") ||
|
| strings.HasPrefix(model, "m3e") ||
|
| strings.Contains(model, "bge-") {
|
|
|
| return &dto.EmbeddingRequest{
|
| Model: model,
|
| Input: []any{"hello world"},
|
| }
|
| }
|
|
|
|
|
| if strings.HasSuffix(model, ratio_setting.CompactModelSuffix) {
|
| return &dto.OpenAIResponsesCompactionRequest{
|
| Model: model,
|
| Input: testResponsesInput,
|
| }
|
| }
|
|
|
|
|
| if strings.Contains(strings.ToLower(model), "codex") {
|
| return &dto.OpenAIResponsesRequest{
|
| Model: model,
|
| Input: json.RawMessage(`[{"role":"user","content":"hi"}]`),
|
| }
|
| }
|
|
|
|
|
| testRequest := &dto.GeneralOpenAIRequest{
|
| Model: model,
|
| Stream: false,
|
| Messages: []dto.Message{
|
| {
|
| Role: "user",
|
| Content: "hi",
|
| },
|
| },
|
| }
|
|
|
| if strings.HasPrefix(model, "o") {
|
| testRequest.MaxCompletionTokens = 16
|
| } else if strings.Contains(model, "thinking") {
|
| if !strings.Contains(model, "claude") {
|
| testRequest.MaxTokens = 50
|
| }
|
| } else if strings.Contains(model, "gemini") {
|
| testRequest.MaxTokens = 3000
|
| } else {
|
| testRequest.MaxTokens = 16
|
| }
|
|
|
| return testRequest
|
| }
|
|
|
| func TestChannel(c *gin.Context) {
|
| channelId, err := strconv.Atoi(c.Param("id"))
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| channel, err := model.CacheGetChannel(channelId)
|
| if err != nil {
|
| channel, err = model.GetChannelById(channelId, true)
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| }
|
|
|
|
|
|
|
|
|
|
|
| testModel := c.Query("model")
|
| endpointType := c.Query("endpoint_type")
|
| tik := time.Now()
|
| result := testChannel(channel, testModel, endpointType)
|
| if result.localErr != nil {
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": false,
|
| "message": result.localErr.Error(),
|
| "time": 0.0,
|
| })
|
| return
|
| }
|
| tok := time.Now()
|
| milliseconds := tok.Sub(tik).Milliseconds()
|
| go channel.UpdateResponseTime(milliseconds)
|
| consumedTime := float64(milliseconds) / 1000.0
|
| if result.newAPIError != nil {
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": false,
|
| "message": result.newAPIError.Error(),
|
| "time": consumedTime,
|
| })
|
| return
|
| }
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": true,
|
| "message": "",
|
| "time": consumedTime,
|
| })
|
| }
|
|
|
// testAllChannelsRunning is true while a testAllChannels sweep is in
// progress, so that sweeps never overlap. testAllChannelsLock guards it.
var (
	testAllChannelsLock    sync.Mutex
	testAllChannelsRunning bool
)
|
|
|
| func testAllChannels(notify bool) error {
|
|
|
| testAllChannelsLock.Lock()
|
| if testAllChannelsRunning {
|
| testAllChannelsLock.Unlock()
|
| return errors.New("测试已在运行中")
|
| }
|
| testAllChannelsRunning = true
|
| testAllChannelsLock.Unlock()
|
| channels, getChannelErr := model.GetAllChannels(0, 0, true, false)
|
| if getChannelErr != nil {
|
| return getChannelErr
|
| }
|
| var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
| if disableThreshold == 0 {
|
| disableThreshold = 10000000
|
| }
|
| gopool.Go(func() {
|
|
|
| defer func() {
|
| testAllChannelsLock.Lock()
|
| testAllChannelsRunning = false
|
| testAllChannelsLock.Unlock()
|
| }()
|
|
|
| for _, channel := range channels {
|
| isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
| tik := time.Now()
|
| result := testChannel(channel, "", "")
|
| tok := time.Now()
|
| milliseconds := tok.Sub(tik).Milliseconds()
|
|
|
| shouldBanChannel := false
|
| newAPIError := result.newAPIError
|
|
|
| if newAPIError != nil {
|
| shouldBanChannel = service.ShouldDisableChannel(channel.Type, result.newAPIError)
|
| }
|
|
|
|
|
| if common.AutomaticDisableChannelEnabled && !shouldBanChannel {
|
| if milliseconds > disableThreshold {
|
| err := fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
| newAPIError = types.NewOpenAIError(err, types.ErrorCodeChannelResponseTimeExceeded, http.StatusRequestTimeout)
|
| shouldBanChannel = true
|
| }
|
| }
|
|
|
|
|
| if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
|
| processChannelError(result.context, *types.NewChannelError(channel.Id, channel.Type, channel.Name, channel.ChannelInfo.IsMultiKey, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.GetAutoBan()), newAPIError)
|
| }
|
|
|
|
|
| if !isChannelEnabled && service.ShouldEnableChannel(newAPIError, channel.Status) {
|
| service.EnableChannel(channel.Id, common.GetContextKeyString(result.context, constant.ContextKeyChannelKey), channel.Name)
|
| }
|
|
|
| channel.UpdateResponseTime(milliseconds)
|
| time.Sleep(common.RequestInterval)
|
| }
|
|
|
| if notify {
|
| service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成")
|
| }
|
| })
|
| return nil
|
| }
|
|
|
| func TestAllChannels(c *gin.Context) {
|
| err := testAllChannels(true)
|
| if err != nil {
|
| common.ApiError(c, err)
|
| return
|
| }
|
| c.JSON(http.StatusOK, gin.H{
|
| "success": true,
|
| "message": "",
|
| })
|
| }
|
|
|
var (
	// autoTestChannelsOnce ensures the AutomaticallyTestChannels loop is
	// entered at most once per process.
	autoTestChannelsOnce sync.Once
)
|
|
|
| func AutomaticallyTestChannels() {
|
|
|
| if !common.IsMasterNode {
|
| return
|
| }
|
| autoTestChannelsOnce.Do(func() {
|
| for {
|
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
| time.Sleep(1 * time.Minute)
|
| continue
|
| }
|
| for {
|
| frequency := operation_setting.GetMonitorSetting().AutoTestChannelMinutes
|
| time.Sleep(time.Duration(int(math.Round(frequency))) * time.Minute)
|
| common.SysLog(fmt.Sprintf("automatically test channels with interval %f minutes", frequency))
|
| common.SysLog("automatically testing all channels")
|
| _ = testAllChannels(false)
|
| common.SysLog("automatically channel test finished")
|
| if !operation_setting.GetMonitorSetting().AutoTestChannelEnabled {
|
| break
|
| }
|
| }
|
| }
|
| })
|
| }
|
|
|