|
|
package main |
|
|
|
|
|
import ( |
|
|
"bytes" |
|
|
"context" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"log" |
|
|
"net/http" |
|
|
"os" |
|
|
"strings" |
|
|
"sync" |
|
|
"time" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/google/generative-ai-go/genai" |
|
|
"google.golang.org/api/option" |
|
|
) |
|
|
|
|
|
|
|
|
// Config holds all runtime configuration, populated once from
// environment variables by loadConfig.
type Config struct {
	AnthropicKey string // ANTHROPIC_API_KEY; empty disables Claude counting
	GoogleKey    string // GOOGLE_API_KEY; empty disables Gemini counting
	ServiceURL   string // SERVICE_URL; self URL for keep-alive pings (empty disables)
	DeepseekURL  string // DEEPSEEK_URL; base URL of the Deepseek tokenizer service
	OpenAIURL    string // OPENAI_URL; base URL of the OpenAI tokenizer service
}
|
|
|
|
|
var (
	// config is the process-wide configuration; written only inside
	// configOnce.Do and read afterwards, so no further locking is used.
	config     Config
	configOnce sync.Once
)
|
|
|
|
|
|
|
|
// TokenCountRequest is the JSON body accepted by POST /count_tokens.
// It mirrors Anthropic's count_tokens request shape so it can be
// forwarded to that API unchanged.
type TokenCountRequest struct {
	Model    string    `json:"model" binding:"required"`
	Messages []Message `json:"messages" binding:"required"`
	// System is an optional system prompt; nil or empty means "no prompt".
	System *string `json:"system,omitempty"`
}
|
|
|
|
|
// Message is a single chat turn. Role is expected to be "user" or
// "assistant"; other values are forwarded but may be rejected upstream.
type Message struct {
	Role    string `json:"role" binding:"required"`
	Content string `json:"content" binding:"required"`
}
|
|
|
|
|
|
|
|
// TokenCountResponse is the success payload returned to the caller and
// the shape decoded from the upstream counting services.
type TokenCountResponse struct {
	InputTokens int `json:"input_tokens"`
}
|
|
|
|
|
|
|
|
// ErrorResponse is the JSON error payload returned on failures.
type ErrorResponse struct {
	Error string `json:"error"`
}
|
|
|
|
|
|
|
|
// ModelRule maps a keyword set to a canonical model name: an input
// matches when it contains every keyword (case-insensitively).
type ModelRule struct {
	Keywords []string
	Target   string
}

// modelRules is evaluated in order and the first fully-matching rule
// wins, so more specific rules must precede looser ones.
var modelRules = []ModelRule{
	{
		Keywords: []string{"deepseek"},
		Target:   "deepseek-v3",
	},
	{
		Keywords: []string{"claude", "3", "5", "sonnet"},
		Target:   "claude-3-5-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "5", "haiku"},
		Target:   "claude-3-5-haiku-latest",
	},
	{
		Keywords: []string{"claude", "3", "7"},
		Target:   "claude-3-7-sonnet-latest",
	},
	{
		Keywords: []string{"claude", "3", "opus"},
		Target:   "claude-3-opus-latest",
	},
	{
		Keywords: []string{"claude", "3", "haiku"},
		Target:   "claude-3-haiku-20240307",
	},
	{
		Keywords: []string{"claude", "3", "sonnet"},
		Target:   "claude-3-sonnet-20240229",
	},
	{
		Keywords: []string{"gemini", "2.0"},
		Target:   "gemini-2.0-flash",
	},
	{
		// NOTE(review): 2.5 intentionally(?) falls back to 2.0-flash —
		// confirm this mapping once a 2.5 tokenizer is available.
		Keywords: []string{"gemini", "2.5"},
		Target:   "gemini-2.0-flash",
	},
	{
		Keywords: []string{"gemini", "1.5"},
		Target:   "gemini-1.5-flash",
	},
}

