Spaces:
Build error
Build error
| package middleware | |
| import ( | |
| "context" | |
| "fmt" | |
| "net/http" | |
| "one-api/common" | |
| "one-api/common/limiter" | |
| "one-api/constant" | |
| "one-api/setting" | |
| "strconv" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/go-redis/redis/v8" | |
| ) | |
// ModelRequestRateLimitCountMark is the key prefix the in-memory limiter uses
// for the per-user counter of all model requests.
const ModelRequestRateLimitCountMark = "MRRL"

// ModelRequestRateLimitSuccessCountMark is the key marker for the per-user
// counter of successful model requests (used by both the Redis and the
// in-memory limiter).
const ModelRequestRateLimitSuccessCountMark = "MRRLS"
| // 检查Redis中的请求限制 | |
| func checkRedisRateLimit(ctx context.Context, rdb *redis.Client, key string, maxCount int, duration int64) (bool, error) { | |
| // 如果maxCount为0,表示不限制 | |
| if maxCount == 0 { | |
| return true, nil | |
| } | |
| // 获取当前计数 | |
| length, err := rdb.LLen(ctx, key).Result() | |
| if err != nil { | |
| return false, err | |
| } | |
| // 如果未达到限制,允许请求 | |
| if length < int64(maxCount) { | |
| return true, nil | |
| } | |
| // 检查时间窗口 | |
| oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() | |
| oldTime, err := time.Parse(timeFormat, oldTimeStr) | |
| if err != nil { | |
| return false, err | |
| } | |
| nowTimeStr := time.Now().Format(timeFormat) | |
| nowTime, err := time.Parse(timeFormat, nowTimeStr) | |
| if err != nil { | |
| return false, err | |
| } | |
| // 如果在时间窗口内已达到限制,拒绝请求 | |
| subTime := nowTime.Sub(oldTime).Seconds() | |
| if int64(subTime) < duration { | |
| rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) | |
| return false, nil | |
| } | |
| return true, nil | |
| } | |
| // 记录Redis请求 | |
| func recordRedisRequest(ctx context.Context, rdb *redis.Client, key string, maxCount int) { | |
| // 如果maxCount为0,不记录请求 | |
| if maxCount == 0 { | |
| return | |
| } | |
| now := time.Now().Format(timeFormat) | |
| rdb.LPush(ctx, key, now) | |
| rdb.LTrim(ctx, key, 0, int64(maxCount-1)) | |
| rdb.Expire(ctx, key, time.Duration(setting.ModelRequestRateLimitDurationMinutes)*time.Minute) | |
| } | |
| // Redis限流处理器 | |
| func redisRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { | |
| return func(c *gin.Context) { | |
| userId := strconv.Itoa(c.GetInt("id")) | |
| ctx := context.Background() | |
| rdb := common.RDB | |
| // 1. 检查成功请求数限制 | |
| successKey := fmt.Sprintf("rateLimit:%s:%s", ModelRequestRateLimitSuccessCountMark, userId) | |
| allowed, err := checkRedisRateLimit(ctx, rdb, successKey, successMaxCount, duration) | |
| if err != nil { | |
| fmt.Println("检查成功请求数限制失败:", err.Error()) | |
| abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") | |
| return | |
| } | |
| if !allowed { | |
| abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到请求数限制:%d分钟内最多请求%d次", setting.ModelRequestRateLimitDurationMinutes, successMaxCount)) | |
| return | |
| } | |
| //2.检查总请求数限制并记录总请求(当totalMaxCount为0时会自动跳过,使用令牌桶限流器 | |
| if totalMaxCount > 0 { | |
| totalKey := fmt.Sprintf("rateLimit:%s", userId) | |
| // 初始化 | |
| tb := limiter.New(ctx, rdb) | |
| allowed, err = tb.Allow( | |
| ctx, | |
| totalKey, | |
| limiter.WithCapacity(int64(totalMaxCount)*duration), | |
| limiter.WithRate(int64(totalMaxCount)), | |
| limiter.WithRequested(duration), | |
| ) | |
| if err != nil { | |
| fmt.Println("检查总请求数限制失败:", err.Error()) | |
| abortWithOpenAiMessage(c, http.StatusInternalServerError, "rate_limit_check_failed") | |
| return | |
| } | |
| if !allowed { | |
| abortWithOpenAiMessage(c, http.StatusTooManyRequests, fmt.Sprintf("您已达到总请求数限制:%d分钟内最多请求%d次,包括失败次数,请检查您的请求是否正确", setting.ModelRequestRateLimitDurationMinutes, totalMaxCount)) | |
| } | |
| } | |
| // 4. 处理请求 | |
| c.Next() | |
| // 5. 如果请求成功,记录成功请求 | |
| if c.Writer.Status() < 400 { | |
| recordRedisRequest(ctx, rdb, successKey, successMaxCount) | |
| } | |
| } | |
| } | |
| // 内存限流处理器 | |
| func memoryRateLimitHandler(duration int64, totalMaxCount, successMaxCount int) gin.HandlerFunc { | |
| inMemoryRateLimiter.Init(time.Duration(setting.ModelRequestRateLimitDurationMinutes) * time.Minute) | |
| return func(c *gin.Context) { | |
| userId := strconv.Itoa(c.GetInt("id")) | |
| totalKey := ModelRequestRateLimitCountMark + userId | |
| successKey := ModelRequestRateLimitSuccessCountMark + userId | |
| // 1. 检查总请求数限制(当totalMaxCount为0时跳过) | |
| if totalMaxCount > 0 && !inMemoryRateLimiter.Request(totalKey, totalMaxCount, duration) { | |
| c.Status(http.StatusTooManyRequests) | |
| c.Abort() | |
| return | |
| } | |
| // 2. 检查成功请求数限制 | |
| // 使用一个临时key来检查限制,这样可以避免实际记录 | |
| checkKey := successKey + "_check" | |
| if !inMemoryRateLimiter.Request(checkKey, successMaxCount, duration) { | |
| c.Status(http.StatusTooManyRequests) | |
| c.Abort() | |
| return | |
| } | |
| // 3. 处理请求 | |
| c.Next() | |
| // 4. 如果请求成功,记录到实际的成功请求计数中 | |
| if c.Writer.Status() < 400 { | |
| inMemoryRateLimiter.Request(successKey, successMaxCount, duration) | |
| } | |
| } | |
| } | |
| // ModelRequestRateLimit 模型请求限流中间件 | |
| func ModelRequestRateLimit() func(c *gin.Context) { | |
| return func(c *gin.Context) { | |
| // 在每个请求时检查是否启用限流 | |
| if !setting.ModelRequestRateLimitEnabled { | |
| c.Next() | |
| return | |
| } | |
| // 计算限流参数 | |
| duration := int64(setting.ModelRequestRateLimitDurationMinutes * 60) | |
| totalMaxCount := setting.ModelRequestRateLimitCount | |
| successMaxCount := setting.ModelRequestRateLimitSuccessCount | |
| // 获取分组 | |
| group := common.GetContextKeyString(c, constant.ContextKeyTokenGroup) | |
| if group == "" { | |
| group = common.GetContextKeyString(c, constant.ContextKeyUserGroup) | |
| } | |
| //获取分组的限流配置 | |
| groupTotalCount, groupSuccessCount, found := setting.GetGroupRateLimit(group) | |
| if found { | |
| totalMaxCount = groupTotalCount | |
| successMaxCount = groupSuccessCount | |
| } | |
| // 根据存储类型选择并执行限流处理器 | |
| if common.RedisEnabled { | |
| redisRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) | |
| } else { | |
| memoryRateLimitHandler(duration, totalMaxCount, successMaxCount)(c) | |
| } | |
| } | |
| } | |