| package repository |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "strconv" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/service" |
| "github.com/redis/go-redis/v9" |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
const (
	// Redis key prefixes. A full key is the prefix followed by the decimal
	// account or user ID (see the *Key helper functions in this file).

	// accountSlotKeyPrefix prefixes the per-account sorted set of held slots.
	accountSlotKeyPrefix = "concurrency:account:"
	// userSlotKeyPrefix prefixes the per-user sorted set of held slots.
	userSlotKeyPrefix = "concurrency:user:"
	// waitQueueKeyPrefix prefixes the per-user wait-queue counter.
	waitQueueKeyPrefix = "concurrency:wait:"
	// accountWaitKeyPrefix prefixes the per-account wait-queue counter.
	accountWaitKeyPrefix = "wait:account:"

	// defaultSlotTTLMinutes is the slot TTL, in minutes, that
	// NewConcurrencyCache falls back to when given a non-positive value.
	defaultSlotTTLMinutes = 15
)
|
|
var (
	// acquireScript atomically tries to take a slot in the sorted set KEYS[1].
	//
	// ARGV[1] is the maximum number of concurrent slots, ARGV[2] the slot TTL
	// in seconds and ARGV[3] the request ID used as the set member. Members
	// are scored with the Redis server time (seconds) at which they were
	// written. The script first removes members whose score is at or below
	// now-TTL, then:
	//   - if the request ID is already a member, refreshes its score and the
	//     key expiry and returns 1;
	//   - else if the set holds fewer than ARGV[1] members, adds the request
	//     ID, refreshes the key expiry and returns 1;
	//   - otherwise returns 0.
	acquireScript = redis.NewScript(`
		local key = KEYS[1]
		local maxConcurrency = tonumber(ARGV[1])
		local ttl = tonumber(ARGV[2])
		local requestID = ARGV[3]

		-- 使用 Redis 服务器时间,确保多实例时钟一致
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl

		-- 清理过期槽位
		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)

		-- 检查是否已存在(支持重试场景刷新时间戳)
		local exists = redis.call('ZSCORE', key, requestID)
		if exists ~= false then
			redis.call('ZADD', key, now, requestID)
			redis.call('EXPIRE', key, ttl)
			return 1
		end

		-- 检查是否达到并发上限
		local count = redis.call('ZCARD', key)
		if count < maxConcurrency then
			redis.call('ZADD', key, now, requestID)
			redis.call('EXPIRE', key, ttl)
			return 1
		end

		return 0
	`)

	// getCountScript removes members of the sorted set KEYS[1] whose score is
	// at or below (Redis server time - ARGV[1] seconds) and returns the number
	// of members that remain.
	getCountScript = redis.NewScript(`
		local key = KEYS[1]
		local ttl = tonumber(ARGV[1])

		-- 使用 Redis 服务器时间
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl

		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
		return redis.call('ZCARD', key)
	`)

	// incrementWaitScript increments the counter KEYS[1] unless it has already
	// reached the limit ARGV[1]. A missing key counts as 0. It returns 0 when
	// the limit is reached (nothing is written), otherwise it increments the
	// counter, sets its expiry to ARGV[2] seconds and returns 1.
	incrementWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current == false then
			current = 0
		else
			current = tonumber(current)
		end

		if current >= tonumber(ARGV[1]) then
			return 0
		end

		local newVal = redis.call('INCR', KEYS[1])

		-- Refresh TTL so long-running traffic doesn't expire active queue counters.
		redis.call('EXPIRE', KEYS[1], ARGV[2])

		return 1
	`)

	// incrementAccountWaitScript has the same body as incrementWaitScript; in
	// this file it is run against the per-account wait counter keys.
	incrementAccountWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current == false then
			current = 0
		else
			current = tonumber(current)
		end

		if current >= tonumber(ARGV[1]) then
			return 0
		end

		local newVal = redis.call('INCR', KEYS[1])

		-- Refresh TTL so long-running traffic doesn't expire active queue counters.
		redis.call('EXPIRE', KEYS[1], ARGV[2])

		return 1
	`)

	// decrementWaitScript decrements the counter KEYS[1] only when the key
	// exists and its value is greater than 0, so the counter never goes
	// negative. It always returns 1.
	decrementWaitScript = redis.NewScript(`
		local current = redis.call('GET', KEYS[1])
		if current ~= false and tonumber(current) > 0 then
			redis.call('DECR', KEYS[1])
		end
		return 1
	`)

	// cleanupExpiredSlotsScript removes members of the sorted set KEYS[1] whose
	// score is at or below (Redis server time - ARGV[1] seconds). If the set is
	// empty afterwards the key is deleted; otherwise its expiry is reset to
	// ARGV[1] seconds. It always returns 1.
	cleanupExpiredSlotsScript = redis.NewScript(`
		local key = KEYS[1]
		local ttl = tonumber(ARGV[1])
		local timeResult = redis.call('TIME')
		local now = tonumber(timeResult[1])
		local expireBefore = now - ttl
		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
		if redis.call('ZCARD', key) == 0 then
			redis.call('DEL', key)
		else
			redis.call('EXPIRE', key, ttl)
		end
		return 1
	`)

	// startupCleanupScript walks every sorted set named in KEYS and removes
	// each member that does not start with the prefix ARGV[1]. Sets left empty
	// are deleted; the others get their expiry reset to ARGV[2] seconds. It
	// returns the total number of members removed.
	startupCleanupScript = redis.NewScript(`
		local activePrefix = ARGV[1]
		local slotTTL = tonumber(ARGV[2])
		local removed = 0
		for i = 1, #KEYS do
			local key = KEYS[i]
			local members = redis.call('ZRANGE', key, 0, -1)
			for _, member in ipairs(members) do
				if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
					removed = removed + redis.call('ZREM', key, member)
				end
			end
			if redis.call('ZCARD', key) == 0 then
				redis.call('DEL', key)
			else
				redis.call('EXPIRE', key, slotTTL)
			end
		end
		return removed
	`)
)
|
|
// concurrencyCache is the Redis-backed value returned by NewConcurrencyCache
// as a service.ConcurrencyCache.
type concurrencyCache struct {
	// rdb is the Redis client every command and script in this type runs on.
	rdb *redis.Client
	// slotTTLSeconds is passed to the slot scripts as the slot TTL and used
	// to compute the expiry cutoff in the batch readers.
	slotTTLSeconds int
	// waitQueueTTLSeconds is passed to the wait-counter increment scripts as
	// the counter expiry.
	waitQueueTTLSeconds int
}
|
|
| |
| |
| |
| func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache { |
| if slotTTLMinutes <= 0 { |
| slotTTLMinutes = defaultSlotTTLMinutes |
| } |
| if waitQueueTTLSeconds <= 0 { |
| waitQueueTTLSeconds = slotTTLMinutes * 60 |
| } |
| return &concurrencyCache{ |
| rdb: rdb, |
| slotTTLSeconds: slotTTLMinutes * 60, |
| waitQueueTTLSeconds: waitQueueTTLSeconds, |
| } |
| } |
|
|
| |
| func accountSlotKey(accountID int64) string { |
| return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID) |
| } |
|
|
| func userSlotKey(userID int64) string { |
| return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID) |
| } |
|
|
| func waitQueueKey(userID int64) string { |
| return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID) |
| } |
|
|
| func accountWaitKey(accountID int64) string { |
| return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID) |
| } |
|
|
| |
|
|
| func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { |
| key := accountSlotKey(accountID) |
| |
| result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() |
| if err != nil { |
| return false, err |
| } |
| return result == 1, nil |
| } |
|
|
| func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { |
| key := accountSlotKey(accountID) |
| return c.rdb.ZRem(ctx, key, requestID).Err() |
| } |
|
|
| func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { |
| key := accountSlotKey(accountID) |
| |
| result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() |
| if err != nil { |
| return 0, err |
| } |
| return result, nil |
| } |
|
|
| func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { |
| if len(accountIDs) == 0 { |
| return map[int64]int{}, nil |
| } |
|
|
| now, err := c.rdb.Time(ctx).Result() |
| if err != nil { |
| return nil, fmt.Errorf("redis TIME: %w", err) |
| } |
| cutoffTime := now.Unix() - int64(c.slotTTLSeconds) |
|
|
| pipe := c.rdb.Pipeline() |
| type accountCmd struct { |
| accountID int64 |
| zcardCmd *redis.IntCmd |
| } |
| cmds := make([]accountCmd, 0, len(accountIDs)) |
| for _, accountID := range accountIDs { |
| slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10) |
| pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) |
| cmds = append(cmds, accountCmd{ |
| accountID: accountID, |
| zcardCmd: pipe.ZCard(ctx, slotKey), |
| }) |
| } |
|
|
| if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { |
| return nil, fmt.Errorf("pipeline exec: %w", err) |
| } |
|
|
| result := make(map[int64]int, len(accountIDs)) |
| for _, cmd := range cmds { |
| result[cmd.accountID] = int(cmd.zcardCmd.Val()) |
| } |
| return result, nil |
| } |
|
|
| |
|
|
| func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { |
| key := userSlotKey(userID) |
| |
| result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int() |
| if err != nil { |
| return false, err |
| } |
| return result == 1, nil |
| } |
|
|
| func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { |
| key := userSlotKey(userID) |
| return c.rdb.ZRem(ctx, key, requestID).Err() |
| } |
|
|
| func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { |
| key := userSlotKey(userID) |
| |
| result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int() |
| if err != nil { |
| return 0, err |
| } |
| return result, nil |
| } |
|
|
| |
|
|
| func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { |
| key := waitQueueKey(userID) |
| result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() |
| if err != nil { |
| return false, err |
| } |
| return result == 1, nil |
| } |
|
|
| func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error { |
| key := waitQueueKey(userID) |
| _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() |
| return err |
| } |
|
|
| |
|
|
| func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { |
| key := accountWaitKey(accountID) |
| result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int() |
| if err != nil { |
| return false, err |
| } |
| return result == 1, nil |
| } |
|
|
| func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { |
| key := accountWaitKey(accountID) |
| _, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result() |
| return err |
| } |
|
|
| func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { |
| key := accountWaitKey(accountID) |
| val, err := c.rdb.Get(ctx, key).Int() |
| if err != nil && !errors.Is(err, redis.Nil) { |
| return 0, err |
| } |
| if errors.Is(err, redis.Nil) { |
| return 0, nil |
| } |
| return val, nil |
| } |
|
|
| func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { |
| if len(accounts) == 0 { |
| return map[int64]*service.AccountLoadInfo{}, nil |
| } |
|
|
| |
| |
| now, err := c.rdb.Time(ctx).Result() |
| if err != nil { |
| return nil, fmt.Errorf("redis TIME: %w", err) |
| } |
| cutoffTime := now.Unix() - int64(c.slotTTLSeconds) |
|
|
| pipe := c.rdb.Pipeline() |
|
|
| type accountCmds struct { |
| id int64 |
| maxConcurrency int |
| zcardCmd *redis.IntCmd |
| getCmd *redis.StringCmd |
| } |
| cmds := make([]accountCmds, 0, len(accounts)) |
| for _, acc := range accounts { |
| slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10) |
| waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10) |
| pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) |
| ac := accountCmds{ |
| id: acc.ID, |
| maxConcurrency: acc.MaxConcurrency, |
| zcardCmd: pipe.ZCard(ctx, slotKey), |
| getCmd: pipe.Get(ctx, waitKey), |
| } |
| cmds = append(cmds, ac) |
| } |
|
|
| if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { |
| return nil, fmt.Errorf("pipeline exec: %w", err) |
| } |
|
|
| loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts)) |
| for _, ac := range cmds { |
| currentConcurrency := int(ac.zcardCmd.Val()) |
| waitingCount := 0 |
| if v, err := ac.getCmd.Int(); err == nil { |
| waitingCount = v |
| } |
| loadRate := 0 |
| if ac.maxConcurrency > 0 { |
| loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency |
| } |
| loadMap[ac.id] = &service.AccountLoadInfo{ |
| AccountID: ac.id, |
| CurrentConcurrency: currentConcurrency, |
| WaitingCount: waitingCount, |
| LoadRate: loadRate, |
| } |
| } |
|
|
| return loadMap, nil |
| } |
|
|
| func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { |
| if len(users) == 0 { |
| return map[int64]*service.UserLoadInfo{}, nil |
| } |
|
|
| |
| now, err := c.rdb.Time(ctx).Result() |
| if err != nil { |
| return nil, fmt.Errorf("redis TIME: %w", err) |
| } |
| cutoffTime := now.Unix() - int64(c.slotTTLSeconds) |
|
|
| pipe := c.rdb.Pipeline() |
|
|
| type userCmds struct { |
| id int64 |
| maxConcurrency int |
| zcardCmd *redis.IntCmd |
| getCmd *redis.StringCmd |
| } |
| cmds := make([]userCmds, 0, len(users)) |
| for _, u := range users { |
| slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10) |
| waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10) |
| pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) |
| uc := userCmds{ |
| id: u.ID, |
| maxConcurrency: u.MaxConcurrency, |
| zcardCmd: pipe.ZCard(ctx, slotKey), |
| getCmd: pipe.Get(ctx, waitKey), |
| } |
| cmds = append(cmds, uc) |
| } |
|
|
| if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { |
| return nil, fmt.Errorf("pipeline exec: %w", err) |
| } |
|
|
| loadMap := make(map[int64]*service.UserLoadInfo, len(users)) |
| for _, uc := range cmds { |
| currentConcurrency := int(uc.zcardCmd.Val()) |
| waitingCount := 0 |
| if v, err := uc.getCmd.Int(); err == nil { |
| waitingCount = v |
| } |
| loadRate := 0 |
| if uc.maxConcurrency > 0 { |
| loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency |
| } |
| loadMap[uc.id] = &service.UserLoadInfo{ |
| UserID: uc.id, |
| CurrentConcurrency: currentConcurrency, |
| WaitingCount: waitingCount, |
| LoadRate: loadRate, |
| } |
| } |
|
|
| return loadMap, nil |
| } |
|
|
| func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { |
| key := accountSlotKey(accountID) |
| _, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result() |
| return err |
| } |
|
|
| func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error { |
| if activeRequestPrefix == "" { |
| return nil |
| } |
|
|
| |
| slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"} |
| for _, pattern := range slotPatterns { |
| if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil { |
| return err |
| } |
| } |
|
|
| |
| waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"} |
| for _, pattern := range waitPatterns { |
| if err := c.deleteKeysByPattern(ctx, pattern); err != nil { |
| return err |
| } |
| } |
|
|
| return nil |
| } |
|
|
| |
| func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error { |
| const scanCount = 200 |
| var cursor uint64 |
| for { |
| keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() |
| if err != nil { |
| return fmt.Errorf("scan %s: %w", pattern, err) |
| } |
| if len(keys) > 0 { |
| _, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result() |
| if err != nil { |
| return fmt.Errorf("cleanup slots %s: %w", pattern, err) |
| } |
| } |
| cursor = nextCursor |
| if cursor == 0 { |
| break |
| } |
| } |
| return nil |
| } |
|
|
| |
| func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error { |
| const scanCount = 200 |
| var cursor uint64 |
| for { |
| keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result() |
| if err != nil { |
| return fmt.Errorf("scan %s: %w", pattern, err) |
| } |
| if len(keys) > 0 { |
| if err := c.rdb.Del(ctx, keys...).Err(); err != nil { |
| return fmt.Errorf("del %s: %w", pattern, err) |
| } |
| } |
| cursor = nextCursor |
| if cursor == 0 { |
| break |
| } |
| } |
| return nil |
| } |
|
|