File size: 5,727 Bytes
6fefda3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
package balancer

import (
	"jetbrains-ai-proxy/internal/config"
	"sync"
	"testing"
	"time"
)

// TestNewJWTBalancer checks that a freshly constructed round-robin balancer
// reports every supplied token and counts all of them as healthy.
func TestNewJWTBalancer(t *testing.T) {
	tokens := []string{"token1", "token2", "token3"}

	// Build a balancer using the round-robin strategy.
	lb := NewJWTBalancer(tokens, config.RoundRobin)
	if lb == nil {
		t.Fatal("Expected balancer to be created")
	}

	if total := lb.GetTotalTokenCount(); total != 3 {
		t.Errorf("Expected 3 tokens, got %d", total)
	}

	if healthy := lb.GetHealthyTokenCount(); healthy != 3 {
		t.Errorf("Expected 3 healthy tokens, got %d", healthy)
	}
}

// TestRoundRobinStrategy verifies that the round-robin strategy hands out
// tokens in their configured order and wraps back to the first one after
// the last.
func TestRoundRobinStrategy(t *testing.T) {
	tokens := []string{"token1", "token2", "token3"}
	lb := NewJWTBalancer(tokens, config.RoundRobin)

	// Two complete passes over the token list, in order.
	want := []string{"token1", "token2", "token3", "token1", "token2", "token3"}

	for i := range want {
		got, err := lb.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", i, err)
		}
		if got != want[i] {
			t.Errorf("At iteration %d, expected %s, got %s", i, want[i], got)
		}
	}
}

// TestRandomStrategy draws many tokens with the random strategy. It checks
// that every draw is one of the configured tokens and that each token is
// picked at least once.
//
// NOTE(review): the "picked at least once" check assumes roughly uniform
// selection. Under that assumption, the chance of some token never appearing
// in 100 draws over 3 tokens is about 3*(2/3)^100, which is negligible.
func TestRandomStrategy(t *testing.T) {
	tokens := []string{"token1", "token2", "token3"}
	lb := NewJWTBalancer(tokens, config.Random)

	// Set of valid tokens for membership checks.
	known := make(map[string]bool, len(tokens))
	for _, tok := range tokens {
		known[tok] = true
	}

	// Draw repeatedly; every result must be one of the configured tokens.
	const draws = 100
	hits := make(map[string]int)
	for i := 0; i < draws; i++ {
		tok, err := lb.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error at iteration %d: %v", i, err)
		}
		if !known[tok] {
			t.Errorf("Got unexpected token: %s", tok)
		}
		hits[tok]++
	}

	// Each token should have been chosen at least once.
	for _, tok := range tokens {
		if hits[tok] == 0 {
			t.Errorf("Token %s was never selected", tok)
		}
	}
}

// TestMarkTokenUnhealthy marks one token unhealthy. It then checks that the
// healthy count drops by one and that GetToken never returns that token.
func TestMarkTokenUnhealthy(t *testing.T) {
	tokens := []string{"token1", "token2", "token3"}
	lb := NewJWTBalancer(tokens, config.RoundRobin)

	// Take a single token out of rotation.
	const sick = "token2"
	lb.MarkTokenUnhealthy(sick)

	if n := lb.GetHealthyTokenCount(); n != 2 {
		t.Errorf("Expected 2 healthy tokens, got %d", n)
	}

	// From now on only healthy tokens may be handed out.
	for attempt := 0; attempt < 10; attempt++ {
		got, err := lb.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if got == sick {
			t.Errorf("Got unhealthy token: %s", got)
		}
	}
}

// TestMarkTokenHealthy checks that a token marked unhealthy is counted as
// healthy again after MarkTokenHealthy is called on it.
func TestMarkTokenHealthy(t *testing.T) {
	lb := NewJWTBalancer([]string{"token1", "token2", "token3"}, config.RoundRobin)

	// Take token2 out of rotation...
	lb.MarkTokenUnhealthy("token2")
	if n := lb.GetHealthyTokenCount(); n != 2 {
		t.Errorf("Expected 2 healthy tokens after marking unhealthy, got %d", n)
	}

	// ...then restore it.
	lb.MarkTokenHealthy("token2")
	if n := lb.GetHealthyTokenCount(); n != 3 {
		t.Errorf("Expected 3 healthy tokens after marking healthy, got %d", n)
	}
}

