Spaces:
Sleeping
Sleeping
| package main | |
import (
	"bufio"
	"crypto/subtle"
	"encoding/json"
	"flag"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path"
	"strings"
	"sync"
)
// Config holds the settings parsed from the command line (see parseFlags).
type Config struct {
	// KeyFile is the path of the file the upstream API keys are loaded from.
	KeyFile string
	// TargetURL is the base URL of the upstream API requests are forwarded to.
	TargetURL string
	// Port is the port the proxy listens on.
	Port string
	// Address is the host/interface the proxy listens on.
	Address string
	// Password is the bearer token clients must send to use the proxy.
	Password string
	// MaxWorkers is the number of worker goroutines in the pool.
	MaxWorkers int
	// MaxQueue is the capacity of the pending-request queue.
	MaxQueue int
}
| // parseFlags 解析命令行参数并返回 Config 实例 | |
| func parseFlags() *Config { | |
| cfg := &Config{} | |
| flag.StringVar(&cfg.KeyFile, "key-file", "", "Path to the API key file") | |
| flag.StringVar(&cfg.TargetURL, "target-url", "", "Target API base URL") | |
| flag.StringVar(&cfg.Port, "port", "8080", "Port to listen on") | |
| flag.StringVar(&cfg.Address, "address", "localhost", "Address to listen on") | |
| flag.StringVar(&cfg.Password, "password", "", "Password for client authentication") | |
| // 添加WorkerPool相关配置 | |
| maxWorkers := flag.Int("max-workers", 50, "Maximum number of worker goroutines") | |
| maxQueue := flag.Int("max-queue", 500, "Maximum size of request queue") | |
| flag.Parse() | |
| // 将WorkerPool配置添加到Config结构体 | |
| cfg.MaxWorkers = *maxWorkers | |
| cfg.MaxQueue = *maxQueue | |
| return cfg | |
| } | |
| // KeyPool 管理 API 密钥池 | |
| type KeyPool struct { | |
| keys []string // 密钥列表 | |
| mu sync.Mutex // 互斥锁,确保线程安全 | |
| currentIndex int // 当前密钥索引,用于循环抽取 | |
| } | |
// NewKeyPool loads API keys from filePath, one per line, and returns a pool
// whose rotation starts at the first key. Each line is whitespace-trimmed and
// blank lines are skipped. Open and read errors are logged and returned.
func NewKeyPool(filePath string) (*KeyPool, error) {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		log.Printf("[ERROR] Failed to open key file %s: %v", filePath, openErr)
		return nil, openErr
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	var loaded []string
	for sc.Scan() {
		if k := strings.TrimSpace(sc.Text()); k != "" {
			loaded = append(loaded, k)
		}
	}
	if scanErr := sc.Err(); scanErr != nil {
		log.Printf("[ERROR] Failed to read key file %s: %v", filePath, scanErr)
		return nil, scanErr
	}

	log.Printf("[INFO] Loaded %d keys from file %s", len(loaded), filePath)
	return &KeyPool{keys: loaded}, nil
}
| // GetRandomKey 按顺序循环返回一个密钥 | |
| func (kp *KeyPool) GetRandomKey() string { | |
| kp.mu.Lock() | |
| defer kp.mu.Unlock() | |
| if len(kp.keys) == 0 { | |
| return "" | |
| } | |
| key := kp.keys[kp.currentIndex] | |
| kp.currentIndex = (kp.currentIndex + 1) % len(kp.keys) // 循环到下一个索引 | |
| return key | |
| } | |
| // 定义请求结构体 | |
| type ProxyRequest struct { | |
| Request *http.Request | |
| Response http.ResponseWriter | |
| Done chan bool // 用于通知请求处理完成 | |
| } | |
| // Worker结构体,表示一个工作协程 | |
| type Worker struct { | |
| ID int | |
| TaskQueue chan *ProxyRequest // 任务队列 | |
| Quit chan bool // 退出信号 | |
| WorkerPool *WorkerPool // 所属工作池 | |
| } | |
| // 创建新的Worker | |
| func NewWorker(id int, workerPool *WorkerPool) *Worker { | |
| return &Worker{ | |
| ID: id, | |
| TaskQueue: make(chan *ProxyRequest), | |
| Quit: make(chan bool), | |
| WorkerPool: workerPool, | |
| } | |
| } | |
// Start launches the worker goroutine. In each iteration the worker offers its
// private task channel on the pool's WorkerQueue (marking itself idle), then
// waits for either a task or a quit signal.
func (w *Worker) Start() {
	go func() {
		for {
			// Register as idle so the dispatcher can hand us the next task.
			w.WorkerPool.WorkerQueue <- w.TaskQueue

			select {
			case task := <-w.TaskQueue:
				w.run(task)
			case <-w.Quit:
				return
			}
		}
	}()
}

