Spaces:
Sleeping
Sleeping
| package main | |
import (
	"bytes"
	"encoding/json"
	"io"
	"log"
	"net/http"
	"os"
	"strings"
	"time"
)
// Fixed configuration for the relay.
const (
	// geminiOpenAIEndpoint is the upstream URL that chat-completion requests
	// are POSTed to.
	geminiOpenAIEndpoint = "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions"

	// defaultPort is the port main binds the HTTP server to by default.
	defaultPort = "8080"
)
// RelayServer relays OpenAI-compatible API calls to Gemini. It owns the one
// HTTP client that is reused for every upstream request.
type RelayServer struct {
	client *http.Client
}

// NewRelayServer builds a RelayServer around a zero-value http.Client
// (default transport, no overall client timeout).
func NewRelayServer() *RelayServer {
	srv := new(RelayServer)
	srv.client = new(http.Client)
	return srv
}
// unsupportedGeminiParams lists the top-level OpenAI request fields that
// filterRequest treats as unsupported by Gemini and strips before relaying.
var unsupportedGeminiParams = []string{
	"frequency_penalty",
	"presence_penalty",
	"logit_bias",
	"user",
	"n",
	"stop",
	"suffix",
	"logprobs",
	"echo",
	"best_of",
	"response_format",
	"seed",
	"tools",
	"tool_choice",
	"parallel_tool_calls",
}