// TestNoHealthyTokens checks that GetToken reports an error once every token
// has been marked unhealthy.
func TestNoHealthyTokens(t *testing.T) {
	tokens := []string{"token1", "token2"}
	lb := NewJWTBalancer(tokens, config.RoundRobin)

	// Take every token out of rotation, in order.
	for _, tok := range tokens {
		lb.MarkTokenUnhealthy(tok)
	}

	// With nothing healthy left, GetToken must fail.
	if _, err := lb.GetToken(); err == nil {
		t.Error("Expected error when no healthy tokens available")
	}
}

// TestConcurrentAccess calls GetToken from many goroutines while other
// goroutines flip token health. Run it with -race to surface data races; it
// also checks that the balancer is in a consistent state afterwards.
//
// tokens[0] is never marked unhealthy, so at least one healthy token exists
// at all times and every GetToken call is expected to succeed. Previously all
// tokens were toggled. An interleaving in which every token was unhealthy at
// the same moment then made GetToken legitimately return an error, which made
// this test flaky.
func TestConcurrentAccess(t *testing.T) {
	tokens := []string{"token1", "token2", "token3", "token4", "token5"}
	balancer := NewJWTBalancer(tokens, config.RoundRobin)

	var wg sync.WaitGroup
	numGoroutines := 10
	tokensPerGoroutine := 100

	// Only these tokens have their health toggled; tokens[0] always stays healthy.
	toggled := tokens[1:]

	// Fetch tokens concurrently.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < tokensPerGoroutine; j++ {
				// t.Errorf (unlike t.Fatalf) may be called from a non-test goroutine.
				if _, err := balancer.GetToken(); err != nil {
					t.Errorf("Unexpected error in concurrent access: %v", err)
				}
			}
		}()
	}

	// Toggle token health concurrently. Each goroutine's last call is
	// MarkTokenHealthy, so once they all finish, the last health update applied
	// to every toggled token marks it healthy.
	for i := 0; i < numGoroutines; i++ {
		wg.Add(1)
		go func(index int) {
			defer wg.Done()
			token := toggled[index%len(toggled)]
			for j := 0; j < 10; j++ {
				if j%2 == 0 {
					balancer.MarkTokenUnhealthy(token)
				} else {
					balancer.MarkTokenHealthy(token)
				}
				time.Sleep(time.Millisecond)
			}
		}(i)
	}

	wg.Wait()

	// The token set itself must be unchanged...
	if got := balancer.GetTotalTokenCount(); got != len(tokens) {
		t.Errorf("Expected %d total tokens, got %d", len(tokens), got)
	}
	// ...and every token should end up healthy again.
	if got := balancer.GetHealthyTokenCount(); got != len(tokens) {
		t.Errorf("Expected %d healthy tokens after concurrent toggling, got %d", len(tokens), got)
	}
}

// TestRefreshTokens replaces the balancer's token set and checks three things:
// the counts reflect the new set, every new token starts out healthy, and
// GetToken only returns tokens from the new set.
func TestRefreshTokens(t *testing.T) {
	lb := NewJWTBalancer([]string{"token1", "token2"}, config.RoundRobin)

	if n := lb.GetTotalTokenCount(); n != 2 {
		t.Errorf("Expected 2 tokens initially, got %d", n)
	}

	// Swap in a completely different set of tokens.
	newTokens := []string{"token3", "token4", "token5"}
	lb.RefreshTokens(newTokens)

	if n := lb.GetTotalTokenCount(); n != 3 {
		t.Errorf("Expected 3 tokens after refresh, got %d", n)
	}

	if n := lb.GetHealthyTokenCount(); n != 3 {
		t.Errorf("Expected 3 healthy tokens after refresh, got %d", n)
	}

	fresh := make(map[string]bool, len(newTokens))
	for _, tok := range newTokens {
		fresh[tok] = true
	}

	// Two complete round-robin passes over the new set; only new tokens may appear.
	for i := 0; i < 2*len(newTokens); i++ {
		got, err := lb.GetToken()
		if err != nil {
			t.Fatalf("Unexpected error: %v", err)
		}
		if !fresh[got] {
			t.Errorf("Got unexpected token after refresh: %s", got)
		}
	}
}