| | package controller |
| |
|
| | import ( |
| | "bytes" |
| | "encoding/json" |
| | "errors" |
| | "fmt" |
| | "io" |
| | "math" |
| | "net/http" |
| | "net/http/httptest" |
| | "net/url" |
| | "one-api/common" |
| | "one-api/dto" |
| | "one-api/middleware" |
| | "one-api/model" |
| | "one-api/relay" |
| | relaycommon "one-api/relay/common" |
| | "one-api/relay/constant" |
| | "one-api/relay/helper" |
| | "one-api/service" |
| | "strconv" |
| | "strings" |
| | "sync" |
| | "time" |
| |
|
| | "github.com/bytedance/gopkg/util/gopool" |
| |
|
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { |
| | tik := time.Now() |
| | if channel.Type == common.ChannelTypeMidjourney { |
| | return errors.New("midjourney channel test is not supported"), nil |
| | } |
| | if channel.Type == common.ChannelTypeMidjourneyPlus { |
| | return errors.New("midjourney plus channel test is not supported!!!"), nil |
| | } |
| | if channel.Type == common.ChannelTypeSunoAPI { |
| | return errors.New("suno channel test is not supported"), nil |
| | } |
| | w := httptest.NewRecorder() |
| | c, _ := gin.CreateTestContext(w) |
| |
|
| | requestPath := "/v1/chat/completions" |
| |
|
| | |
| | if strings.Contains(strings.ToLower(testModel), "embedding") || |
| | strings.HasPrefix(testModel, "m3e") || |
| | strings.Contains(testModel, "bge-") || |
| | strings.Contains(testModel, "embed") || |
| | channel.Type == common.ChannelTypeMokaAI { |
| | requestPath = "/v1/embeddings" |
| | } |
| |
|
| | c.Request = &http.Request{ |
| | Method: "POST", |
| | URL: &url.URL{Path: requestPath}, |
| | Body: nil, |
| | Header: make(http.Header), |
| | } |
| |
|
| | if testModel == "" { |
| | if channel.TestModel != nil && *channel.TestModel != "" { |
| | testModel = *channel.TestModel |
| | } else { |
| | if len(channel.GetModels()) > 0 { |
| | testModel = channel.GetModels()[0] |
| | } else { |
| | testModel = "gpt-4o-mini" |
| | } |
| | } |
| | } |
| |
|
| | cache, err := model.GetUserCache(1) |
| | if err != nil { |
| | return err, nil |
| | } |
| | cache.WriteContext(c) |
| |
|
| | c.Request.Header.Set("Authorization", "Bearer "+channel.Key) |
| | c.Request.Header.Set("Content-Type", "application/json") |
| | c.Set("channel", channel.Type) |
| | c.Set("base_url", channel.GetBaseURL()) |
| | group, _ := model.GetUserGroup(1, false) |
| | c.Set("group", group) |
| |
|
| | middleware.SetupContextForSelectedChannel(c, channel, testModel) |
| |
|
| | info := relaycommon.GenRelayInfo(c) |
| |
|
| | err = helper.ModelMappedHelper(c, info) |
| | if err != nil { |
| | return err, nil |
| | } |
| | testModel = info.UpstreamModelName |
| |
|
| | apiType, _ := constant.ChannelType2APIType(channel.Type) |
| | adaptor := relay.GetAdaptor(apiType) |
| | if adaptor == nil { |
| | return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil |
| | } |
| |
|
| | request := buildTestRequest(testModel) |
| | |
| | logInfo := *info |
| | logInfo.ApiKey = "" |
| | common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) |
| |
|
| | priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) |
| | if err != nil { |
| | return err, nil |
| | } |
| |
|
| | adaptor.Init(info) |
| |
|
| | convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) |
| | if err != nil { |
| | return err, nil |
| | } |
| | jsonData, err := json.Marshal(convertedRequest) |
| | if err != nil { |
| | return err, nil |
| | } |
| | requestBody := bytes.NewBuffer(jsonData) |
| | c.Request.Body = io.NopCloser(requestBody) |
| | resp, err := adaptor.DoRequest(c, info, requestBody) |
| | if err != nil { |
| | return err, nil |
| | } |
| | var httpResp *http.Response |
| | if resp != nil { |
| | httpResp = resp.(*http.Response) |
| | if httpResp.StatusCode != http.StatusOK { |
| | err := service.RelayErrorHandler(httpResp, true) |
| | return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err |
| | } |
| | } |
| | usageA, respErr := adaptor.DoResponse(c, httpResp, info) |
| | if respErr != nil { |
| | return fmt.Errorf("%s", respErr.Error.Message), respErr |
| | } |
| | if usageA == nil { |
| | return errors.New("usage is nil"), nil |
| | } |
| | usage := usageA.(*dto.Usage) |
| | result := w.Result() |
| | respBody, err := io.ReadAll(result.Body) |
| | if err != nil { |
| | return err, nil |
| | } |
| | info.PromptTokens = usage.PromptTokens |
| |
|
| | quota := 0 |
| | if !priceData.UsePrice { |
| | quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*priceData.CompletionRatio)) |
| | quota = int(math.Round(float64(quota) * priceData.ModelRatio)) |
| | if priceData.ModelRatio != 0 && quota <= 0 { |
| | quota = 1 |
| | } |
| | } else { |
| | quota = int(priceData.ModelPrice * common.QuotaPerUnit) |
| | } |
| | tok := time.Now() |
| | milliseconds := tok.Sub(tik).Milliseconds() |
| | consumedTime := float64(milliseconds) / 1000.0 |
| | other := service.GenerateTextOtherInfo(c, info, priceData.ModelRatio, priceData.GroupRatio, priceData.CompletionRatio, |
| | usage.PromptTokensDetails.CachedTokens, priceData.CacheRatio, priceData.ModelPrice) |
| | model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, info.OriginModelName, "模型测试", |
| | quota, "模型测试", 0, quota, int(consumedTime), false, info.Group, other) |
| | common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) |
| | return nil, nil |
| | } |
| |
|
| | func buildTestRequest(model string) *dto.GeneralOpenAIRequest { |
| | testRequest := &dto.GeneralOpenAIRequest{ |
| | Model: "", |
| | Stream: false, |
| | } |
| |
|
| | |
| | if strings.Contains(strings.ToLower(model), "embedding") || |
| | strings.HasPrefix(model, "m3e") || |
| | strings.Contains(model, "bge-") { |
| | testRequest.Model = model |
| | |
| | testRequest.Input = []string{"hello world"} |
| | return testRequest |
| | } |
| | |
| | if strings.HasPrefix(model, "o") { |
| | testRequest.MaxCompletionTokens = 10 |
| | } else if strings.Contains(model, "thinking") { |
| | if !strings.Contains(model, "claude") { |
| | testRequest.MaxTokens = 50 |
| | } |
| | } else if strings.Contains(model, "gemini") { |
| | testRequest.MaxTokens = 300 |
| | } else { |
| | testRequest.MaxTokens = 10 |
| | } |
| | content, _ := json.Marshal("hi") |
| | testMessage := dto.Message{ |
| | Role: "user", |
| | Content: content, |
| | } |
| | testRequest.Model = model |
| | testRequest.Messages = append(testRequest.Messages, testMessage) |
| | return testRequest |
| | } |
| |
|
| | func TestChannel(c *gin.Context) { |
| | channelId, err := strconv.Atoi(c.Param("id")) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | channel, err := model.GetChannelById(channelId, true) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | testModel := c.Query("model") |
| | tik := time.Now() |
| | err, _ = testChannel(channel, testModel) |
| | tok := time.Now() |
| | milliseconds := tok.Sub(tik).Milliseconds() |
| | go channel.UpdateResponseTime(milliseconds) |
| | consumedTime := float64(milliseconds) / 1000.0 |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | "time": consumedTime, |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | "time": consumedTime, |
| | }) |
| | return |
| | } |
| |
|
| | var testAllChannelsLock sync.Mutex |
| | var testAllChannelsRunning bool = false |
| |
|
| | func testAllChannels(notify bool) error { |
| |
|
| | testAllChannelsLock.Lock() |
| | if testAllChannelsRunning { |
| | testAllChannelsLock.Unlock() |
| | return errors.New("测试已在运行中") |
| | } |
| | testAllChannelsRunning = true |
| | testAllChannelsLock.Unlock() |
| | channels, err := model.GetAllChannels(0, 0, true, false) |
| | if err != nil { |
| | return err |
| | } |
| | var disableThreshold = int64(common.ChannelDisableThreshold * 1000) |
| | if disableThreshold == 0 { |
| | disableThreshold = 10000000 |
| | } |
| | gopool.Go(func() { |
| | for _, channel := range channels { |
| | isChannelEnabled := channel.Status == common.ChannelStatusEnabled |
| | tik := time.Now() |
| | err, openaiWithStatusErr := testChannel(channel, "") |
| | tok := time.Now() |
| | milliseconds := tok.Sub(tik).Milliseconds() |
| |
|
| | shouldBanChannel := false |
| |
|
| | |
| | if openaiWithStatusErr != nil { |
| | oaiErr := openaiWithStatusErr.Error |
| | err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) |
| | shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) |
| | } |
| |
|
| | if milliseconds > disableThreshold { |
| | err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) |
| | shouldBanChannel = true |
| | } |
| |
|
| | |
| | if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() { |
| | service.DisableChannel(channel.Id, channel.Name, err.Error()) |
| | } |
| |
|
| | |
| | if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { |
| | service.EnableChannel(channel.Id, channel.Name) |
| | } |
| |
|
| | channel.UpdateResponseTime(milliseconds) |
| | time.Sleep(common.RequestInterval) |
| | } |
| | testAllChannelsLock.Lock() |
| | testAllChannelsRunning = false |
| | testAllChannelsLock.Unlock() |
| | if notify { |
| | service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") |
| | } |
| | }) |
| | return nil |
| | } |
| |
|
| | func TestAllChannels(c *gin.Context) { |
| | err := testAllChannels(true) |
| | if err != nil { |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": false, |
| | "message": err.Error(), |
| | }) |
| | return |
| | } |
| | c.JSON(http.StatusOK, gin.H{ |
| | "success": true, |
| | "message": "", |
| | }) |
| | return |
| | } |
| |
|
| | func AutomaticallyTestChannels(frequency int) { |
| | for { |
| | time.Sleep(time.Duration(frequency) * time.Minute) |
| | common.SysLog("testing all channels") |
| | _ = testAllChannels(false) |
| | common.SysLog("channel test finished") |
| | } |
| | } |
| |
|