| | |
| | |
| | |
| | package middleware |
| |
|
| | import ( |
| | "bytes" |
| | "net/http" |
| | "strings" |
| |
|
| | "github.com/gin-gonic/gin" |
| | "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" |
| | "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" |
| | ) |
| |
|
| | |
// RequestInfo is the snapshot of an incoming request that the response
// wrapper attaches to the log entries it produces.
type RequestInfo struct {
	URL, Method string              // request URL and HTTP method
	Headers     map[string][]string // request headers
	Body        []byte              // raw request body; also scanned for a stream flag
	RequestID   string              // identifier handed through to the logger
}
| |
|
| | |
| | |
| | type ResponseWriterWrapper struct { |
| | gin.ResponseWriter |
| | body *bytes.Buffer |
| | isStreaming bool |
| | streamWriter logging.StreamingLogWriter |
| | chunkChannel chan []byte |
| | streamDone chan struct{} |
| | logger logging.RequestLogger |
| | requestInfo *RequestInfo |
| | statusCode int |
| | headers map[string][]string |
| | logOnErrorOnly bool |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
// NewResponseWriterWrapper wraps w so that the response it carries can be
// logged through logger. requestInfo describes the originating request and is
// used when writing log entries. The body buffer and the header snapshot start
// out empty.
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
	wrapper := &ResponseWriterWrapper{
		ResponseWriter: w,
		logger:         logger,
		requestInfo:    requestInfo,
	}
	wrapper.body = new(bytes.Buffer)
	wrapper.headers = map[string][]string{}
	return wrapper
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | func (w *ResponseWriterWrapper) Write(data []byte) (int, error) { |
| | |
| | |
| | w.ensureHeadersCaptured() |
| |
|
| | |
| | n, err := w.ResponseWriter.Write(data) |
| |
|
| | |
| | if w.isStreaming && w.chunkChannel != nil { |
| | |
| | select { |
| | case w.chunkChannel <- append([]byte(nil), data...): |
| | default: |
| | } |
| | return n, err |
| | } |
| |
|
| | if w.shouldBufferResponseBody() { |
| | w.body.Write(data) |
| | } |
| |
|
| | return n, err |
| | } |
| |
|
// shouldBufferResponseBody reports whether non-streamed response data should
// be kept in memory so Finalize can log it.
//
// It is always true while the logger is enabled. With the logger disabled it
// is true only when logOnErrorOnly is set and the response status is an error
// (>= 400). A recorded status of 0 or below counts as "not set": gin passes
// negative codes to WriteHeader to mean "keep the current status". In that
// case the wrapped writer's own status is consulted, defaulting to 200.
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
	if w.logger != nil && w.logger.IsEnabled() {
		return true
	}
	if !w.logOnErrorOnly {
		return false
	}

	status := w.statusCode
	if status <= 0 {
		status = http.StatusOK
		if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
			status = statusWriter.Status()
		}
	}
	return status >= http.StatusBadRequest
}
| |
|
| | |
| | |
| | |
| | func (w *ResponseWriterWrapper) WriteString(data string) (int, error) { |
| | w.ensureHeadersCaptured() |
| |
|
| | |
| | n, err := w.ResponseWriter.WriteString(data) |
| |
|
| | |
| | if w.isStreaming && w.chunkChannel != nil { |
| | select { |
| | case w.chunkChannel <- []byte(data): |
| | default: |
| | } |
| | return n, err |
| | } |
| |
|
| | if w.shouldBufferResponseBody() { |
| | w.body.WriteString(data) |
| | } |
| | return n, err |
| | } |
| |
|
| | |
| | |
| | |
// WriteHeader records the status code and snapshots the current response
// headers. When the response looks like a stream and request logging is
// enabled, it also opens a streaming log session. The call is then forwarded
// to the wrapped writer.
//
// Non-positive codes are forwarded but not recorded. gin uses them to mean
// "keep the current status" (c.SSEvent renders with code -1), and recording
// them would make the wrapper log a bogus status.
//
// The streaming session is opened at most once. Replacing an active session
// on a repeated WriteHeader call would orphan its forwarding goroutine, which
// would block forever on a channel nobody closes, and leave its stream writer
// unclosed.
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
	if statusCode > 0 {
		w.statusCode = statusCode
	}

	w.captureCurrentHeaders()

	// Only (re)evaluate streaming while no session is active.
	if w.streamWriter == nil {
		contentType := w.ResponseWriter.Header().Get("Content-Type")
		w.isStreaming = w.detectStreaming(contentType)

		if w.isStreaming && w.logger != nil && w.logger.IsEnabled() && w.requestInfo != nil {
			streamWriter, err := w.logger.LogStreamingRequest(
				w.requestInfo.URL,
				w.requestInfo.Method,
				w.requestInfo.Headers,
				w.requestInfo.Body,
				w.requestInfo.RequestID,
			)
			// On error no session is opened; Write then falls back to buffering
			// the body whenever shouldBufferResponseBody allows it.
			if err == nil {
				w.streamWriter = streamWriter
				w.chunkChannel = make(chan []byte, 100)
				done := make(chan struct{})
				w.streamDone = done

				// The fields above are assigned before the goroutine starts, so it observes them.
				go w.processStreamingChunks(done)

				status := w.statusCode
				if status <= 0 {
					status = http.StatusOK
					if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
						status = statusWriter.Status()
					}
				}
				// Best effort: a failed log write must not affect the client response.
				_ = streamWriter.WriteStatus(status, w.headers)
			}
		}
	}

	w.ResponseWriter.WriteHeader(statusCode)
}
| |
|
| | |
| | |
| | |
| | func (w *ResponseWriterWrapper) ensureHeadersCaptured() { |
| | |
| | w.captureCurrentHeaders() |
| | } |
| |
|
| | |
| | |
| | func (w *ResponseWriterWrapper) captureCurrentHeaders() { |
| | |
| | if w.headers == nil { |
| | w.headers = make(map[string][]string) |
| | } |
| |
|
| | |
| | for key, values := range w.ResponseWriter.Header() { |
| | |
| | headerValues := make([]string, len(values)) |
| | copy(headerValues, values) |
| | w.headers[key] = headerValues |
| | } |
| | } |
| |
|
| | |
| | |
| | |
// detectStreaming reports whether the response should be logged as a stream.
//
// An explicit Content-Type decides the answer:
//   - A text/event-stream type, matched case-insensitively because MIME types
//     are case-insensitive, means streaming.
//   - Any other non-blank type means not streaming.
//
// When no Content-Type is set yet, it falls back to looking for a literal
// `"stream": true` or `"stream":true` in the raw request body.
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
	if strings.Contains(strings.ToLower(contentType), "text/event-stream") {
		return true
	}

	// Any other explicit content type is authoritative.
	if strings.TrimSpace(contentType) != "" {
		return false
	}

	if w.requestInfo == nil || len(w.requestInfo.Body) == 0 {
		return false
	}

	// Search the bytes directly. Converting the body to a string would copy the
	// whole request payload on every call, and this path may run on every
	// WriteHeader whose Content-Type is still unset.
	body := w.requestInfo.Body
	return bytes.Contains(body, []byte(`"stream": true`)) || bytes.Contains(body, []byte(`"stream":true`))
}
| |
|
| | |
| | |
| | func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) { |
| | if done == nil { |
| | return |
| | } |
| |
|
| | defer close(done) |
| |
|
| | if w.streamWriter == nil || w.chunkChannel == nil { |
| | return |
| | } |
| |
|
| | for chunk := range w.chunkChannel { |
| | w.streamWriter.WriteChunkAsync(chunk) |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
// Finalize completes logging for the request once the handler has returned.
//
// If WriteHeader opened a streaming log session, Finalize always tears it
// down, in this order:
//  1. It closes the chunk queue and waits for processStreamingChunks to drain it.
//  2. It appends the upstream API request and response stored in the gin
//     context under "API_REQUEST" and "API_RESPONSE".
//  3. It closes the stream writer and returns the Close error.
//
// The teardown runs even if the logger was disabled after the session
// opened; skipping it would leak the goroutine and the writer.
//
// Otherwise the buffered response is logged through logRequest when the
// logger is enabled. With logOnErrorOnly set, it is also logged when the
// request failed: the final status is >= 400 or errors are stored under
// "API_RESPONSE_ERROR".
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
	if w.logger == nil {
		return nil
	}

	if w.streamWriter != nil {
		if w.chunkChannel != nil {
			close(w.chunkChannel)
		}
		if w.streamDone != nil {
			<-w.streamDone
		}
		// Clear the fields only after the forwarding goroutine has exited. It
		// reads them when it starts, so clearing them earlier races with that read.
		w.chunkChannel = nil
		w.streamDone = nil

		streamWriter := w.streamWriter
		w.streamWriter = nil

		// Appending the upstream exchange is best effort; only Close reports errors.
		if apiRequest := w.extractAPIRequest(c); len(apiRequest) > 0 {
			_ = streamWriter.WriteAPIRequest(apiRequest)
		}
		if apiResponse := w.extractAPIResponse(c); len(apiResponse) > 0 {
			_ = streamWriter.WriteAPIResponse(apiResponse)
		}
		return streamWriter.Close()
	}

	finalStatusCode := w.statusCode
	if finalStatusCode <= 0 {
		// No usable status recorded: never set, or gin's negative "keep current status" sentinel.
		finalStatusCode = http.StatusOK
		if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
			finalStatusCode = statusWriter.Status()
		}
	}

	var apiErrors []*interfaces.ErrorMessage
	if value, exists := c.Get("API_RESPONSE_ERROR"); exists {
		if errs, ok := value.([]*interfaces.ErrorMessage); ok {
			apiErrors = errs
		}
	}

	enabled := w.logger.IsEnabled()
	hasAPIError := len(apiErrors) > 0 || finalStatusCode >= http.StatusBadRequest
	forceLog := w.logOnErrorOnly && hasAPIError && !enabled
	if !enabled && !forceLog {
		return nil
	}

	return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), apiErrors, forceLog)
}
| |
|
| | func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { |
| | w.ensureHeadersCaptured() |
| |
|
| | finalHeaders := make(map[string][]string, len(w.headers)) |
| | for key, values := range w.headers { |
| | headerValues := make([]string, len(values)) |
| | copy(headerValues, values) |
| | finalHeaders[key] = headerValues |
| | } |
| |
|
| | return finalHeaders |
| | } |
| |
|
| | func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte { |
| | apiRequest, isExist := c.Get("API_REQUEST") |
| | if !isExist { |
| | return nil |
| | } |
| | data, ok := apiRequest.([]byte) |
| | if !ok || len(data) == 0 { |
| | return nil |
| | } |
| | return data |
| | } |
| |
|
| | func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte { |
| | apiResponse, isExist := c.Get("API_RESPONSE") |
| | if !isExist { |
| | return nil |
| | } |
| | data, ok := apiResponse.([]byte) |
| | if !ok || len(data) == 0 { |
| | return nil |
| | } |
| | return data |
| | } |
| |
|
| | func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { |
| | if w.requestInfo == nil { |
| | return nil |
| | } |
| |
|
| | var requestBody []byte |
| | if len(w.requestInfo.Body) > 0 { |
| | requestBody = w.requestInfo.Body |
| | } |
| |
|
| | if loggerWithOptions, ok := w.logger.(interface { |
| | LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error |
| | }); ok { |
| | return loggerWithOptions.LogRequestWithOptions( |
| | w.requestInfo.URL, |
| | w.requestInfo.Method, |
| | w.requestInfo.Headers, |
| | requestBody, |
| | statusCode, |
| | headers, |
| | body, |
| | apiRequestBody, |
| | apiResponseBody, |
| | apiResponseErrors, |
| | forceLog, |
| | w.requestInfo.RequestID, |
| | ) |
| | } |
| |
|
| | return w.logger.LogRequest( |
| | w.requestInfo.URL, |
| | w.requestInfo.Method, |
| | w.requestInfo.Headers, |
| | requestBody, |
| | statusCode, |
| | headers, |
| | body, |
| | apiRequestBody, |
| | apiResponseBody, |
| | apiResponseErrors, |
| | w.requestInfo.RequestID, |
| | ) |
| | } |
| |
|