File size: 5,069 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package middleware

import (
	"crypto/sha256"
	"crypto/subtle"
	"errors"
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
)

var (
	// ErrMissingOrMalformedAPIKey is the sentinel error used by the API-key
	// middleware when a request presents no key, or a key that is not accepted.
	ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
)

// GetKeyAuthConfig builds the API-key authentication middleware for Echo.
//
// While applicationConfig.ApiKeys is empty, or when the skip filter from
// getApiKeyRequiredFilterFunction exempts the request, the middleware simply
// calls the next handler. Otherwise the request must present a key (see
// extractKeyFromMultipleSources) that the validator from
// getApiKeyValidationFunction accepts. Failures are answered by the handler
// from getApiKeyErrorHandler. On success, the key is stored in the request
// context under "api_key" before the next handler runs.
//
// The returned error is always nil at present.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
	var (
		validate    = getApiKeyValidationFunction(applicationConfig)
		handleError = getApiKeyErrorHandler(applicationConfig)
		skip        = getApiKeyRequiredFilterFunction(applicationConfig)
	)

	mw := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Authentication is off when no keys are configured or the route is exempt.
			if len(applicationConfig.ApiKeys) == 0 || (skip != nil && skip(c)) {
				return next(c)
			}

			key, err := extractKeyFromMultipleSources(c)
			if err != nil {
				return handleError(err, c)
			}

			// Any validator failure is reported uniformly as a bad/missing key.
			if ok, verr := validate(key, c); verr != nil || !ok {
				return handleError(ErrMissingOrMalformedAPIKey, c)
			}

			c.Set("api_key", key)
			return next(c)
		}
	}
	return mw, nil
}

// extractKeyFromMultipleSources returns the API key presented by the request.
//
// Sources are checked in this order, and the first non-empty one wins:
//  1. Authorization header. A "Bearer " scheme prefix is matched
//     case-insensitively (auth-schemes are case-insensitive per RFC 7235).
//     It is stripped together with any extra spaces before the token. A
//     Bearer header with no token counts as a missing key. A value without
//     the Bearer scheme is returned verbatim for backward compatibility.
//  2. x-api-key header.
//  3. xi-api-key header.
//  4. "token" cookie.
//
// It returns ErrMissingOrMalformedAPIKey when no source yields a key.
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
	const bearerPrefix = "Bearer "
	req := c.Request()

	if auth := req.Header.Get("Authorization"); auth != "" {
		if len(auth) >= len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
			// RFC 6750 allows one or more spaces between the scheme and the token.
			token := strings.TrimSpace(auth[len(bearerPrefix):])
			if token == "" {
				return "", ErrMissingOrMalformedAPIKey
			}
			return token, nil
		}
		// No Bearer scheme: return as-is (for backward compatibility).
		return auth, nil
	}

	if key := req.Header.Get("x-api-key"); key != "" {
		return key, nil
	}

	if key := req.Header.Get("xi-api-key"); key != "" {
		return key, nil
	}

	if cookie, err := c.Cookie("token"); err == nil && cookie != nil && cookie.Value != "" {
		return cookie.Value, nil
	}

	return "", ErrMissingOrMalformedAPIKey
}

// getApiKeyErrorHandler returns the function GetKeyAuthConfig uses to turn an
// authentication failure into an HTTP response.
//
// For errors matching ErrMissingOrMalformedAPIKey:
//   - if no API keys are configured, nothing is written and nil is returned;
//   - otherwise a 401 is sent with "WWW-Authenticate: Bearer". The body is
//     empty when OpaqueErrors is set. It is a JSON schema.ErrorResponse when
//     the request's Content-Type or Accept header names application/json.
//     Otherwise it is the "views/login" page.
//
// Any other error becomes an empty 500 when OpaqueErrors is set, and is
// returned unchanged otherwise.
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
	return func(err error, c echo.Context) error {
		if errors.Is(err, ErrMissingOrMalformedAPIKey) {
			if len(applicationConfig.ApiKeys) == 0 {
				return nil // if no keys are set up, any error we get here is not an error.
			}
			c.Response().Header().Set("WWW-Authenticate", "Bearer")
			if applicationConfig.OpaqueErrors {
				return c.NoContent(http.StatusUnauthorized)
			}

			// API clients get a machine-readable error; browsers get the login page.
			// GET API calls usually carry no Content-Type, so Accept is honoured too.
			if apiKeyErrorWantsJSON(c.Request()) {
				return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
					Error: &schema.APIError{
						Message: "An authentication key is required",
						Code:    401,
						Type:    "invalid_request_error",
					},
				})
			}

			return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
				"BaseURL": BaseURL(c),
			})
		}
		if applicationConfig.OpaqueErrors {
			return c.NoContent(http.StatusInternalServerError)
		}
		return err
	}
}

