// This program is an HTTP proxy: it authenticates clients with a bearer
// password and forwards their requests to a target API, authorizing each
// upstream call with a key taken from a pool that is loaded from a file.
package main

import (
	"bufio"
	"encoding/json"
	"flag"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path"
	"strings"
	"sync"
)

// Config holds the settings supplied on the command line.
type Config struct {
	KeyFile    string // path to the file that lists the API keys
	TargetURL  string // base URL of the target API
	Port       string // port the proxy listens on
	Address    string // address the proxy listens on
	Password   string // password clients must present
	MaxWorkers int    // number of worker goroutines
	MaxQueue   int    // capacity of the request queue
}

// parseFlags registers the command-line flags on the default flag set, parses
// the process arguments, and returns the resulting Config.
func parseFlags() *Config {
	var cfg Config

	flag.StringVar(&cfg.KeyFile, "key-file", "", "Path to the API key file")
	flag.StringVar(&cfg.TargetURL, "target-url", "", "Target API base URL")
	flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on")
	flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on")
	flag.StringVar(&cfg.Password, "password", "", "Password for client authentication")

	// Worker-pool sizing is stored straight into the Config as well.
	flag.IntVar(&cfg.MaxWorkers, "max-workers", 50, "Maximum number of worker goroutines")
	flag.IntVar(&cfg.MaxQueue, "max-queue", 500, "Maximum size of request queue")

	flag.Parse()
	return &cfg
}

// KeyPool hands out API keys from a fixed list in round-robin order.
type KeyPool struct {
	keys         []string   // distinct keys loaded from the key file
	mu           sync.Mutex // guards currentIndex
	currentIndex int        // index of the key GetRandomKey returns next
}

// NewKeyPool loads API keys from filePath and returns a pool over them.
//
// The file is read line by line. Each line is trimmed of surrounding
// whitespace, blank lines are skipped, and a key that already appeared earlier
// in the file is ignored, so the pool holds each key once in first-seen order.
// Callers that track attempted keys by value can therefore rely on the number
// of keys matching the number of distinct keys.
//
// It returns an error when the file cannot be opened or read. A file without
// any keys yields an empty pool and a warning in the log.
func NewKeyPool(filePath string) (*KeyPool, error) {
	file, err := os.Open(filePath)
	if err != nil {
		log.Printf("[ERROR] Failed to open key file %s: %v", filePath, err)
		return nil, err
	}
	defer file.Close()

	var keys []string
	seen := make(map[string]bool)
	duplicates := 0

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		key := strings.TrimSpace(scanner.Text())
		if key == "" {
			continue
		}
		if seen[key] {
			duplicates++
			continue
		}
		seen[key] = true
		keys = append(keys, key)
	}
	if err := scanner.Err(); err != nil {
		log.Printf("[ERROR] Failed to read key file %s: %v", filePath, err)
		return nil, err
	}

	if duplicates > 0 {
		log.Printf("[WARN] Ignored %d duplicate keys in file %s", duplicates, filePath)
	}
	if len(keys) == 0 {
		log.Printf("[WARN] Key file %s contains no keys", filePath)
	}
	log.Printf("[INFO] Loaded %d keys from file %s", len(keys), filePath)
	return &KeyPool{keys: keys}, nil
}

// GetRandomKey returns the next key in round-robin order and advances the
// pool's position, wrapping to the first key after the last one. Despite the
// name, the selection is sequential rather than random. It returns the empty
// string when the pool holds no keys.
func (kp *KeyPool) GetRandomKey() string {
	kp.mu.Lock()
	defer kp.mu.Unlock()

	if len(kp.keys) == 0 {
		return ""
	}
	key := kp.keys[kp.currentIndex]
	kp.currentIndex = (kp.currentIndex + 1) % len(kp.keys)
	return key
}

// ProxyRequest is one unit of work for the worker pool: a client request, the
// writer for its response, and a channel used to report that handling is done.
type ProxyRequest struct {
	Request  *http.Request
	Response http.ResponseWriter
	Done     chan bool // receives a value once the request has been handled
}

