Spaces:
Build error
Build error
| package types | |
| import ( | |
| "errors" | |
| "fmt" | |
| "net/http" | |
| "one-api/common" | |
| "strings" | |
| ) | |
// OpenAIError is an error payload with message, type, param and code fields,
// serialized under the JSON keys of the same (lower-case) names.
//
// Code is declared as any, so it can hold a string as well as non-string
// values; WithOpenAIError normalizes whatever it holds into an ErrorCode.
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}
// ClaudeError is an error payload with a type and a message. Both fields are
// omitted from the JSON output when empty.
type ClaudeError struct {
	Type    string `json:"type,omitempty"`
	Message string `json:"message,omitempty"`
}
// ErrorType names the family a NewAPIError belongs to. ToOpenAIError and
// ToClaudeError switch on it to decide how RelayError is converted.
type ErrorType string

// Known error families.
const (
	ErrorTypeNewAPIError     ErrorType = "new_api_error"
	ErrorTypeOpenAIError     ErrorType = "openai_error"
	ErrorTypeClaudeError     ErrorType = "claude_error"
	ErrorTypeMidjourneyError ErrorType = "midjourney_error"
	ErrorTypeGeminiError     ErrorType = "gemini_error"
	ErrorTypeRerankError     ErrorType = "rerank_error"
	ErrorTypeUpstreamError   ErrorType = "upstream_error"
)
// ErrorCode is the machine-readable code attached to a NewAPIError.
type ErrorCode string

