|
|
package service |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"errors" |
|
|
"fmt" |
|
|
"io" |
|
|
"net/http" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/QuantumNous/new-api/dto" |
|
|
"github.com/QuantumNous/new-api/logger" |
|
|
"github.com/QuantumNous/new-api/types" |
|
|
) |
|
|
|
|
|
func MidjourneyErrorWrapper(code int, desc string) *dto.MidjourneyResponse { |
|
|
return &dto.MidjourneyResponse{ |
|
|
Code: code, |
|
|
Description: desc, |
|
|
} |
|
|
} |
|
|
|
|
|
func MidjourneyErrorWithStatusCodeWrapper(code int, desc string, statusCode int) *dto.MidjourneyResponseWithStatusCode { |
|
|
return &dto.MidjourneyResponseWithStatusCode{ |
|
|
StatusCode: statusCode, |
|
|
Response: *MidjourneyErrorWrapper(code, desc), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func ClaudeErrorWrapper(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { |
|
|
text := err.Error() |
|
|
lowerText := strings.ToLower(text) |
|
|
if !strings.HasPrefix(lowerText, "get file base64 from url") { |
|
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { |
|
|
common.SysLog(fmt.Sprintf("error: %s", text)) |
|
|
text = "请求上游地址失败" |
|
|
} |
|
|
} |
|
|
claudeError := types.ClaudeError{ |
|
|
Message: text, |
|
|
Type: "new_api_error", |
|
|
} |
|
|
return &dto.ClaudeErrorWithStatusCode{ |
|
|
Error: claudeError, |
|
|
StatusCode: statusCode, |
|
|
} |
|
|
} |
|
|
|
|
|
func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.ClaudeErrorWithStatusCode { |
|
|
claudeErr := ClaudeErrorWrapper(err, code, statusCode) |
|
|
claudeErr.LocalError = true |
|
|
return claudeErr |
|
|
} |
|
|
|
|
|
func RelayErrorHandler(ctx context.Context, resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) { |
|
|
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode) |
|
|
|
|
|
responseBody, err := io.ReadAll(resp.Body) |
|
|
if err != nil { |
|
|
return |
|
|
} |
|
|
CloseResponseBodyGracefully(resp) |
|
|
var errResponse dto.GeneralErrorResponse |
|
|
|
|
|
err = common.Unmarshal(responseBody, &errResponse) |
|
|
if err != nil { |
|
|
if showBodyWhenFail { |
|
|
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) |
|
|
} else { |
|
|
if common.DebugEnabled { |
|
|
logger.LogInfo(ctx, fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) |
|
|
} |
|
|
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) |
|
|
} |
|
|
return |
|
|
} |
|
|
if errResponse.Error.Message != "" { |
|
|
|
|
|
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode) |
|
|
} else { |
|
|
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode) |
|
|
} |
|
|
return |
|
|
} |
|
|
|
|
|
func ResetStatusCode(newApiErr *types.NewAPIError, statusCodeMappingStr string) { |
|
|
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" { |
|
|
return |
|
|
} |
|
|
statusCodeMapping := make(map[string]string) |
|
|
err := common.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) |
|
|
if err != nil { |
|
|
return |
|
|
} |
|
|
if newApiErr.StatusCode == http.StatusOK { |
|
|
return |
|
|
} |
|
|
codeStr := strconv.Itoa(newApiErr.StatusCode) |
|
|
if _, ok := statusCodeMapping[codeStr]; ok { |
|
|
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr]) |
|
|
newApiErr.StatusCode = intCode |
|
|
} |
|
|
} |
|
|
|
|
|
func TaskErrorWrapperLocal(err error, code string, statusCode int) *dto.TaskError { |
|
|
openaiErr := TaskErrorWrapper(err, code, statusCode) |
|
|
openaiErr.LocalError = true |
|
|
return openaiErr |
|
|
} |
|
|
|
|
|
func TaskErrorWrapper(err error, code string, statusCode int) *dto.TaskError { |
|
|
text := err.Error() |
|
|
lowerText := strings.ToLower(text) |
|
|
if strings.Contains(lowerText, "post") || strings.Contains(lowerText, "dial") || strings.Contains(lowerText, "http") { |
|
|
common.SysLog(fmt.Sprintf("error: %s", text)) |
|
|
|
|
|
text = common.MaskSensitiveInfo(text) |
|
|
} |
|
|
|
|
|
taskError := &dto.TaskError{ |
|
|
Code: code, |
|
|
Message: text, |
|
|
StatusCode: statusCode, |
|
|
Error: err, |
|
|
} |
|
|
|
|
|
return taskError |
|
|
} |
|
|
|