| package sql |
|
|
| import ( |
| "context" |
| "database/sql" |
| "errors" |
| "fmt" |
| "time" |
|
|
| "ccLoad/internal/util" |
| ) |
|
|
| |
|
|
| |
| func (s *SQLStore) BumpChannelCooldown(ctx context.Context, channelID int64, now time.Time, statusCode int) (time.Duration, error) { |
| |
| |
|
|
| var nextDuration time.Duration |
|
|
| err := s.WithTransaction(ctx, func(tx *sql.Tx) error { |
| |
| var cooldownUntil, cooldownDurationMs int64 |
| err := tx.QueryRowContext(ctx, ` |
| SELECT cooldown_until, cooldown_duration_ms |
| FROM channels |
| WHERE id = ? |
| `, channelID).Scan(&cooldownUntil, &cooldownDurationMs) |
|
|
| if err != nil { |
| if errors.Is(err, sql.ErrNoRows) { |
| return errors.New("channel not found") |
| } |
| return fmt.Errorf("query channel cooldown: %w", err) |
| } |
|
|
| |
| until := unixToTime(cooldownUntil) |
| nextDuration = util.CalculateBackoffDuration(cooldownDurationMs, until, now, &statusCode) |
| newUntil := now.Add(nextDuration) |
|
|
| |
| _, err = tx.ExecContext(ctx, ` |
| UPDATE channels |
| SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? |
| WHERE id = ? |
| `, timeToUnix(newUntil), int64(nextDuration/time.Millisecond), timeToUnix(now), channelID) |
|
|
| if err != nil { |
| return fmt.Errorf("update channel cooldown: %w", err) |
| } |
|
|
| return nil |
| }) |
|
|
| return nextDuration, err |
| } |
|
|
| |
| |
| func (s *SQLStore) ResetChannelCooldown(ctx context.Context, channelID int64) error { |
| _, err := s.db.ExecContext(ctx, ` |
| UPDATE channels |
| SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? |
| WHERE id = ? AND cooldown_until > 0 |
| `, timeToUnix(time.Now()), channelID) |
|
|
| if err != nil { |
| return fmt.Errorf("reset channel cooldown: %w", err) |
| } |
|
|
| return nil |
| } |
|
|
| |
| func (s *SQLStore) SetChannelCooldown(ctx context.Context, channelID int64, until time.Time) error { |
| now := time.Now() |
| durationMs := util.CalculateCooldownDuration(until, now) |
|
|
| _, err := s.db.ExecContext(ctx, ` |
| UPDATE channels |
| SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? |
| WHERE id = ? |
| `, timeToUnix(until), durationMs, timeToUnix(now), channelID) |
|
|
| if err != nil { |
| return fmt.Errorf("set channel cooldown: %w", err) |
| } |
|
|
| return nil |
| } |
|
|
| |
| func (s *SQLStore) GetAllChannelCooldowns(ctx context.Context) (map[int64]time.Time, error) { |
| now := timeToUnix(time.Now()) |
| query := `SELECT id, cooldown_until FROM channels WHERE cooldown_until > ?` |
|
|
| rows, err := s.db.QueryContext(ctx, query, now) |
| if err != nil { |
| return nil, fmt.Errorf("query all channel cooldowns: %w", err) |
| } |
| defer func() { _ = rows.Close() }() |
|
|
| result := make(map[int64]time.Time) |
| for rows.Next() { |
| var channelID int64 |
| var until int64 |
|
|
| if err := rows.Scan(&channelID, &until); err != nil { |
| return nil, fmt.Errorf("scan channel cooldown: %w", err) |
| } |
|
|
| result[channelID] = unixToTime(until) |
| } |
|
|
| if err := rows.Err(); err != nil { |
| return nil, fmt.Errorf("iterate channel cooldowns: %w", err) |
| } |
|
|
| return result, nil |
| } |
|
|
| |
|
|
| |
| func (s *SQLStore) GetKeyCooldownUntil(ctx context.Context, configID int64, keyIndex int) (time.Time, bool) { |
| var cooldownUntil int64 |
| err := s.db.QueryRowContext(ctx, ` |
| SELECT cooldown_until |
| FROM api_keys |
| WHERE channel_id = ? AND key_index = ? |
| `, configID, keyIndex).Scan(&cooldownUntil) |
|
|
| if err != nil { |
| return time.Time{}, false |
| } |
|
|
| if cooldownUntil == 0 { |
| return time.Time{}, false |
| } |
|
|
| return unixToTime(cooldownUntil), true |
| } |
|
|
| |
| |
| func (s *SQLStore) GetAllKeyCooldowns(ctx context.Context) (map[int64]map[int]time.Time, error) { |
| now := timeToUnix(time.Now()) |
| query := `SELECT channel_id, key_index, cooldown_until FROM api_keys WHERE cooldown_until > ? AND disabled = 0` |
|
|
| rows, err := s.db.QueryContext(ctx, query, now) |
| if err != nil { |
| return nil, fmt.Errorf("query all key cooldowns: %w", err) |
| } |
| defer func() { _ = rows.Close() }() |
|
|
| result := make(map[int64]map[int]time.Time) |
| for rows.Next() { |
| var channelID int64 |
| var keyIndex int |
| var until int64 |
|
|
| if err := rows.Scan(&channelID, &keyIndex, &until); err != nil { |
| return nil, fmt.Errorf("scan key cooldown: %w", err) |
| } |
|
|
| |
| if result[channelID] == nil { |
| result[channelID] = make(map[int]time.Time) |
| } |
| result[channelID][keyIndex] = unixToTime(until) |
| } |
|
|
| if err := rows.Err(); err != nil { |
| return nil, fmt.Errorf("rows error: %w", err) |
| } |
|
|
| return result, nil |
| } |
|
|
| |
| func (s *SQLStore) BumpKeyCooldown(ctx context.Context, configID int64, keyIndex int, now time.Time, statusCode int) (time.Duration, error) { |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| var nextDuration time.Duration |
|
|
| err := s.WithTransaction(ctx, func(tx *sql.Tx) error { |
| |
| var cooldownUntil, cooldownDurationMs int64 |
| err := tx.QueryRowContext(ctx, ` |
| SELECT cooldown_until, cooldown_duration_ms |
| FROM api_keys |
| WHERE channel_id = ? AND key_index = ? |
| `, configID, keyIndex).Scan(&cooldownUntil, &cooldownDurationMs) |
|
|
| if err != nil { |
| if errors.Is(err, sql.ErrNoRows) { |
| return errors.New("api key not found") |
| } |
| return fmt.Errorf("query key cooldown: %w", err) |
| } |
|
|
| |
| until := unixToTime(cooldownUntil) |
| nextDuration = util.CalculateBackoffDuration(cooldownDurationMs, until, now, &statusCode) |
| newUntil := now.Add(nextDuration) |
|
|
| |
| _, err = tx.ExecContext(ctx, ` |
| UPDATE api_keys |
| SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? |
| WHERE channel_id = ? AND key_index = ? |
| `, timeToUnix(newUntil), int64(nextDuration/time.Millisecond), timeToUnix(now), configID, keyIndex) |
|
|
| if err != nil { |
| return fmt.Errorf("update key cooldown: %w", err) |
| } |
|
|
| return nil |
| }) |
|
|
| return nextDuration, err |
| } |
|
|
| |
| func (s *SQLStore) SetKeyCooldown(ctx context.Context, configID int64, keyIndex int, until time.Time) error { |
| now := time.Now() |
| durationMs := util.CalculateCooldownDuration(until, now) |
|
|
| _, err := s.db.ExecContext(ctx, ` |
| UPDATE api_keys |
| SET cooldown_until = ?, cooldown_duration_ms = ?, updated_at = ? |
| WHERE channel_id = ? AND key_index = ? |
| `, timeToUnix(until), durationMs, timeToUnix(now), configID, keyIndex) |
|
|
| return err |
| } |
|
|
| |
| |
| func (s *SQLStore) ResetKeyCooldown(ctx context.Context, configID int64, keyIndex int) error { |
| _, err := s.db.ExecContext(ctx, ` |
| UPDATE api_keys |
| SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? |
| WHERE channel_id = ? AND key_index = ? AND cooldown_until > 0 |
| `, timeToUnix(time.Now()), configID, keyIndex) |
|
|
| return err |
| } |
|
|
| |
| func (s *SQLStore) ClearAllKeyCooldowns(ctx context.Context, configID int64) error { |
| _, err := s.db.ExecContext(ctx, ` |
| UPDATE api_keys |
| SET cooldown_until = 0, cooldown_duration_ms = 0, updated_at = ? |
| WHERE channel_id = ? |
| `, timeToUnix(time.Now()), configID) |
|
|
| return err |
| } |
|
|