// Known error codes, grouped by where the failure originates. Codes in the
// channel group share the "channel:" prefix that IsChannelError matches on.
const (
	ErrorCodeInvalidRequest         ErrorCode = "invalid_request"
	ErrorCodeSensitiveWordsDetected ErrorCode = "sensitive_words_detected"

	// New API (internal) errors.
	ErrorCodeCountTokenFailed   ErrorCode = "count_token_failed"
	ErrorCodeModelPriceError    ErrorCode = "model_price_error"
	ErrorCodeInvalidApiType     ErrorCode = "invalid_api_type"
	ErrorCodeJsonMarshalFailed  ErrorCode = "json_marshal_failed"
	ErrorCodeDoRequestFailed    ErrorCode = "do_request_failed"
	ErrorCodeGetChannelFailed   ErrorCode = "get_channel_failed"
	ErrorCodeGenRelayInfoFailed ErrorCode = "gen_relay_info_failed"

	// Channel errors.
	ErrorCodeChannelNoAvailableKey        ErrorCode = "channel:no_available_key"
	ErrorCodeChannelParamOverrideInvalid  ErrorCode = "channel:param_override_invalid"
	ErrorCodeChannelHeaderOverrideInvalid ErrorCode = "channel:header_override_invalid"
	ErrorCodeChannelModelMappedError      ErrorCode = "channel:model_mapped_error"
	ErrorCodeChannelAwsClientError        ErrorCode = "channel:aws_client_error"
	ErrorCodeChannelInvalidKey            ErrorCode = "channel:invalid_key"
	ErrorCodeChannelResponseTimeExceeded  ErrorCode = "channel:response_time_exceeded"

	// Client request errors.
	ErrorCodeReadRequestBodyFailed ErrorCode = "read_request_body_failed"
	ErrorCodeConvertRequestFailed  ErrorCode = "convert_request_failed"
	ErrorCodeAccessDenied          ErrorCode = "access_denied"

	// Response errors.
	ErrorCodeReadResponseBodyFailed ErrorCode = "read_response_body_failed"
	ErrorCodeBadResponseStatusCode  ErrorCode = "bad_response_status_code"
	ErrorCodeBadResponse            ErrorCode = "bad_response"
	ErrorCodeBadResponseBody        ErrorCode = "bad_response_body"
	ErrorCodeEmptyResponse          ErrorCode = "empty_response"
	ErrorCodeAwsInvokeError         ErrorCode = "aws_invoke_error"
	ErrorCodeModelNotFound          ErrorCode = "model_not_found"

	// SQL errors.
	ErrorCodeQueryDataError  ErrorCode = "query_data_error"
	ErrorCodeUpdateDataError ErrorCode = "update_data_error"

	// Quota errors.
	ErrorCodeInsufficientUserQuota      ErrorCode = "insufficient_user_quota"
	ErrorCodePreConsumeTokenQuotaFailed ErrorCode = "pre_consume_token_quota_failed"
)
| type NewAPIError struct { | |
| Err error | |
| RelayError any | |
| skipRetry bool | |
| recordErrorLog *bool | |
| errorType ErrorType | |
| errorCode ErrorCode | |
| StatusCode int | |
| } | |
| func (e *NewAPIError) GetErrorCode() ErrorCode { | |
| if e == nil { | |
| return "" | |
| } | |
| return e.errorCode | |
| } | |
| func (e *NewAPIError) GetErrorType() ErrorType { | |
| if e == nil { | |
| return "" | |
| } | |
| return e.errorType | |
| } | |
| func (e *NewAPIError) Error() string { | |
| if e == nil { | |
| return "" | |
| } | |
| if e.Err == nil { | |
| // fallback message when underlying error is missing | |
| return string(e.errorCode) | |
| } | |
| return e.Err.Error() | |
| } | |
| func (e *NewAPIError) MaskSensitiveError() string { | |
| if e == nil { | |
| return "" | |
| } | |
| if e.Err == nil { | |
| return string(e.errorCode) | |
| } | |
| errStr := e.Err.Error() | |
| if e.errorCode == ErrorCodeCountTokenFailed { | |
| return errStr | |
| } | |
| return common.MaskSensitiveInfo(errStr) | |
| } | |
| func (e *NewAPIError) SetMessage(message string) { | |
| e.Err = errors.New(message) | |
| } | |
| func (e *NewAPIError) ToOpenAIError() OpenAIError { | |
| var result OpenAIError | |
| switch e.errorType { | |
| case ErrorTypeOpenAIError: | |
| if openAIError, ok := e.RelayError.(OpenAIError); ok { | |
| result = openAIError | |
| } | |
| case ErrorTypeClaudeError: | |
| if claudeError, ok := e.RelayError.(ClaudeError); ok { | |
| result = OpenAIError{ | |
| Message: e.Error(), | |
| Type: claudeError.Type, | |
| Param: "", | |
| Code: e.errorCode, | |
| } | |
| } | |
| default: | |
| result = OpenAIError{ | |
| Message: e.Error(), | |
| Type: string(e.errorType), | |
| Param: "", | |
| Code: e.errorCode, | |
| } | |
| } | |
| if e.errorCode != ErrorCodeCountTokenFailed { | |
| result.Message = common.MaskSensitiveInfo(result.Message) | |
| } | |
| return result | |
| } | |
| func (e *NewAPIError) ToClaudeError() ClaudeError { | |
| var result ClaudeError | |
| switch e.errorType { | |
| case ErrorTypeOpenAIError: | |
| if openAIError, ok := e.RelayError.(OpenAIError); ok { | |
| result = ClaudeError{ | |
| Message: e.Error(), | |
| Type: fmt.Sprintf("%v", openAIError.Code), | |
| } | |
| } | |
| case ErrorTypeClaudeError: | |
| if claudeError, ok := e.RelayError.(ClaudeError); ok { | |
| result = claudeError | |
| } | |
| default: | |
| result = ClaudeError{ | |
| Message: e.Error(), | |
| Type: string(e.errorType), | |
| } | |
| } | |
| if e.errorCode != ErrorCodeCountTokenFailed { | |
| result.Message = common.MaskSensitiveInfo(result.Message) | |
| } | |
| return result | |
| } | |
| type NewAPIErrorOptions func(*NewAPIError) | |
| func NewError(err error, errorCode ErrorCode, ops ...NewAPIErrorOptions) *NewAPIError { | |
| var newErr *NewAPIError | |
| // 保留深层传递的 new err | |
| if errors.As(err, &newErr) { | |
| for _, op := range ops { | |
| op(newErr) | |
| } | |
| return newErr | |
| } | |
| e := &NewAPIError{ | |
| Err: err, | |
| RelayError: nil, | |
| errorType: ErrorTypeNewAPIError, | |
| StatusCode: http.StatusInternalServerError, | |
| errorCode: errorCode, | |
| } | |
| for _, op := range ops { | |
| op(e) | |
| } | |
| return e | |
| } | |
| func NewOpenAIError(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { | |
| var newErr *NewAPIError | |
| // 保留深层传递的 new err | |
| if errors.As(err, &newErr) { | |
| if newErr.RelayError == nil { | |
| openaiError := OpenAIError{ | |
| Message: newErr.Error(), | |
| Type: string(errorCode), | |
| Code: errorCode, | |
| } | |
| newErr.RelayError = openaiError | |
| } | |
| for _, op := range ops { | |
| op(newErr) | |
| } | |
| return newErr | |
| } | |
| openaiError := OpenAIError{ | |
| Message: err.Error(), | |
| Type: string(errorCode), | |
| Code: errorCode, | |
| } | |
| return WithOpenAIError(openaiError, statusCode, ops...) | |
| } | |
| func InitOpenAIError(errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { | |
| openaiError := OpenAIError{ | |
| Type: string(errorCode), | |
| Code: errorCode, | |
| } | |
| return WithOpenAIError(openaiError, statusCode, ops...) | |
| } | |
| func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { | |
| e := &NewAPIError{ | |
| Err: err, | |
| RelayError: OpenAIError{ | |
| Message: err.Error(), | |
| Type: string(errorCode), | |
| }, | |
| errorType: ErrorTypeNewAPIError, | |
| StatusCode: statusCode, | |
| errorCode: errorCode, | |
| } | |
| for _, op := range ops { | |
| op(e) | |
| } | |
| return e | |
| } | |
| func WithOpenAIError(openAIError OpenAIError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { | |
| code, ok := openAIError.Code.(string) | |
| if !ok { | |
| if openAIError.Code != nil { | |
| code = fmt.Sprintf("%v", openAIError.Code) | |
| } else { | |
| code = "unknown_error" | |
| } | |
| } | |
| if openAIError.Type == "" { | |
| openAIError.Type = "upstream_error" | |
| } | |
| e := &NewAPIError{ | |
| RelayError: openAIError, | |
| errorType: ErrorTypeOpenAIError, | |
| StatusCode: statusCode, | |
| Err: errors.New(openAIError.Message), | |
| errorCode: ErrorCode(code), | |
| } | |
| for _, op := range ops { | |
| op(e) | |
| } | |
| return e | |
| } | |
| func WithClaudeError(claudeError ClaudeError, statusCode int, ops ...NewAPIErrorOptions) *NewAPIError { | |
| if claudeError.Type == "" { | |
| claudeError.Type = "upstream_error" | |
| } | |
| e := &NewAPIError{ | |
| RelayError: claudeError, | |
| errorType: ErrorTypeClaudeError, | |
| StatusCode: statusCode, | |
| Err: errors.New(claudeError.Message), | |
| errorCode: ErrorCode(claudeError.Type), | |
| } | |
| for _, op := range ops { | |
| op(e) | |
| } | |
| return e | |
| } | |
| func IsChannelError(err *NewAPIError) bool { | |
| if err == nil { | |
| return false | |
| } | |
| return strings.HasPrefix(string(err.errorCode), "channel:") | |
| } | |
| func IsSkipRetryError(err *NewAPIError) bool { | |
| if err == nil { | |
| return false | |
| } | |
| return err.skipRetry | |
| } | |
| func ErrOptionWithSkipRetry() NewAPIErrorOptions { | |
| return func(e *NewAPIError) { | |
| e.skipRetry = true | |
| } | |
| } | |
| func ErrOptionWithNoRecordErrorLog() NewAPIErrorOptions { | |
| return func(e *NewAPIError) { | |
| e.recordErrorLog = common.GetPointer(false) | |
| } | |
| } | |
| func ErrOptionWithHideErrMsg(replaceStr string) NewAPIErrorOptions { | |
| return func(e *NewAPIError) { | |
| if common.DebugEnabled { | |
| fmt.Printf("ErrOptionWithHideErrMsg: %s, origin error: %s", replaceStr, e.Err) | |
| } | |
| e.Err = errors.New(replaceStr) | |
| } | |
| } | |
| func IsRecordErrorLog(e *NewAPIError) bool { | |
| if e == nil { | |
| return false | |
| } | |
| if e.recordErrorLog == nil { | |
| // default to true if not set | |
| return true | |
| } | |
| return *e.recordErrorLog | |
| } | |