// matchModelName normalizes a user-supplied model name to a canonical
// one via hard-coded checks first, then the modelRules table. Unknown
// names are returned unchanged (lowercased).
func matchModelName(input string) string {
	input = strings.ToLower(input)
	log.Printf("正在匹配模型名称: %s", input)

	// FIX: if the input is already one of our canonical targets, return
	// it as-is. Previously date-suffixed names such as
	// "claude-3-haiku-20240307" fell through to the loose digit checks,
	// where the "7" inside the date made the {"claude","3","7"} rule
	// fire and rewrite the name to claude-3-7-sonnet-latest.
	for _, rule := range modelRules {
		if input == strings.ToLower(rule.Target) {
			log.Printf("输入已是规范模型名称: %s", rule.Target)
			return rule.Target
		}
	}

	// Claude 3.5 family: accept "3.5" or the digits 3 and 5 anywhere.
	if strings.Contains(input, "claude") && strings.Contains(input, "3.5") ||
		strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "5") {
		if strings.Contains(input, "sonnet") {
			log.Printf("匹配到Claude 3.5 Sonnet")
			return "claude-3-5-sonnet-latest"
		} else if strings.Contains(input, "haiku") {
			log.Printf("匹配到Claude 3.5 Haiku")
			return "claude-3-5-haiku-latest"
		} else {
			// No variant given: default to Sonnet.
			log.Printf("匹配到Claude 3.5 (默认使用Sonnet)")
			return "claude-3-5-sonnet-latest"
		}
	}

	// Claude 3.7 family (only Sonnet exists).
	if strings.Contains(input, "claude") && strings.Contains(input, "3.7") ||
		strings.Contains(input, "claude") && strings.Contains(input, "3") && strings.Contains(input, "7") {
		log.Printf("匹配到Claude 3.7")
		return "claude-3-7-sonnet-latest"
	}

	// GPT-4o bucket also absorbs the o1/o3 reasoning models, since we
	// only need an approximate tokenizer for them.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "4o")) ||
		strings.Contains(input, "o1") ||
		strings.Contains(input, "o3") {
		log.Printf("匹配到GPT-4o")
		return "gpt-4o"
	}

	// Older GPT models (3.5 and non-4o 4) all use the gpt-4 tokenizer.
	if (strings.Contains(input, "gpt") && strings.Contains(input, "3") && strings.Contains(input, "5")) ||
		(strings.Contains(input, "gpt") && strings.Contains(input, "4") && !strings.Contains(input, "4o")) {
		log.Printf("匹配到GPT-4")
		return "gpt-4"
	}

	// Fall back to the keyword table; first rule whose keywords all
	// appear in the input wins.
	for _, rule := range modelRules {
		matches := true
		for _, keyword := range rule.Keywords {
			if !strings.Contains(input, strings.ToLower(keyword)) {
				matches = false
				break
			}
		}
		if matches {
			log.Printf("通过规则匹配到: %s", rule.Target)
			return rule.Target
		}
	}

	log.Printf("没有匹配到任何规则,使用原始输入: %s", input)
	return input
}
|
|
|
|
|
|
|
|
func loadConfig() Config { |
|
|
configOnce.Do(func() { |
|
|
|
|
|
log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) |
|
|
log.Println("开始加载配置...") |
|
|
|
|
|
config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY") |
|
|
if config.AnthropicKey == "" { |
|
|
log.Println("警告: ANTHROPIC_API_KEY 环境变量未设置,Claude模型将无法使用") |
|
|
} else { |
|
|
log.Println("Anthropic API Key已配置") |
|
|
} |
|
|
|
|
|
config.GoogleKey = os.Getenv("GOOGLE_API_KEY") |
|
|
if config.GoogleKey == "" { |
|
|
log.Println("警告: GOOGLE_API_KEY 环境变量未设置,Gemini模型将无法使用") |
|
|
} else { |
|
|
log.Println("Google API Key已配置") |
|
|
} |
|
|
|
|
|
|
|
|
config.DeepseekURL = os.Getenv("DEEPSEEK_URL") |
|
|
if config.DeepseekURL == "" { |
|
|
config.DeepseekURL = "http://127.0.0.1:7861" |
|
|
log.Println("使用默认Deepseek服务地址:", config.DeepseekURL) |
|
|
} else { |
|
|
log.Println("使用配置的Deepseek服务地址:", config.DeepseekURL) |
|
|
} |
|
|
|
|
|
|
|
|
config.OpenAIURL = os.Getenv("OPENAI_URL") |
|
|
if config.OpenAIURL == "" { |
|
|
config.OpenAIURL = "http://127.0.0.1:7862" |
|
|
log.Println("使用默认OpenAI服务地址:", config.OpenAIURL) |
|
|
} else { |
|
|
log.Println("使用配置的OpenAI服务地址:", config.OpenAIURL) |
|
|
} |
|
|
|
|
|
|
|
|
config.ServiceURL = os.Getenv("SERVICE_URL") |
|
|
if config.ServiceURL == "" { |
|
|
log.Println("SERVICE_URL 未设置,防休眠功能将被禁用") |
|
|
} else { |
|
|
log.Println("防休眠URL已配置:", config.ServiceURL) |
|
|
} |
|
|
|
|
|
log.Println("配置加载完成") |
|
|
}) |
|
|
return config |
|
|
} |
|
|
|
|
|
|
|
|
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) { |
|
|
|
|
|
log.Printf("开始Claude API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) |
|
|
|
|
|
|
|
|
var filteredMessages []Message |
|
|
for i, msg := range req.Messages { |
|
|
if msg.Content == "" { |
|
|
log.Printf("警告: 消息 #%d 内容为空,将被过滤掉", i) |
|
|
continue |
|
|
} |
|
|
if msg.Role != "user" && msg.Role != "assistant" { |
|
|
log.Printf("警告: 消息 #%d 角色'%s'不是标准角色(user/assistant),可能导致请求失败", i, msg.Role) |
|
|
} |
|
|
filteredMessages = append(filteredMessages, msg) |
|
|
} |
|
|
|
|
|
if len(filteredMessages) == 0 { |
|
|
log.Printf("错误: 过滤后没有有效消息") |
|
|
return TokenCountResponse{}, fmt.Errorf("没有有效消息:所有消息内容都为空") |
|
|
} |
|
|
|
|
|
|
|
|
filteredReq := TokenCountRequest{ |
|
|
Model: req.Model, |
|
|
Messages: filteredMessages, |
|
|
System: req.System, |
|
|
} |
|
|
|
|
|
|
|
|
if len(filteredMessages) != len(req.Messages) { |
|
|
log.Printf("消息过滤: 原始消息数=%d, 过滤后消息数=%d", len(req.Messages), len(filteredMessages)) |
|
|
} |
|
|
|
|
|
client := &http.Client{} |
|
|
data, err := json.Marshal(filteredReq) |
|
|
if err != nil { |
|
|
log.Printf("错误: 序列化Claude请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
if len(data) < 1000 { |
|
|
log.Printf("Claude请求内容: %s", string(data)) |
|
|
} else { |
|
|
log.Printf("Claude请求内容较大,长度=%d字节", len(data)) |
|
|
} |
|
|
|
|
|
|
|
|
request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data)) |
|
|
if err != nil { |
|
|
log.Printf("错误: 创建Claude请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
request.Header.Set("x-api-key", config.AnthropicKey) |
|
|
request.Header.Set("anthropic-version", "2023-06-01") |
|
|
request.Header.Set("content-type", "application/json") |
|
|
|
|
|
|
|
|
response, err := client.Do(request) |
|
|
if err != nil { |
|
|
log.Printf("错误: 发送请求到Anthropic API失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err) |
|
|
} |
|
|
defer response.Body.Close() |
|
|
|
|
|
|
|
|
if response.StatusCode != http.StatusOK { |
|
|
|
|
|
var errorBody []byte |
|
|
errorBody, _ = io.ReadAll(response.Body) |
|
|
log.Printf("错误: Claude API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
|
|
|
|
|
|
|
|
errorStr := string(errorBody) |
|
|
if response.StatusCode == http.StatusUnauthorized || strings.Contains(errorStr, "invalid_api_key") { |
|
|
log.Printf("错误: Claude API密钥无效或过期") |
|
|
return TokenCountResponse{}, fmt.Errorf("Claude API验证失败,请检查API Key是否有效: %s", string(errorBody)) |
|
|
} else if response.StatusCode == http.StatusBadRequest { |
|
|
if strings.Contains(errorStr, "empty content") { |
|
|
log.Printf("错误: 请求包含空内容的消息") |
|
|
return TokenCountResponse{}, fmt.Errorf("请求格式错误: 消息不能有空内容: %s", string(errorBody)) |
|
|
} else if strings.Contains(errorStr, "invalid_request_error") { |
|
|
log.Printf("错误: 无效的请求格式") |
|
|
return TokenCountResponse{}, fmt.Errorf("无效的请求格式: %s", string(errorBody)) |
|
|
} |
|
|
} |
|
|
|
|
|
return TokenCountResponse{}, fmt.Errorf("Claude API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
|
|
} |
|
|
|
|
|
|
|
|
var result TokenCountResponse |
|
|
if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
|
|
log.Printf("错误: 解码Claude响应失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
|
|
} |
|
|
|
|
|
log.Printf("Claude API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) |
|
|
return result, nil |
|
|
} |
|
|
|
|
|
|
|
|
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) { |
|
|
|
|
|
log.Printf("开始Gemini API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages)) |
|
|
|
|
|
if config.GoogleKey == "" { |
|
|
log.Printf("错误: Gemini API密钥未设置") |
|
|
return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置") |
|
|
} |
|
|
|
|
|
|
|
|
ctx := context.Background() |
|
|
client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey)) |
|
|
if err != nil { |
|
|
log.Printf("错误: 创建Gemini客户端失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err) |
|
|
} |
|
|
defer client.Close() |
|
|
|
|
|
|
|
|
modelName := req.Model |
|
|
log.Printf("使用Gemini模型: %s", modelName) |
|
|
|
|
|
|
|
|
model := client.GenerativeModel(modelName) |
|
|
|
|
|
|
|
|
var content string |
|
|
if req.System != nil && *req.System != "" { |
|
|
content += *req.System + "\n\n" |
|
|
log.Printf("Gemini请求包含系统提示: %s", *req.System) |
|
|
} |
|
|
|
|
|
for _, msg := range req.Messages { |
|
|
if msg.Role == "user" { |
|
|
content += "用户: " + msg.Content + "\n" |
|
|
} else if msg.Role == "assistant" { |
|
|
content += "助手: " + msg.Content + "\n" |
|
|
} else { |
|
|
content += msg.Role + ": " + msg.Content + "\n" |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
log.Printf("开始计算Gemini tokens...") |
|
|
tokResp, err := model.CountTokens(ctx, genai.Text(content)) |
|
|
if err != nil { |
|
|
log.Printf("错误: 计算Gemini token失败: %v", err) |
|
|
if strings.Contains(err.Error(), "invalid_api_key") || strings.Contains(err.Error(), "permission_denied") { |
|
|
log.Printf("错误: Gemini API密钥可能无效或过期") |
|
|
} |
|
|
return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err) |
|
|
} |
|
|
|
|
|
log.Printf("Gemini API请求成功: 模型=%s, 输入tokens=%d", req.Model, tokResp.TotalTokens) |
|
|
return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil |
|
|
} |
|
|
|
|
|
|
|
|
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) { |
|
|
log.Printf("开始Deepseek API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.DeepseekURL) |
|
|
|
|
|
|
|
|
client := &http.Client{} |
|
|
data, err := json.Marshal(req) |
|
|
if err != nil { |
|
|
log.Printf("错误: 序列化Deepseek请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
requestURL := config.DeepseekURL + "/count_tokens" |
|
|
log.Printf("发送请求到Deepseek服务: %s", requestURL) |
|
|
request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) |
|
|
if err != nil { |
|
|
log.Printf("错误: 创建Deepseek请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
request.Header.Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
|
response, err := client.Do(request) |
|
|
if err != nil { |
|
|
log.Printf("错误: 发送请求到Deepseek服务失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err) |
|
|
} |
|
|
defer response.Body.Close() |
|
|
|
|
|
|
|
|
if response.StatusCode != http.StatusOK { |
|
|
|
|
|
var errorBody []byte |
|
|
errorBody, _ = io.ReadAll(response.Body) |
|
|
log.Printf("错误: Deepseek API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
|
|
return TokenCountResponse{}, fmt.Errorf("Deepseek API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
|
|
} |
|
|
|
|
|
|
|
|
var result TokenCountResponse |
|
|
if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
|
|
log.Printf("错误: 解码Deepseek响应失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
|
|
} |
|
|
|
|
|
log.Printf("Deepseek API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens) |
|
|
return result, nil |
|
|
} |
|
|
|
|
|
|
|
|
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) { |
|
|
log.Printf("开始OpenAI API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.OpenAIURL) |
|
|
|
|
|
|
|
|
client := &http.Client{} |
|
|
data, err := json.Marshal(req) |
|
|
if err != nil { |
|
|
log.Printf("错误: 序列化OpenAI请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
requestURL := config.OpenAIURL + "/count_tokens" |
|
|
log.Printf("发送请求到OpenAI服务: %s", requestURL) |
|
|
request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data)) |
|
|
if err != nil { |
|
|
log.Printf("错误: 创建OpenAI请求失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err) |
|
|
} |
|
|
|
|
|
|
|
|
request.Header.Set("Content-Type", "application/json") |
|
|
|
|
|
|
|
|
response, err := client.Do(request) |
|
|
if err != nil { |
|
|
log.Printf("错误: 发送请求到OpenAI服务失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err) |
|
|
} |
|
|
defer response.Body.Close() |
|
|
|
|
|
|
|
|
if response.StatusCode != http.StatusOK { |
|
|
|
|
|
var errorBody []byte |
|
|
errorBody, _ = io.ReadAll(response.Body) |
|
|
log.Printf("错误: OpenAI API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody)) |
|
|
return TokenCountResponse{}, fmt.Errorf("OpenAI API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody)) |
|
|
} |
|
|
|
|
|
|
|
|
var result struct { |
|
|
InputTokens int `json:"input_tokens"` |
|
|
Model string `json:"model"` |
|
|
Encoding string `json:"encoding"` |
|
|
} |
|
|
if err := json.NewDecoder(response.Body).Decode(&result); err != nil { |
|
|
log.Printf("错误: 解码OpenAI响应失败: %v", err) |
|
|
return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err) |
|
|
} |
|
|
|
|
|
log.Printf("OpenAI API请求成功: 模型=%s(实际使用=%s), 编码=%s, 输入tokens=%d", |
|
|
req.Model, result.Model, result.Encoding, result.InputTokens) |
|
|
return TokenCountResponse{InputTokens: result.InputTokens}, nil |
|
|
} |
|
|
|
|
|
|
|
|
// countTokens is the POST /count_tokens handler. It normalizes the
// model name, dispatches to the matching backend, and — when the model
// is unsupported or the backend fails — falls back to an estimation
// with gpt-4o, returned with HTTP 400 plus a warning so clients can
// tell an estimate from an exact count.
func countTokens(c *gin.Context) {
	var req TokenCountRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		log.Printf("错误: 无效的请求格式: %v", err)
		c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
		return
	}

	// Log the system prompt (or a "none" placeholder) for traceability.
	systemPrompt := "无"
	if req.System != nil && *req.System != "" {
		systemPrompt = *req.System
	}
	log.Printf("收到token计算请求: 原始模型=%s, 消息数量=%d, 系统提示=%s",
		req.Model, len(req.Messages), systemPrompt)

	// Keep the caller's original name for log/response messages; req.Model
	// is rewritten to the canonical name below.
	originalModel := req.Model

	// isUnsupportedModel is computed on the ORIGINAL name, before
	// matchModelName rewrites it; it selects the estimation branch when
	// no known family keyword is present.
	isUnsupportedModel := true

	modelLower := strings.ToLower(req.Model)
	if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
		strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
		strings.HasPrefix(modelLower, "claude") ||
		strings.Contains(modelLower, "gemini") ||
		strings.Contains(modelLower, "deepseek") {
		isUnsupportedModel = false
	}

	req.Model = matchModelName(req.Model)
	log.Printf("模型名称匹配结果: 原始=%s -> 匹配=%s", originalModel, req.Model)

	var result TokenCountResponse
	var err error

	// Dispatch on the CANONICAL name. Order matters: deepseek and
	// gpt/openai are substring checks, claude is a prefix check.
	if strings.Contains(strings.ToLower(req.Model), "deepseek") {
		log.Printf("使用Deepseek API计算token")
		result, err = countTokensWithDeepseek(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
		log.Printf("使用OpenAI API计算token")
		result, err = countTokensWithOpenAI(req)
	} else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
		log.Printf("使用Claude API计算token")
		// Fail fast with 400 when the required key is missing.
		if config.AnthropicKey == "" {
			log.Printf("错误: ANTHROPIC_API_KEY未设置")
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY 未设置,无法使用Claude模型"})
			return
		}
		result, err = countTokensWithClaude(req)
	} else if strings.Contains(strings.ToLower(req.Model), "gemini") {
		log.Printf("使用Gemini API计算token")
		if config.GoogleKey == "" {
			log.Printf("错误: GOOGLE_API_KEY未设置")
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY 未设置,无法使用Gemini模型"})
			return
		}
		result, err = countTokensWithGemini(req)
	} else if isUnsupportedModel {
		// Unknown family: estimate with gpt-4o. NOTE(review): the
		// estimate is deliberately(?) returned with HTTP 400 so callers
		// treat it as non-authoritative — confirm clients rely on this.
		log.Printf("不支持的模型: %s, 将使用GPT-4o估算", originalModel)

		gptReq := req
		gptReq.Model = "gpt-4o"

		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)

		if estimateErr == nil {
			log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)
			c.JSON(http.StatusBadRequest, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
				"error":          fmt.Sprintf("Unsupported model: %s", originalModel),
			})
			return
		} else {
			log.Printf("使用GPT-4o估算失败: %v", estimateErr)
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("Failed to estimate tokens for unsupported model: %s", originalModel)})
			return
		}
	} else {
		// Defensive branch: the original name looked supported but the
		// canonical name matched no dispatcher. Same gpt-4o estimation,
		// with a slightly different failure message.
		log.Printf("完全不支持的模型: %s, 将尝试使用GPT-4o估算", originalModel)

		gptReq := req
		gptReq.Model = "gpt-4o"

		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
		if estimateErr == nil {
			log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)
			c.JSON(http.StatusBadRequest, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
				"error":          fmt.Sprintf("Unsupported model: %s", originalModel),
			})
		} else {
			log.Printf("使用GPT-4o估算失败: %v", estimateErr)
			c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
		}
		return
	}

	// Backend failed: try a gpt-4o estimate before giving up with 500.
	if err != nil {
		log.Printf("计算token失败: %v", err)

		log.Printf("API调用失败,尝试使用GPT-4o估算: 原始模型=%s, 错误=%v", req.Model, err)

		gptReq := req
		gptReq.Model = "gpt-4o"

		estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)

		if estimateErr == nil {
			log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)

			c.JSON(http.StatusBadRequest, gin.H{
				"input_tokens":   estimatedResult.InputTokens,
				"warning":        fmt.Sprintf("Token calculation for model '%s' failed. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
				"estimated_with": "gpt-4o",
				"error":          err.Error(),
			})
			return
		} else {
			log.Printf("使用GPT-4o估算也失败: %v", estimateErr)
			c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
			return
		}
	}

	log.Printf("成功计算token: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens)
	c.JSON(http.StatusOK, result)
}
|
|
|
|
|
|
|
|
func healthCheck(c *gin.Context) { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": "healthy", |
|
|
"time": time.Now().Format(time.RFC3339), |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func startKeepAlive() { |
|
|
if config.ServiceURL == "" { |
|
|
return |
|
|
} |
|
|
|
|
|
healthURL := fmt.Sprintf("%s/health", config.ServiceURL) |
|
|
ticker := time.NewTicker(10 * time.Hour) |
|
|
|
|
|
|
|
|
go func() { |
|
|
log.Printf("Starting keep-alive checks to %s", healthURL) |
|
|
for { |
|
|
resp, err := http.Get(healthURL) |
|
|
if err != nil { |
|
|
log.Printf("Keep-alive check failed: %v", err) |
|
|
} else { |
|
|
resp.Body.Close() |
|
|
log.Printf("Keep-alive check successful") |
|
|
} |
|
|
|
|
|
|
|
|
<-ticker.C |
|
|
} |
|
|
}() |
|
|
} |
|
|
|
|
|
// main wires up configuration, logging/CORS middleware, the two routes
// (GET /health, POST /count_tokens), the keep-alive pinger, and starts
// the HTTP server on $PORT (default 7860).
func main() {
	loadConfig()
	log.Println("=== Token计算服务启动 ===")

	gin.SetMode(gin.ReleaseMode)
	log.Println("设置Gin为发布模式")

	r := gin.Default()
	log.Println("创建Gin路由")

	// gin.Default already includes Recovery; adding it again is
	// harmless but redundant.
	r.Use(gin.Recovery())
	// Combined request-logging + permissive CORS middleware. OPTIONS
	// preflights are answered with 204 and never reach the handlers.
	r.Use(func(c *gin.Context) {
		startTime := time.Now()

		log.Printf("收到请求: %s %s 来自 %s", c.Request.Method, c.Request.URL.Path, c.ClientIP())

		c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
		c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type")
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}

		c.Next()

		// Log status and latency after the handler has run.
		endTime := time.Now()
		latency := endTime.Sub(startTime)

		log.Printf("请求完成: %s %s 状态=%d 耗时=%v",
			c.Request.Method, c.Request.URL.Path, c.Writer.Status(), latency)
	})

	r.GET("/health", healthCheck)
	r.POST("/count_tokens", countTokens)
	log.Println("配置路由: GET /health, POST /count_tokens")

	port := os.Getenv("PORT")
	if port == "" {
		port = "7860"
		log.Println("使用默认端口: 7860")
	} else {
		log.Println("使用配置端口:", port)
	}

	startKeepAlive()

	log.Printf("=== 服务器启动在端口 %s ===", port)
	// r.Run blocks; a non-nil error here is fatal.
	if err := r.Run(":" + port); err != nil {
		log.Fatalf("服务器启动失败: %v", err)
	}
}
|
|
|