Spaces:
Runtime error
Runtime error
| package main | |
| import ( | |
| "bytes" | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "log" | |
| "net/http" | |
| "os" | |
| "strings" | |
| "sync" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/google/generative-ai-go/genai" | |
| "google.golang.org/api/option" | |
| ) | |
// Config holds all runtime configuration, populated exactly once from
// environment variables by loadConfig.
type Config struct {
	AnthropicKey string // ANTHROPIC_API_KEY; empty disables Claude models
	GoogleKey    string // GOOGLE_API_KEY; empty disables Gemini models
	ServiceURL   string // SERVICE_URL; public URL pinged by the keep-alive task
	DeepseekURL  string // DEEPSEEK_URL; base URL of the Deepseek token-count service
	OpenAIURL    string // OPENAI_URL; base URL of the OpenAI token-count service
}

var (
	config     Config    // package-wide configuration, written only inside configOnce.Do
	configOnce sync.Once // guards the one-time initialization in loadConfig
)
// TokenCountRequest is the inbound JSON body for POST /count_tokens.
type TokenCountRequest struct {
	Model    string    `json:"model" binding:"required"`
	Messages []Message `json:"messages" binding:"required"`
	System   *string   `json:"system,omitempty"` // optional system prompt
}

// Message is a single chat turn inside a TokenCountRequest.
type Message struct {
	Role    string `json:"role" binding:"required"`
	Content string `json:"content" binding:"required"`
}

// TokenCountResponse is the success payload: the input token count.
type TokenCountResponse struct {
	InputTokens int `json:"input_tokens"`
}

// ErrorResponse is the JSON payload returned on failures.
type ErrorResponse struct {
	Error string `json:"error"`
}
// ModelRule maps a keyword set to a canonical model name: an input
// matches a rule when it contains every keyword (case-insensitive).
type ModelRule struct {
	Keywords []string
	Target   string
}