// filterRequest removes every field named in unsupportedGeminiParams from a
// JSON request object and returns the re-encoded object.
//
// Field values are held as json.RawMessage so that fields which survive are
// passed through verbatim. Decoding into map[string]interface{} would turn
// every number into a float64, silently corrupting integers above 2^53 and
// rewriting number formatting.
//
// If body is not a JSON object it is returned unchanged with a nil error. If
// no unsupported field is present, the original bytes are returned untouched.
// A non-nil error is returned only when re-encoding the filtered object fails.
func filterRequest(body []byte) ([]byte, error) {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		// Not a JSON object: pass it through and let the upstream API report it.
		return body, nil
	}

	removed := false
	for _, name := range unsupportedGeminiParams {
		if _, ok := fields[name]; ok {
			delete(fields, name)
			removed = true
		}
	}
	if !removed {
		// Nothing to strip; avoid a needless decode/encode round trip.
		return body, nil
	}
	return json.Marshal(fields)
}
| // handleRequest 处理所有的API请求 | |
| func (s *RelayServer) handleRequest(w http.ResponseWriter, r *http.Request) { | |
| // 检查是否有Authorization头 | |
| authHeader := r.Header.Get("Authorization") | |
| if authHeader == "" { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusUnauthorized) | |
| w.Write([]byte(`{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`)) | |
| return | |
| } | |
| // 读取请求体 | |
| bodyBytes, err := io.ReadAll(r.Body) | |
| if err != nil { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusBadRequest) | |
| w.Write([]byte(`{"error": {"message": "Failed to read request body", "type": "invalid_request_error"}}`)) | |
| return | |
| } | |
| defer r.Body.Close() | |
| // 打印原始请求(调试用) | |
| log.Printf("Original request body: %s", string(bodyBytes)) | |
| // 过滤请求参数 | |
| filteredBody, err := filterRequest(bodyBytes) | |
| if err != nil { | |
| log.Printf("Failed to filter request: %v", err) | |
| filteredBody = bodyBytes // 使用原始数据 | |
| } | |
| log.Printf("Filtered request body: %s", string(filteredBody)) | |
| // 创建新的请求 | |
| proxyReq, err := http.NewRequest("POST", geminiOpenAIEndpoint, bytes.NewReader(filteredBody)) | |
| if err != nil { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusInternalServerError) | |
| w.Write([]byte(`{"error": {"message": "Failed to create proxy request", "type": "server_error"}}`)) | |
| return | |
| } | |
| // 复制所有请求头 | |
| for name, values := range r.Header { | |
| // 跳过Host和Content-Length,这些会自动设置 | |
| if name == "Host" || name == "Content-Length" { | |
| continue | |
| } | |
| for _, value := range values { | |
| proxyReq.Header.Add(name, value) | |
| } | |
| } | |
| // 确保Authorization头被正确设置 | |
| proxyReq.Header.Set("Authorization", authHeader) | |
| proxyReq.Header.Set("Content-Type", "application/json") | |
| log.Printf("Request headers being sent to Gemini: %v", proxyReq.Header) | |
| // 发送请求 | |
| resp, err := s.client.Do(proxyReq) | |
| if err != nil { | |
| log.Printf("Failed to send request to Gemini: %v", err) | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusBadGateway) | |
| w.Write([]byte(`{"error": {"message": "Failed to connect to Gemini API", "type": "server_error"}}`)) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| // 打印响应状态(调试用) | |
| log.Printf("Response status from Gemini: %d", resp.StatusCode) | |
| // 复制响应头 | |
| for name, values := range resp.Header { | |
| // 跳过一些头部 | |
| if name == "Content-Length" { | |
| continue | |
| } | |
| for _, value := range values { | |
| w.Header().Add(name, value) | |
| } | |
| } | |
| // 添加CORS头 | |
| w.Header().Set("Access-Control-Allow-Origin", "*") | |
| // 设置状态码 | |
| w.WriteHeader(resp.StatusCode) | |
| // 处理流式响应 | |
| if strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream") { | |
| // 确保流式响应的头部设置正确 | |
| w.Header().Set("Cache-Control", "no-cache") | |
| w.Header().Set("Connection", "keep-alive") | |
| // 使用缓冲区进行流式传输 | |
| buf := make([]byte, 1024) | |
| for { | |
| n, err := resp.Body.Read(buf) | |
| if n > 0 { | |
| if _, writeErr := w.Write(buf[:n]); writeErr != nil { | |
| log.Printf("Error writing response: %v", writeErr) | |
| return | |
| } | |
| if flusher, ok := w.(http.Flusher); ok { | |
| flusher.Flush() | |
| } | |
| } | |
| if err != nil { | |
| if err != io.EOF { | |
| log.Printf("Error reading response: %v", err) | |
| } | |
| break | |
| } | |
| } | |
| } else { | |
| // 非流式响应 | |
| // 如果是错误响应,打印出来以便调试 | |
| if resp.StatusCode >= 400 { | |
| bodyBytes, _ := io.ReadAll(resp.Body) | |
| log.Printf("Error response from Gemini: %s", string(bodyBytes)) | |
| w.Write(bodyBytes) | |
| } else { | |
| io.Copy(w, resp.Body) | |
| } | |
| } | |
| } | |
| // handleHealth 健康检查端点 | |
| func (s *RelayServer) handleHealth(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusOK) | |
| w.Write([]byte(`{"status": "ok", "service": "gemini-relay"}`)) | |
| } | |
// defaultCORSAllowHeaders is advertised in Access-Control-Allow-Headers when
// the request does not say which headers it intends to send.
const defaultCORSAllowHeaders = "Content-Type, Authorization, X-Requested-With"