// Worker describes one worker goroutine of a WorkerPool.
type Worker struct {
	ID         int                // identifier given at construction
	TaskQueue  chan *ProxyRequest // tasks assigned to this worker arrive here
	Quit       chan bool          // a value here tells the worker to exit
	WorkerPool *WorkerPool        // pool the worker belongs to
}

// NewWorker returns a worker with the given id that belongs to workerPool.
// Its task and quit channels are created unbuffered; the worker does nothing
// until Start is called.
func NewWorker(id int, workerPool *WorkerPool) *Worker {
	w := new(Worker)
	w.ID = id
	w.WorkerPool = workerPool
	w.TaskQueue = make(chan *ProxyRequest)
	w.Quit = make(chan bool)
	return w
}

// Start launches the worker's goroutine and returns immediately.
func (w *Worker) Start() {
	go w.run()
}

// run is the worker's loop. Each iteration registers the worker's task channel
// with the pool's idle queue and then waits for either a task or a quit
// signal. A task is passed to the pool's HandleFunc, after which the task's
// Done channel is signalled; a quit signal ends the loop.
//
// NOTE(review): when the quit signal is taken, the task channel registered at
// the top of that iteration is still in the idle queue - confirm nothing can
// dispatch to it afterwards.
func (w *Worker) run() {
	for {
		// Announce that this worker is free to take a task.
		w.WorkerPool.WorkerQueue <- w.TaskQueue

		select {
		case <-w.Quit:
			return
		case job := <-w.TaskQueue:
			w.WorkerPool.HandleFunc(job.Response, job.Request)
			job.Done <- true
		}
	}
}

// Stop asks the worker to exit. The quit signal is sent from a separate
// goroutine, so Stop returns at once even while the worker is busy.
func (w *Worker) Stop() {
	notify := func() {
		w.Quit <- true
	}
	go notify()
}

// WorkerPool runs request handling on a fixed set of worker goroutines that
// are fed from a bounded queue.
type WorkerPool struct {
	// WorkerQueue holds the task channels of workers that are currently idle.
	WorkerQueue chan chan *ProxyRequest

	// TaskQueue holds submitted requests that wait for an idle worker.
	TaskQueue chan *ProxyRequest

	// MaxWorkers is the number of workers Start launches.
	MaxWorkers int

	// MaxQueue is the capacity of TaskQueue.
	MaxQueue int

	// HandleFunc is what a worker calls to handle a request.
	HandleFunc func(http.ResponseWriter, *http.Request)
}

// NewWorkerPool returns a pool of maxWorkers workers with a queue of maxQueue
// pending requests that handles each request with handleFunc. The pool is
// idle until Start is called.
//
// Out-of-range sizes are corrected with a logged warning instead of being
// passed to make, where a negative size panics: maxWorkers below 1 becomes 1
// (a pool without workers would leave every accepted request waiting forever)
// and a negative maxQueue becomes 0.
func NewWorkerPool(maxWorkers int, maxQueue int, handleFunc func(http.ResponseWriter, *http.Request)) *WorkerPool {
	if maxWorkers < 1 {
		log.Printf("[WARN] Invalid worker count %d, using 1", maxWorkers)
		maxWorkers = 1
	}
	if maxQueue < 0 {
		log.Printf("[WARN] Invalid queue size %d, using 0", maxQueue)
		maxQueue = 0
	}

	return &WorkerPool{
		WorkerQueue: make(chan chan *ProxyRequest, maxWorkers),
		TaskQueue:   make(chan *ProxyRequest, maxQueue),
		MaxWorkers:  maxWorkers,
		MaxQueue:    maxQueue,
		HandleFunc:  handleFunc,
	}
}

