Spaces:
Build error
Build error
| package middleware | |
| import ( | |
| "context" | |
| "fmt" | |
| "github.com/gin-gonic/gin" | |
| "net/http" | |
| "one-api/common" | |
| "time" | |
| ) | |
| var timeFormat = "2006-01-02T15:04:05.000Z" | |
| var inMemoryRateLimiter common.InMemoryRateLimiter | |
| var defNext = func(c *gin.Context) { | |
| c.Next() | |
| } | |
| func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { | |
| ctx := context.Background() | |
| rdb := common.RDB | |
| key := "rateLimit:" + mark + c.ClientIP() | |
| listLength, err := rdb.LLen(ctx, key).Result() | |
| if err != nil { | |
| fmt.Println(err.Error()) | |
| c.Status(http.StatusInternalServerError) | |
| c.Abort() | |
| return | |
| } | |
| if listLength < int64(maxRequestNum) { | |
| rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | |
| rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | |
| } else { | |
| oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | |
| oldTime, err := time.Parse(timeFormat, oldTimeStr) | |
| if err != nil { | |
| fmt.Println(err) | |
| c.Status(http.StatusInternalServerError) | |
| c.Abort() | |
| return | |
| } | |
| nowTimeStr := time.Now().Format(timeFormat) | |
| nowTime, err := time.Parse(timeFormat, nowTimeStr) | |
| if err != nil { | |
| fmt.Println(err) | |
| c.Status(http.StatusInternalServerError) | |
| c.Abort() | |
| return | |
| } | |
| // time.Since will return negative number! | |
| // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows | |
| if int64(nowTime.Sub(oldTime).Seconds()) < duration { | |
| rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | |
| c.Status(http.StatusTooManyRequests) | |
| c.Abort() | |
| return | |
| } else { | |
| rdb.LPush(ctx, key, time.Now().Format(timeFormat)) | |
| rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) | |
| rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) | |
| } | |
| } | |
| } | |
| func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { | |
| key := mark + c.ClientIP() | |
| if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { | |
| c.Status(http.StatusTooManyRequests) | |
| c.Abort() | |
| return | |
| } | |
| } | |
| func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { | |
| if common.RedisEnabled { | |
| return func(c *gin.Context) { | |
| redisRateLimiter(c, maxRequestNum, duration, mark) | |
| } | |
| } else { | |
| // It's safe to call multi times. | |
| inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) | |
| return func(c *gin.Context) { | |
| memoryRateLimiter(c, maxRequestNum, duration, mark) | |
| } | |
| } | |
| } | |
| func GlobalWebRateLimit() func(c *gin.Context) { | |
| if common.GlobalWebRateLimitEnable { | |
| return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") | |
| } | |
| return defNext | |
| } | |
| func GlobalAPIRateLimit() func(c *gin.Context) { | |
| if common.GlobalApiRateLimitEnable { | |
| return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") | |
| } | |
| return defNext | |
| } | |
| func CriticalRateLimit() func(c *gin.Context) { | |
| return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") | |
| } | |
| func DownloadRateLimit() func(c *gin.Context) { | |
| return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") | |
| } | |
| func UploadRateLimit() func(c *gin.Context) { | |
| return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") | |
| } | |