// run handles a single task and always signals task.Done, even if the handler
// panics.
//
// net/http only recovers panics raised inside the goroutine that called
// ServeHTTP. HandleFunc runs on this worker goroutine instead, so an
// unrecovered panic here would crash the whole process, and the submitter
// would otherwise wait on task.Done forever.
func (w *Worker) run(task *ProxyRequest) {
	defer func() { task.Done <- true }() // Done is buffered (size 1) by Submit
	defer func() {
		if rec := recover(); rec != nil {
			log.Printf("[ERROR] Worker %d recovered from panic: %v", w.ID, rec)
			// Best effort: if headers were already sent this can only append
			// to the body, but it avoids an empty 200 when nothing was written.
			http.Error(task.Response, "Internal Server Error", http.StatusInternalServerError)
		}
	}()
	w.WorkerPool.HandleFunc(task.Response, task.Request)
}
| // Worker停止工作 | |
| func (w *Worker) Stop() { | |
| go func() { | |
| w.Quit <- true | |
| }() | |
| } | |
| // WorkerPool结构体,管理工作协程池 | |
| type WorkerPool struct { | |
| WorkerQueue chan chan *ProxyRequest // 空闲Worker队列 | |
| TaskQueue chan *ProxyRequest // 任务队列 | |
| MaxWorkers int // 最大Worker数量 | |
| MaxQueue int // 最大队列长度 | |
| HandleFunc func(http.ResponseWriter, *http.Request) // 请求处理函数 | |
| } | |
// NewWorkerPool allocates a pool with room for maxWorkers idle workers and
// maxQueue pending requests, using handleFunc to process each request.
// Nothing runs until Start is called.
func NewWorkerPool(maxWorkers int, maxQueue int, handleFunc func(http.ResponseWriter, *http.Request)) *WorkerPool {
	idle := make(chan chan *ProxyRequest, maxWorkers)
	pending := make(chan *ProxyRequest, maxQueue)
	return &WorkerPool{
		WorkerQueue: idle,
		TaskQueue:   pending,
		MaxWorkers:  maxWorkers,
		MaxQueue:    maxQueue,
		HandleFunc:  handleFunc,
	}
}
// Start launches MaxWorkers workers (IDs 0 through MaxWorkers-1) and then the
// dispatcher goroutine that routes queued tasks to idle workers.
func (wp *WorkerPool) Start() {
	for id := 0; id < wp.MaxWorkers; id++ {
		NewWorker(id, wp).Start()
		log.Printf("[INFO] Started worker %d", id)
	}
	go wp.dispatch()
}
| // 停止WorkerPool | |
| func (wp *WorkerPool) Stop() { | |
| // TODO: 实现停止逻辑 | |
| } | |
| // 将任务分发给空闲worker | |
| func (wp *WorkerPool) dispatch() { | |
| for { | |
| select { | |
| case task := <-wp.TaskQueue: | |
| // 等待空闲worker | |
| workerTaskQueue := <-wp.WorkerQueue | |
| // 将任务发送给worker | |
| workerTaskQueue <- task | |
| } | |
| } | |
| } | |
// Submit enqueues a request and blocks until a worker has finished handling
// it, then returns true. If the queue is full, the request is rejected at once
// with 503 Service Unavailable and Submit returns false.
func (wp *WorkerPool) Submit(response http.ResponseWriter, request *http.Request) bool {
	task := &ProxyRequest{Request: request, Response: response, Done: make(chan bool, 1)}

	select {
	case wp.TaskQueue <- task:
	default:
		// Queue full: apply backpressure rather than blocking the caller.
		log.Println("[WARN] Task queue is full, rejecting request")
		http.Error(response, "Server is busy, please try again later", http.StatusServiceUnavailable)
		return false
	}

	<-task.Done // wait for the worker to finish writing the response
	return true
}
| // ProxyHandler 处理 HTTP 代理请求 | |
| type ProxyHandler struct { | |
| cfg *Config // 配置信息 | |
| keyPool *KeyPool // 密钥池 | |
| client *http.Client // HTTP 客户端 | |
| workerPool *WorkerPool // 工作协程池 | |
| } | |
// NewProxyHandler returns a handler for cfg and keyPool that uses a default
// http.Client. InitWorkerPool must be called before serving, because
// ServeHTTP submits every request to the worker pool.
//
// NOTE(review): the client has no timeout. A whole-request Client.Timeout
// would also cut off long streaming responses, so any limit should be
// chosen with that in mind.
func NewProxyHandler(cfg *Config, keyPool *KeyPool) *ProxyHandler {
	return &ProxyHandler{
		cfg:     cfg,
		keyPool: keyPool,
		client:  new(http.Client),
	}
}
// InitWorkerPool creates a worker pool that runs ph.HandleRequest for each
// request, stores it on ph, starts it, and logs its sizing.
func (ph *ProxyHandler) InitWorkerPool(maxWorkers int, maxQueue int) {
	pool := NewWorkerPool(maxWorkers, maxQueue, ph.HandleRequest)
	ph.workerPool = pool
	pool.Start()
	log.Printf("[INFO] Started worker pool with %d workers and queue size %d", maxWorkers, maxQueue)
}
// ServeHTTP logs the incoming request and hands it to the worker pool. It
// returns once a worker has written the response, or immediately if the pool
// rejected the request (Submit writes the 503 itself in that case).
func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Printf("[INFO] Received request: %s %s", r.Method, r.URL.String())
	_ = ph.workerPool.Submit(w, r)
}
| // HandleRequest 处理请求的方法,由Worker调用 | |
| func (ph *ProxyHandler) HandleRequest(w http.ResponseWriter, r *http.Request) { | |
| // 验证客户端身份 | |
| if !ph.authenticate(r) { | |
| log.Println("[WARN] Unauthorized access attempt") | |
| http.Error(w, "Unauthorized", http.StatusUnauthorized) | |
| return | |
| } | |
| log.Println("[INFO] Authentication successful") | |
| // 尝试解析请求体中的模型信息 | |
| model, err := ph.extractModelFromRequest(r) | |
| if err != nil { | |
| log.Printf("[WARN] Failed to extract model from request: %v", err) | |
| } else if model != "" { | |
| log.Printf("[INFO] Model specified in request: %s", model) | |
| } | |
| // 构建目标 URL | |
| targetURL, err := ph.buildTargetURL(r) | |
| if err != nil { | |
| log.Printf("[ERROR] Failed to build target URL: %v", err) | |
| http.Error(w, "Bad Request", http.StatusBadRequest) | |
| return | |
| } | |
| log.Printf("[INFO] Target URL: %s", targetURL) | |
| // 重试逻辑 | |
| maxRetries := len(ph.keyPool.keys) | |
| attemptedKeys := make(map[string]bool) | |
| log.Printf("[INFO] Starting key selection process, total keys available: %d", maxRetries) | |
| for i := 0; i < maxRetries; i++ { | |
| key := ph.getUnusedKey(attemptedKeys) | |
| if key == "" { | |
| log.Printf("[ERROR] No unused keys remaining after %d attempts", i) | |
| break | |
| } | |
| attemptedKeys[key] = true | |
| maskedKey := maskKey(key) | |
| log.Printf("[INFO] Attempt %d/%d: Selecting key %s", i+1, maxRetries, maskedKey) | |
| // 创建请求 | |
| req, err := ph.createRequest(r, targetURL, key) | |
| if err != nil { | |
| log.Printf("[ERROR] Failed to create request with key %s: %v", maskedKey, err) | |
| log.Printf("[INFO] Switching to another key due to request creation failure") | |
| continue | |
| } | |
| // 发送请求 | |
| log.Printf("[INFO] Sending request to target API with key %s", maskedKey) | |
| resp, err := ph.client.Do(req) | |
| if err != nil { | |
| log.Printf("[ERROR] Failed to send request with key %s: %v", maskedKey, err) | |
| log.Printf("[INFO] Switching to another key due to network error") | |
| continue | |
| } | |
| defer resp.Body.Close() | |
| // 处理响应 | |
| log.Printf("[INFO] Received response with status code %d", resp.StatusCode) | |
| if resp.StatusCode >= 200 && resp.StatusCode < 300 { | |
| log.Println("[INFO] Request successful, forwarding response") | |
| ph.forwardResponse(w, resp) | |
| return | |
| } else if resp.StatusCode == 403 || resp.StatusCode == 429 { | |
| log.Printf("[WARN] Received %d status code with key %s", resp.StatusCode, maskedKey) | |
| log.Printf("[INFO] Switching to another key due to status code %d", resp.StatusCode) | |
| continue | |
| } else { | |
| log.Printf("[INFO] Forwarding response with status code %d", resp.StatusCode) | |
| ph.forwardResponse(w, resp) | |
| return | |
| } | |
| } | |
| // 所有密钥尝试后仍失败 | |
| log.Printf("[ERROR] All %d keys failed after retries", maxRetries) | |
| http.Error(w, "Failed to get response from API after all retries", http.StatusBadGateway) | |
| } | |
| // getUnusedKey 获取一个未使用过的密钥 | |
| func (ph *ProxyHandler) getUnusedKey(attempted map[string]bool) string { | |
| key := ph.keyPool.GetRandomKey() | |
| // 如果获取到的密钥已使用过,则尝试其他密钥 | |
| for attempted[key] && len(attempted) < len(ph.keyPool.keys) { | |
| key = ph.keyPool.GetRandomKey() | |
| } | |
| // 如果所有密钥都已尝试过,返回空字符串 | |
| if attempted[key] { | |
| return "" | |
| } | |
| return key | |
| } | |
// authenticate reports whether r carries an "Authorization: Bearer <token>"
// header whose token equals the configured password. The header must contain
// exactly one space, separating the scheme from the token.
func (ph *ProxyHandler) authenticate(r *http.Request) bool {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return false
	}
	parts := strings.Split(authHeader, " ")
	if len(parts) != 2 || parts[0] != "Bearer" {
		return false
	}
	// Constant-time comparison, so response latency does not reveal how many
	// leading bytes of a guessed password were correct.
	return subtle.ConstantTimeCompare([]byte(parts[1]), []byte(ph.cfg.Password)) == 1
}
// buildTargetURL returns the upstream URL for r: the configured TargetURL with
// r's path appended and r's raw query string copied verbatim. It fails only if
// TargetURL cannot be parsed.
//
// NOTE(review): path.Join also cleans the result, so a trailing slash in the
// client path is dropped and ".." segments are resolved. Confirm the upstream
// API does not distinguish "/x/" from "/x".
func (ph *ProxyHandler) buildTargetURL(r *http.Request) (string, error) {
	base, parseErr := url.Parse(ph.cfg.TargetURL)
	if parseErr != nil {
		return "", parseErr
	}
	base.Path = path.Join(base.Path, r.URL.Path)
	base.RawQuery = r.URL.RawQuery
	return base.String(), nil
}
// createRequest builds the upstream request for r. The request:
//   - keeps r's method and body and is sent to targetURL;
//   - copies r's headers except Host, Connection, Proxy-Connection and
//     Authorization;
//   - carries "Authorization: Bearer <key>" for the upstream API.
//
// The request inherits r's context. For server requests, net/http cancels that
// context when the client disconnects, so an abandoned upstream call is
// aborted rather than running to completion on the key's quota.
func (ph *ProxyHandler) createRequest(r *http.Request, targetURL, key string) (*http.Request, error) {
	req, err := http.NewRequestWithContext(r.Context(), r.Method, targetURL, r.Body)
	if err != nil {
		return nil, err
	}

	for name, values := range r.Header {
		switch name {
		case "Host", "Connection", "Proxy-Connection", "Authorization":
			// Connection-level headers and the client's own credentials are
			// not forwarded.
			continue
		}
		req.Header[name] = values
	}
	req.Header.Set("Authorization", "Bearer "+key)
	return req, nil
}
// forwardResponse copies resp's headers, status code and body to the client.
//
// Event streams (Content-Type text/event-stream) and chunked responses are
// relayed line by line with a flush after each line, so the client sees data
// as it arrives. Everything else, including streams on a writer that cannot
// flush, is copied with io.Copy.
//
// After WriteHeader the status line is already on the wire, so later failures
// can only be logged. Calling http.Error at that point would splice an error
// message into the middle of the forwarded body.
func (ph *ProxyHandler) forwardResponse(w http.ResponseWriter, resp *http.Response) {
	for k, v := range resp.Header {
		w.Header()[k] = v
	}
	w.WriteHeader(resp.StatusCode)

	streaming := strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") ||
		resp.Header.Get("Transfer-Encoding") == "chunked"
	// Go's HTTP client moves Transfer-Encoding out of resp.Header and into
	// resp.TransferEncoding, so the header check alone never sees "chunked".
	for _, te := range resp.TransferEncoding {
		if te == "chunked" {
			streaming = true
		}
	}

	if streaming {
		if flusher, ok := w.(http.Flusher); ok {
			log.Println("[INFO] Handling streaming response")
			relayLines(w, flusher, resp.Body)
			return
		}
		log.Println("[WARN] Streaming unsupported by server, falling back to plain copy")
	}

	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("[ERROR] Failed to forward response: %v", err)
	}
}

