|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
"fmt" |
|
|
"github.com/gin-gonic/gin" |
|
|
"net/http" |
|
|
"one-api/common" |
|
|
"time" |
|
|
) |
|
|
|
|
|
var timeFormat = "2006-01-02T15:04:05.000Z" |
|
|
|
|
|
var inMemoryRateLimiter common.InMemoryRateLimiter |
|
|
|
|
|
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { |
|
|
ctx := context.Background() |
|
|
rdb := common.RDB |
|
|
key := "rateLimit:" + mark + c.ClientIP() |
|
|
listLength, err := rdb.LLen(ctx, key).Result() |
|
|
if err != nil { |
|
|
fmt.Println(err.Error()) |
|
|
c.Status(http.StatusInternalServerError) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
if listLength < int64(maxRequestNum) { |
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) |
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
|
|
} else { |
|
|
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() |
|
|
oldTime, err := time.Parse(timeFormat, oldTimeStr) |
|
|
if err != nil { |
|
|
fmt.Println(err) |
|
|
c.Status(http.StatusInternalServerError) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
nowTimeStr := time.Now().Format(timeFormat) |
|
|
nowTime, err := time.Parse(timeFormat, nowTimeStr) |
|
|
if err != nil { |
|
|
fmt.Println(err) |
|
|
c.Status(http.StatusInternalServerError) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
if int64(nowTime.Sub(oldTime).Seconds()) < duration { |
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
|
|
c.Status(http.StatusTooManyRequests) |
|
|
c.Abort() |
|
|
return |
|
|
} else { |
|
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) |
|
|
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) |
|
|
rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { |
|
|
key := mark + c.ClientIP() |
|
|
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { |
|
|
c.Status(http.StatusTooManyRequests) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
} |
|
|
|
|
|
func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { |
|
|
if common.RedisEnabled { |
|
|
return func(c *gin.Context) { |
|
|
redisRateLimiter(c, maxRequestNum, duration, mark) |
|
|
} |
|
|
} else { |
|
|
|
|
|
inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) |
|
|
return func(c *gin.Context) { |
|
|
memoryRateLimiter(c, maxRequestNum, duration, mark) |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
func GlobalWebRateLimit() func(c *gin.Context) { |
|
|
return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") |
|
|
} |
|
|
|
|
|
func GlobalAPIRateLimit() func(c *gin.Context) { |
|
|
return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") |
|
|
} |
|
|
|
|
|
func CriticalRateLimit() func(c *gin.Context) { |
|
|
return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") |
|
|
} |
|
|
|
|
|
func DownloadRateLimit() func(c *gin.Context) { |
|
|
return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") |
|
|
} |
|
|
|
|
|
func UploadRateLimit() func(c *gin.Context) { |
|
|
return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") |
|
|
} |
|
|
|