|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"github.com/emirpasic/gods/v2/queues/circularbuffer" |
|
|
"io" |
|
|
"net/http" |
|
|
"sort" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/mudler/LocalAI/core/application" |
|
|
"github.com/mudler/xlog" |
|
|
) |
|
|
|
|
|
// Trace record types produced by TraceMiddleware and returned by GetTraces.
// The JSON tags define the serialized trace format.
type (
	// APIExchangeRequest holds the captured parts of an incoming HTTP request.
	APIExchangeRequest struct {
		// Method is the HTTP method of the request.
		Method string `json:"method"`
		// Path is the request path recorded for the exchange.
		Path string `json:"path"`
		// Headers is the set of request headers captured for the exchange.
		Headers *http.Header `json:"headers"`
		// Body is the raw request body captured for the exchange.
		Body *[]byte `json:"body"`
	}

	// APIExchangeResponse holds the captured parts of the HTTP response.
	APIExchangeResponse struct {
		// Status is the HTTP status code sent to the client.
		Status int `json:"status"`
		// Headers is the set of response headers captured for the exchange.
		Headers *http.Header `json:"headers"`
		// Body is the raw response body captured for the exchange.
		Body *[]byte `json:"body"`
	}

	// APIExchange pairs a request with its response and a timestamp.
	APIExchange struct {
		// Timestamp is when the exchange was recorded as starting.
		Timestamp time.Time `json:"timestamp"`
		// Request is the captured request.
		Request APIExchangeRequest `json:"request"`
		// Response is the captured response.
		Response APIExchangeResponse `json:"response"`
	}
)
|
|
|
|
|
var traceBuffer *circularbuffer.Queue[APIExchange] |
|
|
var mu sync.Mutex |
|
|
var logChan = make(chan APIExchange, 100) |
|
|
|
|
|
// bodyWriter wraps an http.ResponseWriter and keeps a copy of every byte the
// wrapped writer accepted, so the response body can be attached to a trace.
type bodyWriter struct {
	http.ResponseWriter

	// body accumulates the response bytes written so far.
	body *bytes.Buffer
}

// Write forwards b to the wrapped writer and records only the bytes it
// actually accepted. Forwarding first keeps the captured body faithful to what
// the client received when the underlying write fails part-way.
func (w *bodyWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	if n > 0 {
		w.body.Write(b[:n])
	}
	return n, err
}

// Flush forwards to the wrapped writer when it supports http.Flusher, so
// streamed responses are still delivered incrementally; otherwise it is a no-op.
func (w *bodyWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer, letting http.ResponseController reach
// optional capabilities of the underlying ResponseWriter through this wrapper.
func (w *bodyWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
|
|
|
|
|
|
|
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { |
|
|
if app.ApplicationConfig().EnableTracing && traceBuffer == nil { |
|
|
traceBuffer = circularbuffer.New[APIExchange](app.ApplicationConfig().TracingMaxItems) |
|
|
|
|
|
go func() { |
|
|
for exchange := range logChan { |
|
|
mu.Lock() |
|
|
traceBuffer.Enqueue(exchange) |
|
|
mu.Unlock() |
|
|
} |
|
|
}() |
|
|
} |
|
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
if !app.ApplicationConfig().EnableTracing { |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
if c.Request().Header.Get("Content-Type") != "application/json" { |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
body, err := io.ReadAll(c.Request().Body) |
|
|
if err != nil { |
|
|
xlog.Error("Failed to read request body") |
|
|
return err |
|
|
} |
|
|
|
|
|
|
|
|
c.Request().Body = io.NopCloser(bytes.NewBuffer(body)) |
|
|
|
|
|
startTime := time.Now() |
|
|
|
|
|
|
|
|
resBody := new(bytes.Buffer) |
|
|
mw := &bodyWriter{ |
|
|
ResponseWriter: c.Response().Writer, |
|
|
body: resBody, |
|
|
} |
|
|
c.Response().Writer = mw |
|
|
|
|
|
err = next(c) |
|
|
if err != nil { |
|
|
c.Response().Writer = mw.ResponseWriter |
|
|
return err |
|
|
} |
|
|
|
|
|
|
|
|
requestHeaders := c.Request().Header.Clone() |
|
|
requestBody := make([]byte, len(body)) |
|
|
copy(requestBody, body) |
|
|
responseHeaders := c.Response().Header().Clone() |
|
|
responseBody := make([]byte, resBody.Len()) |
|
|
copy(responseBody, resBody.Bytes()) |
|
|
exchange := APIExchange{ |
|
|
Timestamp: startTime, |
|
|
Request: APIExchangeRequest{ |
|
|
Method: c.Request().Method, |
|
|
Path: c.Path(), |
|
|
Headers: &requestHeaders, |
|
|
Body: &requestBody, |
|
|
}, |
|
|
Response: APIExchangeResponse{ |
|
|
Status: c.Response().Status, |
|
|
Headers: &responseHeaders, |
|
|
Body: &responseBody, |
|
|
}, |
|
|
} |
|
|
|
|
|
select { |
|
|
case logChan <- exchange: |
|
|
default: |
|
|
xlog.Warn("Trace channel full, dropping trace") |
|
|
} |
|
|
|
|
|
return nil |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func GetTraces() []APIExchange { |
|
|
mu.Lock() |
|
|
traces := traceBuffer.Values() |
|
|
mu.Unlock() |
|
|
|
|
|
sort.Slice(traces, func(i, j int) bool { |
|
|
return traces[i].Timestamp.Before(traces[j].Timestamp) |
|
|
}) |
|
|
|
|
|
return traces |
|
|
} |
|
|
|
|
|
|
|
|
func ClearTraces() { |
|
|
mu.Lock() |
|
|
traceBuffer.Clear() |
|
|
mu.Unlock() |
|
|
} |
|
|
|