|
|
package limiter |
|
|
|
|
|
import ( |
|
|
"context" |
|
|
_ "embed" |
|
|
"fmt" |
|
|
"sync" |
|
|
|
|
|
"github.com/QuantumNous/new-api/common" |
|
|
"github.com/go-redis/redis/v8" |
|
|
) |
|
|
|
|
|
|
|
|
// rateLimitScript is the Lua source of the rate-limit script. New registers it
// with Redis (SCRIPT LOAD) and Allow executes it against a single key.
//
// NOTE(review): the file blank-imports "embed", which suggests this variable
// is meant to be populated by a go:embed directive placed directly above it.
// No such directive is visible in this view. Without one, the string stays
// empty and New would load an empty script. Confirm the directive is present.
var rateLimitScript string
|
|
|
|
|
type RedisLimiter struct { |
|
|
client *redis.Client |
|
|
limitScriptSHA string |
|
|
} |
|
|
|
|
|
var ( |
|
|
instance *RedisLimiter |
|
|
once sync.Once |
|
|
) |
|
|
|
|
|
func New(ctx context.Context, r *redis.Client) *RedisLimiter { |
|
|
once.Do(func() { |
|
|
|
|
|
limitSHA, err := r.ScriptLoad(ctx, rateLimitScript).Result() |
|
|
if err != nil { |
|
|
common.SysLog(fmt.Sprintf("Failed to load rate limit script: %v", err)) |
|
|
} |
|
|
instance = &RedisLimiter{ |
|
|
client: r, |
|
|
limitScriptSHA: limitSHA, |
|
|
} |
|
|
}) |
|
|
|
|
|
return instance |
|
|
} |
|
|
|
|
|
func (rl *RedisLimiter) Allow(ctx context.Context, key string, opts ...Option) (bool, error) { |
|
|
|
|
|
config := &Config{ |
|
|
Capacity: 10, |
|
|
Rate: 1, |
|
|
Requested: 1, |
|
|
} |
|
|
|
|
|
|
|
|
for _, opt := range opts { |
|
|
opt(config) |
|
|
} |
|
|
|
|
|
|
|
|
result, err := rl.client.EvalSha( |
|
|
ctx, |
|
|
rl.limitScriptSHA, |
|
|
[]string{key}, |
|
|
config.Requested, |
|
|
config.Rate, |
|
|
config.Capacity, |
|
|
).Int() |
|
|
|
|
|
if err != nil { |
|
|
return false, fmt.Errorf("rate limit failed: %w", err) |
|
|
} |
|
|
return result == 1, nil |
|
|
} |
|
|
|
|
|
|
|
|
// Config carries the per-call parameters Allow hands to the rate-limit
// script. Allow starts from Capacity 10, Rate 1, Requested 1 and then applies
// any Options.
type Config struct {
	// Capacity is sent to the script as the third ARGV value (presumably the
	// bucket size — confirm against the Lua script).
	Capacity int64

	// Rate is sent to the script as the second ARGV value (presumably the
	// refill rate — confirm against the Lua script).
	Rate int64

	// Requested is sent to the script as the first ARGV value (presumably the
	// amount this call consumes — confirm against the Lua script).
	Requested int64
}

// Option adjusts a Config. Options are applied in order, so a later option
// overrides an earlier one that sets the same field.
type Option func(cfg *Config)

// WithCapacity returns an Option that sets Config.Capacity to c.
func WithCapacity(c int64) Option {
	return func(target *Config) {
		target.Capacity = c
	}
}

// WithRate returns an Option that sets Config.Rate to r.
func WithRate(r int64) Option {
	return func(target *Config) {
		target.Rate = r
	}
}

// WithRequested returns an Option that sets Config.Requested to n.
func WithRequested(n int64) Option {
	return func(target *Config) {
		target.Requested = n
	}
}
|
|
|