// goapi / main.go — uploaded via web UI ("Update main.go", commit 1bca3a1).
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"math/rand"
"net/http"
"os"
"strings"
"time"
"github.com/google/generative-ai-go/genai"
"github.com/rs/cors"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
// --- Configuration ---

// apiKeys is the pool of Gemini API keys that requests rotate across;
// getRandomAPIKey picks one at random for each attempt.
//
// SECURITY(review): literal keys are committed to source control here. They
// should be revoked and supplied via the GEMINI_API_KEY environment variable
// (comma-separated), which main() already supports.
var apiKeys = []string{
	"AIzaSyBUIl9AisD8FHUn5HLQcriXZnF4n5MqnWU",
	"AIzaSyAId4YPsZSTLJ5_fA5BESjYxWZBzwADTJI",
	// add more keys here
}
// supportedModels is the static catalogue returned by /v1/models and
// /health. Every entry is owned by "google" and timestamped once at
// process start-up.
var supportedModels = func() []ModelInfo {
	createdAt := time.Now().Unix()
	entries := []struct{ id, desc string }{
		{"gemini-2.5-flash-preview-05-20", "Gemini 2.5 Flash Preview - 最新实验性模型"},
		{"gemini-2.5-flash", "gemini-2.5-flash稳定经典专业模型"},
		{"gemini-2.5-pro", "Gemini 2.5 Pro 专业模型"},
	}
	models := make([]ModelInfo, 0, len(entries))
	for _, e := range entries {
		models = append(models, ModelInfo{
			ID:          e.id,
			Object:      "model",
			Created:     createdAt,
			OwnedBy:     "google",
			Description: e.desc,
		})
	}
	return models
}()
// modelMapping maps the model name received from the client to the name sent
// to Gemini. By design the mapping is the identity for every supported model:
// no translation is performed.
var modelMapping = func() map[string]string {
	names := []string{
		"gemini-2.5-flash-preview-05-20",
		"gemini-2.5-flash",
		"gemini-2.5-pro",
	}
	m := make(map[string]string, len(names))
	for _, n := range names {
		m[n] = n
	}
	return m
}()
// safetySettings disables every Gemini safety filter category
// (threshold "block none") so the proxy never blocks content on its own.
var safetySettings = func() []*genai.SafetySetting {
	categories := []genai.HarmCategory{
		genai.HarmCategoryHarassment,
		genai.HarmCategoryHateSpeech,
		genai.HarmCategorySexuallyExplicit,
		genai.HarmCategoryDangerousContent,
	}
	settings := make([]*genai.SafetySetting, 0, len(categories))
	for _, category := range categories {
		settings = append(settings, &genai.SafetySetting{
			Category:  category,
			Threshold: genai.HarmBlockNone,
		})
	}
	return settings
}()
// maxRetries caps how many attempts (each with a newly chosen API key)
// chatCompletionsHandler makes before reporting failure to the client.
const maxRetries = 3
// --- Data structures (JSON serialization/deserialization) ---

// ChatMessage is one OpenAI-style conversation turn.
type ChatMessage struct {
	Role    string `json:"role"`    // "system", "user" or "assistant" (see convertMessages)
	Content string `json:"content"` // plain-text content of the turn
}

// ChatCompletionRequest mirrors the body of an OpenAI
// POST /v1/chat/completions request.
type ChatCompletionRequest struct {
	Model       string        `json:"model"`                 // used verbatim as the Gemini model name
	Messages    []ChatMessage `json:"messages"`              // full conversation, oldest first
	Stream      bool          `json:"stream"`                // true selects the SSE streaming path
	MaxTokens   int32         `json:"max_tokens,omitempty"`  // >0 sets Gemini's max output tokens
	Temperature float32       `json:"temperature,omitempty"` // passed through to the model
	TopP        float32       `json:"top_p,omitempty"`       // passed through to the model
}

// ChatCompletionResponse is the non-streaming OpenAI-compatible reply.
type ChatCompletionResponse struct {
	ID      string   `json:"id"`
	Object  string   `json:"object"` // always "chat.completion"
	Created int64    `json:"created"`
	Model   string   `json:"model"`
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"`
}

// Choice is a single completion candidate in a non-streaming response.
type Choice struct {
	Index        int         `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason string      `json:"finish_reason"`
}

// Usage reports token accounting for a non-streaming response.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// ChatCompletionStreamResponse is one SSE chunk of a streaming reply.
type ChatCompletionStreamResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"` // always "chat.completion.chunk"
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// StreamChoice is a single delta inside a streaming chunk.
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        ChatMessage `json:"delta"`
	FinishReason *string     `json:"finish_reason,omitempty"` // non-nil only on the final chunk
}

