| package sqlite_test |
|
|
| import ( |
| "context" |
| "sync" |
| "sync/atomic" |
| "testing" |
| "time" |
|
|
| "ccLoad/internal/model" |
| "ccLoad/internal/storage" |
| "ccLoad/internal/util" |
| ) |
|
|
| |
| func TestAuthErrorInitialCooldown(t *testing.T) { |
| tests := []struct { |
| name string |
| statusCode int |
| expectedMinDur time.Duration |
| expectedMaxDur time.Duration |
| }{ |
| { |
| name: "401未认证错误-初始冷却5分钟", |
| statusCode: 401, |
| expectedMinDur: 5 * time.Minute, |
| expectedMaxDur: 5 * time.Minute, |
| }, |
| { |
| name: "403禁止访问错误-初始冷却5分钟", |
| statusCode: 403, |
| expectedMinDur: 5 * time.Minute, |
| expectedMaxDur: 5 * time.Minute, |
| }, |
| { |
| name: "429限流错误-初始冷却1分钟", |
| statusCode: 429, |
| expectedMinDur: time.Minute, |
| expectedMaxDur: time.Minute, |
| }, |
| { |
| name: "500服务器错误-初始冷却2分钟", |
| statusCode: 500, |
| expectedMinDur: 2 * time.Minute, |
| expectedMaxDur: 2 * time.Minute, |
| }, |
| } |
|
|
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| |
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
| now := time.Now() |
|
|
| |
| cfg := &model.Config{ |
| Name: "test-channel", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| duration, err := store.BumpChannelCooldown(ctx, created.ID, now, tt.statusCode) |
| if err != nil { |
| t.Fatalf("BumpCooldownOnError失败: %v", err) |
| } |
|
|
| |
| if duration < tt.expectedMinDur || duration > tt.expectedMaxDur { |
| t.Errorf("状态码%d的初始冷却时间错误: 期望%v,实际%v", |
| tt.statusCode, tt.expectedMinDur, duration) |
| } |
|
|
| |
| until, exists := getChannelCooldownUntil(ctx, store, created.ID) |
| if !exists { |
| t.Fatal("冷却记录不存在") |
| } |
|
|
| actualDuration := until.Sub(now) |
| tolerance := 1 * time.Second |
|
|
| if actualDuration < tt.expectedMinDur-tolerance || actualDuration > tt.expectedMaxDur+tolerance { |
| t.Errorf("数据库冷却时间错误: 期望%v,实际%v", |
| tt.expectedMinDur, actualDuration) |
| } |
|
|
| t.Logf("[INFO] 状态码%d: 初始冷却时间=%v(期望%v)", |
| tt.statusCode, duration, tt.expectedMinDur) |
| }) |
| } |
| } |
|
|
| |
| func TestAuthErrorExponentialBackoff(t *testing.T) { |
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
| now := time.Now() |
|
|
| |
| cfg := &model.Config{ |
| Name: "test-channel-backoff", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| expectedSequence := []time.Duration{ |
| 5 * time.Minute, |
| 10 * time.Minute, |
| 20 * time.Minute, |
| 30 * time.Minute, |
| 30 * time.Minute, |
| } |
|
|
| for i, expected := range expectedSequence { |
| |
| duration, err := store.BumpChannelCooldown(ctx, created.ID, now, 401) |
| if err != nil { |
| t.Fatalf("第%d次BumpCooldownOnError失败: %v", i+1, err) |
| } |
|
|
| |
| tolerance := 100 * time.Millisecond |
| if duration < expected-tolerance || duration > expected+tolerance { |
| t.Errorf("第%d次错误的冷却时间错误: 期望%v,实际%v", |
| i+1, expected, duration) |
| } |
|
|
| t.Logf("[INFO] 第%d次401错误: 冷却时间=%v(期望%v)", |
| i+1, duration, expected) |
|
|
| |
| now = now.Add(expected + 1*time.Second) |
| } |
| } |
|
|
| |
| func TestKeyLevelAuthErrorCooldown(t *testing.T) { |
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
| now := time.Now() |
|
|
| |
| cfg := &model.Config{ |
| Name: "multi-key-channel", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| keyNames := []string{"sk-key1", "sk-key2", "sk-key3"} |
| keys := make([]*model.APIKey, len(keyNames)) |
| for i, key := range keyNames { |
| keys[i] = &model.APIKey{ |
| ChannelID: created.ID, |
| KeyIndex: i, |
| APIKey: key, |
| KeyStrategy: model.KeyStrategySequential, |
| } |
| } |
| if err = store.CreateAPIKeysBatch(ctx, keys); err != nil { |
| t.Fatalf("批量创建API Keys失败: %v", err) |
| } |
|
|
| |
| duration, err := store.BumpKeyCooldown(ctx, created.ID, 0, now, 401) |
| if err != nil { |
| t.Fatalf("BumpKeyCooldownOnError失败: %v", err) |
| } |
|
|
| |
| expectedDuration := 5 * time.Minute |
| tolerance := 1 * time.Second |
| if duration < expectedDuration-tolerance || duration > expectedDuration+tolerance { |
| t.Errorf("Key级401错误初始冷却时间错误: 期望%v,实际%v", |
| expectedDuration, duration) |
| } |
|
|
| |
| until, exists := getKeyCooldownUntil(ctx, store, created.ID, 0) |
| if !exists { |
| t.Fatal("Key冷却记录不存在") |
| } |
|
|
| actualDuration := until.Sub(now) |
| if actualDuration < expectedDuration-tolerance || actualDuration > expectedDuration+tolerance { |
| t.Errorf("数据库Key冷却时间错误: 期望%v,实际%v", |
| expectedDuration, actualDuration) |
| } |
|
|
| t.Logf("[INFO] Key级401错误: 初始冷却时间=%v(期望%v)", |
| duration, expectedDuration) |
| } |
|
|
| |
| func TestMixedErrorCodesCooldown(t *testing.T) { |
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
| now := time.Now() |
|
|
| |
| cfg := &model.Config{ |
| Name: "mixed-errors-channel", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| duration1, err := store.BumpChannelCooldown(ctx, created.ID, now, 500) |
| if err != nil { |
| t.Fatalf("首次500错误失败: %v", err) |
| } |
|
|
| if duration1 != 2*time.Minute { |
| t.Errorf("500错误初始冷却时间错误: 期望2分钟,实际%v", duration1) |
| } |
|
|
| |
| now2 := now.Add(3 * time.Minute) |
| duration2, err := store.BumpChannelCooldown(ctx, created.ID, now2, 401) |
| if err != nil { |
| t.Fatalf("后续401错误失败: %v", err) |
| } |
|
|
| |
| |
| |
| |
| expectedDuration := 4 * time.Minute |
| tolerance := 100 * time.Millisecond |
|
|
| if duration2 < expectedDuration-tolerance || duration2 > expectedDuration+tolerance { |
| t.Errorf("混合错误场景冷却时间错误: 期望%v,实际%v", |
| expectedDuration, duration2) |
| } |
|
|
| t.Logf("[INFO] 500错误(2min) → 401错误(%v) - 使用指数退避而非重置", duration2) |
| } |
|
|
| |
| func TestConcurrentCooldownUpdates(t *testing.T) { |
| if testing.Short() { |
| t.Skip("跳过并发测试(使用 -short 标志)") |
| } |
|
|
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
|
|
| |
| cfg := &model.Config{ |
| Name: "concurrent-test", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| const concurrency = 10 |
| var wg sync.WaitGroup |
| for i := 0; i < concurrency; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| _, _ = store.BumpChannelCooldown(ctx, created.ID, time.Now(), 401) |
| }() |
| } |
| wg.Wait() |
|
|
| |
| until, exists := getChannelCooldownUntil(ctx, store, created.ID) |
| if !exists { |
| t.Fatal("冷却记录不存在") |
| } |
|
|
| duration := time.Until(until) |
| minDuration := util.AuthErrorInitialCooldown - 1*time.Second |
| maxDuration := util.MaxCooldownDuration + 1*time.Second |
|
|
| if duration < minDuration || duration > maxDuration { |
| t.Errorf("并发场景冷却时间异常: %v (期望范围: %v - %v)", |
| duration, minDuration, maxDuration) |
| } |
|
|
| t.Logf("[INFO] 并发测试通过: %d个并发更新,最终冷却时间=%v", concurrency, duration) |
| } |
|
|
| |
| func TestConcurrentKeyCooldownUpdates(t *testing.T) { |
| if testing.Short() { |
| t.Skip("跳过并发测试(使用 -short 标志)") |
| } |
|
|
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
|
|
| |
| cfg := &model.Config{ |
| Name: "concurrent-key-test", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| keyNames := []string{"sk-key1", "sk-key2", "sk-key3"} |
| keys := make([]*model.APIKey, len(keyNames)) |
| for i, key := range keyNames { |
| keys[i] = &model.APIKey{ |
| ChannelID: created.ID, |
| KeyIndex: i, |
| APIKey: key, |
| KeyStrategy: model.KeyStrategySequential, |
| } |
| } |
| if err = store.CreateAPIKeysBatch(ctx, keys); err != nil { |
| t.Fatalf("批量创建API Keys失败: %v", err) |
| } |
|
|
| |
| sem := make(chan struct{}, 2) |
| var wg sync.WaitGroup |
| var successCount int32 |
|
|
| |
| for keyIndex := 0; keyIndex < 3; keyIndex++ { |
| for i := 0; i < 3; i++ { |
| wg.Add(1) |
| go func(idx int) { |
| defer wg.Done() |
| sem <- struct{}{} |
| defer func() { <-sem }() |
|
|
| _, err := store.BumpKeyCooldown(ctx, created.ID, idx, time.Now(), 401) |
| if err == nil { |
| atomic.AddInt32(&successCount, 1) |
| } |
| }(keyIndex) |
| } |
| } |
| wg.Wait() |
|
|
| t.Logf("[INFO] 并发更新完成: 成功次数=%d/9", successCount) |
|
|
| |
| for keyIndex := 0; keyIndex < 3; keyIndex++ { |
| until, exists := getKeyCooldownUntil(ctx, store, created.ID, keyIndex) |
| if !exists { |
| t.Errorf("Key %d 冷却记录不存在", keyIndex) |
| continue |
| } |
|
|
| duration := time.Until(until) |
| minDuration := util.AuthErrorInitialCooldown - 1*time.Second |
| maxDuration := util.MaxCooldownDuration + 1*time.Second |
|
|
| if duration < minDuration || duration > maxDuration { |
| t.Errorf("Key %d 并发场景冷却时间异常: %v (期望范围: %v - %v)", |
| keyIndex, duration, minDuration, maxDuration) |
| } |
| } |
| } |
|
|
| |
| |
| func TestRaceConditionDetection(t *testing.T) { |
| if testing.Short() { |
| t.Skip("跳过竞态检测测试(使用 -short 标志)") |
| } |
|
|
| store, cleanup := setupSQLiteTestStore(t, "test-auth-error.db") |
| defer cleanup() |
|
|
| ctx := context.Background() |
|
|
| cfg := &model.Config{ |
| Name: "race-test", |
| URL: "https://api.example.com", |
| Enabled: true, |
| } |
| created, err := store.CreateConfig(ctx, cfg) |
| if err != nil { |
| t.Fatalf("创建测试渠道失败: %v", err) |
| } |
|
|
| |
| _ = store.CreateAPIKeysBatch(ctx, []*model.APIKey{ |
| {ChannelID: created.ID, KeyIndex: 0, APIKey: "sk-key1", KeyStrategy: model.KeyStrategySequential}, |
| {ChannelID: created.ID, KeyIndex: 1, APIKey: "sk-key2", KeyStrategy: model.KeyStrategySequential}, |
| }) |
|
|
| |
| var wg sync.WaitGroup |
| for i := 0; i < 5; i++ { |
| wg.Add(3) |
|
|
| |
| go func() { |
| defer wg.Done() |
| _, _ = store.BumpChannelCooldown(ctx, created.ID, time.Now(), 401) |
| }() |
|
|
| |
| go func() { |
| defer wg.Done() |
| _, _ = store.BumpKeyCooldown(ctx, created.ID, 0, time.Now(), 401) |
| }() |
|
|
| |
| go func() { |
| defer wg.Done() |
| _, _ = store.GetConfig(ctx, created.ID) |
| }() |
| } |
|
|
| wg.Wait() |
| t.Log("[INFO] 竞态检测测试通过(使用 go test -race 运行以检测竞态条件)") |
| } |
|
|
| |
| |
| |
| func getChannelCooldownUntil(ctx context.Context, store storage.Store, channelID int64) (time.Time, bool) { |
| cfg, err := store.GetConfig(ctx, channelID) |
| if err != nil || cfg == nil { |
| return time.Time{}, false |
| } |
| if cfg.CooldownUntil == 0 { |
| return time.Time{}, false |
| } |
| until := time.Unix(cfg.CooldownUntil, 0) |
| |
| return until, time.Now().Before(until) |
| } |
|
|
| |
| func getKeyCooldownUntil(ctx context.Context, store storage.Store, channelID int64, keyIndex int) (time.Time, bool) { |
| key, err := store.GetAPIKey(ctx, channelID, keyIndex) |
| if err != nil || key == nil { |
| return time.Time{}, false |
| } |
| if key.CooldownUntil == 0 { |
| return time.Time{}, false |
| } |
| until := time.Unix(key.CooldownUntil, 0) |
| |
| return until, time.Now().Before(until) |
| } |
|
|