File size: 3,601 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package middleware

import (
	"bytes"
	"io"
	"mime"
	"net/http"
	"sort"
	"sync"
	"time"

	"github.com/emirpasic/gods/v2/queues/circularbuffer"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/application"
	"github.com/mudler/xlog"
)

type (
	// APIExchangeRequest is the recorded form of a traced HTTP request.
	APIExchangeRequest struct {
		Method  string       `json:"method"`  // HTTP method, e.g. "POST"
		Path    string       `json:"path"`    // matched route path
		Headers *http.Header `json:"headers"` // snapshot of the request headers
		Body    *[]byte      `json:"body"`    // raw request body bytes
	}

	// APIExchangeResponse is the recorded form of the response sent for a
	// traced request.
	APIExchangeResponse struct {
		Status  int          `json:"status"`  // HTTP status code
		Headers *http.Header `json:"headers"` // snapshot of the response headers
		Body    *[]byte      `json:"body"`    // raw response body bytes
	}

	// APIExchange pairs a traced request with its response and the time the
	// exchange started.
	APIExchange struct {
		Timestamp time.Time           `json:"timestamp"`
		Request   APIExchangeRequest  `json:"request"`
		Response  APIExchangeResponse `json:"response"`
	}
)

// Package-level tracing state shared by the middleware and the trace accessors.
var (
	// traceBuffer holds the retained API exchanges; it is declared without an
	// initializer and so is nil until tracing sets it up.
	traceBuffer *circularbuffer.Queue[APIExchange]

	// mu serializes access to the contents of traceBuffer.
	mu sync.Mutex

	// logChan carries captured exchanges from request handlers to the
	// goroutine that stores them; it buffers up to 100 pending exchanges.
	logChan = make(chan APIExchange, 100)
)

// bodyWriter wraps an http.ResponseWriter and keeps a copy of every byte the
// wrapped writer accepts, so the response body can be recorded after the
// handler has finished.
type bodyWriter struct {
	http.ResponseWriter
	body *bytes.Buffer // receives a copy of the bytes accepted by ResponseWriter
}

// Write forwards b to the wrapped writer and then records only the n bytes it
// actually accepted. A short or failed write (e.g. a disconnected client) is
// therefore not reported in the trace as if it had been delivered.
func (w *bodyWriter) Write(b []byte) (int, error) {
	n, err := w.ResponseWriter.Write(b)
	if n > 0 {
		// bytes.Buffer.Write always returns a nil error.
		w.body.Write(b[:n])
	}
	return n, err
}

