File size: 5,069 Bytes
0f07ba7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
)
// ErrMissingOrMalformedAPIKey is the sentinel error reported when a request
// carries no API key, or carries one that matches none of the configured keys.
var (
	ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
)
// GetKeyAuthConfig builds the API-key authentication middleware for Echo.
//
// Despite its name, it does not return an Echo KeyAuth config. It returns a
// custom middleware that:
//   - looks the key up in several places (see extractKeyFromMultipleSources);
//   - validates it against applicationConfig.ApiKeys;
//   - on success, stores it in the request context under "api_key".
//
// Failures are turned into responses by getApiKeyErrorHandler. Requests
// accepted by getApiKeyRequiredFilterFunction bypass the check entirely.
//
// applicationConfig.ApiKeys is re-read on every request. Whenever it is empty,
// every request is passed through unauthenticated.
//
// The returned error is always nil.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
	var (
		validate   = getApiKeyValidationFunction(applicationConfig)
		onAuthFail = getApiKeyErrorHandler(applicationConfig)
		skip       = getApiKeyRequiredFilterFunction(applicationConfig)
	)

	authMiddleware := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			// Authentication is off when no keys are configured; otherwise let
			// the skipper exempt public paths before any key lookup happens.
			authDisabled := len(applicationConfig.ApiKeys) == 0
			if authDisabled || (skip != nil && skip(c)) {
				return next(c)
			}

			key, lookupErr := extractKeyFromMultipleSources(c)
			if lookupErr != nil {
				return onAuthFail(lookupErr, c)
			}

			// Any validation failure is reported uniformly as a missing or
			// malformed key, regardless of the validator's own error.
			if ok, validateErr := validate(key, c); validateErr != nil || !ok {
				return onAuthFail(ErrMissingOrMalformedAPIKey, c)
			}

			// Expose the accepted key to downstream handlers.
			c.Set("api_key", key)
			return next(c)
		}
	}

	return authMiddleware, nil
}
// extractKeyFromMultipleSources returns the API key presented by the request.
//
// Sources are consulted in this order, and the first one present wins:
//  1. Authorization header. "Bearer <token>" yields <token>; the scheme is
//     matched case-insensitively, as RFC 7235 requires. Any other non-empty
//     value is used verbatim for backward compatibility.
//  2. x-api-key header.
//  3. xi-api-key header.
//  4. "token" cookie.
//
// It returns ErrMissingOrMalformedAPIKey when no source carries a key. It
// also returns it when the Authorization header is present but its
// credential is empty (e.g. "Bearer " with nothing after it).
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
	header := c.Request().Header

	if auth := header.Get("Authorization"); auth != "" {
		const bearerPrefix = "Bearer "
		key := auth
		if len(auth) >= len(bearerPrefix) && strings.EqualFold(auth[:len(bearerPrefix)], bearerPrefix) {
			key = strings.TrimSpace(auth[len(bearerPrefix):])
		}
		// A present Authorization header is authoritative: do not fall back
		// to the other sources. Also never hand an empty credential to the
		// validator, where it could match an accidentally empty configured key.
		if key == "" {
			return "", ErrMissingOrMalformedAPIKey
		}
		return key, nil
	}

	if key := header.Get("x-api-key"); key != "" {
		return key, nil
	}

	if key := header.Get("xi-api-key"); key != "" {
		return key, nil
	}

	if cookie, err := c.Cookie("token"); err == nil && cookie != nil && cookie.Value != "" {
		return cookie.Value, nil
	}

	return "", ErrMissingOrMalformedAPIKey
}
// getApiKeyErrorHandler returns the function the auth middleware uses to turn
// an authentication failure into an HTTP response.
//
// For ErrMissingOrMalformedAPIKey it answers 401 with a
// "WWW-Authenticate: Bearer" challenge. The body is:
//   - empty when OpaqueErrors is set;
//   - an OpenAI-style JSON error when the client sends or accepts JSON;
//   - the rendered "views/login" page otherwise.
//
// If no API keys are configured, it returns nil instead.
//
// Any other error yields an empty 500 when OpaqueErrors is set. Otherwise the
// error is returned unchanged for Echo's HTTP error handler to deal with.
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
	return func(err error, c echo.Context) error {
		if !errors.Is(err, ErrMissingOrMalformedAPIKey) {
			if applicationConfig.OpaqueErrors {
				return c.NoContent(http.StatusInternalServerError)
			}
			return err
		}

		if len(applicationConfig.ApiKeys) == 0 {
			return nil // if no keys are set up, any error we get here is not an error.
		}

		c.Response().Header().Set("WWW-Authenticate", "Bearer")
		if applicationConfig.OpaqueErrors {
			return c.NoContent(http.StatusUnauthorized)
		}

		// API clients signal JSON either through the body's Content-Type
		// (POST) or through Accept. Body-less calls such as GET model listing
		// carry no Content-Type, so checking it alone served them the HTML
		// login page. Media types are case-insensitive.
		req := c.Request()
		wantsJSON := strings.Contains(strings.ToLower(req.Header.Get("Content-Type")), "application/json") ||
			strings.Contains(strings.ToLower(req.Header.Get("Accept")), "application/json")
		if wantsJSON {
			return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
				Error: &schema.APIError{
					Message: "An authentication key is required",
					Code:    401,
					Type:    "invalid_request_error",
				},
			})
		}

		return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
			"BaseURL": BaseURL(c),
		})
	}
}
// getApiKeyValidationFunction returns a predicate reporting whether key is one
// of applicationConfig.ApiKeys.
//
// The comparison strategy is fixed when this function is called. With
// UseSubtleKeyComparison set, candidates are compared with
// subtle.ConstantTimeCompare; otherwise plain string equality is used.
//
// The returned predicate:
//   - re-reads ApiKeys on each call;
//   - accepts every key while that list is empty;
//   - reports (false, ErrMissingOrMalformedAPIKey) when nothing matches.
//
// Its echo.Context argument is unused.
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
	sameKey := func(presented, configured string) bool {
		return presented == configured
	}
	if applicationConfig.UseSubtleKeyComparison {
		sameKey = func(presented, configured string) bool {
			return subtle.ConstantTimeCompare([]byte(presented), []byte(configured)) == 1
		}
	}

	return func(key string, _ echo.Context) (bool, error) {
		if len(applicationConfig.ApiKeys) == 0 {
			return true, nil // If no keys are setup, accept everything
		}
		for _, candidate := range applicationConfig.ApiKeys {
			if sameKey(key, candidate) {
				return true, nil
			}
		}
		return false, ErrMissingOrMalformedAPIKey
	}
}
// getApiKeyRequiredFilterFunction returns the skipper that decides which
// requests bypass API-key authentication.
//
// A request is exempt in either of two cases:
//   - its URL path starts with any entry of PathWithoutAuth;
//   - DisableApiKeyRequirementForHttpGet is set, the request is a GET, and
//     c.Path() matches one of HttpGetExemptedEndpoints.
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
	return func(c echo.Context) bool {
		req := c.Request()
		urlPath := req.URL.Path

		for _, publicPrefix := range applicationConfig.PathWithoutAuth {
			if strings.HasPrefix(urlPath, publicPrefix) {
				return true
			}
		}

		getExemptionsApply := applicationConfig.DisableApiKeyRequirementForHttpGet && req.Method == http.MethodGet
		if !getExemptionsApply {
			return false
		}

		// NOTE(review): in Echo, c.Path() is the registered route path
		// (e.g. "/models/:id"), not the raw request path used for the prefix
		// check above. Confirm that HttpGetExemptedEndpoints are written
		// against route templates.
		routePath := c.Path()
		for _, exempt := range applicationConfig.HttpGetExemptedEndpoints {
			if exempt.MatchString(routePath) {
				return true
			}
		}
		return false
	}
}
|