// apiKeyErrorWantsJSON reports whether r's Content-Type or Accept header
// mentions application/json. Media-type tokens are case-insensitive
// (RFC 9110), so the match is case-insensitive as well.
func apiKeyErrorWantsJSON(r *http.Request) bool {
	for _, header := range [...]string{"Content-Type", "Accept"} {
		if strings.Contains(strings.ToLower(r.Header.Get(header)), "application/json") {
			return true
		}
	}
	return false
}

// getApiKeyValidationFunction returns the predicate that decides whether a
// presented key matches one of applicationConfig.ApiKeys.
//
// The predicate returns (true, nil) on a match, and also when no keys are
// configured. Otherwise it returns (false, ErrMissingOrMalformedAPIKey). Its
// echo.Context argument is not used.
//
// With UseSubtleKeyComparison set, the presented key and each configured key
// are reduced to SHA-256 digests. The digests are compared with
// subtle.ConstantTimeCompare against every configured key, without stopping
// early. ConstantTimeCompare returns immediately on a length mismatch, so
// comparing raw keys would leak a configured key's length. Fixed-size digests
// plus a full scan keep response timing from revealing that length, or which
// entry matched.
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
	if applicationConfig.UseSubtleKeyComparison {
		return func(key string, c echo.Context) (bool, error) {
			if len(applicationConfig.ApiKeys) == 0 {
				return true, nil // If no keys are setup, accept everything
			}
			keyDigest := sha256.Sum256([]byte(key))
			matched := 0
			for _, validKey := range applicationConfig.ApiKeys {
				validDigest := sha256.Sum256([]byte(validKey))
				// No early return: every configured key is compared on every call.
				matched |= subtle.ConstantTimeCompare(keyDigest[:], validDigest[:])
			}
			if matched == 1 {
				return true, nil
			}
			return false, ErrMissingOrMalformedAPIKey
		}
	}

	return func(key string, c echo.Context) (bool, error) {
		if len(applicationConfig.ApiKeys) == 0 {
			return true, nil // If no keys are setup, accept everything
		}
		for _, validKey := range applicationConfig.ApiKeys {
			if key == validKey {
				return true, nil
			}
		}
		return false, ErrMissingOrMalformedAPIKey
	}
}

// getApiKeyRequiredFilterFunction returns the skipper that exempts requests
// from API-key authentication.
//
// A request is exempt when either condition holds:
//   - its URL path starts with a non-empty entry of PathWithoutAuth;
//   - DisableApiKeyRequirementForHttpGet is set, the method is GET, and the
//     matched route pattern (c.Path()) matches one of
//     HttpGetExemptedEndpoints.
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
	return func(c echo.Context) bool {
		req := c.Request()
		path := req.URL.Path

		for _, prefix := range applicationConfig.PathWithoutAuth {
			// An empty prefix would match every path; ignore it so a stray empty
			// entry cannot silently disable authentication for the whole server.
			if prefix != "" && strings.HasPrefix(path, prefix) {
				return true
			}
		}

		// GET exemptions apply only when enabled, and only to GET requests.
		if !applicationConfig.DisableApiKeyRequirementForHttpGet || req.Method != http.MethodGet {
			return false
		}
		for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
			if rx.MatchString(c.Path()) {
				return true
			}
		}
		return false
	}
}