| package limiter |
|
|
import (
	"context"
	_ "embed"
	"fmt"
	"strings"
	"sync"

	"github.com/QuantumNous/new-api/common"
	"github.com/go-redis/redis/v8"
)
|
|
| |
// rateLimitScript holds the Lua source of the rate-limit script. New uploads
// it to Redis with SCRIPT LOAD, and RedisLimiter.Allow invokes it by SHA.
// NOTE(review): this file blank-imports "embed", which suggests the variable is
// meant to be filled by a //go:embed directive, but none is visible directly
// above this declaration. Without one the string stays empty; confirm that the
// directive is present in the real file.
var rateLimitScript string
|
|
// RedisLimiter runs a rate-limit Lua script in Redis.
// Obtain it through New, which creates a single shared instance.
type RedisLimiter struct {
	// client is the Redis connection used for SCRIPT LOAD and EVALSHA.
	client *redis.Client
	// limitScriptSHA is the SHA1 under which rateLimitScript is invoked via
	// EVALSHA; it is set once by New.
	limitScriptSHA string
}
|
|
// Package-level singleton state used by New: instance is created at most once,
// guarded by once, and then returned to every caller.
var (
	instance *RedisLimiter
	once     sync.Once
)
|
|
// New returns the process-wide RedisLimiter, creating it on the first call.
//
// Only the first call has any effect. It uploads rateLimitScript to Redis with
// SCRIPT LOAD (using ctx) and binds the limiter to r. Later calls ignore both
// arguments and return the same instance.
//
// A failed SCRIPT LOAD is logged but not fatal. In that case the limiter falls
// back to the script's locally computed SHA1 digest, which is the same digest
// Redis assigns. EVALSHA then starts working as soon as the script is present
// in the Redis script cache. Without this fallback the limiter would keep an
// empty SHA for the lifetime of the process, because sync.Once never retries.
func New(ctx context.Context, r *redis.Client) *RedisLimiter {
	once.Do(func() {
		limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result()
		if err != nil {
			common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err))
			// Hash() is the hex SHA1 of the source, exactly what SCRIPT LOAD returns.
			limitSHA = redis.NewScript(rateLimitScript).Hash()
		}
		instance = &RedisLimiter{
			client:         r,
			limitScriptSHA: limitSHA,
		}
	})
	return instance
}
|
|
// Allow reports whether the request identified by key may proceed, evaluating
// the rate-limit Lua script in Redis.
//
// Without options the call uses Capacity 10, Rate 1 and Requested 1. Each
// Option overrides one of those values. The script receives key as KEYS[1]
// and Requested, Rate, Capacity (in that order) as ARGV. The request is
// allowed only when the script returns exactly 1.
//
// If Redis does not have the script cached (a NOSCRIPT reply), the source is
// re-sent with EVAL instead of failing. This happens after a Redis restart,
// failover or SCRIPT FLUSH, or when New could not load the script. EVAL
// normally re-caches the script, so later EVALSHA calls succeed again.
//
// Any other Redis or result-conversion error is returned wrapped, with false.
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) {
	config := &Config{
		Capacity:  10,
		Rate:      1,
		Requested: 1,
	}
	for _, opt := range opts {
		opt(config)
	}

	keys := []string{key}
	args := []interface{}{config.Requested, config.Rate, config.Capacity}

	result, err := rl.client.EvalSha(ctx, rl.limitScriptSHA, keys, args...).Int()
	if err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT") {
		// The script cache is per server and volatile; fall back to sending
		// the full source, as go-redis Script.Run does.
		result, err = rl.client.Eval(ctx, rateLimitScript, keys, args...).Int()
	}
	if err != nil {
		return false, fmt.Errorf("rate limit failed: %w", err)
	}
	return result == 1, nil
}
|
|
| |
// Config holds the per-call values that RedisLimiter.Allow passes to the
// rate-limit script. Allow starts from Capacity 10, Rate 1 and Requested 1,
// then applies any Options on top.
type Config struct {
	// Capacity is sent as ARGV[3].
	// NOTE(review): presumably the bucket's maximum size; confirm against the Lua script.
	Capacity int64
	// Rate is sent as ARGV[2].
	// NOTE(review): presumably the refill rate; confirm its unit in the Lua script.
	Rate int64
	// Requested is sent as ARGV[1].
	// NOTE(review): presumably the amount this single call consumes; confirm.
	Requested int64
}
|
|
// Option mutates a Config. Pass Options to RedisLimiter.Allow to override its
// default values.
type Option func(*Config)
|
|
| func WithCapacity(c int64) Option { |
| return func(cfg *Config) { cfg.Capacity = c } |
| } |
|
|
| func WithRate(r int64) Option { |
| return func(cfg *Config) { cfg.Rate = r } |
| } |
|
|
| func WithRequested(n int64) Option { |
| return func(cfg *Config) { cfg.Requested = n } |
| } |
|
|