// Start creates and starts MaxWorkers workers, logging each one, and then
// starts the goroutine that hands queued tasks to idle workers.
func (wp *WorkerPool) Start() {
	for id := 0; id < wp.MaxWorkers; id++ {
		NewWorker(id, wp).Start()
		log.Printf("[INFO] Started worker %d", id)
	}

	// Dispatching runs for the rest of the pool's life in its own goroutine.
	go wp.dispatch()
}

// Stop is a placeholder: shutting the pool down is not implemented yet, so
// calling it has no effect.
func (wp *WorkerPool) Stop() {
	// TODO: implement the shutdown logic
}

// dispatch hands queued tasks to idle workers, forever. It takes the next task
// from TaskQueue, waits until a worker has registered itself as idle, and
// sends the task on that worker's channel.
func (wp *WorkerPool) dispatch() {
	for {
		task := <-wp.TaskQueue
		idle := <-wp.WorkerQueue
		idle <- task
	}
}

// Submit queues the request for a worker and blocks until it has been handled,
// returning true. When the queue has no room the request is rejected instead:
// a warning is logged, the client receives 503, and Submit returns false
// without waiting.
func (wp *WorkerPool) Submit(response http.ResponseWriter, request *http.Request) bool {
	// Done is buffered so the worker's completion signal never blocks it.
	task := &ProxyRequest{Done: make(chan bool, 1)}
	task.Request = request
	task.Response = response

	select {
	case wp.TaskQueue <- task:
		// Accepted; fall through and wait for completion.
	default:
		// Back-pressure: refuse rather than wait for room in the queue.
		log.Println("[WARN] Task queue is full, rejecting request")
		http.Error(response, "Server is busy, please try again later", http.StatusServiceUnavailable)
		return false
	}

	<-task.Done
	return true
}

// ProxyHandler is the http.Handler of the proxy.
type ProxyHandler struct {
	cfg        *Config      // command-line configuration
	keyPool    *KeyPool     // source of keys for upstream requests
	client     *http.Client // client used to call the target API
	workerPool *WorkerPool  // assigned by InitWorkerPool
}

// NewProxyHandler returns a handler that uses cfg and keyPool and a fresh
// http.Client with default settings. The worker pool is left unset; call
// InitWorkerPool before serving requests.
func NewProxyHandler(cfg *Config, keyPool *KeyPool) *ProxyHandler {
	return &ProxyHandler{
		client:  new(http.Client),
		cfg:     cfg,
		keyPool: keyPool,
	}
}

// InitWorkerPool creates the handler's worker pool with the given sizes, wires
// it to HandleRequest, starts it, and logs the sizes that were requested.
func (ph *ProxyHandler) InitWorkerPool(maxWorkers int, maxQueue int) {
	pool := NewWorkerPool(maxWorkers, maxQueue, ph.HandleRequest)
	ph.workerPool = pool
	pool.Start()

	log.Printf("[INFO] Started worker pool with %d workers and queue size %d", maxWorkers, maxQueue)
}

// ServeHTTP logs the incoming request and passes it to the worker pool, which
// blocks until the request has been handled or rejected.
func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	method, target := r.Method, r.URL.String()
	log.Printf("[INFO] Received request: %s %s", method, target)

	// Submit answers the client itself when it rejects the request, so its
	// result needs no further handling here.
	_ = ph.workerPool.Submit(w, r)
}

