package main import ( "bytes" "context" "encoding/json" "fmt" "io" "log" "net/http" "os" "strings" "sync" "time" "github.com/gin-gonic/gin" "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" ) // 配置结构 type Config struct { AnthropicKey string GoogleKey string ServiceURL string DeepseekURL string OpenAIURL string } var ( config Config configOnce sync.Once ) // 请求结构 type TokenCountRequest struct { Model string `json:"model" binding:"required"` Messages []Message `json:"messages" binding:"required"` System *string `json:"system,omitempty"` } type Message struct { Role string `json:"role" binding:"required"` Content string `json:"content" binding:"required"` } // 响应结构 type TokenCountResponse struct { InputTokens int `json:"input_tokens"` } // 错误响应结构 type ErrorResponse struct { Error string `json:"error"` } // 模型映射规则 type ModelRule struct { Keywords []string Target string } var modelRules = []ModelRule{ { Keywords: []string{"deepseek"}, Target: "deepseek-v3", }, // 先放更具体的规则 { Keywords: []string{"claude", "3", "5", "sonnet"}, Target: "claude-3-5-sonnet-latest", }, { Keywords: []string{"claude", "3", "5", "haiku"}, Target: "claude-3-5-haiku-latest", }, { Keywords: []string{"claude", "3", "7"}, Target: "claude-3-7-sonnet-latest", }, { Keywords: []string{"claude", "3", "opus"}, Target: "claude-3-opus-latest", }, { Keywords: []string{"claude", "3", "haiku"}, Target: "claude-3-haiku-20240307", }, // 再放一般规则 { Keywords: []string{"claude", "3", "sonnet"}, Target: "claude-3-sonnet-20240229", }, { Keywords: []string{"gemini", "2.0"}, Target: "gemini-2.0-flash", }, { Keywords: []string{"gemini", "2.5"}, Target: "gemini-2.0-flash", // 目前使用2.0-flash作为2.5的替代 }, { Keywords: []string{"gemini", "1.5"}, Target: "gemini-1.5-flash", }, } // 智能匹配模型名称 func matchModelName(input string) string { // 转换为小写进行匹配 input = strings.ToLower(input) log.Printf("正在匹配模型名称: %s", input) // 特殊处理 Claude 3.5 系列 if strings.Contains(input, "claude") && strings.Contains(input, "3.5") || strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "5") { if strings.Contains(input, "sonnet") { log.Printf("匹配到Claude 3.5 Sonnet") return "claude-3-5-sonnet-latest" } else if strings.Contains(input, "haiku") { log.Printf("匹配到Claude 3.5 Haiku") return "claude-3-5-haiku-latest" } else { // 默认为Sonnet log.Printf("匹配到Claude 3.5 (默认使用Sonnet)") return "claude-3-5-sonnet-latest" } } // 特殊处理 Claude 3.7 系列 if strings.Contains(input, "claude") && strings.Contains(input, "3.7") || strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "7") { log.Printf("匹配到Claude 3.7") return "claude-3-7-sonnet-latest" } // 特殊规则:OpenAI GPT-4o if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) || strings.Contains(input, "o1") || strings.Contains(input, "o3") { log.Printf("匹配到GPT-4o") return "gpt-4o" } // 特殊规则:OpenAI GPT-4 if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) || (strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) { log.Printf("匹配到GPT-4") return "gpt-4" } // 遍历所有规则 for _, rule := range modelRules { matches := true for _, keyword := range rule.Keywords { if !strings.Contains(input, strings.ToLower(keyword)) { matches = false break } } if matches { log.Printf("通过规则匹配到: %s", rule.Target) return rule.Target } } // 如果没有匹配到,返回原始输入 log.Printf("没有匹配到任何规则,使用原始输入: %s", input) return input } // 加载配置 func loadConfig() Config { configOnce.Do(func() { // 配置日志格式 log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) log.Println("开始加载配置...") config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY") if config.AnthropicKey == "" { log.Println("警告: ANTHROPIC_API_KEY 环境变量未设置,Claude模型将无法使用") } else { log.Println("Anthropic API Key已配置") } config.GoogleKey = os.Getenv("GOOGLE_API_KEY") if config.GoogleKey == "" { log.Println("警告: GOOGLE_API_KEY 环境变量未设置,Gemini模型将无法使用") } else { log.Println("Google API Key已配置") } // 获取Deepseek服务URL config.DeepseekURL = os.Getenv("DEEPSEEK_URL") if config.DeepseekURL == "" { config.DeepseekURL = "http://127.0.0.1:7861" // 默认本地地址 log.Println("使用默认Deepseek服务地址:", config.DeepseekURL) } else { log.Println("使用配置的Deepseek服务地址:", config.DeepseekURL) } // 获取OpenAI服务URL config.OpenAIURL = os.Getenv("OPENAI_URL") if config.OpenAIURL == "" { config.OpenAIURL = "http://127.0.0.1:7862" // 默认本地地址 log.Println("使用默认OpenAI服务地址:", config.OpenAIURL) } else { log.Println("使用配置的OpenAI服务地址:", config.OpenAIURL) } // 获取服务URL,用于防休眠 config.ServiceURL = os.Getenv("SERVICE_URL") if config.ServiceURL == "" { log.Println("SERVICE_URL 未设置,防休眠功能将被禁用") } else { log.Println("防休眠URL已配置:", config.ServiceURL) } log.Println("配置加载完成") }) return config } // 使用Claude API计算token func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) { // 准备请求Anthropic API log.Printf("开始Claude API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) // 验证并过滤空内容的消息 var filteredMessages []Message for i, msg := range req.Messages { if msg.Content == "" { log.Printf("警告: 消息 #%d 内容为空,将被过滤掉", i) continue // 跳过空内容消息 } if msg.Role != "user" && msg.Role != "assistant" { log.Printf("警告: 消息 #%d 角色'%s'不是标准角色(user/assistant),可能导致请求失败", i, msg.Role) } filteredMessages = append(filteredMessages, msg) } if len(filteredMessages) == 0 { log.Printf("错误: 过滤后没有有效消息") return TokenCountResponse{}, fmt.Errorf("没有有效消息:所有消息内容都为空") } // 创建新请求,使用过滤后的消息 filteredReq := TokenCountRequest{ Model: req.Model, Messages: filteredMessages, System: req.System, } // 记录过滤后的消息数量 if len(filteredMessages) != len(req.Messages) { log.Printf("消息过滤: 原始消息数=%d, 过滤后消息数=%d", len(req.Messages), len(filteredMessages)) } client := &http.Client{} data, err := json.Marshal(filteredReq) if err != nil { log.Printf("错误: 序列化Claude请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) } // 记录请求内容用于调试 if len(data) < 1000 { log.Printf("Claude请求内容: %s", string(data)) } else { log.Printf("Claude请求内容较大,长度=%d字节", len(data)) } // 创建请求 request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data)) if err != nil { log.Printf("错误: 创建Claude请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) } // 设置请求头 request.Header.Set("x-api-key", config.AnthropicKey) request.Header.Set("anthropic-version", "2023-06-01") request.Header.Set("content-type", "application/json") // 发送请求 response, err := client.Do(request) if err != nil { log.Printf("错误: 发送请求到Anthropic API失败: %v", err) return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err) } defer response.Body.Close() // 检查响应状态码 if response.StatusCode != http.StatusOK { // 读取错误响应 var errorBody []byte errorBody, _ = io.ReadAll(response.Body) log.Printf("错误: Claude API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) // 检查常见错误 errorStr := string(errorBody) if response.StatusCode == http.StatusUnauthorized || strings.Contains(errorStr, "invalid_api_key") { log.Printf("错误: Claude API密钥无效或过期") return TokenCountResponse{}, fmt.Errorf("Claude API验证失败,请检查API Key是否有效: %s", string(errorBody)) } else if response.StatusCode == http.StatusBadRequest { if strings.Contains(errorStr, "empty content") { log.Printf("错误: 请求包含空内容的消息") return TokenCountResponse{}, fmt.Errorf("请求格式错误: 消息不能有空内容: %s", string(errorBody)) } else if strings.Contains(errorStr, "invalid_request_error") { log.Printf("错误: 无效的请求格式") return TokenCountResponse{}, fmt.Errorf("无效的请求格式: %s", string(errorBody)) } } return TokenCountResponse{}, fmt.Errorf("Claude API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) } // 读取响应 var result TokenCountResponse if err := json.NewDecoder(response.Body).Decode(&result); err != nil { log.Printf("错误: 解码Claude响应失败: %v", err) return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) } log.Printf("Claude API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) return result, nil } // 使用Gemini API计算token func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) { // 检查API密钥 log.Printf("开始Gemini API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) if config.GoogleKey == "" { log.Printf("错误: Gemini API密钥未设置") return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置") } // 创建Gemini客户端 ctx := context.Background() client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey)) if err != nil { log.Printf("错误: 创建Gemini客户端失败: %v", err) return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err) } defer client.Close() // 使用已经匹配好的模型名称 modelName := req.Model log.Printf("使用Gemini模型: %s", modelName) // 创建Gemini模型 model := client.GenerativeModel(modelName) // 构建提示内容 var content string if req.System != nil && *req.System != "" { content += *req.System + "\n\n" log.Printf("Gemini请求包含系统提示: %s", *req.System) } for _, msg := range req.Messages { if msg.Role == "user" { content += "用户: " + msg.Content + "\n" } else if msg.Role == "assistant" { content += "助手: " + msg.Content + "\n" } else { content += msg.Role + ": " + msg.Content + "\n" } } // 计算token log.Printf("开始计算Gemini tokens...") tokResp, err := model.CountTokens(ctx, genai.Text(content)) if err != nil { log.Printf("错误: 计算Gemini token失败: %v", err) if strings.Contains(err.Error(), "invalid_api_key") || strings.Contains(err.Error(), "permission_denied") { log.Printf("错误: Gemini API密钥可能无效或过期") } return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err) } log.Printf("Gemini API请求成功: 模型=%s, 输入tokens=%d", req.Model, tokResp.TotalTokens) return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil } // 使用Deepseek API计算token func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) { log.Printf("开始Deepseek API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.DeepseekURL) // 准备请求 client := &http.Client{} data, err := json.Marshal(req) if err != nil { log.Printf("错误: 序列化Deepseek请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) } // 创建请求 requestURL := config.DeepseekURL + "/count_tokens" log.Printf("发送请求到Deepseek服务: %s", requestURL) request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) if err != nil { log.Printf("错误: 创建Deepseek请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) } // 设置请求头 request.Header.Set("Content-Type", "application/json") // 发送请求 response, err := client.Do(request) if err != nil { log.Printf("错误: 发送请求到Deepseek服务失败: %v", err) return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err) } defer response.Body.Close() // 检查响应状态码 if response.StatusCode != http.StatusOK { // 读取错误响应 var errorBody []byte errorBody, _ = io.ReadAll(response.Body) log.Printf("错误: Deepseek API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) return TokenCountResponse{}, fmt.Errorf("Deepseek API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) } // 读取响应 var result TokenCountResponse if err := json.NewDecoder(response.Body).Decode(&result); err != nil { log.Printf("错误: 解码Deepseek响应失败: %v", err) return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) } log.Printf("Deepseek API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) return result, nil } // 使用OpenAI API计算token func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) { log.Printf("开始OpenAI API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.OpenAIURL) // 准备请求 client := &http.Client{} data, err := json.Marshal(req) if err != nil { log.Printf("错误: 序列化OpenAI请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) } // 创建请求 requestURL := config.OpenAIURL + "/count_tokens" log.Printf("发送请求到OpenAI服务: %s", requestURL) request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) if err != nil { log.Printf("错误: 创建OpenAI请求失败: %v", err) return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) } // 设置请求头 request.Header.Set("Content-Type", "application/json") // 发送请求 response, err := client.Do(request) if err != nil { log.Printf("错误: 发送请求到OpenAI服务失败: %v", err) return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err) } defer response.Body.Close() // 检查响应状态码 if response.StatusCode != http.StatusOK { // 读取错误响应 var errorBody []byte errorBody, _ = io.ReadAll(response.Body) log.Printf("错误: OpenAI API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) return TokenCountResponse{}, fmt.Errorf("OpenAI API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) } // 读取响应 var result struct { InputTokens int `json:"input_tokens"` Model string `json:"model"` Encoding string `json:"encoding"` } if err := json.NewDecoder(response.Body).Decode(&result); err != nil { log.Printf("错误: 解码OpenAI响应失败: %v", err) return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) } log.Printf("OpenAI API请求成功: 模型=%s(实际使用=%s), 编码=%s, 输入tokens=%d", req.Model, result.Model, result.Encoding, result.InputTokens) return TokenCountResponse{InputTokens: result.InputTokens}, nil } // 计算token func countTokens(c *gin.Context) { var req TokenCountRequest if err := c.ShouldBindJSON(&req); err != nil { log.Printf("错误: 无效的请求格式: %v", err) c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()}) return } // 记录请求详情 systemPrompt := "无" if req.System != nil && *req.System != "" { systemPrompt = *req.System } log.Printf("收到token计算请求: 原始模型=%s, 消息数量=%d, 系统提示=%s", req.Model, len(req.Messages), systemPrompt) // 保存原始模型名称 originalModel := req.Model // 检查是否为不支持的模型 isUnsupportedModel := true // 检查是否为支持的模型类型 modelLower := strings.ToLower(req.Model) if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") || strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") || strings.HasPrefix(modelLower, "claude") || strings.Contains(modelLower, "gemini") || strings.Contains(modelLower, "deepseek") { isUnsupportedModel = false } // 智能匹配模型名称 req.Model = matchModelName(req.Model) log.Printf("模型名称匹配结果: 原始=%s -> 匹配=%s", originalModel, req.Model) var result TokenCountResponse var err error // 优先检查是否为Deepseek模型 if strings.Contains(strings.ToLower(req.Model), "deepseek") { log.Printf("使用Deepseek API计算token") // 使用Deepseek API result, err = countTokensWithDeepseek(req) } else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") { log.Printf("使用OpenAI API计算token") // 使用OpenAI API result, err = countTokensWithOpenAI(req) } else if strings.HasPrefix(strings.ToLower(req.Model), "claude") { log.Printf("使用Claude API计算token") // 使用Claude API if config.AnthropicKey == "" { log.Printf("错误: ANTHROPIC_API_KEY未设置") c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY 未设置,无法使用Claude模型"}) return } result, err = countTokensWithClaude(req) } else if strings.Contains(strings.ToLower(req.Model), "gemini") { log.Printf("使用Gemini API计算token") // 使用Gemini API if config.GoogleKey == "" { log.Printf("错误: GOOGLE_API_KEY未设置") c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY 未设置,无法使用Gemini模型"}) return } result, err = countTokensWithGemini(req) } else if isUnsupportedModel { log.Printf("不支持的模型: %s, 将使用GPT-4o估算", originalModel) // 不支持的模型,使用GPT-4o估算 // 创建新的请求,使用GPT-4o gptReq := req gptReq.Model = "gpt-4o" // 使用OpenAI API estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) if estimateErr == nil { log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) // 返回估算值,但添加警告信息,使用400状态码 c.JSON(http.StatusBadRequest, gin.H{ "input_tokens": estimatedResult.InputTokens, "warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel), "estimated_with": "gpt-4o", "error": fmt.Sprintf("Unsupported model: %s", originalModel), }) return } else { log.Printf("使用GPT-4o估算失败: %v", estimateErr) c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("Failed to estimate tokens for unsupported model: %s", originalModel)}) return } } else { log.Printf("完全不支持的模型: %s, 将尝试使用GPT-4o估算", originalModel) // 完全不支持的情况,返回错误但仍提供估算值 // 使用GPT-4o进行估算 gptReq := req gptReq.Model = "gpt-4o" estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) if estimateErr == nil { log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) c.JSON(http.StatusBadRequest, gin.H{ "input_tokens": estimatedResult.InputTokens, "warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel), "estimated_with": "gpt-4o", "error": fmt.Sprintf("Unsupported model: %s", originalModel), }) } else { log.Printf("使用GPT-4o估算失败: %v", estimateErr) c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)}) } return } if err != nil { log.Printf("计算token失败: %v", err) // 对所有API调用失败的情况尝试使用GPT-4o估算 log.Printf("API调用失败,尝试使用GPT-4o估算: 原始模型=%s, 错误=%v", req.Model, err) // 创建新的请求,使用GPT-4o gptReq := req gptReq.Model = "gpt-4o" // 使用OpenAI API进行估算 estimatedResult, estimateErr := countTokensWithOpenAI(gptReq) if estimateErr == nil { log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens) // 返回估算值,但添加警告信息和原始错误,使用400状态码 c.JSON(http.StatusBadRequest, gin.H{ "input_tokens": estimatedResult.InputTokens, "warning": fmt.Sprintf("Token calculation for model '%s' failed. This is an estimation based on gpt-4o and may not be accurate.", originalModel), "estimated_with": "gpt-4o", "error": err.Error(), }) return } else { log.Printf("使用GPT-4o估算也失败: %v", estimateErr) // 如果GPT-4o估算也失败,返回原始错误 c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()}) return } } // 返回结果 log.Printf("成功计算token: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) c.JSON(http.StatusOK, result) } // 健康检查 func healthCheck(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "status": "healthy", "time": time.Now().Format(time.RFC3339), }) } // 防休眠任务 func startKeepAlive() { if config.ServiceURL == "" { return } healthURL := fmt.Sprintf("%s/health", config.ServiceURL) ticker := time.NewTicker(10 * time.Hour) // 立即执行一次检查 go func() { log.Printf("Starting keep-alive checks to %s", healthURL) for { resp, err := http.Get(healthURL) if err != nil { log.Printf("Keep-alive check failed: %v", err) } else { resp.Body.Close() log.Printf("Keep-alive check successful") } // 等待下一次触发 <-ticker.C } }() } func main() { // 加载配置 loadConfig() log.Println("=== Token计算服务启动 ===") // 设置gin模式 gin.SetMode(gin.ReleaseMode) log.Println("设置Gin为发布模式") // 创建路由 r := gin.Default() log.Println("创建Gin路由") // 添加中间件 r.Use(gin.Recovery()) r.Use(func(c *gin.Context) { // 请求开始时间 startTime := time.Now() // 请求信息记录 log.Printf("收到请求: %s %s 来自 %s", c.Request.Method, c.Request.URL.Path, c.ClientIP()) c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type") if c.Request.Method == "OPTIONS" { c.AbortWithStatus(204) return } // 处理请求 c.Next() // 请求完成时间 endTime := time.Now() latency := endTime.Sub(startTime) // 请求结果记录 log.Printf("请求完成: %s %s 状态=%d 耗时=%v", c.Request.Method, c.Request.URL.Path, c.Writer.Status(), latency) }) // 路由 r.GET("/health", healthCheck) r.POST("/count_tokens", countTokens) log.Println("配置路由: GET /health, POST /count_tokens") // 获取端口 port := os.Getenv("PORT") if port == "" { port = "7860" // Hugging Face默认端口 log.Println("使用默认端口: 7860") } else { log.Println("使用配置端口:", port) } // 启动防休眠任务 startKeepAlive() // 启动服务器 log.Printf("=== 服务器启动在端口 %s ===", port) if err := r.Run(":" + port); err != nil { log.Fatalf("服务器启动失败: %v", err) } }