// relayLines copies src to w one newline-terminated line at a time, flushing
// after each write. It stops at end of stream, on a read error, or when a
// write to the client fails (for example because the client disconnected).
func relayLines(w io.Writer, flusher http.Flusher, src io.Reader) {
	reader := bufio.NewReader(src)
	for {
		line, readErr := reader.ReadBytes('\n')
		// ReadBytes can return data together with an error, for example a final
		// line without a trailing newline followed by io.EOF. Forward it first.
		if len(line) > 0 {
			if _, err := w.Write(line); err != nil {
				log.Printf("[ERROR] Failed to write stream to client: %v", err)
				return
			}
			flusher.Flush()
		}
		if readErr == io.EOF {
			log.Println("[INFO] Stream ended")
			return
		}
		if readErr != nil {
			log.Printf("[ERROR] Error reading stream: %v", readErr)
			return
		}
	}
}
// extractModelFromRequest reads r's body, puts an identical copy back in
// r.Body so the request can still be forwarded, and returns the top-level
// "model" string if the body is a JSON object that contains one.
//
// Return values:
//   - "", nil when r has no body, the body is empty, or "model" is missing or
//     not a string.
//   - an error when the body cannot be parsed as a JSON object.
//   - an error when reading fails; r.Body is then left partially consumed.
//
// The whole body is held in memory.
func (ph *ProxyHandler) extractModelFromRequest(r *http.Request) (string, error) {
	if r.Body == nil {
		return "", nil
	}
	body, err := io.ReadAll(r.Body)
	if err != nil {
		return "", err
	}
	r.Body = io.NopCloser(strings.NewReader(string(body)))

	// Bodiless requests (e.g. GET) have nothing to inspect. Parsing them would
	// only yield a spurious "unexpected end of JSON input" error.
	if len(body) == 0 {
		return "", nil
	}

	var data map[string]interface{}
	if err := json.Unmarshal(body, &data); err != nil {
		return "", err
	}
	model, _ := data["model"].(string)
	return model, nil
}
| // maskKey 直接返回原始密钥,不再进行掩码处理 | |
| func maskKey(key string) string { | |
| return key | |
| } | |
// main parses and validates the configuration, loads the key pool, starts the
// worker pool, and serves the proxy. It exits with status 1 on invalid flags,
// a key file that cannot be read or holds no keys, or a server error.
func main() {
	cfg := parseFlags()
	if cfg.KeyFile == "" || cfg.TargetURL == "" || cfg.Password == "" {
		log.Println("[ERROR] Missing required flags: --key-file, --target-url, --password")
		os.Exit(1)
	}
	// Reject unusable pool sizes up front. With no workers every request would
	// wait forever for a worker, and a negative queue size makes the
	// channel allocation panic.
	if cfg.MaxWorkers < 1 || cfg.MaxQueue < 0 {
		log.Printf("[ERROR] Invalid worker pool size: --max-workers must be >= 1 and --max-queue >= 0 (got %d and %d)",
			cfg.MaxWorkers, cfg.MaxQueue)
		os.Exit(1)
	}
	log.Printf("[INFO] Configuration loaded: KeyFile=%s, TargetURL=%s, Address=%s, Port=%s, MaxWorkers=%d, MaxQueue=%d",
		cfg.KeyFile, cfg.TargetURL, cfg.Address, cfg.Port, cfg.MaxWorkers, cfg.MaxQueue)

	keyPool, err := NewKeyPool(cfg.KeyFile)
	if err != nil {
		log.Printf("[ERROR] Failed to initialize key pool: %v", err)
		os.Exit(1)
	}
	// An empty pool would answer every request with 502, so fail fast instead.
	if len(keyPool.keys) == 0 {
		log.Printf("[ERROR] Key file %s contains no keys", cfg.KeyFile)
		os.Exit(1)
	}

	proxyHandler := NewProxyHandler(cfg, keyPool)
	proxyHandler.InitWorkerPool(cfg.MaxWorkers, cfg.MaxQueue)

	addr := cfg.Address + ":" + cfg.Port
	log.Printf("[INFO] Starting proxy server on %s", addr)
	if err := http.ListenAndServe(addr, proxyHandler); err != nil {
		log.Printf("[ERROR] Failed to start server: %v", err)
		os.Exit(1)
	}
}