|
|
package types |
|
|
|
|
|
import ( |
|
|
"errors" |
|
|
"fmt" |
|
|
"net/http" |
|
|
"strings" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
) |
|
|
|
|
|
// OpenAIError is the OpenAI-style error payload (message / type / param /
// code) exchanged as JSON. It is stored in NewAPIError.RelayError by
// WithOpenAIError and produced by NewAPIError.ToOpenAIError.
type OpenAIError struct {
	// Message is the human-readable error text.
	Message string `json:"message"`
	// Type is the error category string.
	Type string `json:"type"`
	// Param names the offending request parameter, if any.
	Param string `json:"param"`
	// Code is untyped because payloads may carry non-string codes;
	// WithOpenAIError stringifies non-string values with %v.
	Code any `json:"code"`
}
|
|
|
|
|
// ClaudeError is the Claude-style error payload. It is stored in
// NewAPIError.RelayError by WithClaudeError and produced by
// NewAPIError.ToClaudeError. Both fields are omitted from JSON when empty.
type ClaudeError struct {
	// Type is the error category; WithClaudeError also uses it as the error code.
	Type string `json:"type,omitempty"`
	// Message is the human-readable error text.
	Message string `json:"message,omitempty"`
}
|
|
|
|
|
// ErrorType identifies which family a NewAPIError belongs to.
// ToOpenAIError and ToClaudeError switch on it to decide how to render the error.
type ErrorType string

// Known ErrorType values.
const (
	// ErrorTypeNewAPIError is assigned by NewError and NewErrorWithStatusCode
	// to errors created directly in this package.
	ErrorTypeNewAPIError ErrorType = "new_api_error"
	// ErrorTypeOpenAIError marks errors whose RelayError holds an OpenAIError
	// (assigned by WithOpenAIError).
	ErrorTypeOpenAIError ErrorType = "openai_error"
	// ErrorTypeClaudeError marks errors whose RelayError holds a ClaudeError
	// (assigned by WithClaudeError).
	ErrorTypeClaudeError ErrorType = "claude_error"
	// NOTE(review): the types below are not assigned by any constructor in
	// this file; presumably they are used elsewhere — confirm.
	ErrorTypeMidjourneyError ErrorType = "midjourney_error"
	ErrorTypeGeminiError ErrorType = "gemini_error"
	ErrorTypeRerankError ErrorType = "rerank_error"
	// ErrorTypeUpstreamError has the same string value that WithOpenAIError and
	// WithClaudeError use as the default payload Type.
	ErrorTypeUpstreamError ErrorType = "upstream_error"
)
|
|
|
|
|
// ErrorCode is the machine-readable code carried by a NewAPIError. Codes
// prefixed with "channel:" mark channel errors (see IsChannelError).
type ErrorCode string