// corsMiddleware adds permissive CORS headers to every response and answers
// preflight (OPTIONS) requests itself with 200 OK, without calling next.
//
// When the request carries Access-Control-Request-Headers (as browser
// preflights do), that list is echoed back in Access-Control-Allow-Headers.
// A fixed allow-list makes the browser reject any request that carries a
// custom header not on the list. The origin is already "*" and no
// Access-Control-Allow-Credentials is sent, so reflecting the requested
// header names grants nothing beyond what the wildcard origin already allows.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		allowHeaders := defaultCORSAllowHeaders
		if requested := r.Header.Get("Access-Control-Request-Headers"); requested != "" {
			allowHeaders = requested
		}

		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
		h.Set("Access-Control-Allow-Headers", allowHeaders)
		h.Set("Access-Control-Max-Age", "86400")
		// The Allow-Headers value now depends on the request, so caches must
		// key on it.
		h.Add("Vary", "Access-Control-Request-Headers")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusOK)
			return
		}
		next(w, r)
	}
}
// loggingMiddleware emits one access-log line per request (remote address,
// method, path and User-Agent) and then hands the request to next.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	logged := func(w http.ResponseWriter, r *http.Request) {
		remote, method, path, agent := r.RemoteAddr, r.Method, r.URL.Path, r.UserAgent()
		log.Printf("[%s] %s %s %s", remote, method, path, agent)
		next.ServeHTTP(w, r)
	}
	return logged
}
| // 模型映射 | |
| func (s *RelayServer) handleModels(w http.ResponseWriter, r *http.Request) { | |
| // 检查Authorization | |
| if r.Header.Get("Authorization") == "" { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusUnauthorized) | |
| w.Write([]byte(`{"error": {"message": "Missing Authorization header", "type": "invalid_request_error"}}`)) | |
| return | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(http.StatusOK) | |
| models := `{ | |
| "object": "list", | |
| "data": [ | |
| { | |
| "id": "gemini-1.5-pro", | |
| "object": "model", | |
| "created": 1686935002, | |
| "owned_by": "google" | |
| }, | |
| { | |
| "id": "gemini-1.5-flash", | |
| "object": "model", | |
| "created": 1686935002, | |
| "owned_by": "google" | |
| }, | |
| { | |
| "id": "gemini-1.5-flash-8b", | |
| "object": "model", | |
| "created": 1686935002, | |
| "owned_by": "google" | |
| }, | |
| { | |
| "id": "gemini-2.0-flash-exp", | |
| "object": "model", | |
| "created": 1686935002, | |
| "owned_by": "google" | |
| } | |
| ] | |
| }` | |
| w.Write([]byte(models)) | |
| } | |
| func main() { | |
| // 创建中继服务器 | |
| server := NewRelayServer() | |
| // 设置路由 | |
| mux := http.NewServeMux() | |
| // OpenAI兼容的端点 | |
| mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest))) | |
| mux.HandleFunc("/chat/completions", corsMiddleware(loggingMiddleware(server.handleRequest))) | |
| mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(server.handleModels))) | |
| mux.HandleFunc("/models", corsMiddleware(loggingMiddleware(server.handleModels))) | |
| // 健康检查 | |
| mux.HandleFunc("/health", corsMiddleware(server.handleHealth)) | |
| mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Content-Type", "application/json") | |
| w.Write([]byte(`{ | |
| "service": "Gemini API Relay", | |
| "version": "1.0.0", | |
| "endpoints": { | |
| "chat": "/v1/chat/completions", | |
| "models": "/v1/models", | |
| "health": "/health" | |
| }, | |
| "supported_models": [ | |
| "gemini-1.5-pro", | |
| "gemini-1.5-flash", | |
| "gemini-1.5-flash-8b", | |
| "gemini-2.0-flash-exp" | |
| ], | |
| "note": "Use Authorization header with 'Bearer YOUR_GEMINI_API_KEY'" | |
| }`)) | |
| }) | |
| // 启动服务器 | |
| port := defaultPort | |
| log.Printf("========================================") | |
| log.Printf("Gemini API Relay Server") | |
| log.Printf("Port: %s", port) | |
| log.Printf("Endpoint: %s", geminiOpenAIEndpoint) | |
| log.Printf("========================================") | |
| log.Printf("Usage:") | |
| log.Printf(" Authorization: Bearer YOUR_GEMINI_API_KEY") | |
| log.Printf("========================================") | |
| if err := http.ListenAndServe(":"+port, mux); err != nil { | |
| log.Fatalf("Server failed to start: %v", err) | |
| } | |
| } |