Spaces:
Paused
Paused
| package balancer | |
| import ( | |
| "jetbrains-ai-proxy/internal/config" | |
| "sync" | |
| "testing" | |
| "time" | |
| ) | |
| func TestNewJWTBalancer(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3"} | |
| // 测试轮询策略 | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| if balancer == nil { | |
| t.Fatal("Expected balancer to be created") | |
| } | |
| if balancer.GetTotalTokenCount() != 3 { | |
| t.Errorf("Expected 3 tokens, got %d", balancer.GetTotalTokenCount()) | |
| } | |
| if balancer.GetHealthyTokenCount() != 3 { | |
| t.Errorf("Expected 3 healthy tokens, got %d", balancer.GetHealthyTokenCount()) | |
| } | |
| } | |
| func TestRoundRobinStrategy(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| // 测试轮询顺序 | |
| expectedOrder := []string{"token1", "token2", "token3", "token1", "token2", "token3"} | |
| for i, expected := range expectedOrder { | |
| token, err := balancer.GetToken() | |
| if err != nil { | |
| t.Fatalf("Unexpected error at iteration %d: %v", i, err) | |
| } | |
| if token != expected { | |
| t.Errorf("At iteration %d, expected %s, got %s", i, expected, token) | |
| } | |
| } | |
| } | |
| func TestRandomStrategy(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3"} | |
| balancer := NewJWTBalancer(tokens, config.Random) | |
| // 测试随机策略 - 多次获取token,确保都是有效的 | |
| tokenCounts := make(map[string]int) | |
| iterations := 100 | |
| for i := 0; i < iterations; i++ { | |
| token, err := balancer.GetToken() | |
| if err != nil { | |
| t.Fatalf("Unexpected error at iteration %d: %v", i, err) | |
| } | |
| // 检查token是否在预期列表中 | |
| found := false | |
| for _, expectedToken := range tokens { | |
| if token == expectedToken { | |
| found = true | |
| break | |
| } | |
| } | |
| if !found { | |
| t.Errorf("Got unexpected token: %s", token) | |
| } | |
| tokenCounts[token]++ | |
| } | |
| // 确保所有token都被使用过(随机策略下应该都有机会被选中) | |
| for _, token := range tokens { | |
| if tokenCounts[token] == 0 { | |
| t.Errorf("Token %s was never selected", token) | |
| } | |
| } | |
| } | |
| func TestMarkTokenUnhealthy(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| // 标记一个token为不健康 | |
| balancer.MarkTokenUnhealthy("token2") | |
| if balancer.GetHealthyTokenCount() != 2 { | |
| t.Errorf("Expected 2 healthy tokens, got %d", balancer.GetHealthyTokenCount()) | |
| } | |
| // 获取token,应该只返回健康的token | |
| for i := 0; i < 10; i++ { | |
| token, err := balancer.GetToken() | |
| if err != nil { | |
| t.Fatalf("Unexpected error: %v", err) | |
| } | |
| if token == "token2" { | |
| t.Errorf("Got unhealthy token: %s", token) | |
| } | |
| } | |
| } | |
| func TestMarkTokenHealthy(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| // 先标记为不健康,再标记为健康 | |
| balancer.MarkTokenUnhealthy("token2") | |
| if balancer.GetHealthyTokenCount() != 2 { | |
| t.Errorf("Expected 2 healthy tokens after marking unhealthy, got %d", balancer.GetHealthyTokenCount()) | |
| } | |
| balancer.MarkTokenHealthy("token2") | |
| if balancer.GetHealthyTokenCount() != 3 { | |
| t.Errorf("Expected 3 healthy tokens after marking healthy, got %d", balancer.GetHealthyTokenCount()) | |
| } | |
| } | |
| func TestNoHealthyTokens(t *testing.T) { | |
| tokens := []string{"token1", "token2"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| // 标记所有token为不健康 | |
| balancer.MarkTokenUnhealthy("token1") | |
| balancer.MarkTokenUnhealthy("token2") | |
| // 尝试获取token应该返回错误 | |
| _, err := balancer.GetToken() | |
| if err == nil { | |
| t.Error("Expected error when no healthy tokens available") | |
| } | |
| } | |
| func TestConcurrentAccess(t *testing.T) { | |
| tokens := []string{"token1", "token2", "token3", "token4", "token5"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| var wg sync.WaitGroup | |
| numGoroutines := 10 | |
| tokensPerGoroutine := 100 | |
| // 并发获取tokens | |
| for i := 0; i < numGoroutines; i++ { | |
| wg.Add(1) | |
| go func() { | |
| defer wg.Done() | |
| for j := 0; j < tokensPerGoroutine; j++ { | |
| _, err := balancer.GetToken() | |
| if err != nil { | |
| t.Errorf("Unexpected error in concurrent access: %v", err) | |
| } | |
| } | |
| }() | |
| } | |
| // 并发标记tokens健康状态 | |
| for i := 0; i < numGoroutines; i++ { | |
| wg.Add(1) | |
| go func(index int) { | |
| defer wg.Done() | |
| token := tokens[index%len(tokens)] | |
| for j := 0; j < 10; j++ { | |
| if j%2 == 0 { | |
| balancer.MarkTokenUnhealthy(token) | |
| } else { | |
| balancer.MarkTokenHealthy(token) | |
| } | |
| time.Sleep(time.Millisecond) | |
| } | |
| }(i) | |
| } | |
| wg.Wait() | |
| // 确保最终状态正常 | |
| if balancer.GetTotalTokenCount() != len(tokens) { | |
| t.Errorf("Expected %d total tokens, got %d", len(tokens), balancer.GetTotalTokenCount()) | |
| } | |
| } | |
| func TestRefreshTokens(t *testing.T) { | |
| tokens := []string{"token1", "token2"} | |
| balancer := NewJWTBalancer(tokens, config.RoundRobin) | |
| if balancer.GetTotalTokenCount() != 2 { | |
| t.Errorf("Expected 2 tokens initially, got %d", balancer.GetTotalTokenCount()) | |
| } | |
| // 刷新tokens | |
| newTokens := []string{"token3", "token4", "token5"} | |
| balancer.RefreshTokens(newTokens) | |
| if balancer.GetTotalTokenCount() != 3 { | |
| t.Errorf("Expected 3 tokens after refresh, got %d", balancer.GetTotalTokenCount()) | |
| } | |
| if balancer.GetHealthyTokenCount() != 3 { | |
| t.Errorf("Expected 3 healthy tokens after refresh, got %d", balancer.GetHealthyTokenCount()) | |
| } | |
| // 验证新tokens可以被获取 | |
| for i := 0; i < 6; i++ { // 两轮完整轮询 | |
| token, err := balancer.GetToken() | |
| if err != nil { | |
| t.Fatalf("Unexpected error: %v", err) | |
| } | |
| found := false | |
| for _, newToken := range newTokens { | |
| if token == newToken { | |
| found = true | |
| break | |
| } | |
| } | |
| if !found { | |
| t.Errorf("Got unexpected token after refresh: %s", token) | |
| } | |
| } | |
| } | |