// Flush forwards to the wrapped writer when it supports http.Flusher, so
// streaming responses keep working through the wrapper.
func (w *bodyWriter) Flush() {
	if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

// Unwrap returns the wrapped writer. http.NewResponseController uses it to
// reach optional capabilities (deadlines, hijacking, ...) of the underlying
// writer through this wrapper.
func (w *bodyWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}

// TraceMiddleware intercepts and logs JSON API requests and responses
// defaultTraceBufferSize is the trace buffer capacity used when the configured
// TracingMaxItems is below 1; circularbuffer.New rejects (panics on) such sizes.
const defaultTraceBufferSize = 100

// traceInitOnce guards the one-time creation of traceBuffer and the start of
// the goroutine that drains logChan into it.
var traceInitOnce sync.Once

// ensureTraceBuffer creates traceBuffer with capacity maxItems and starts the
// goroutine that moves exchanges from logChan into it. Only the first call has
// any effect; later calls, including ones with a different maxItems, are no-ops.
func ensureTraceBuffer(maxItems int) {
	traceInitOnce.Do(func() {
		if maxItems < 1 {
			maxItems = defaultTraceBufferSize
		}

		// Assign under mu so GetTraces/ClearTraces observe either nil or a
		// fully constructed buffer, never a torn write.
		mu.Lock()
		traceBuffer = circularbuffer.New[APIExchange](maxItems)
		mu.Unlock()

		// Nothing in this file closes logChan, so this consumer runs for the
		// remaining lifetime of the process.
		go func() {
			for exchange := range logChan {
				mu.Lock()
				traceBuffer.Enqueue(exchange)
				mu.Unlock()
			}
		}()
	})
}

// isTraceableJSON reports whether r declares an application/json body.
// Media-type parameters (e.g. "; charset=utf-8") and letter case are ignored;
// a missing or malformed Content-Type is treated as non-JSON.
func isTraceableJSON(r *http.Request) bool {
	mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
	return err == nil && mediaType == "application/json"
}

// TraceMiddleware intercepts and logs JSON API requests and responses.
//
// While app.ApplicationConfig().EnableTracing is true, every request with an
// application/json Content-Type is handled as follows:
//   - its body is read and restored;
//   - its response body is captured through a bodyWriter;
//   - on success, an APIExchange is sent without blocking to the background
//     consumer. The exchange is dropped with a warning if logChan is full.
//
// Requests whose handler returns an error are not recorded. EnableTracing is
// re-read on every request, so the trace buffer is created on first use if
// tracing was off when the middleware was constructed.
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
	if app.ApplicationConfig().EnableTracing {
		ensureTraceBuffer(app.ApplicationConfig().TracingMaxItems)
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if !app.ApplicationConfig().EnableTracing {
				return next(c)
			}

			if !isTraceableJSON(c.Request()) {
				return next(c)
			}

			// Tracing may have been switched on after construction; make sure
			// the buffer and its consumer exist before producing into logChan.
			ensureTraceBuffer(app.ApplicationConfig().TracingMaxItems)

			body, err := io.ReadAll(c.Request().Body)
			if err != nil {
				xlog.Error("Failed to read request body")
				return err
			}

			// Restore the body for downstream handlers.
			c.Request().Body = io.NopCloser(bytes.NewBuffer(body))

			startTime := time.Now()

			// Wrap the response writer to capture the response body.
			resBody := new(bytes.Buffer)
			mw := &bodyWriter{
				ResponseWriter: c.Response().Writer,
				body:           resBody,
			}
			c.Response().Writer = mw

			if err := next(c); err != nil {
				// Restore the original writer so the error handler writes
				// straight to the client; failed requests are not traced.
				c.Response().Writer = mw.ResponseWriter
				return err
			}

			// Snapshot headers and bodies so the stored exchange does not
			// alias buffers the request/response may still own.
			requestHeaders := c.Request().Header.Clone()
			requestBody := make([]byte, len(body))
			copy(requestBody, body)
			responseHeaders := c.Response().Header().Clone()
			responseBody := make([]byte, resBody.Len())
			copy(responseBody, resBody.Bytes())

			exchange := APIExchange{
				Timestamp: startTime,
				Request: APIExchangeRequest{
					Method:  c.Request().Method,
					Path:    c.Path(),
					Headers: &requestHeaders,
					Body:    &requestBody,
				},
				Response: APIExchangeResponse{
					Status:  c.Response().Status,
					Headers: &responseHeaders,
					Body:    &responseBody,
				},
			}

			// Never block the request on tracing: drop when the queue is full.
			select {
			case logChan <- exchange:
			default:
				xlog.Warn("Trace channel full, dropping trace")
			}

			return nil
		}
	}
}

// GetTraces returns a copy of the logged API exchanges for display
// GetTraces returns the logged API exchanges for display, ordered by request
// start time (oldest first). Exchanges with equal timestamps keep the order in
// which the buffer returned them.
//
// If the trace buffer has never been created (tracing not yet enabled), it
// returns an empty, non-nil slice instead of dereferencing a nil buffer.
func GetTraces() []APIExchange {
	mu.Lock()
	if traceBuffer == nil {
		mu.Unlock()
		return []APIExchange{}
	}
	traces := traceBuffer.Values()
	mu.Unlock()

	// Sort outside the lock to keep the critical section short.
	sort.SliceStable(traces, func(i, j int) bool {
		return traces[i].Timestamp.Before(traces[j].Timestamp)
	})

	return traces
}

// ClearTraces clears the in-memory logs
// ClearTraces clears the in-memory logs. It is a no-op when the trace buffer
// has not been created yet, rather than dereferencing a nil buffer.
func ClearTraces() {
	mu.Lock()
	defer mu.Unlock()

	if traceBuffer != nil {
		traceBuffer.Clear()
	}
}