// Viewer metadata: file size 5,727 bytes; revision 6fefda3.

package balancer
import (
"jetbrains-ai-proxy/internal/config"
"sync"
"testing"
"time"
)
// TestNewJWTBalancer checks that a balancer built from three tokens with the
// round-robin strategy is non-nil and reports three total and three healthy
// tokens.
func TestNewJWTBalancer(t *testing.T) {
	// Build the balancer with the round-robin strategy.
	b := NewJWTBalancer([]string{"token1", "token2", "token3"}, config.RoundRobin)
	if b == nil {
		t.Fatal("Expected balancer to be created")
	}

	if total := b.GetTotalTokenCount(); total != 3 {
		t.Errorf("Expected 3 tokens, got %d", total)
	}

	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens, got %d", healthy)
	}
}
// TestRoundRobinStrategy requests six tokens from a three-token round-robin
// balancer and expects them in list order, wrapping around after the third.
func TestRoundRobinStrategy(t *testing.T) {
	b := NewJWTBalancer([]string{"token1", "token2", "token3"}, config.RoundRobin)

	// Two complete passes over the token list, in order.
	want := []string{"token1", "token2", "token3", "token1", "token2", "token3"}

	for step := 0; step < len(want); step++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", step, err)
		}
		if got == want[step] {
			continue
		}
		t.Errorf("At iteration %d, expected %s, got %s", step, want[step], got)
	}
}
// TestRandomStrategy draws 100 tokens from a three-token balancer configured
// with the random strategy. Every draw must be one of the configured tokens,
// and each configured token must be drawn at least once.
func TestRandomStrategy(t *testing.T) {
	pool := []string{"token1", "token2", "token3"}
	b := NewJWTBalancer(pool, config.Random)

	// Set of tokens the balancer is allowed to hand out.
	allowed := make(map[string]bool, len(pool))
	for _, tok := range pool {
		allowed[tok] = true
	}

	const draws = 100
	seen := make(map[string]int)

	for n := 0; n < draws; n++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", n, err)
		}

		// Reject anything that was not part of the configured pool.
		if !allowed[got] {
			t.Errorf("Got unexpected token: %s", got)
		}
		seen[got]++
	}

	// With random selection every token should have been picked at least once.
	for _, tok := range pool {
		if seen[tok] == 0 {
			t.Errorf("Token %s was never selected", tok)
		}
	}
}
// TestMarkTokenUnhealthy marks one of three tokens unhealthy, then checks
// that the healthy count drops to two and that ten consecutive GetToken calls
// never return the unhealthy token.
func TestMarkTokenUnhealthy(t *testing.T) {
	const sick = "token2"

	b := NewJWTBalancer([]string{"token1", sick, "token3"}, config.RoundRobin)

	// Take a single token out of rotation.
	b.MarkTokenUnhealthy(sick)

	if healthy := b.GetHealthyTokenCount(); healthy != 2 {
		t.Errorf("Expected 2 healthy tokens, got %d", healthy)
	}

	// Only the remaining healthy tokens may be handed out.
	for attempt := 0; attempt < 10; attempt++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if got == sick {
			t.Errorf("Got unhealthy token: %s", got)
		}
	}
}
// TestMarkTokenHealthy marks a token unhealthy and then healthy again,
// checking that the healthy count goes from three to two and back to three.
func TestMarkTokenHealthy(t *testing.T) {
	const target = "token2"

	b := NewJWTBalancer([]string{"token1", target, "token3"}, config.RoundRobin)

	// Step 1: take the token out of rotation.
	b.MarkTokenUnhealthy(target)
	if healthy := b.GetHealthyTokenCount(); healthy != 2 {
		t.Errorf("Expected 2 healthy tokens after marking unhealthy, got %d", healthy)
	}

	// Step 2: put it back.
	b.MarkTokenHealthy(target)
	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens after marking healthy, got %d", healthy)
	}
}
// TestNoHealthyTokens marks every configured token unhealthy and expects
// GetToken to return an error.
func TestNoHealthyTokens(t *testing.T) {
	pool := []string{"token1", "token2"}
	b := NewJWTBalancer(pool, config.RoundRobin)

	// Take every token out of rotation.
	for _, tok := range pool {
		b.MarkTokenUnhealthy(tok)
	}

	// With nothing healthy left, asking for a token must fail.
	if _, err := b.GetToken(); err == nil {
		t.Error("Expected error when no healthy tokens available")
	}
}
// TestConcurrentAccess calls GetToken, MarkTokenUnhealthy and MarkTokenHealthy
// from many goroutines at once and then checks that the total token count is
// unchanged. It is most useful when run with the race detector (go test -race).
//
// The toggling goroutines deliberately never touch the last token. Each
// toggler starts by marking its token unhealthy, so if the togglers covered
// every token, all of them could be unhealthy at the same moment. GetToken is
// expected to fail in that state (see TestNoHealthyTokens), which would make
// the reader goroutines report a spurious, timing-dependent failure.
func TestConcurrentAccess(t *testing.T) {
	tokens := []string{"token1", "token2", "token3", "token4", "token5"}
	balancer := NewJWTBalancer(tokens, config.RoundRobin)

	const (
		numGoroutines      = 10
		tokensPerGoroutine = 100
		togglesPerToken    = 10
	)

	// Tokens whose health is flipped during the test; the final token is
	// excluded so that at least one token is healthy at all times.
	toggled := tokens[:len(tokens)-1]

	var wg sync.WaitGroup

	// Readers: fetch tokens concurrently.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < tokensPerGoroutine; j++ {
				// t.Errorf (unlike t.Fatalf) may be called from goroutines
				// other than the one running the test function.
				if _, err := balancer.GetToken(); err != nil {
					t.Errorf("Unexpected error in concurrent access: %v", err)
				}
			}
		}()
	}

	// Togglers: flip token health concurrently with the readers.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		// The index is passed as an argument so each goroutine gets its own
		// copy regardless of the Go version's loop-variable semantics.
		go func(index int) {
			defer wg.Done()
			token := toggled[index%len(toggled)]
			for j := 0; j < togglesPerToken; j++ {
				if j%2 == 0 {
					balancer.MarkTokenUnhealthy(token)
				} else {
					balancer.MarkTokenHealthy(token)
				}
				// Brief pause so toggles interleave with the readers.
				time.Sleep(time.Millisecond)
			}
		}(i)
	}

	wg.Wait()

	// Health toggling must not add or remove tokens.
	if balancer.GetTotalTokenCount() != len(tokens) {
		t.Errorf("Expected %d total tokens, got %d", len(tokens), balancer.GetTotalTokenCount())
	}
}
// TestRefreshTokens replaces a two-token set with a three-token set via
// RefreshTokens, then checks the total and healthy counts and that six
// consecutive GetToken calls return only tokens from the new set.
func TestRefreshTokens(t *testing.T) {
	b := NewJWTBalancer([]string{"token1", "token2"}, config.RoundRobin)

	if total := b.GetTotalTokenCount(); total != 2 {
		t.Errorf("Expected 2 tokens initially, got %d", total)
	}

	// Swap in a completely new token set.
	replacement := []string{"token3", "token4", "token5"}
	b.RefreshTokens(replacement)

	if total := b.GetTotalTokenCount(); total != 3 {
		t.Errorf("Expected 3 tokens after refresh, got %d", total)
	}
	if healthy := b.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens after refresh, got %d", healthy)
	}

	// isReplacement reports whether tok belongs to the refreshed set.
	isReplacement := func(tok string) bool {
		for _, candidate := range replacement {
			if candidate == tok {
				return true
			}
		}
		return false
	}

	// Six draws cover two full round-robin passes over the three new tokens.
	for round := 0; round < 6; round++ {
		got, err := b.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if !isReplacement(got) {
			t.Errorf("Got unexpected token after refresh: %s", got)
		}
	}
}
|