// ModelInfo describes one entry of the /v1/models catalogue.
type ModelInfo struct {
	ID          string `json:"id"`
	Object      string `json:"object"` // always "model"
	Created     int64  `json:"created"`
	OwnedBy     string `json:"owned_by"`
	Description string `json:"description"`
}

// ModelListResponse is the OpenAI-style envelope for /v1/models.
type ModelListResponse struct {
	Object string      `json:"object"` // always "list"
	Data   []ModelInfo `json:"data"`
}
// --- 核心逻辑 ---
// getRandomAPIKey returns a uniformly random key from the configured pool,
// terminating the process if the pool is empty.
//
// The original built a fresh rand.Source seeded with time.Now().UnixNano()
// on every call; that is wasteful and two calls in the same nanosecond would
// pick the same key. The package-level generator (auto-seeded since Go 1.20)
// avoids both problems.
func getRandomAPIKey() string {
	if len(apiKeys) == 0 {
		log.Fatal("API密钥列表为空,请在 `apiKeys` 变量中配置密钥。")
	}
	return apiKeys[rand.Intn(len(apiKeys))]
}
// convertMessages splits an OpenAI-style message list into the pieces the
// Gemini SDK expects: prior conversation history, the parts of the final user
// prompt, and an optional system instruction.
//
// Rules (unchanged from the original implementation):
//   - every "system" message becomes the system instruction (the last one wins);
//   - the final message, when its role is "user", becomes the prompt to send;
//   - everything else is history, with "assistant" mapped to Gemini's "model"
//     role and any other role treated as "user".
func convertMessages(messages []ChatMessage) (history []*genai.Content, lastPrompt []genai.Part, systemInstruction *genai.Content) {
	if len(messages) == 0 {
		return nil, nil, nil
	}
	last := len(messages) - 1
	for idx, m := range messages {
		switch {
		case m.Role == "system":
			systemInstruction = &genai.Content{Parts: []genai.Part{genai.Text(m.Content)}}
		case idx == last && m.Role == "user":
			lastPrompt = append(lastPrompt, genai.Text(m.Content))
		default:
			role := "user"
			if m.Role == "assistant" {
				role = "model"
			}
			history = append(history, &genai.Content{
				Role:  role,
				Parts: []genai.Part{genai.Text(m.Content)},
			})
		}
	}
	return history, lastPrompt, systemInstruction
}
// chatCompletionsHandler serves OpenAI-compatible POST /v1/chat/completions
// requests, retrying up to maxRetries times and preferring API keys that have
// not been tried yet in this request.
func chatCompletionsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "仅支持POST方法", http.StatusMethodNotAllowed)
		return
	}
	var req ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, fmt.Sprintf("解析请求体失败: %v", err), http.StatusBadRequest)
		return
	}
	// The requested model name is passed through to Gemini verbatim.
	modelName := req.Model
	log.Printf("接收到模型请求: '%s',将直接使用该名称。", modelName)
	history, lastPrompt, systemInstruction := convertMessages(req.Messages)

	var lastErr error
	usedKeys := make(map[string]bool)
	for attempt := 1; attempt <= maxRetries; attempt++ {
		apiKey := getRandomAPIKey()
		// While untried keys remain, re-draw until an unused one comes up.
		if len(usedKeys) < len(apiKeys) {
			for usedKeys[apiKey] {
				apiKey = getRandomAPIKey()
			}
		}
		usedKeys[apiKey] = true
		keyTail := apiKey
		if len(keyTail) > 4 {
			// Only log the last four characters. The original sliced
			// unconditionally and panicked on keys shorter than 4 chars.
			keyTail = keyTail[len(keyTail)-4:]
		}
		log.Printf("尝试第 %d 次, 使用密钥: ...%s", attempt, keyTail)

		err := attemptChatCompletion(r.Context(), w, &req, apiKey, modelName, history, lastPrompt, systemInstruction)
		if err == nil {
			return
		}
		lastErr = err
		log.Printf("第 %d 次尝试失败: %v", attempt, err)
		time.Sleep(1 * time.Second)
	}
	// NOTE(review): if a streaming attempt failed after SSE data was written,
	// this error body is appended to a partial stream — confirm acceptable.
	http.Error(w, fmt.Sprintf("所有重试均失败: %v", lastErr), http.StatusInternalServerError)
}