// HandleRequest processes one proxied request; the worker pool calls it for
// every request it accepts.
//
// Clients that fail authentication receive 401. Otherwise the model named in
// the JSON body, if any, is logged and the request is sent to the target URL
// with keys from the pool, each key at most once. A 403 or 429 from the
// target, or a failure to build or send the request, moves on to the next key;
// any other response is forwarded to the client as is. The client receives 400
// when the target URL cannot be built or the body cannot be read, and 502 when
// no key produced a response that could be forwarded.
func (ph *ProxyHandler) HandleRequest(w http.ResponseWriter, r *http.Request) {
	if !ph.authenticate(r) {
		log.Println("[WARN] Unauthorized access attempt")
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	log.Println("[INFO] Authentication successful")

	// The model is only used for logging, so a failure here is not fatal.
	model, err := ph.extractModelFromRequest(r)
	if err != nil {
		log.Printf("[WARN] Failed to extract model from request: %v", err)
	} else if model != "" {
		log.Printf("[INFO] Model specified in request: %s", model)
	}

	targetURL, err := ph.buildTargetURL(r)
	if err != nil {
		log.Printf("[ERROR] Failed to build target URL: %v", err)
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	log.Printf("[INFO] Target URL: %s", targetURL)

	// Buffer the body once. The HTTP client consumes the reader it is given,
	// so without a copy every attempt after the first would send no body.
	var payload string
	if r.Body != nil {
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			log.Printf("[ERROR] Failed to read request body: %v", err)
			http.Error(w, "Bad Request", http.StatusBadRequest)
			return
		}
		payload = string(raw)
	}

	maxRetries := len(ph.keyPool.keys)
	attemptedKeys := make(map[string]bool, maxRetries)
	log.Printf("[INFO] Starting key selection process, total keys available: %d", maxRetries)

	for i := 0; i < maxRetries; i++ {
		key := ph.getUnusedKey(attemptedKeys)
		if key == "" {
			log.Printf("[ERROR] No unused keys remaining after %d attempts", i)
			break
		}
		attemptedKeys[key] = true
		maskedKey := maskKey(key)
		log.Printf("[INFO] Attempt %d/%d: Selecting key %s", i+1, maxRetries, maskedKey)

		// Give this attempt its own reader over the buffered body.
		if payload == "" {
			r.Body = http.NoBody
		} else {
			r.Body = io.NopCloser(strings.NewReader(payload))
		}

		req, err := ph.createRequest(r, targetURL, key)
		if err != nil {
			log.Printf("[ERROR] Failed to create request with key %s: %v", maskedKey, err)
			log.Printf("[INFO] Switching to another key due to request creation failure")
			continue
		}

		log.Printf("[INFO] Sending request to target API with key %s", maskedKey)
		resp, err := ph.client.Do(req)
		if err != nil {
			log.Printf("[ERROR] Failed to send request with key %s: %v", maskedKey, err)
			log.Printf("[INFO] Switching to another key due to network error")
			continue
		}

		log.Printf("[INFO] Received response with status code %d", resp.StatusCode)
		switch {
		case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusTooManyRequests:
			log.Printf("[WARN] Received %d status code with key %s", resp.StatusCode, maskedKey)
			log.Printf("[INFO] Switching to another key due to status code %d", resp.StatusCode)
			// Release the rejected response now instead of deferring to the
			// end of the function, so responses do not pile up across
			// retries. The bounded drain is best effort and only helps the
			// connection to be reused, hence the ignored result.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
			resp.Body.Close()
			continue
		case resp.StatusCode >= 200 && resp.StatusCode < 300:
			log.Println("[INFO] Request successful, forwarding response")
		default:
			log.Printf("[INFO] Forwarding response with status code %d", resp.StatusCode)
		}

		ph.forwardResponse(w, resp)
		resp.Body.Close()
		return
	}

	// No key produced a response that could be forwarded.
	log.Printf("[ERROR] All %d keys failed after retries", maxRetries)
	http.Error(w, "Failed to get response from API after all retries", http.StatusBadGateway)
}

// getUnusedKey returns a key from the pool that is not marked in attempted, or
// the empty string when every key has been attempted or the pool is empty.
//
// Keys are first taken in the pool's round-robin order, for at most one full
// cycle, so the search always terminates - even when the pool lists the same
// key more than once. Because other requests advance the shared round-robin
// position at the same time, a cycle can miss a key; a direct scan of the key
// list then catches any key that is still unattempted.
func (ph *ProxyHandler) getUnusedKey(attempted map[string]bool) string {
	keys := ph.keyPool.keys

	for range keys {
		if key := ph.keyPool.GetRandomKey(); !attempted[key] {
			return key
		}
	}

	for _, key := range keys {
		if !attempted[key] {
			return key
		}
	}
	return ""
}

// authenticate reports whether the request carries the configured password as
// a bearer token, i.e. an Authorization header of the form "Bearer <password>".
//
// The scheme name is matched case-insensitively, as HTTP defines authentication
// scheme names, and any run of whitespace may separate it from the token. A
// missing header, a different scheme, a missing token, or extra fields all
// fail.
func (ph *ProxyHandler) authenticate(r *http.Request) bool {
	fields := strings.Fields(r.Header.Get("Authorization"))
	if len(fields) != 2 || !strings.EqualFold(fields[0], "Bearer") {
		return false
	}
	return fields[1] == ph.cfg.Password
}

// buildTargetURL returns the URL on the target API that corresponds to the
// request: the configured base URL with the request's path appended below the
// base path and the request's query string attached.
//
// The request path is cleaned as a rooted path before it is joined, so ".."
// segments cannot climb out of the configured base path. A trailing slash on
// the request path is kept, because path.Join would otherwise drop it and
// address a different resource. An error is returned when the configured
// target URL does not parse.
func (ph *ProxyHandler) buildTargetURL(r *http.Request) (string, error) {
	u, err := url.Parse(ph.cfg.TargetURL)
	if err != nil {
		return "", err
	}

	requestPath := path.Clean("/" + r.URL.Path)
	joined := path.Join(u.Path, requestPath)
	if strings.HasSuffix(r.URL.Path, "/") && !strings.HasSuffix(joined, "/") {
		joined += "/"
	}

	u.Path = joined
	u.RawQuery = r.URL.RawQuery
	return u.String(), nil
}

// createRequest builds the upstream request for r: same method and body, sent
// to targetURL, and authorized with key.
//
// The new request carries r's context, so it is cancelled when the client's
// request is. A known positive content length is carried over, which lets the
// body be sent with a Content-Length instead of chunked encoding. Headers are
// copied except for Host, the client's own Authorization, and headers that
// apply to a single connection only; the Authorization header is then set to
// "Bearer <key>".
func (ph *ProxyHandler) createRequest(r *http.Request, targetURL, key string) (*http.Request, error) {
	req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		return nil, err
	}
	if r.Body != nil && r.ContentLength > 0 {
		req.ContentLength = r.ContentLength
	}

	for name, values := range r.Header {
		switch name {
		case "Host", "Authorization",
			"Connection", "Proxy-Connection", "Keep-Alive", "Proxy-Authorization",
			"Te", "Trailer", "Transfer-Encoding", "Upgrade":
			continue
		}
		// Copy the values so the two requests do not share a slice.
		req.Header[name] = append([]string(nil), values...)
	}
	req.Header.Set("Authorization", "Bearer "+key)
	return req, nil
}