// modelRules is scanned in order by matchModelName and the FIRST rule
// whose keywords all appear in the input wins. More specific rules must
// therefore precede more generic ones: e.g. "claude-3-5-sonnet" contains
// all keywords of the generic {"claude","3","sonnet"} rule, so the
// 3-5-sonnet rule has to come first or it would be shadowed.
var modelRules = []ModelRule{
	{
		Keywords: []string{"gpt"},
		Target:   "gpt-3.5-turbo",
	},
	{
		Keywords: []string{"openai"},
		Target:   "gpt-3.5-turbo",
	},
	{
		Keywords: []string{"deepseek"},
		Target:   "deepseek-v3",
	},
	// Specific Claude 3.x rules first (see ordering note above).
	{
		Keywords: []string{"claude", "3", "7"},
		Target:   "claude-3-7-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "sonnet"},
		Target:   "claude-3-5-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "haiku"},
		Target:   "claude-3-5-haiku-latest",
	},
	// Generic Claude 3 rules after the specific ones.
	{
		Keywords: []string{"claude", "3", "sonnet"},
		Target:   "claude-3-sonnet-20240229",
	},
	{
		Keywords: []string{"claude", "3", "opus"},
		Target:   "claude-3-opus-latest",
	},
	{
		Keywords: []string{"claude", "3", "haiku"},
		Target:   "claude-3-haiku-20240307",
	},
	{
		Keywords: []string{"gemini", "2.0"},
		Target:   "gemini-2.0-flash",
	},
	{
		Keywords: []string{"gemini", "2.5"},
		Target:   "gemini-2.0-flash", // 2.0-flash currently stands in for 2.5
	},
	{
		Keywords: []string{"gemini", "1.5"},
		Target:   "gemini-1.5-flash",
	},
}
| // 智能匹配模型名称 | |
| func matchModelName(input string) string { | |
| // 转换为小写进行匹配 | |
| input = strings.ToLower(input) | |
| // 特殊规则:OpenAI GPT-4o | |
| if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) || | |
| strings.Contains(input, "o1") || | |
| strings.Contains(input, "o3") { | |
| return "gpt-4o" | |
| } | |
| // 特殊规则:OpenAI GPT-4 | |
| if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) || | |
| (strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) { | |
| return "gpt-4" | |
| } | |
| // 遍历所有规则 | |
| for _, rule := range modelRules { | |
| matches := true | |
| for _, keyword := range rule.Keywords { | |
| if !strings.Contains(input, strings.ToLower(keyword)) { | |
| matches = false | |
| break | |
| } | |
| } | |
| if matches { | |
| return rule.Target | |
| } | |
| } | |
| // 如果没有匹配到,返回原始输入 | |
| return input | |
| } | |
| // 加载配置 | |
| func loadConfig() Config { | |
| configOnce.Do(func() { | |
| config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY") | |
| if config.AnthropicKey == "" { | |
| log.Println("警告: ANTHROPIC_API_KEY 环境变量未设置,Claude模型将无法使用") | |
| } | |
| config.GoogleKey = os.Getenv("GOOGLE_API_KEY") | |
| if config.GoogleKey == "" { | |
| log.Println("警告: GOOGLE_API_KEY 环境变量未设置,Gemini模型将无法使用") | |
| } | |
| // 获取Deepseek服务URL | |
| config.DeepseekURL = os.Getenv("DEEPSEEK_URL") | |
| if config.DeepseekURL == "" { | |
| config.DeepseekURL = "http://127.0.0.1:7861" // 默认本地地址 | |
| log.Println("使用默认Deepseek服务地址:", config.DeepseekURL) | |
| } | |
| // 获取OpenAI服务URL | |
| config.OpenAIURL = os.Getenv("OPENAI_URL") | |
| if config.OpenAIURL == "" { | |
| config.OpenAIURL = "http://127.0.0.1:7862" // 默认本地地址 | |
| log.Println("使用默认OpenAI服务地址:", config.OpenAIURL) | |
| } | |
| // 获取服务URL,用于防休眠 | |
| config.ServiceURL = os.Getenv("SERVICE_URL") | |
| if config.ServiceURL == "" { | |
| log.Println("SERVICE_URL 未设置,防休眠功能将被禁用") | |
| } | |
| }) | |
| return config | |
| } | |
| // 使用Claude API计算token | |
| func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) { | |
| // 准备请求Anthropic API | |
| client := &http.Client{} | |
| data, err := json.Marshal(req) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) | |
| } | |
| // 创建请求 | |
| request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data)) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) | |
| } | |
| // 设置请求头 | |
| request.Header.Set("x-api-key", config.AnthropicKey) | |
| request.Header.Set("anthropic-version", "2023-06-01") | |
| request.Header.Set("content-type", "application/json") | |
| // 发送请求 | |
| response, err := client.Do(request) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err) | |
| } | |
| defer response.Body.Close() | |
| // 读取响应 | |
| var result TokenCountResponse | |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) | |
| } | |
| return result, nil | |
| } | |
| // 使用Gemini API计算token | |
| func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) { | |
| // 检查API密钥 | |
| if config.GoogleKey == "" { | |
| return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置") | |
| } | |
| // 创建Gemini客户端 | |
| ctx := context.Background() | |
| client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey)) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err) | |
| } | |
| defer client.Close() | |
| // 使用已经匹配好的模型名称 | |
| modelName := req.Model | |
| // 创建Gemini模型 | |
| model := client.GenerativeModel(modelName) | |
| // 构建提示内容 | |
| var content string | |
| if req.System != nil && *req.System != "" { | |
| content += *req.System + "\n\n" | |
| } | |
| for _, msg := range req.Messages { | |
| if msg.Role == "user" { | |
| content += "用户: " + msg.Content + "\n" | |
| } else if msg.Role == "assistant" { | |
| content += "助手: " + msg.Content + "\n" | |
| } else { | |
| content += msg.Role + ": " + msg.Content + "\n" | |
| } | |
| } | |
| // 计算token | |
| tokResp, err := model.CountTokens(ctx, genai.Text(content)) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err) | |
| } | |
| return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil | |
| } | |
| // 使用Deepseek API计算token | |
| func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) { | |
| // 准备请求 | |
| client := &http.Client{} | |
| data, err := json.Marshal(req) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) | |
| } | |
| // 创建请求 | |
| request, err := http.NewRequest("POST", config.DeepseekURL+"/count_tokens", bytes.NewBuffer(data)) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) | |
| } | |
| // 设置请求头 | |
| request.Header.Set("Content-Type", "application/json") | |
| // 发送请求 | |
| response, err := client.Do(request) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err) | |
| } | |
| defer response.Body.Close() | |
| // 读取响应 | |
| var result TokenCountResponse | |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) | |
| } | |
| return result, nil | |
| } | |
| // 使用OpenAI API计算token | |
| func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) { | |
| // 准备请求 | |
| client := &http.Client{} | |
| data, err := json.Marshal(req) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) | |
| } | |
| // 创建请求 | |
| request, err := http.NewRequest("POST", config.OpenAIURL+"/count_tokens", bytes.NewBuffer(data)) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) | |
| } | |
| // 设置请求头 | |
| request.Header.Set("Content-Type", "application/json") | |
| // 发送请求 | |
| response, err := client.Do(request) | |
| if err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err) | |
| } | |
| defer response.Body.Close() | |
| // 读取响应 | |
| var result struct { | |
| InputTokens int `json:"input_tokens"` | |
| Model string `json:"model"` | |
| Encoding string `json:"encoding"` | |
| } | |
| if err := json.NewDecoder(response.Body).Decode(&result); err != nil { | |
| return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) | |
| } | |
| return TokenCountResponse{InputTokens: result.InputTokens}, nil | |
| } | |
// countTokens handles POST /count_tokens: it binds the JSON request,
// normalizes the model name via matchModelName, dispatches to the
// backend owning that model family (Deepseek / OpenAI / Claude /
// Gemini), and falls back to a gpt-4o based estimate for models with
// no native tokenizer support. Branch order matters: Deepseek is
// checked first because matchModelName may map other inputs to names
// containing overlapping keywords.
func countTokens(c *gin.Context) {
	var req TokenCountRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
		return
	}
	// Keep the caller's original model string for warning messages.
	originalModel := req.Model
	// Assume unsupported until a known model-family keyword is found.
	isUnsupportedModel := true
	// Family detection is done on the lowercased ORIGINAL name, before
	// normalization rewrites it.
	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
		strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
		strings.HasPrefix(modelLower, "claude") ||
		strings.Contains(modelLower, "gemini") ||
		strings.Contains(modelLower, "deepseek") {
		isUnsupportedModel = false
	}
	// Normalize to a canonical model name before dispatching.
	req.Model = matchModelName(req.Model)
	var result TokenCountResponse
	var err error
	// Deepseek is dispatched first.
	if strings.Contains(strings.ToLower(req.Model), "deepseek") {
		// Forward to the Deepseek sidecar service.
		result, err = countTokensWithDeepseek(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
		// Forward to the OpenAI sidecar service.
		result, err = countTokensWithOpenAI(req)
	} else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
		// Anthropic API path; requires an API key.
		if config.AnthropicKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY 未设置,无法使用Claude模型"})
			return
		}
		result, err = countTokensWithClaude(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gemini") {
		// Gemini API path; requires an API key.
		if config.GoogleKey == "" {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY 未设置,无法使用Gemini模型"})
			return
		}
		result, err = countTokensWithGemini(req)
	} else if isUnsupportedModel {
		// Unknown model family: estimate with gpt-4o and attach a
		// warning. On estimation failure, fall through to the generic
		// 500 below.
		gptReq := req
		gptReq.Model = "gpt-4o"
		result, err = countTokensWithOpenAI(gptReq)
		if err == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   result.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
			return
		}
	} else {
		// Recognized family but no dispatch branch matched: still try a
		// gpt-4o estimate; otherwise report the model as unsupported.
		gptReq := req
		gptReq.Model = "gpt-4o"
		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
		if estimateErr == nil {
			c.JSON(http.StatusOK, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
			})
		} else {
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
		}
		return
	}
	if err != nil {
		c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
		return
	}
	// Success: plain token-count payload.
	c.JSON(http.StatusOK, result)
}
| // 健康检查 | |
| func healthCheck(c *gin.Context) { | |
| c.JSON(http.StatusOK, gin.H{ | |
| "status": "healthy", | |
| "time": time.Now().Format(time.RFC3339), | |
| }) | |
| } | |
| // 防休眠任务 | |
| func startKeepAlive() { | |
| if config.ServiceURL == "" { | |
| return | |
| } | |
| healthURL := fmt.Sprintf("%s/health", config.ServiceURL) | |
| ticker := time.NewTicker(10 * time.Hour) | |
| // 立即执行一次检查 | |
| go func() { | |
| log.Printf("Starting keep-alive checks to %s", healthURL) | |
| for { | |
| resp, err := http.Get(healthURL) | |
| if err != nil { | |
| log.Printf("Keep-alive check failed: %v", err) | |
| } else { | |
| resp.Body.Close() | |
| log.Printf("Keep-alive check successful") | |
| } | |
| // 等待下一次触发 | |
| <-ticker.C | |
| } | |
| }() | |
| } | |
// main wires configuration, middleware, and routes, then serves HTTP.
func main() {
	// Load configuration from env vars before anything depends on it.
	loadConfig()
	// Release mode silences gin's per-route debug logging.
	gin.SetMode(gin.ReleaseMode)
	r := gin.Default()
	// Panic recovery plus a permissive CORS middleware that also
	// short-circuits preflight OPTIONS requests with 204.
	r.Use(gin.Recovery())
	r.Use(func(c *gin.Context) {
		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}
		c.Next()
	})
	// Routes.
	r.GET("/health", healthCheck)
	r.POST("/count_tokens", countTokens)
	// Listen port from PORT, defaulting to Hugging Face's 7860.
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860"
	}
	// Start the anti-sleep pinger (no-op without SERVICE_URL).
	startKeepAlive()
	// Serve until fatal error.
	log.Printf("Server starting on port %s", port)
	if err := r.Run(":" + port); err != nil {
		log.Fatal(err)
	}
}