|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"crypto/subtle" |
|
|
"errors" |
|
|
"net/http" |
|
|
"strings" |
|
|
|
|
|
"github.com/labstack/echo/v4" |
|
|
"github.com/labstack/echo/v4/middleware" |
|
|
"github.com/mudler/LocalAI/core/config" |
|
|
"github.com/mudler/LocalAI/core/schema" |
|
|
) |
|
|
|
|
|
// ErrMissingOrMalformedAPIKey is the sentinel error used by the key-auth
// middleware when a request carries no usable API key, or when the key it
// carries does not match any configured key. Compare against it with errors.Is.
var (
	ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
)
|
|
|
|
|
|
|
|
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) { |
|
|
|
|
|
validator := getApiKeyValidationFunction(applicationConfig) |
|
|
|
|
|
|
|
|
errorHandler := getApiKeyErrorHandler(applicationConfig) |
|
|
|
|
|
|
|
|
skipper := getApiKeyRequiredFilterFunction(applicationConfig) |
|
|
|
|
|
|
|
|
return func(next echo.HandlerFunc) echo.HandlerFunc { |
|
|
return func(c echo.Context) error { |
|
|
if len(applicationConfig.ApiKeys) == 0 { |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
|
|
|
if skipper != nil && skipper(c) { |
|
|
return next(c) |
|
|
} |
|
|
|
|
|
|
|
|
key, err := extractKeyFromMultipleSources(c) |
|
|
if err != nil { |
|
|
return errorHandler(err, c) |
|
|
} |
|
|
|
|
|
|
|
|
valid, err := validator(key, c) |
|
|
if err != nil || !valid { |
|
|
return errorHandler(ErrMissingOrMalformedAPIKey, c) |
|
|
} |
|
|
|
|
|
|
|
|
c.Set("api_key", key) |
|
|
|
|
|
return next(c) |
|
|
} |
|
|
}, nil |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func extractKeyFromMultipleSources(c echo.Context) (string, error) { |
|
|
|
|
|
auth := c.Request().Header.Get("Authorization") |
|
|
if auth != "" { |
|
|
|
|
|
if strings.HasPrefix(auth, "Bearer ") { |
|
|
return strings.TrimPrefix(auth, "Bearer "), nil |
|
|
} |
|
|
|
|
|
return auth, nil |
|
|
} |
|
|
|
|
|
|
|
|
if key := c.Request().Header.Get("x-api-key"); key != "" { |
|
|
return key, nil |
|
|
} |
|
|
|
|
|
|
|
|
if key := c.Request().Header.Get("xi-api-key"); key != "" { |
|
|
return key, nil |
|
|
} |
|
|
|
|
|
|
|
|
cookie, err := c.Cookie("token") |
|
|
if err == nil && cookie != nil && cookie.Value != "" { |
|
|
return cookie.Value, nil |
|
|
} |
|
|
|
|
|
return "", ErrMissingOrMalformedAPIKey |
|
|
} |
|
|
|
|
|
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error { |
|
|
return func(err error, c echo.Context) error { |
|
|
if errors.Is(err, ErrMissingOrMalformedAPIKey) { |
|
|
if len(applicationConfig.ApiKeys) == 0 { |
|
|
return nil |
|
|
} |
|
|
c.Response().Header().Set("WWW-Authenticate", "Bearer") |
|
|
if applicationConfig.OpaqueErrors { |
|
|
return c.NoContent(http.StatusUnauthorized) |
|
|
} |
|
|
|
|
|
|
|
|
contentType := c.Request().Header.Get("Content-Type") |
|
|
if strings.Contains(contentType, "application/json") { |
|
|
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ |
|
|
Error: &schema.APIError{ |
|
|
Message: "An authentication key is required", |
|
|
Code: 401, |
|
|
Type: "invalid_request_error", |
|
|
}, |
|
|
}) |
|
|
} |
|
|
|
|
|
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ |
|
|
"BaseURL": BaseURL(c), |
|
|
}) |
|
|
} |
|
|
if applicationConfig.OpaqueErrors { |
|
|
return c.NoContent(http.StatusInternalServerError) |
|
|
} |
|
|
return err |
|
|
} |
|
|
} |
|
|
|
|
|
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) { |
|
|
if applicationConfig.UseSubtleKeyComparison { |
|
|
return func(key string, c echo.Context) (bool, error) { |
|
|
if len(applicationConfig.ApiKeys) == 0 { |
|
|
return true, nil |
|
|
} |
|
|
for _, validKey := range applicationConfig.ApiKeys { |
|
|
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { |
|
|
return true, nil |
|
|
} |
|
|
} |
|
|
return false, ErrMissingOrMalformedAPIKey |
|
|
} |
|
|
} |
|
|
|
|
|
return func(key string, c echo.Context) (bool, error) { |
|
|
if len(applicationConfig.ApiKeys) == 0 { |
|
|
return true, nil |
|
|
} |
|
|
for _, validKey := range applicationConfig.ApiKeys { |
|
|
if key == validKey { |
|
|
return true, nil |
|
|
} |
|
|
} |
|
|
return false, ErrMissingOrMalformedAPIKey |
|
|
} |
|
|
} |
|
|
|
|
|
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper { |
|
|
return func(c echo.Context) bool { |
|
|
path := c.Request().URL.Path |
|
|
|
|
|
for _, p := range applicationConfig.PathWithoutAuth { |
|
|
if strings.HasPrefix(path, p) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if applicationConfig.DisableApiKeyRequirementForHttpGet { |
|
|
if c.Request().Method != http.MethodGet { |
|
|
return false |
|
|
} |
|
|
for _, rx := range applicationConfig.HttpGetExemptedEndpoints { |
|
|
if rx.MatchString(c.Path()) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return false |
|
|
} |
|
|
} |
|
|
|