// forwardResponse copies resp - headers, status code, and body - to the
// client. It does not close resp.Body.
//
// A streaming response (see isStreamingResponse) is relayed chunk by chunk
// with a flush after every chunk, so data reaches the client as soon as it
// arrives, including a final chunk that does not end in a newline. Any other
// response is copied in one go. Because the status line has already been sent
// by the time the body is copied, failures while copying are logged and end
// the copy; they cannot be reported to the client as an error status.
func (ph *ProxyHandler) forwardResponse(w http.ResponseWriter, resp *http.Response) {
	dst := w.Header()
	for name, values := range resp.Header {
		dst[name] = values
	}
	w.WriteHeader(resp.StatusCode)

	if !isStreamingResponse(resp) {
		if _, err := io.Copy(w, resp.Body); err != nil {
			log.Printf("[ERROR] Failed to forward response: %v", err)
		}
		return
	}

	log.Println("[INFO] Handling streaming response")
	flusher, canFlush := w.(http.Flusher)
	if !canFlush {
		// Still deliver the data; it is merely buffered by the writer.
		log.Println("[WARN] Response writer cannot flush, streaming without explicit flushes")
	}

	buf := make([]byte, 32*1024)
	for {
		n, readErr := resp.Body.Read(buf)
		if n > 0 {
			if _, err := w.Write(buf[:n]); err != nil {
				// The client is gone; stop pulling data from the target.
				log.Printf("[ERROR] Failed to write stream to client: %v", err)
				return
			}
			if canFlush {
				flusher.Flush()
			}
		}
		if readErr == io.EOF {
			log.Println("[INFO] Stream ended")
			return
		}
		if readErr != nil {
			log.Printf("[ERROR] Error reading stream: %v", readErr)
			return
		}
	}
}

