package sql import ( "context" "database/sql" "errors" "fmt" "time" "ccLoad/internal/util" ) // ==================== 渠道级冷却方法(操作 channels 表内联字段)==================== // BumpChannelCooldown 渠道级冷却:指数退避策略(认证错误5分钟起,其他1秒起,最大30分钟) func (s *SQLStore) BumpChannelCooldown(ctx context.Context, channelID int64, now time.Time, statusCode int) (time.Duration, error) { // 使用事务保护Read-Modify-Write操作,防止并发竞态 // 问题场景同BumpKeyCooldown,多个并发请求可能导致指数退避计算错误 var nextDuration time.Duration err := s.WithTransaction(ctx, func(tx *sql.Tx) error { // 1. 读取当前冷却状态(事务内,隐式锁定行) var cooldownUntil, cooldownDurationMs int64 err := tx.QueryRowContext(ctx, ` SELECT cooldown_until, cooldown_duration_ms FROM channels WHERE id = ? `, channelID).Scan(&cooldownUntil, &cooldownDurationMs) if err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.New("channel not found") } return fmt.Errorf("query channel cooldown: %w", err) } // 2. 计算新的冷却时间(指数退避) until := unixToTime(cooldownUntil) nextDuration = util.CalculateBackoffDuration(cooldownDurationMs, until, now, &statusCode) newUntil := now.Add(nextDuration) // 3. 更新 channels 表(事务内) _, err = tx.ExecContext(ctx, ` UPDATE channels SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? WHERE id = ? `, timeToUnix(newUntil), int64(nextDuration/time.Millisecond), timeToUnix(now), channelID) if err != nil { return fmt.Errorf("update channel cooldown: %w", err) } return nil }) return nextDuration, err } // ResetChannelCooldown 重置渠道冷却状态 // 优化:仅更新实际处于冷却中的记录,避免无谓的写入 func (s *SQLStore) ResetChannelCooldown(ctx context.Context, channelID int64) error { _, err := s.db.ExecContext(ctx, ` UPDATE channels SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? WHERE id = ? AND cooldown_until > 0 `, timeToUnix(time.Now()), channelID) if err != nil { return fmt.Errorf("reset channel cooldown: %w", err) } return nil } // SetChannelCooldown 设置渠道冷却(手动设置冷却时间) func (s *SQLStore) SetChannelCooldown(ctx context.Context, channelID int64, until time.Time) error { now := time.Now() durationMs := util.CalculateCooldownDuration(until, now) _, err := s.db.ExecContext(ctx, ` UPDATE channels SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? WHERE id = ? `, timeToUnix(until), durationMs, timeToUnix(now), channelID) if err != nil { return fmt.Errorf("set channel cooldown: %w", err) } return nil } // GetAllChannelCooldowns 批量查询所有渠道冷却状态(从 channels 表读取) func (s *SQLStore) GetAllChannelCooldowns(ctx context.Context) (map[int64]time.Time, error) { now := timeToUnix(time.Now()) query := `SELECT id, cooldown_until FROM channels WHERE cooldown_until > ?` rows, err := s.db.QueryContext(ctx, query, now) if err != nil { return nil, fmt.Errorf("query all channel cooldowns: %w", err) } defer func() { _ = rows.Close() }() result := make(map[int64]time.Time) for rows.Next() { var channelID int64 var until int64 if err := rows.Scan(&channelID, &until); err != nil { return nil, fmt.Errorf("scan channel cooldown: %w", err) } result[channelID] = unixToTime(until) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("iterate channel cooldowns: %w", err) } return result, nil } // ==================== Key级别冷却机制(操作 api_keys 表内联字段)==================== // GetKeyCooldownUntil 查询指定Key的冷却截止时间(从 api_keys 表读取) func (s *SQLStore) GetKeyCooldownUntil(ctx context.Context, configID int64, keyIndex int) (time.Time, bool) { var cooldownUntil int64 err := s.db.QueryRowContext(ctx, ` SELECT cooldown_until FROM api_keys WHERE channel_id = ? AND key_index = ? `, configID, keyIndex).Scan(&cooldownUntil) if err != nil { return time.Time{}, false } if cooldownUntil == 0 { return time.Time{}, false } return unixToTime(cooldownUntil), true } // GetAllKeyCooldowns 批量查询所有Key冷却状态(从 api_keys 表读取) // 返回: map[channelID]map[keyIndex]cooldownUntil func (s *SQLStore) GetAllKeyCooldowns(ctx context.Context) (map[int64]map[int]time.Time, error) { now := timeToUnix(time.Now()) query := `SELECT channel_id, key_index, cooldown_until FROM api_keys WHERE cooldown_until > ? AND disabled = 0` rows, err := s.db.QueryContext(ctx, query, now) if err != nil { return nil, fmt.Errorf("query all key cooldowns: %w", err) } defer func() { _ = rows.Close() }() result := make(map[int64]map[int]time.Time) for rows.Next() { var channelID int64 var keyIndex int var until int64 if err := rows.Scan(&channelID, &keyIndex, &until); err != nil { return nil, fmt.Errorf("scan key cooldown: %w", err) } // 初始化渠道级map if result[channelID] == nil { result[channelID] = make(map[int]time.Time) } result[channelID][keyIndex] = unixToTime(until) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("rows error: %w", err) } return result, nil } // BumpKeyCooldown Key级别冷却:指数退避策略(认证错误5分钟起,其他1秒起,最大30分钟) func (s *SQLStore) BumpKeyCooldown(ctx context.Context, configID int64, keyIndex int, now time.Time, statusCode int) (time.Duration, error) { // 使用事务保护Read-Modify-Write操作,防止并发竞态 // 问题场景: // 请求A: 读取duration=1000 → 计算新值=2000 // 请求B: 读取duration=1000 → 计算新值=2000 (应该是4000!) // 请求A: 写入2000 // 请求B: 写入2000 (覆盖A的更新,指数退避失效!) // // 修复后: 整个操作在事务中原子执行,避免Lost Update问题 var nextDuration time.Duration err := s.WithTransaction(ctx, func(tx *sql.Tx) error { // 1. 读取当前冷却状态(事务内,隐式锁定行) var cooldownUntil, cooldownDurationMs int64 err := tx.QueryRowContext(ctx, ` SELECT cooldown_until, cooldown_duration_ms FROM api_keys WHERE channel_id = ? AND key_index = ? `, configID, keyIndex).Scan(&cooldownUntil, &cooldownDurationMs) if err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.New("api key not found") } return fmt.Errorf("query key cooldown: %w", err) } // 2. 计算新的冷却时间(指数退避) until := unixToTime(cooldownUntil) nextDuration = util.CalculateBackoffDuration(cooldownDurationMs, until, now, &statusCode) newUntil := now.Add(nextDuration) // 3. 更新 api_keys 表(事务内) _, err = tx.ExecContext(ctx, ` UPDATE api_keys SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? WHERE channel_id = ? AND key_index = ? `, timeToUnix(newUntil), int64(nextDuration/time.Millisecond), timeToUnix(now), configID, keyIndex) if err != nil { return fmt.Errorf("update key cooldown: %w", err) } return nil }) return nextDuration, err } // SetKeyCooldown 设置指定Key的冷却截止时间(操作 api_keys 表) func (s *SQLStore) SetKeyCooldown(ctx context.Context, configID int64, keyIndex int, until time.Time) error { now := time.Now() durationMs := util.CalculateCooldownDuration(until, now) _, err := s.db.ExecContext(ctx, ` UPDATE api_keys SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? WHERE channel_id = ? AND key_index = ? `, timeToUnix(until), durationMs, timeToUnix(now), configID, keyIndex) return err } // ResetKeyCooldown 重置指定Key的冷却状态(操作 api_keys 表) // 优化:仅更新实际处于冷却中的记录,避免无谓的写入 func (s *SQLStore) ResetKeyCooldown(ctx context.Context, configID int64, keyIndex int) error { _, err := s.db.ExecContext(ctx, ` UPDATE api_keys SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? WHERE channel_id = ? AND key_index = ? AND cooldown_until > 0 `, timeToUnix(time.Now()), configID, keyIndex) return err } // ClearAllKeyCooldowns 清理渠道的所有Key冷却数据(操作 api_keys 表) func (s *SQLStore) ClearAllKeyCooldowns(ctx context.Context, configID int64) error { _, err := s.db.ExecContext(ctx, ` UPDATE api_keys SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? WHERE channel_id = ? `, timeToUnix(time.Now()), configID) return err }