// attemptChatCompletion performs one complete request against the Gemini API
// with the given key and writes the response. The client is closed before
// returning, so retries no longer stack deferred Close calls inside the loop
// (the original leaked every client until the handler returned).
func attemptChatCompletion(ctx context.Context, w http.ResponseWriter, req *ChatCompletionRequest, apiKey, modelName string, history []*genai.Content, lastPrompt []genai.Part, systemInstruction *genai.Content) error {
	client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey))
	if err != nil {
		return fmt.Errorf("创建客户端失败: %v", err)
	}
	defer client.Close()

	model := client.GenerativeModel(modelName)
	model.SystemInstruction = systemInstruction
	model.SafetySettings = safetySettings
	// NOTE(review): zero values are applied verbatim — a client that omits
	// temperature/top_p gets 0 rather than the model default; confirm intended.
	model.SetTemperature(req.Temperature)
	model.SetTopP(req.TopP)
	if req.MaxTokens > 0 {
		model.SetMaxOutputTokens(req.MaxTokens)
	}

	chat := model.StartChat()
	chat.History = history
	if req.Stream {
		return handleStream(w, ctx, chat, lastPrompt, req.Model)
	}
	return handleNonStream(w, ctx, model, chat, lastPrompt, req.Model)
}
// handleStream streams a chat completion back to the client as OpenAI-style
// server-sent events: one chunk per Gemini response, then a "stop" chunk,
// then the "[DONE]" marker. modelID is echoed back in every chunk.
func handleStream(w http.ResponseWriter, ctx context.Context, chat *genai.ChatSession, prompt []genai.Part, modelID string) error {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")

	flusher, _ := w.(http.Flusher)
	flush := func() {
		if flusher != nil {
			flusher.Flush()
		}
	}
	// writeChunk serializes one SSE event. json.Encoder already appends a
	// trailing '\n', so a single extra newline terminates the event; the
	// original appended two and emitted a spurious empty SSE event each time.
	// It also checks the write/encode errors the original discarded.
	writeChunk := func(chunk ChatCompletionStreamResponse) error {
		var buf bytes.Buffer
		if err := json.NewEncoder(&buf).Encode(chunk); err != nil {
			return fmt.Errorf("序列化流式块失败: %v", err)
		}
		if _, err := fmt.Fprintf(w, "data: %s\n", buf.String()); err != nil {
			return err
		}
		flush()
		return nil
	}

	iter := chat.SendMessageStream(ctx, prompt...)
	for {
		resp, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return fmt.Errorf("流式生成内容失败: %v", err)
		}
		// Guard against an empty candidate list; the original indexed
		// Candidates[0] unconditionally and could panic.
		if len(resp.Candidates) == 0 || resp.Candidates[0].Content == nil {
			continue
		}
		var contentBuilder strings.Builder
		for _, part := range resp.Candidates[0].Content.Parts {
			if txt, ok := part.(genai.Text); ok {
				contentBuilder.WriteString(string(txt))
			}
		}
		chunk := ChatCompletionStreamResponse{
			ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
			Object:  "chat.completion.chunk",
			Created: time.Now().Unix(),
			Model:   modelID,
			Choices: []StreamChoice{
				{
					Index: 0,
					Delta: ChatMessage{
						Role:    "assistant",
						Content: contentBuilder.String(),
					},
				},
			},
		}
		if err := writeChunk(chunk); err != nil {
			return err
		}
	}

	// Terminal chunk carrying the finish reason, then the [DONE] sentinel.
	finishReason := "stop"
	doneChunk := ChatCompletionStreamResponse{
		ID:      fmt.Sprintf("chatcmpl-%d-done", time.Now().Unix()),
		Object:  "chat.completion.chunk",
		Created: time.Now().Unix(),
		Model:   modelID,
		Choices: []StreamChoice{
			{
				Index:        0,
				FinishReason: &finishReason,
			},
		},
	}
	if err := writeChunk(doneChunk); err != nil {
		return err
	}
	fmt.Fprintf(w, "data: [DONE]\n\n")
	flush()
	return nil
}
// handleNonStream performs a blocking chat completion and writes a single
// OpenAI-compatible JSON response, including prompt/completion token usage.
func handleNonStream(w http.ResponseWriter, ctx context.Context, model *genai.GenerativeModel, chat *genai.ChatSession, prompt []genai.Part, modelID string) error {
	resp, err := chat.SendMessage(ctx, prompt...)
	if err != nil {
		return fmt.Errorf("生成内容失败: %v", err)
	}

	// Collect the text parts of the first candidate, if any.
	var completionParts []genai.Part
	var contentBuilder strings.Builder
	if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
		completionParts = resp.Candidates[0].Content.Parts
		for _, part := range completionParts {
			if txt, ok := part.(genai.Text); ok {
				contentBuilder.WriteString(string(txt))
			}
		}
	}

	// Token accounting: the prompt is the whole history plus the new parts.
	var promptParts []genai.Part
	for _, c := range chat.History {
		promptParts = append(promptParts, c.Parts...)
	}
	promptParts = append(promptParts, prompt...)
	promptTokenCount, err := model.CountTokens(ctx, promptParts...)
	if err != nil {
		return fmt.Errorf("计算prompt tokens失败: %v", err)
	}
	// Only count completion tokens when a candidate exists; the original
	// indexed resp.Candidates[0] here without the guard applied above and
	// could panic on an empty candidate list.
	completionTokens := 0
	if len(completionParts) > 0 {
		count, err := model.CountTokens(ctx, completionParts...)
		if err != nil {
			return fmt.Errorf("计算completion tokens失败: %v", err)
		}
		completionTokens = int(count.TotalTokens)
	}

	response := ChatCompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   modelID,
		Choices: []Choice{
			{
				Index: 0,
				Message: ChatMessage{
					Role:    "assistant",
					Content: contentBuilder.String(),
				},
				FinishReason: "stop",
			},
		},
		Usage: Usage{
			PromptTokens:     int(promptTokenCount.TotalTokens),
			CompletionTokens: completionTokens,
			TotalTokens:      int(promptTokenCount.TotalTokens) + completionTokens,
		},
	}
	w.Header().Set("Content-Type", "application/json")
	return json.NewEncoder(w).Encode(response)
}
// modelsHandler serves the static model catalogue in OpenAI's
// GET /v1/models list format.
func modelsHandler(w http.ResponseWriter, r *http.Request) {
	resp := ModelListResponse{
		Object: "list",
		Data:   supportedModels,
	}
	w.Header().Set("Content-Type", "application/json")
	// The Encode error was previously discarded; log write failures
	// (e.g. client disconnects) so they are visible.
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		log.Printf("编码模型列表响应失败: %v", err)
	}
}
// rootHandler serves a small JSON service descriptor (name, version, and the
// available endpoint paths) at "/".
func rootHandler(w http.ResponseWriter, r *http.Request) {
	info := map[string]interface{}{
		"name":        "Gemini Official API (Go Version)",
		"version":     "1.3.0",
		"description": "Google Gemini官方API接口服务",
		"endpoints": map[string]string{
			"models": "/v1/models",
			"chat":   "/v1/chat/completions",
			"health": "/health",
		},
	}
	w.Header().Set("Content-Type", "application/json")
	// The Encode error was previously discarded; log write failures.
	if err := json.NewEncoder(w).Encode(info); err != nil {
		log.Printf("编码根路径响应失败: %v", err)
	}
}
// healthHandler reports service liveness at "/health" along with the IDs of
// the models in the static catalogue.
func healthHandler(w http.ResponseWriter, r *http.Request) {
	modelIDs := make([]string, 0, len(supportedModels))
	for _, m := range supportedModels {
		modelIDs = append(modelIDs, m.ID)
	}
	health := map[string]interface{}{
		"status":           "healthy",
		"timestamp":        time.Now().Unix(),
		"api":              "gemini-official-go",
		"available_models": modelIDs,
		"version":          "1.3.0",
	}
	w.Header().Set("Content-Type", "application/json")
	// The Encode error was previously discarded; log write failures.
	if err := json.NewEncoder(w).Encode(health); err != nil {
		log.Printf("编码健康检查响应失败: %v", err)
	}
}
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", rootHandler)
mux.HandleFunc("/health", healthHandler)
mux.HandleFunc("/v1/models", modelsHandler)
mux.HandleFunc("/v1/chat/completions", chatCompletionsHandler)
mux.HandleFunc("/v1/chat/completions/v1/models", modelsHandler)
c := cors.New(cors.Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "OPTIONS"},
AllowedHeaders: []string{"*"},
AllowCredentials: true,
})
handler := c.Handler(mux)
port := "7860"
log.Println("🚀 启动Gemini官方API服务器 (Go版本)")
log.Printf("📊 支持的模型: %v", func() []string {
var ids []string
for _, m := range supportedModels {
ids = append(ids, m.ID)
}
return ids
}())
log.Printf("🔑 已配置 %d 个API密钥", len(apiKeys))
log.Println("🔄 支持自动重试和密钥轮换")
log.Printf("🔗 服务器正在监听 http://0.0.0.0:%s", port)
envKey := os.Getenv("GEMINI_API_KEY")
if envKey != "" {
apiKeys = strings.Split(envKey, ",")
log.Printf("从环境变量 GEMINI_API_KEY 加载了 %d 个密钥", len(apiKeys))
}
if err := http.ListenAndServe(":"+port, handler); err != nil {
log.Fatalf("启动服务器失败: %v", err)
}
}