// isStreamingResponse reports whether resp should be relayed incrementally:
// its content type is text/event-stream, or it uses chunked transfer encoding.
// The encoding is looked up in resp.TransferEncoding, and also in the header
// for responses that carry it there.
func isStreamingResponse(resp *http.Response) bool {
	if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") {
		return true
	}
	if resp.Header.Get("Transfer-Encoding") == "chunked" {
		return true
	}
	for _, encoding := range resp.TransferEncoding {
		if strings.EqualFold(encoding, "chunked") {
			return true
		}
	}
	return false
}

// extractModelFromRequest returns the string value of the top-level "model"
// field of the request's JSON body.
//
// The body is read completely and then replaced with a reader over the same
// bytes, so it can still be forwarded afterwards. The result is "" with a nil
// error when the request has no body, the body is empty, or the JSON object
// has no string "model" field. An error is returned when the body cannot be
// read or is not a JSON object.
func (ph *ProxyHandler) extractModelFromRequest(r *http.Request) (string, error) {
	if r.Body == nil {
		return "", nil
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return "", err
	}
	r.Body = io.NopCloser(strings.NewReader(string(body)))

	// Requests without a payload, such as a plain GET, have nothing to parse.
	if len(body) == 0 {
		return "", nil
	}

	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		return "", err
	}
	model, _ := data["model"].(string)
	return model, nil
}

// maskKey returns a form of key that is safe to write to the log. Keys of at
// least 12 characters keep their first and last 4 characters, which is enough
// to tell keys apart, with "..." in place of the rest. Shorter keys are
// replaced entirely by asterisks, one per character.
//
// NOTE(review): this previously returned the key unchanged on purpose, which
// put complete API keys into the log - confirm nothing relies on that.
func maskKey(key string) string {
	const visible = 4

	runes := []rune(key)
	if len(runes) < 3*visible {
		return strings.Repeat("*", len(runes))
	}
	return string(runes[:visible]) + "..." + string(runes[len(runes)-visible:])
}

// main parses the flags, loads the key pool, starts the worker pool, and
// serves the proxy. A missing required flag, a key pool that cannot be loaded,
// or a server that stops with an error is logged and ends the process with
// exit status 1.
func main() {
	cfg := parseFlags()

	// These three flags have no usable default.
	missing := cfg.KeyFile == "" || cfg.TargetURL == "" || cfg.Password == ""
	if missing {
		log.Fatalln("[ERROR] Missing required flags: --key-file, --target-url, --password")
	}
	log.Printf("[INFO] Configuration loaded: KeyFile=%s, TargetURL=%s, Address=%s, Port=%s, MaxWorkers=%d, MaxQueue=%d",
		cfg.KeyFile, cfg.TargetURL, cfg.Address, cfg.Port, cfg.MaxWorkers, cfg.MaxQueue)

	keyPool, err := NewKeyPool(cfg.KeyFile)
	if err != nil {
		log.Fatalf("[ERROR] Failed to initialize key pool: %v", err)
	}

	// The handler needs its worker pool running before it can serve.
	proxyHandler := NewProxyHandler(cfg, keyPool)
	proxyHandler.InitWorkerPool(cfg.MaxWorkers, cfg.MaxQueue)

	addr := cfg.Address + ":" + cfg.Port
	log.Printf("[INFO] Starting proxy server on %s", addr)
	if err := http.ListenAndServe(addr, proxyHandler); err != nil {
		log.Fatalf("[ERROR] Failed to start server: %v", err)
	}
}