| package repository |
|
|
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "log" |
| "math/rand/v2" |
| "strconv" |
| "time" |
|
|
| "github.com/Wei-Shaw/sub2api/internal/service" |
| "github.com/redis/go-redis/v9" |
| ) |
|
|
| const ( |
| billingBalanceKeyPrefix = "billing:balance:" |
| billingSubKeyPrefix = "billing:sub:" |
| billingRateLimitKeyPrefix = "apikey:rate:" |
| billingCacheTTL = 5 * time.Minute |
| billingCacheJitter = 30 * time.Second |
| rateLimitCacheTTL = 7 * 24 * time.Hour |
|
|
| |
| rateLimitWindow5h = 5 * time.Hour |
| rateLimitWindow1d = 24 * time.Hour |
| rateLimitWindow7d = 7 * 24 * time.Hour |
| ) |
|
|
| |
| func jitteredTTL() time.Duration { |
| |
| if billingCacheJitter <= 0 { |
| return billingCacheTTL |
| } |
| jitter := time.Duration(rand.IntN(int(billingCacheJitter))) |
| return billingCacheTTL - jitter |
| } |
|
|
| |
| func billingBalanceKey(userID int64) string { |
| return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) |
| } |
|
|
| |
| func billingSubKey(userID, groupID int64) string { |
| return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID) |
| } |
|
|
| const ( |
| subFieldStatus = "status" |
| subFieldExpiresAt = "expires_at" |
| subFieldDailyUsage = "daily_usage" |
| subFieldWeeklyUsage = "weekly_usage" |
| subFieldMonthlyUsage = "monthly_usage" |
| subFieldVersion = "version" |
| ) |
|
|
| |
| func billingRateLimitKey(keyID int64) string { |
| return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID) |
| } |
|
|
| const ( |
| rateLimitFieldUsage5h = "usage_5h" |
| rateLimitFieldUsage1d = "usage_1d" |
| rateLimitFieldUsage7d = "usage_7d" |
| rateLimitFieldWindow5h = "window_5h" |
| rateLimitFieldWindow1d = "window_1d" |
| rateLimitFieldWindow7d = "window_7d" |
| ) |
|
|
| var ( |
| deductBalanceScript = redis.NewScript(` |
| local current = redis.call('GET', KEYS[1]) |
| if current == false then |
| return 0 |
| end |
| local newVal = tonumber(current) - tonumber(ARGV[1]) |
| redis.call('SET', KEYS[1], newVal) |
| redis.call('EXPIRE', KEYS[1], ARGV[2]) |
| return 1 |
| `) |
|
|
| updateSubUsageScript = redis.NewScript(` |
| local exists = redis.call('EXISTS', KEYS[1]) |
| if exists == 0 then |
| return 0 |
| end |
| local cost = tonumber(ARGV[1]) |
| redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost) |
| redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost) |
| redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost) |
| redis.call('EXPIRE', KEYS[1], ARGV[2]) |
| return 1 |
| `) |
|
|
| |
| |
| |
| |
| |
| |
| updateRateLimitUsageScript = redis.NewScript(` |
| local exists = redis.call('EXISTS', KEYS[1]) |
| if exists == 0 then |
| return 0 |
| end |
| local cost = tonumber(ARGV[1]) |
| local now = tonumber(ARGV[3]) |
| local win5h = tonumber(ARGV[4]) |
| local win1d = tonumber(ARGV[5]) |
| local win7d = tonumber(ARGV[6]) |
| |
| -- Helper: check if window is expired and update usage + window accordingly |
| -- Returns nothing, modifies the hash in-place. |
| local function update_window(usage_field, window_field, window_duration) |
| local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0) |
| if w == 0 or (now - w) >= window_duration then |
| -- Window expired or never started: reset usage to cost, start new window |
| redis.call('HSET', KEYS[1], usage_field, tostring(cost)) |
| redis.call('HSET', KEYS[1], window_field, tostring(now)) |
| else |
| -- Window still valid: accumulate |
| redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost) |
| end |
| end |
| |
| update_window('usage_5h', 'window_5h', win5h) |
| update_window('usage_1d', 'window_1d', win1d) |
| update_window('usage_7d', 'window_7d', win7d) |
| redis.call('EXPIRE', KEYS[1], ARGV[2]) |
| return 1 |
| `) |
| ) |
|
|
| type billingCache struct { |
| rdb *redis.Client |
| } |
|
|
| func NewBillingCache(rdb *redis.Client) service.BillingCache { |
| return &billingCache{rdb: rdb} |
| } |
|
|
| func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) { |
| key := billingBalanceKey(userID) |
| val, err := c.rdb.Get(ctx, key).Result() |
| if err != nil { |
| return 0, err |
| } |
| return strconv.ParseFloat(val, 64) |
| } |
|
|
| func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { |
| key := billingBalanceKey(userID) |
| return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() |
| } |
|
|
| func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { |
| key := billingBalanceKey(userID) |
| _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() |
| if err != nil && !errors.Is(err, redis.Nil) { |
| log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) |
| return err |
| } |
| return nil |
| } |
|
|
| func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { |
| key := billingBalanceKey(userID) |
| return c.rdb.Del(ctx, key).Err() |
| } |
|
|
| func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) { |
| key := billingSubKey(userID, groupID) |
| result, err := c.rdb.HGetAll(ctx, key).Result() |
| if err != nil { |
| return nil, err |
| } |
| if len(result) == 0 { |
| return nil, redis.Nil |
| } |
| return c.parseSubscriptionCache(result) |
| } |
|
|
| func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) { |
| result := &service.SubscriptionCacheData{} |
|
|
| result.Status = data[subFieldStatus] |
| if result.Status == "" { |
| return nil, errors.New("invalid cache: missing status") |
| } |
|
|
| if expiresStr, ok := data[subFieldExpiresAt]; ok { |
| expiresAt, err := strconv.ParseInt(expiresStr, 10, 64) |
| if err == nil { |
| result.ExpiresAt = time.Unix(expiresAt, 0) |
| } |
| } |
|
|
| if dailyStr, ok := data[subFieldDailyUsage]; ok { |
| result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64) |
| } |
|
|
| if weeklyStr, ok := data[subFieldWeeklyUsage]; ok { |
| result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64) |
| } |
|
|
| if monthlyStr, ok := data[subFieldMonthlyUsage]; ok { |
| result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64) |
| } |
|
|
| if versionStr, ok := data[subFieldVersion]; ok { |
| result.Version, _ = strconv.ParseInt(versionStr, 10, 64) |
| } |
|
|
| return result, nil |
| } |
|
|
| func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error { |
| if data == nil { |
| return nil |
| } |
|
|
| key := billingSubKey(userID, groupID) |
|
|
| fields := map[string]any{ |
| subFieldStatus: data.Status, |
| subFieldExpiresAt: data.ExpiresAt.Unix(), |
| subFieldDailyUsage: data.DailyUsage, |
| subFieldWeeklyUsage: data.WeeklyUsage, |
| subFieldMonthlyUsage: data.MonthlyUsage, |
| subFieldVersion: data.Version, |
| } |
|
|
| pipe := c.rdb.Pipeline() |
| pipe.HSet(ctx, key, fields) |
| pipe.Expire(ctx, key, jitteredTTL()) |
| _, err := pipe.Exec(ctx) |
| return err |
| } |
|
|
| func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { |
| key := billingSubKey(userID, groupID) |
| _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() |
| if err != nil && !errors.Is(err, redis.Nil) { |
| log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) |
| return err |
| } |
| return nil |
| } |
|
|
| func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { |
| key := billingSubKey(userID, groupID) |
| return c.rdb.Del(ctx, key).Err() |
| } |
|
|
| func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) { |
| key := billingRateLimitKey(keyID) |
| result, err := c.rdb.HGetAll(ctx, key).Result() |
| if err != nil { |
| return nil, err |
| } |
| if len(result) == 0 { |
| return nil, redis.Nil |
| } |
| data := &service.APIKeyRateLimitCacheData{} |
| if v, ok := result[rateLimitFieldUsage5h]; ok { |
| data.Usage5h, _ = strconv.ParseFloat(v, 64) |
| } |
| if v, ok := result[rateLimitFieldUsage1d]; ok { |
| data.Usage1d, _ = strconv.ParseFloat(v, 64) |
| } |
| if v, ok := result[rateLimitFieldUsage7d]; ok { |
| data.Usage7d, _ = strconv.ParseFloat(v, 64) |
| } |
| if v, ok := result[rateLimitFieldWindow5h]; ok { |
| data.Window5h, _ = strconv.ParseInt(v, 10, 64) |
| } |
| if v, ok := result[rateLimitFieldWindow1d]; ok { |
| data.Window1d, _ = strconv.ParseInt(v, 10, 64) |
| } |
| if v, ok := result[rateLimitFieldWindow7d]; ok { |
| data.Window7d, _ = strconv.ParseInt(v, 10, 64) |
| } |
| return data, nil |
| } |
|
|
| func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error { |
| if data == nil { |
| return nil |
| } |
| key := billingRateLimitKey(keyID) |
| fields := map[string]any{ |
| rateLimitFieldUsage5h: data.Usage5h, |
| rateLimitFieldUsage1d: data.Usage1d, |
| rateLimitFieldUsage7d: data.Usage7d, |
| rateLimitFieldWindow5h: data.Window5h, |
| rateLimitFieldWindow1d: data.Window1d, |
| rateLimitFieldWindow7d: data.Window7d, |
| } |
| pipe := c.rdb.Pipeline() |
| pipe.HSet(ctx, key, fields) |
| pipe.Expire(ctx, key, rateLimitCacheTTL) |
| _, err := pipe.Exec(ctx) |
| return err |
| } |
|
|
| func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error { |
| key := billingRateLimitKey(keyID) |
| now := time.Now().Unix() |
| _, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key}, |
| cost, |
| int(rateLimitCacheTTL.Seconds()), |
| now, |
| int(rateLimitWindow5h.Seconds()), |
| int(rateLimitWindow1d.Seconds()), |
| int(rateLimitWindow7d.Seconds()), |
| ).Result() |
| if err != nil && !errors.Is(err, redis.Nil) { |
| log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err) |
| return err |
| } |
| return nil |
| } |
|
|
| func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error { |
| key := billingRateLimitKey(keyID) |
| return c.rdb.Del(ctx, key).Err() |
| } |
|
|