// Known ErrorCode values, grouped by the area they describe.
const (
	// Request validation.
	ErrorCodeInvalidRequest ErrorCode = "invalid_request"
	ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"

	// Internal processing failures.
	// ErrorCodeCountTokenFailed messages are exempt from sensitive-info masking
	// in MaskSensitiveError, ToOpenAIError and ToClaudeError.
	ErrorCodeCountTokenFailed ErrorCode = "count_token_failed"
	ErrorCodeModelPriceError ErrorCode = "model_price_error"
	ErrorCodeInvalidApiType ErrorCode = "invalid_api_type"
	ErrorCodeJsonMarshalFailed ErrorCode = "json_marshal_failed"
	ErrorCodeDoRequestFailed ErrorCode = "do_request_failed"
	ErrorCodeGetChannelFailed ErrorCode = "get_channel_failed"
	ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"

	// Channel errors: the "channel:" prefix is what IsChannelError tests for.
	ErrorCodeChannelNoAvailableKey ErrorCode = "channel:no_available_key"
	ErrorCodeChannelParamOverrideInvalid ErrorCode = "channel:param_override_invalid"
	ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
	ErrorCodeChannelModelMappedError ErrorCode = "channel:model_mapped_error"
	ErrorCodeChannelAwsClientError ErrorCode = "channel:aws_client_error"
	ErrorCodeChannelInvalidKey ErrorCode = "channel:invalid_key"
	ErrorCodeChannelResponseTimeExceeded ErrorCode = "channel:response_time_exceeded"

	// Incoming-request handling.
	ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
	ErrorCodeConvertRequestFailed ErrorCode = "convert_request_failed"
	ErrorCodeAccessDenied ErrorCode = "access_denied"

	ErrorCodeBadRequestBody ErrorCode = "bad_request_body"

	// Response handling.
	ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
	ErrorCodeBadResponseStatusCode ErrorCode = "bad_response_status_code"
	ErrorCodeBadResponse ErrorCode = "bad_response"
	ErrorCodeBadResponseBody ErrorCode = "bad_response_body"
	ErrorCodeEmptyResponse ErrorCode = "empty_response"
	ErrorCodeAwsInvokeError ErrorCode = "aws_invoke_error"
	ErrorCodeModelNotFound ErrorCode = "model_not_found"
	ErrorCodePromptBlocked ErrorCode = "prompt_blocked"

	// Data access.
	ErrorCodeQueryDataError ErrorCode = "query_data_error"
	ErrorCodeUpdateDataError ErrorCode = "update_data_error"

	// Quota.
	ErrorCodeInsufficientUserQuota ErrorCode = "insufficient_user_quota"
	ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed"
)
|
|
|
|
|
// NewAPIError is this package's unified error value. It wraps an underlying Go
// error together with an optional upstream payload and the metadata that
// controls rendering (errorType, errorCode), retrying (skipRetry) and logging
// (recordErrorLog).
type NewAPIError struct {
	// Err is the underlying error; Error() returns its text when non-nil.
	Err error
	// RelayError optionally holds an upstream-format payload. Constructors in
	// this file store an OpenAIError or ClaudeError value here.
	RelayError any
	// skipRetry is set by ErrOptionWithSkipRetry and read by IsSkipRetryError.
	skipRetry bool
	// recordErrorLog stays nil unless ErrOptionWithNoRecordErrorLog sets it to
	// false; IsRecordErrorLog treats nil as "record".
	recordErrorLog *bool
	// errorType selects the conversion branch in ToOpenAIError / ToClaudeError.
	errorType ErrorType
	// errorCode is the machine-readable code ("channel:"-prefixed codes are
	// channel errors, see IsChannelError).
	errorCode ErrorCode
	// StatusCode is the HTTP status associated with the error
	// (NewError defaults it to http.StatusInternalServerError).
	StatusCode int
}
|
|
|
|
|
func (e *NewAPIError) GetErrorCode() ErrorCode { |
|
|
if e == nil { |
|
|
return "" |
|
|
} |
|
|
return e.errorCode |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) GetErrorType() ErrorType { |
|
|
if e == nil { |
|
|
return "" |
|
|
} |
|
|
return e.errorType |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) Error() string { |
|
|
if e == nil { |
|
|
return "" |
|
|
} |
|
|
if e.Err == nil { |
|
|
|
|
|
return string(e.errorCode) |
|
|
} |
|
|
return e.Err.Error() |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) MaskSensitiveError() string { |
|
|
if e == nil { |
|
|
return "" |
|
|
} |
|
|
if e.Err == nil { |
|
|
return string(e.errorCode) |
|
|
} |
|
|
errStr := e.Err.Error() |
|
|
if e.errorCode == ErrorCodeCountTokenFailed { |
|
|
return errStr |
|
|
} |
|
|
return common.MaskSensitiveInfo(errStr) |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) SetMessage(message string) { |
|
|
e.Err = errors.New(message) |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) ToOpenAIError() OpenAIError { |
|
|
var result OpenAIError |
|
|
switch e.errorType { |
|
|
case ErrorTypeOpenAIError: |
|
|
if openAIError, ok := e.RelayError.(OpenAIError); ok { |
|
|
result = openAIError |
|
|
} |
|
|
case ErrorTypeClaudeError: |
|
|
if claudeError, ok := e.RelayError.(ClaudeError); ok { |
|
|
result = OpenAIError{ |
|
|
Message: e.Error(), |
|
|
Type: claudeError.Type, |
|
|
Param: "", |
|
|
Code: e.errorCode, |
|
|
} |
|
|
} |
|
|
default: |
|
|
result = OpenAIError{ |
|
|
Message: e.Error(), |
|
|
Type: string(e.errorType), |
|
|
Param: "", |
|
|
Code: e.errorCode, |
|
|
} |
|
|
} |
|
|
if e.errorCode != ErrorCodeCountTokenFailed { |
|
|
result.Message = common.MaskSensitiveInfo(result.Message) |
|
|
} |
|
|
if result.Message == "" { |
|
|
result.Message = string(e.errorType) |
|
|
} |
|
|
return result |
|
|
} |
|
|
|
|
|
func (e *NewAPIError) ToClaudeError() ClaudeError { |
|
|
var result ClaudeError |
|
|
switch e.errorType { |
|
|
case ErrorTypeOpenAIError: |
|
|
if openAIError, ok := e.RelayError.(OpenAIError); ok { |
|
|
result = ClaudeError{ |
|
|
Message: e.Error(), |
|
|
Type: fmt.Sprintf("%v", openAIError.Code), |
|
|
} |
|
|
} |
|
|
case ErrorTypeClaudeError: |
|
|
if claudeError, ok := e.RelayError.(ClaudeError); ok { |
|
|
result = claudeError |
|
|
} |
|
|
default: |
|
|
result = ClaudeError{ |
|
|
Message: e.Error(), |
|
|
Type: string(e.errorType), |
|
|
} |
|
|
} |
|
|
if e.errorCode != ErrorCodeCountTokenFailed { |
|
|
result.Message = common.MaskSensitiveInfo(result.Message) |
|
|
} |
|
|
if result.Message == "" { |
|
|
result.Message = string(e.errorType) |
|
|
} |
|
|
return result |
|
|
} |
|
|
|
|
|
// NewAPIErrorOptions is a functional option that mutates a *NewAPIError. The
// constructors in this file (NewError, NewOpenAIError, WithOpenAIError, ...)
// apply options in order after the error value has been populated.
type NewAPIErrorOptions func(*NewAPIError)
|
|
|
|
|
func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
var newErr *NewAPIError |
|
|
|
|
|
if errors.As(err, &newErr) { |
|
|
for _, op := range ops { |
|
|
op(newErr) |
|
|
} |
|
|
return newErr |
|
|
} |
|
|
e := &NewAPIError{ |
|
|
Err: err, |
|
|
RelayError: nil, |
|
|
errorType: ErrorTypeNewAPIError, |
|
|
StatusCode: http.StatusInternalServerError, |
|
|
errorCode: errorCode, |
|
|
} |
|
|
for _, op := range ops { |
|
|
op(e) |
|
|
} |
|
|
return e |
|
|
} |
|
|
|
|
|
func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
var newErr *NewAPIError |
|
|
|
|
|
if errors.As(err, &newErr) { |
|
|
if newErr.RelayError == nil { |
|
|
openaiError := OpenAIError{ |
|
|
Message: newErr.Error(), |
|
|
Type: string(errorCode), |
|
|
Code: errorCode, |
|
|
} |
|
|
newErr.RelayError = openaiError |
|
|
} |
|
|
for _, op := range ops { |
|
|
op(newErr) |
|
|
} |
|
|
return newErr |
|
|
} |
|
|
openaiError := OpenAIError{ |
|
|
Message: err.Error(), |
|
|
Type: string(errorCode), |
|
|
Code: errorCode, |
|
|
} |
|
|
return WithOpenAIError(openaiError, statusCode, ops...) |
|
|
} |
|
|
|
|
|
func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
openaiError := OpenAIError{ |
|
|
Type: string(errorCode), |
|
|
Code: errorCode, |
|
|
} |
|
|
return WithOpenAIError(openaiError, statusCode, ops...) |
|
|
} |
|
|
|
|
|
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
e := &NewAPIError{ |
|
|
Err: err, |
|
|
RelayError: OpenAIError{ |
|
|
Message: err.Error(), |
|
|
Type: string(errorCode), |
|
|
}, |
|
|
errorType: ErrorTypeNewAPIError, |
|
|
StatusCode: statusCode, |
|
|
errorCode: errorCode, |
|
|
} |
|
|
for _, op := range ops { |
|
|
op(e) |
|
|
} |
|
|
|
|
|
return e |
|
|
} |
|
|
|
|
|
func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
code, ok := openAIError.Code.(string) |
|
|
if !ok { |
|
|
if openAIError.Code != nil { |
|
|
code = fmt.Sprintf("%v", openAIError.Code) |
|
|
} else { |
|
|
code = "unknown_error" |
|
|
} |
|
|
} |
|
|
if openAIError.Type == "" { |
|
|
openAIError.Type = "upstream_error" |
|
|
} |
|
|
e := &NewAPIError{ |
|
|
RelayError: openAIError, |
|
|
errorType: ErrorTypeOpenAIError, |
|
|
StatusCode: statusCode, |
|
|
Err: errors.New(openAIError.Message), |
|
|
errorCode: ErrorCode(code), |
|
|
} |
|
|
for _, op := range ops { |
|
|
op(e) |
|
|
} |
|
|
return e |
|
|
} |
|
|
|
|
|
func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { |
|
|
if claudeError.Type == "" { |
|
|
claudeError.Type = "upstream_error" |
|
|
} |
|
|
e := &NewAPIError{ |
|
|
RelayError: claudeError, |
|
|
errorType: ErrorTypeClaudeError, |
|
|
StatusCode: statusCode, |
|
|
Err: errors.New(claudeError.Message), |
|
|
errorCode: ErrorCode(claudeError.Type), |
|
|
} |
|
|
for _, op := range ops { |
|
|
op(e) |
|
|
} |
|
|
return e |
|
|
} |
|
|
|
|
|
func IsChannelError(err *NewAPIError) bool { |
|
|
if err == nil { |
|
|
return false |
|
|
} |
|
|
return strings.HasPrefix(string(err.errorCode), "channel:") |
|
|
} |
|
|
|
|
|
func IsSkipRetryError(err *NewAPIError) bool { |
|
|
if err == nil { |
|
|
return false |
|
|
} |
|
|
|
|
|
return err.skipRetry |
|
|
} |
|
|
|
|
|
func ErrOptionWithSkipRetry() NewAPIErrorOptions { |
|
|
return func(e *NewAPIError) { |
|
|
e.skipRetry = true |
|
|
} |
|
|
} |
|
|
|
|
|
func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { |
|
|
return func(e *NewAPIError) { |
|
|
e.recordErrorLog = common.GetPointer(false) |
|
|
} |
|
|
} |
|
|
|
|
|
func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { |
|
|
return func(e *NewAPIError) { |
|
|
if common.DebugEnabled { |
|
|
fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err) |
|
|
} |
|
|
e.Err = errors.New(replaceStr) |
|
|
} |
|
|
} |
|
|
|
|
|
func IsRecordErrorLog(e *NewAPIError) bool { |
|
|
if e == nil { |
|
|
return false |
|
|
} |
|
|
if e.recordErrorLog == nil { |
|
|
|
|
|
return true |
|
|
} |
|
|
return *e.recordErrorLog |
|
|
} |
|
|
|