| | package middleware |
| |
|
| | import ( |
| | "context" |
| | "fmt" |
| | "net/http" |
| | "time" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/gin-gonic/gin" |
| | ) |
| |
|
| | var timeFormat = "2006-01-02T15:04:05.000Z" |
| |
|
| | var inMemoryRateLimiter common.InMemoryRateLimiter |
| |
|
| | var defNext = func(c *gin.Context) { |
| | c.Next() |
| | } |
| |
|
| | func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { |
| | ctx := context.Background() |
| | rdb := common.RDB |
| | key := "rateLimit:" + mark + c.ClientIP() |
| | listLength, err := rdb.LLen(ctx, key).Result() |
| | if err != nil { |
| | fmt.Println(err.Error()) |
| | c.Status(http.StatusInternalServerError) |
| | c.Abort() |
| | return |
| | } |
| | if listLength < int64(maxRequestNum) { |
| | rdb.LPush(ctx, key, time.Now().Format(timeFormat)) |
| | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
| | } else { |
| | oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() |
| | oldTime, err := time.Parse(timeFormat, oldTimeStr) |
| | if err != nil { |
| | fmt.Println(err) |
| | c.Status(http.StatusInternalServerError) |
| | c.Abort() |
| | return |
| | } |
| | nowTimeStr := time.Now().Format(timeFormat) |
| | nowTime, err := time.Parse(timeFormat, nowTimeStr) |
| | if err != nil { |
| | fmt.Println(err) |
| | c.Status(http.StatusInternalServerError) |
| | c.Abort() |
| | return |
| | } |
| | |
| | |
| | if int64(nowTime.Sub(oldTime).Seconds()) < duration { |
| | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
| | c.Status(http.StatusTooManyRequests) |
| | c.Abort() |
| | return |
| | } else { |
| | rdb.LPush(ctx, key, time.Now().Format(timeFormat)) |
| | rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) |
| | rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
| | } |
| | } |
| | } |
| |
|
| | func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { |
| | key := mark + c.ClientIP() |
| | if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { |
| | c.Status(http.StatusTooManyRequests) |
| | c.Abort() |
| | return |
| | } |
| | } |
| |
|
| | func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { |
| | if common.RedisEnabled { |
| | return func(c *gin.Context) { |
| | redisRateLimiter(c, maxRequestNum, duration, mark) |
| | } |
| | } else { |
| | |
| | inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) |
| | return func(c *gin.Context) { |
| | memoryRateLimiter(c, maxRequestNum, duration, mark) |
| | } |
| | } |
| | } |
| |
|
| | func GlobalWebRateLimit() func(c *gin.Context) { |
| | if common.GlobalWebRateLimitEnable { |
| | return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") |
| | } |
| | return defNext |
| | } |
| |
|
| | func GlobalAPIRateLimit() func(c *gin.Context) { |
| | if common.GlobalApiRateLimitEnable { |
| | return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") |
| | } |
| | return defNext |
| | } |
| |
|
| | func CriticalRateLimit() func(c *gin.Context) { |
| | if common.CriticalRateLimitEnable { |
| | return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") |
| | } |
| | return defNext |
| | } |
| |
|
| | func DownloadRateLimit() func(c *gin.Context) { |
| | return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") |
| | } |
| |
|
| | func UploadRateLimit() func(c *gin.Context) { |
| | return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") |
| | } |
| |
|