| package middleware |
|
|
| import ( |
| "context" |
| "errors" |
| "strings" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/config" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" |
| "github.com/Wei-Shaw/sub2api/internal/pkg/ip" |
| "github.com/Wei-Shaw/sub2api/internal/service" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| |
| func NewAPIKeyAuthMiddleware(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) APIKeyAuthMiddleware { |
| return APIKeyAuthMiddleware(apiKeyAuthWithSubscription(apiKeyService, subscriptionService, cfg)) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) gin.HandlerFunc { |
| return func(c *gin.Context) { |
| |
|
|
| queryKey := strings.TrimSpace(c.Query("key")) |
| queryApiKey := strings.TrimSpace(c.Query("api_key")) |
| if queryKey != "" || queryApiKey != "" { |
| AbortWithError(c, 400, "api_key_in_query_deprecated", "API key in query parameter is deprecated. Please use Authorization header instead.") |
| return |
| } |
|
|
| |
| authHeader := c.GetHeader("Authorization") |
| var apiKeyString string |
|
|
| if authHeader != "" { |
| |
| parts := strings.SplitN(authHeader, " ", 2) |
| if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { |
| apiKeyString = strings.TrimSpace(parts[1]) |
| } |
| } |
|
|
| |
| if apiKeyString == "" { |
| apiKeyString = c.GetHeader("x-api-key") |
| } |
|
|
| |
| if apiKeyString == "" { |
| apiKeyString = c.GetHeader("x-goog-api-key") |
| } |
|
|
| |
| if apiKeyString == "" { |
| AbortWithError(c, 401, "API_KEY_REQUIRED", "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header") |
| return |
| } |
|
|
| |
|
|
| apiKey, err := apiKeyService.GetByKey(c.Request.Context(), apiKeyString) |
| if err != nil { |
| if errors.Is(err, service.ErrAPIKeyNotFound) { |
| AbortWithError(c, 401, "INVALID_API_KEY", "Invalid API key") |
| return |
| } |
| AbortWithError(c, 500, "INTERNAL_ERROR", "Failed to validate API key") |
| return |
| } |
|
|
| |
|
|
| |
| if !apiKey.IsActive() && |
| apiKey.Status != service.StatusAPIKeyExpired && |
| apiKey.Status != service.StatusAPIKeyQuotaExhausted { |
| AbortWithError(c, 401, "API_KEY_DISABLED", "API key is disabled") |
| return |
| } |
|
|
| |
| |
| if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { |
| clientIP := ip.GetTrustedClientIP(c) |
| allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) |
| if !allowed { |
| AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") |
| return |
| } |
| } |
|
|
| |
| if apiKey.User == nil { |
| AbortWithError(c, 401, "USER_NOT_FOUND", "User associated with API key not found") |
| return |
| } |
|
|
| |
| if !apiKey.User.IsActive() { |
| AbortWithError(c, 401, "USER_INACTIVE", "User account is not active") |
| return |
| } |
|
|
| |
|
|
| if cfg.RunMode == config.RunModeSimple { |
| c.Set(string(ContextKeyAPIKey), apiKey) |
| c.Set(string(ContextKeyUser), AuthSubject{ |
| UserID: apiKey.User.ID, |
| Concurrency: apiKey.User.Concurrency, |
| }) |
| c.Set(string(ContextKeyUserRole), apiKey.User.Role) |
| setGroupContext(c, apiKey.Group) |
| _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) |
| c.Next() |
| return |
| } |
|
|
| |
|
|
| |
| skipBilling := c.Request.URL.Path == "/v1/usage" |
|
|
| var subscription *service.UserSubscription |
| isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() |
|
|
| if isSubscriptionType && subscriptionService != nil { |
| sub, subErr := subscriptionService.GetActiveSubscription( |
| c.Request.Context(), |
| apiKey.User.ID, |
| apiKey.Group.ID, |
| ) |
| if subErr != nil { |
| if !skipBilling { |
| AbortWithError(c, 403, "SUBSCRIPTION_NOT_FOUND", "No active subscription found for this group") |
| return |
| } |
| |
| } else { |
| subscription = sub |
| } |
| } |
|
|
| |
|
|
| if !skipBilling { |
| |
| switch apiKey.Status { |
| case service.StatusAPIKeyQuotaExhausted: |
| AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") |
| return |
| case service.StatusAPIKeyExpired: |
| AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") |
| return |
| } |
|
|
| |
| if apiKey.IsExpired() { |
| AbortWithError(c, 403, "API_KEY_EXPIRED", "API key 已过期") |
| return |
| } |
| if apiKey.IsQuotaExhausted() { |
| AbortWithError(c, 429, "API_KEY_QUOTA_EXHAUSTED", "API key 额度已用完") |
| return |
| } |
|
|
| |
| if subscription != nil { |
| needsMaintenance, validateErr := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) |
| if validateErr != nil { |
| code := "SUBSCRIPTION_INVALID" |
| status := 403 |
| if errors.Is(validateErr, service.ErrDailyLimitExceeded) || |
| errors.Is(validateErr, service.ErrWeeklyLimitExceeded) || |
| errors.Is(validateErr, service.ErrMonthlyLimitExceeded) { |
| code = "USAGE_LIMIT_EXCEEDED" |
| status = 429 |
| } |
| AbortWithError(c, status, code, validateErr.Error()) |
| return |
| } |
|
|
| |
| if needsMaintenance { |
| maintenanceCopy := *subscription |
| subscriptionService.DoWindowMaintenance(&maintenanceCopy) |
| } |
| } else { |
| |
| if apiKey.User.Balance <= 0 { |
| AbortWithError(c, 403, "INSUFFICIENT_BALANCE", "Insufficient account balance") |
| return |
| } |
| } |
| } |
|
|
| |
|
|
| if subscription != nil { |
| c.Set(string(ContextKeySubscription), subscription) |
| } |
| c.Set(string(ContextKeyAPIKey), apiKey) |
| c.Set(string(ContextKeyUser), AuthSubject{ |
| UserID: apiKey.User.ID, |
| Concurrency: apiKey.User.Concurrency, |
| }) |
| c.Set(string(ContextKeyUserRole), apiKey.User.Role) |
| setGroupContext(c, apiKey.Group) |
| _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) |
|
|
| c.Next() |
| } |
| } |
|
|
| |
| func GetAPIKeyFromContext(c *gin.Context) (*service.APIKey, bool) { |
| value, exists := c.Get(string(ContextKeyAPIKey)) |
| if !exists { |
| return nil, false |
| } |
| apiKey, ok := value.(*service.APIKey) |
| return apiKey, ok |
| } |
|
|
| |
| func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool) { |
| value, exists := c.Get(string(ContextKeySubscription)) |
| if !exists { |
| return nil, false |
| } |
| subscription, ok := value.(*service.UserSubscription) |
| return subscription, ok |
| } |
|
|
| func setGroupContext(c *gin.Context, group *service.Group) { |
| if !service.IsGroupContextValid(group) { |
| return |
| } |
| if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) { |
| return |
| } |
| ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group) |
| c.Request = c.Request.WithContext(ctx) |
| } |
|
|