|
|
package controller |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"encoding/json" |
|
|
"errors" |
|
|
"fmt" |
|
|
"github.com/gin-gonic/gin" |
|
|
"net/http" |
|
|
"one-api/common" |
|
|
"one-api/model" |
|
|
"strconv" |
|
|
"sync" |
|
|
"time" |
|
|
) |
|
|
|
|
|
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) { |
|
|
switch channel.Type { |
|
|
case common.ChannelTypePaLM: |
|
|
fallthrough |
|
|
case common.ChannelTypeAnthropic: |
|
|
fallthrough |
|
|
case common.ChannelTypeBaidu: |
|
|
fallthrough |
|
|
case common.ChannelTypeZhipu: |
|
|
fallthrough |
|
|
case common.ChannelTypeAli: |
|
|
fallthrough |
|
|
case common.ChannelType360: |
|
|
fallthrough |
|
|
case common.ChannelTypeXunfei: |
|
|
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil |
|
|
case common.ChannelTypeAzure: |
|
|
request.Model = "gpt-35-turbo" |
|
|
default: |
|
|
request.Model = "gpt-3.5-turbo" |
|
|
} |
|
|
requestURL := common.ChannelBaseURLs[channel.Type] |
|
|
if channel.Type == common.ChannelTypeAzure { |
|
|
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) |
|
|
} else { |
|
|
if channel.BaseURL != "" { |
|
|
requestURL = channel.BaseURL |
|
|
} |
|
|
requestURL += "/v1/chat/completions" |
|
|
} |
|
|
|
|
|
jsonData, err := json.Marshal(request) |
|
|
if err != nil { |
|
|
return err, nil |
|
|
} |
|
|
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData)) |
|
|
if err != nil { |
|
|
return err, nil |
|
|
} |
|
|
if channel.Type == common.ChannelTypeAzure { |
|
|
req.Header.Set("api-key", channel.Key) |
|
|
} else { |
|
|
req.Header.Set("Authorization", "Bearer "+channel.Key) |
|
|
} |
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
resp, err := httpClient.Do(req) |
|
|
if err != nil { |
|
|
return err, nil |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
var response TextResponse |
|
|
err = json.NewDecoder(resp.Body).Decode(&response) |
|
|
if err != nil { |
|
|
return err, nil |
|
|
} |
|
|
if response.Usage.CompletionTokens == 0 { |
|
|
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error |
|
|
} |
|
|
return nil, nil |
|
|
} |
|
|
|
|
|
func buildTestRequest() *ChatRequest { |
|
|
testRequest := &ChatRequest{ |
|
|
Model: "", |
|
|
MaxTokens: 1, |
|
|
} |
|
|
testMessage := Message{ |
|
|
Role: "user", |
|
|
Content: "hi", |
|
|
} |
|
|
testRequest.Messages = append(testRequest.Messages, testMessage) |
|
|
return testRequest |
|
|
} |
|
|
|
|
|
func TestChannel(c *gin.Context) { |
|
|
id, err := strconv.Atoi(c.Param("id")) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": false, |
|
|
"message": err.Error(), |
|
|
}) |
|
|
return |
|
|
} |
|
|
channel, err := model.GetChannelById(id, true) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": false, |
|
|
"message": err.Error(), |
|
|
}) |
|
|
return |
|
|
} |
|
|
testRequest := buildTestRequest() |
|
|
tik := time.Now() |
|
|
err, _ = testChannel(channel, *testRequest) |
|
|
tok := time.Now() |
|
|
milliseconds := tok.Sub(tik).Milliseconds() |
|
|
go channel.UpdateResponseTime(milliseconds) |
|
|
consumedTime := float64(milliseconds) / 1000.0 |
|
|
if err != nil { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": false, |
|
|
"message": err.Error(), |
|
|
"time": consumedTime, |
|
|
}) |
|
|
return |
|
|
} |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": true, |
|
|
"message": "", |
|
|
"time": consumedTime, |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
// Package-level state for the "test all channels" run.
var (
	// testAllChannelsLock is held while testAllChannelsRunning is read or
	// written.
	testAllChannelsLock sync.Mutex
	// testAllChannelsRunning starts out false and is set while a run is
	// in progress.
	testAllChannelsRunning bool
)
|
|
|
|
|
|
|
|
func disableChannel(channelId int, channelName string, reason string) { |
|
|
if common.RootUserEmail == "" { |
|
|
common.RootUserEmail = model.GetRootUserEmail() |
|
|
} |
|
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) |
|
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) |
|
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) |
|
|
err := common.SendEmail(subject, common.RootUserEmail, content) |
|
|
if err != nil { |
|
|
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) |
|
|
} |
|
|
} |
|
|
|
|
|
func testAllChannels(notify bool) error { |
|
|
if common.RootUserEmail == "" { |
|
|
common.RootUserEmail = model.GetRootUserEmail() |
|
|
} |
|
|
testAllChannelsLock.Lock() |
|
|
if testAllChannelsRunning { |
|
|
testAllChannelsLock.Unlock() |
|
|
return errors.New("测试已在运行中") |
|
|
} |
|
|
testAllChannelsRunning = true |
|
|
testAllChannelsLock.Unlock() |
|
|
channels, err := model.GetAllChannels(0, 0, true) |
|
|
if err != nil { |
|
|
return err |
|
|
} |
|
|
testRequest := buildTestRequest() |
|
|
var disableThreshold = int64(common.ChannelDisableThreshold * 1000) |
|
|
if disableThreshold == 0 { |
|
|
disableThreshold = 10000000 |
|
|
} |
|
|
go func() { |
|
|
for _, channel := range channels { |
|
|
if channel.Status != common.ChannelStatusEnabled { |
|
|
continue |
|
|
} |
|
|
tik := time.Now() |
|
|
err, openaiErr := testChannel(channel, *testRequest) |
|
|
tok := time.Now() |
|
|
milliseconds := tok.Sub(tik).Milliseconds() |
|
|
if milliseconds > disableThreshold { |
|
|
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) |
|
|
disableChannel(channel.Id, channel.Name, err.Error()) |
|
|
} |
|
|
if shouldDisableChannel(openaiErr, -1) { |
|
|
disableChannel(channel.Id, channel.Name, err.Error()) |
|
|
} |
|
|
channel.UpdateResponseTime(milliseconds) |
|
|
time.Sleep(common.RequestInterval) |
|
|
} |
|
|
testAllChannelsLock.Lock() |
|
|
testAllChannelsRunning = false |
|
|
testAllChannelsLock.Unlock() |
|
|
if notify { |
|
|
err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") |
|
|
if err != nil { |
|
|
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) |
|
|
} |
|
|
} |
|
|
}() |
|
|
return nil |
|
|
} |
|
|
|
|
|
func TestAllChannels(c *gin.Context) { |
|
|
err := testAllChannels(true) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": false, |
|
|
"message": err.Error(), |
|
|
}) |
|
|
return |
|
|
} |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"success": true, |
|
|
"message": "", |
|
|
}) |
|
|
return |
|
|
} |
|
|
|
|
|
func AutomaticallyTestChannels(frequency int) { |
|
|
for { |
|
|
time.Sleep(time.Duration(frequency) * time.Minute) |
|
|
common.SysLog("testing all channels") |
|
|
_ = testAllChannels(false) |
|
|
common.SysLog("channel test finished") |
|
|
} |
|
|
} |
|
|
|