| package requestbody |
|
|
| import ( |
| "bytes" |
| "errors" |
| "io" |
| "mime" |
| "net/http" |
| "strings" |
| "unicode/utf8" |
| ) |
|
|
| var ( |
| ErrInvalidUTF8Body = errors.New("invalid utf-8 request body") |
| errRequestBodyTooLarge = errors.New("request body too large") |
| ) |
|
|
| const maxJSONUTF8ValidationSize = 100 << 20 |
|
|
| |
| |
| func ValidateJSONUTF8(next http.Handler) http.Handler { |
| if next == nil { |
| return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) |
| } |
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if shouldValidateJSONBody(r) { |
| r.Body = validateAndReplayBody(r.Body) |
| } |
| next.ServeHTTP(w, r) |
| }) |
| } |
|
|
| func shouldValidateJSONBody(r *http.Request) bool { |
| if r == nil || r.Body == nil { |
| return false |
| } |
| path := "" |
| if r.URL != nil { |
| path = r.URL.Path |
| } |
| return isJSONContentType(r.Header.Get("Content-Type")) || isKnownJSONRequestPath(r.Method, path) |
| } |
|
|
| func isJSONContentType(raw string) bool { |
| raw = strings.TrimSpace(raw) |
| if raw == "" { |
| return false |
| } |
| mediaType, _, err := mime.ParseMediaType(raw) |
| if err != nil { |
| mediaType = raw |
| } |
| mediaType = strings.ToLower(strings.TrimSpace(mediaType)) |
| return strings.Contains(mediaType, "json") |
| } |
|
|
| func isKnownJSONRequestPath(method, path string) bool { |
| switch strings.ToUpper(strings.TrimSpace(method)) { |
| case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: |
| default: |
| return false |
| } |
| path = strings.TrimSpace(path) |
| if path == "" { |
| return false |
| } |
| switch { |
| case path == "/v1/chat/completions" || path == "/chat/completions": |
| return true |
| case path == "/v1/responses" || path == "/responses": |
| return true |
| case path == "/v1/embeddings" || path == "/embeddings": |
| return true |
| case path == "/anthropic/v1/messages" || path == "/v1/messages" || path == "/messages": |
| return true |
| case path == "/anthropic/v1/messages/count_tokens" || path == "/v1/messages/count_tokens" || path == "/messages/count_tokens": |
| return true |
| case strings.HasPrefix(path, "/v1beta/models/") || strings.HasPrefix(path, "/v1/models/"): |
| return strings.Contains(path, ":generateContent") || strings.Contains(path, ":streamGenerateContent") |
| case strings.HasPrefix(path, "/admin/"): |
| return true |
| default: |
| return false |
| } |
| } |
|
|
| func validateAndReplayBody(body io.ReadCloser) io.ReadCloser { |
| if body == nil { |
| return body |
| } |
| raw, err := io.ReadAll(io.LimitReader(body, maxJSONUTF8ValidationSize+1)) |
| if err != nil { |
| return &errorReadCloser{err: err, closer: body} |
| } |
| if len(raw) > maxJSONUTF8ValidationSize { |
| return &errorReadCloser{err: errRequestBodyTooLarge, closer: body} |
| } |
| if !utf8.Valid(raw) { |
| return &errorReadCloser{err: ErrInvalidUTF8Body, closer: body} |
| } |
| return &replayReadCloser{Reader: bytes.NewReader(raw), closer: body} |
| } |
|
|
| type replayReadCloser struct { |
| *bytes.Reader |
| closer io.Closer |
| } |
|
|
| func (r *replayReadCloser) Close() error { |
| if r == nil || r.closer == nil { |
| return nil |
| } |
| return r.closer.Close() |
| } |
|
|
| type errorReadCloser struct { |
| err error |
| closer io.Closer |
| } |
|
|
| func (r *errorReadCloser) Read([]byte) (int, error) { |
| if r == nil || r.err == nil { |
| return 0, io.EOF |
| } |
| return 0, r.err |
| } |
|
|
| func (r *errorReadCloser) Close() error { |
| if r == nil || r.closer == nil { |
| return nil |
| } |
| return r.closer.Close() |
| } |
|
|