| | package middleware |
| |
|
| | import ( |
| | "context" |
| | "fmt" |
| | "net/http" |
| | "strconv" |
| | "time" |
| |
|
| | "github.com/QuantumNous/new-api/common" |
| | "github.com/QuantumNous/new-api/common/limiter" |
| | "github.com/QuantumNous/new-api/constant" |
| | "github.com/QuantumNous/new-api/setting" |
| |
|
| | "github.com/gin-gonic/gin" |
| | "github.com/go-redis/redis/v8" |
| | ) |
| |
|
| | const ( |
| | ModelRequestRateLimitCountMark = "MRRL" |
| | ModelRequestRateLimitSuccessCountMark = "MRRLS" |
| | ) |
| |
|
| | |
| | func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { |
| | |
| | if maxCount == 0 { |
| | return true, nil |
| | } |
| |
|
| | |
| | length, err := rdb.LLen(ctx, key).Result() |
| | if err != nil { |
| | return false, err |
| | } |
| |
|
| | |
| | if length < int64(maxCount) { |
| | return true, nil |
| | } |
| |
|
| | |
| | oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() |
| | oldTime, err := time.Parse(timeFormat, oldTimeStr) |
| | if err != nil { |
| | return false, err |
| | } |
| |
|
| | nowTimeStr := time.Now().Format(timeFormat) |
| | nowTime, err := time.Parse(timeFormat, nowTimeStr) |
| | if err != nil { |
| | return false, err |
| | } |
| | |
| | subTime := nowTime.Sub(oldTime).Seconds() |
| | if int64(subTime) < duration { |
| | rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) |
| | return false, nil |
| | } |
| |
|
| | return true, nil |
| | } |
| |
|
| | |
| | func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { |
| | |
| | if maxCount == 0 { |
| | return |
| | } |
| |
|
| | now := time.Now().Format(timeFormat) |
| | rdb.LPush(ctx, key, now) |
| | rdb.LTrim(ctx, key, 0, int64(maxCount-1)) |
| | rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) |
| | } |
| |
|
| | |
| | func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { |
| | return func(c *gin.Context) { |
| | userId := strconv.Itoa(c.GetInt("id")) |
| | ctx := context.Background() |
| | rdb := common.RDB |
| |
|
| | |
| | successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) |
| | allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) |
| | if err != nil { |
| | fmt.Println("检查成功请求数限制失败:", err.Error()) |
| | abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") |
| | return |
| | } |
| | if !allowed { |
| | abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)) |
| | return |
| | } |
| |
|
| | |
| | if totalMaxCount > 0 { |
| | totalKey := fmt.Sprintf("rateLimit:%s", userId) |
| | |
| | tb := limiter.New(ctx, rdb) |
| | allowed, err = tb.Allow( |
| | ctx, |
| | totalKey, |
| | limiter.WithCapacity(int64(totalMaxCount)*duration), |
| | limiter.WithRate(int64(totalMaxCount)), |
| | limiter.WithRequested(duration), |
| | ) |
| |
|
| | if err != nil { |
| | fmt.Println("检查总请求数限制失败:", err.Error()) |
| | abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") |
| | return |
| | } |
| |
|
| | if !allowed { |
| | abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) |
| | } |
| | } |
| |
|
| | |
| | c.Next() |
| |
|
| | |
| | if c.Writer.Status() < 400 { |
| | recordRedisRequest(ctx, rdb, successKey, successMaxCount) |
| | } |
| | } |
| | } |
| |
|
| | |
| | func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { |
| | inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) |
| |
|
| | return func(c *gin.Context) { |
| | userId := strconv.Itoa(c.GetInt("id")) |
| | totalKey := ModelRequestRateLimitCountMark + userId |
| | successKey := ModelRequestRateLimitSuccessCountMark + userId |
| |
|
| | |
| | if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { |
| | c.Status(http.StatusTooManyRequests) |
| | c.Abort() |
| | return |
| | } |
| |
|
| | |
| | |
| | checkKey := successKey + "_check" |
| | if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { |
| | c.Status(http.StatusTooManyRequests) |
| | c.Abort() |
| | return |
| | } |
| |
|
| | |
| | c.Next() |
| |
|
| | |
| | if c.Writer.Status() < 400 { |
| | inMemoryRateLimiter.Request(successKey, successMaxCount, duration) |
| | } |
| | } |
| | } |
| |
|
| | |
| | func ModelRequestRateLimit() func(c *gin.Context) { |
| | return func(c *gin.Context) { |
| | |
| | if !setting.ModelRequestRateLimitEnabled { |
| | c.Next() |
| | return |
| | } |
| |
|
| | |
| | duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) |
| | totalMaxCount := setting.ModelRequestRateLimitCount |
| | successMaxCount := setting.ModelRequestRateLimitSuccessCount |
| |
|
| | |
| | group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) |
| | if group == "" { |
| | group = common.GetContextKeyString(c, constant.ContextKeyUserGroup) |
| | } |
| |
|
| | |
| | groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) |
| | if found { |
| | totalMaxCount = groupTotalCount |
| | successMaxCount = groupSuccessCount |
| | } |
| |
|
| | |
| | if common.RedisEnabled { |
| | redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) |
| | } else { |
| | memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) |
| | } |
| | } |
| | } |
| |
|