// tokenizer / main.go
// Origin: Hugging Face Space upload by malt666 ("Upload main.go", commit 2033651, verified).
package main
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/generative-ai-go/genai"
"google.golang.org/api/option"
)
// Config holds all runtime settings, populated once from environment
// variables by loadConfig.
type Config struct {
	AnthropicKey string // ANTHROPIC_API_KEY — required for Claude token counting
	GoogleKey    string // GOOGLE_API_KEY — required for Gemini token counting
	ServiceURL   string // SERVICE_URL — public URL of this service, used for keep-alive pings
	DeepseekURL  string // DEEPSEEK_URL — base URL of the Deepseek tokenizer sidecar
	OpenAIURL    string // OPENAI_URL — base URL of the OpenAI tokenizer sidecar
}

var (
	config     Config    // process-wide configuration; written only inside configOnce.Do
	configOnce sync.Once // ensures loadConfig initializes config exactly once
)
// TokenCountRequest is the request body accepted by POST /count_tokens:
// a model name, a list of chat messages, and an optional system prompt.
// It mirrors the Anthropic count_tokens request shape.
type TokenCountRequest struct {
	Model    string    `json:"model" binding:"required"`
	Messages []Message `json:"messages" binding:"required"`
	System   *string   `json:"system,omitempty"` // optional system prompt; nil when absent
}

// Message is a single chat turn; Role is normally "user" or "assistant".
type Message struct {
	Role    string `json:"role" binding:"required"`
	Content string `json:"content" binding:"required"`
}

// TokenCountResponse is the successful reply: the token count of the input.
type TokenCountResponse struct {
	InputTokens int `json:"input_tokens"`
}

// ErrorResponse is the JSON body returned for failed requests.
type ErrorResponse struct {
	Error string `json:"error"`
}
// ModelRule maps a set of keywords to a canonical model name: an input
// matches when it contains every keyword (case-insensitive).
type ModelRule struct {
	Keywords []string // all keywords must be present in the lowercased input
	Target   string   // canonical model name returned on a match
}

// modelRules is the ordered keyword table consulted by matchModelName after
// its hard-coded special cases. More specific rules are listed before more
// general ones, since the first matching rule wins.
var modelRules = []ModelRule{
	{Keywords: []string{"deepseek"}, Target: "deepseek-v3"},
	// Specific Claude rules come first.
	{Keywords: []string{"claude", "3", "5", "sonnet"}, Target: "claude-3-5-sonnet-latest"},
	{Keywords: []string{"claude", "3", "5", "haiku"}, Target: "claude-3-5-haiku-latest"},
	{Keywords: []string{"claude", "3", "7"}, Target: "claude-3-7-sonnet-latest"},
	{Keywords: []string{"claude", "3", "opus"}, Target: "claude-3-opus-latest"},
	{Keywords: []string{"claude", "3", "haiku"}, Target: "claude-3-haiku-20240307"},
	// Then the generic Claude rule.
	{Keywords: []string{"claude", "3", "sonnet"}, Target: "claude-3-sonnet-20240229"},
	{Keywords: []string{"gemini", "2.0"}, Target: "gemini-2.0-flash"},
	{Keywords: []string{"gemini", "2.5"}, Target: "gemini-2.0-flash"}, // 2.0-flash currently stands in for 2.5
	{Keywords: []string{"gemini", "1.5"}, Target: "gemini-1.5-flash"},
}

// matchModelName resolves a user-supplied model identifier to a canonical
// model name. Matching is case-insensitive: special cases (Claude 3.5/3.7
// and the OpenAI families) are tried first, then modelRules is scanned in
// order. If nothing matches, the lowercased input is returned unchanged so
// the caller can decide how to handle it.
func matchModelName(input string) string {
	input = strings.ToLower(input)
	log.Printf("正在匹配模型名称: %s", input)

	has := func(sub string) bool { return strings.Contains(input, sub) }

	// Claude 3.5 family: either a literal "3.5" or separate "3" and "5".
	if has("claude") && (has("3.5") || (has("3") && has("5"))) {
		switch {
		case has("sonnet"):
			log.Printf("匹配到Claude 3.5 Sonnet")
			return "claude-3-5-sonnet-latest"
		case has("haiku"):
			log.Printf("匹配到Claude 3.5 Haiku")
			return "claude-3-5-haiku-latest"
		default:
			// Sonnet is the default member of the 3.5 family.
			log.Printf("匹配到Claude 3.5 (默认使用Sonnet)")
			return "claude-3-5-sonnet-latest"
		}
	}

	// Claude 3.7 family: "3.7" or separate "3" and "7".
	if has("claude") && (has("3.7") || (has("3") && has("7"))) {
		log.Printf("匹配到Claude 3.7")
		return "claude-3-7-sonnet-latest"
	}

	// OpenAI GPT-4o plus the o1/o3 families share the gpt-4o tokenizer.
	if (has("gpt") && has("4o")) || has("o1") || has("o3") {
		log.Printf("匹配到GPT-4o")
		return "gpt-4o"
	}

	// Remaining GPT variants (3.5, plain 4) fall back to the gpt-4 tokenizer.
	if has("gpt") && ((has("3") && has("5")) || (has("4") && !has("4o"))) {
		log.Printf("匹配到GPT-4")
		return "gpt-4"
	}

	// Keyword table: the first rule whose keywords are all present wins.
	for _, rule := range modelRules {
		allFound := true
		for _, kw := range rule.Keywords {
			if !strings.Contains(input, strings.ToLower(kw)) {
				allFound = false
				break
			}
		}
		if allFound {
			log.Printf("通过规则匹配到: %s", rule.Target)
			return rule.Target
		}
	}

	// No rule applied; hand back the (lowercased) input as-is.
	log.Printf("没有匹配到任何规则,使用原始输入: %s", input)
	return input
}
// loadConfig populates the package-level config from environment variables
// and returns it. It may be called any number of times; the initialization
// work runs exactly once (guarded by configOnce).
func loadConfig() Config {
	configOnce.Do(func() {
		// Timestamped, file-annotated log lines for the whole process.
		log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
		log.Println("开始加载配置...")

		// API keys: missing keys only disable the corresponding backend.
		config.AnthropicKey = os.Getenv("ANTHROPIC_API_KEY")
		if config.AnthropicKey != "" {
			log.Println("Anthropic API Key已配置")
		} else {
			log.Println("警告: ANTHROPIC_API_KEY 环境变量未设置,Claude模型将无法使用")
		}

		config.GoogleKey = os.Getenv("GOOGLE_API_KEY")
		if config.GoogleKey != "" {
			log.Println("Google API Key已配置")
		} else {
			log.Println("警告: GOOGLE_API_KEY 环境变量未设置,Gemini模型将无法使用")
		}

		// Sidecar tokenizer services fall back to local defaults.
		if config.DeepseekURL = os.Getenv("DEEPSEEK_URL"); config.DeepseekURL == "" {
			config.DeepseekURL = "http://127.0.0.1:7861" // local default
			log.Println("使用默认Deepseek服务地址:", config.DeepseekURL)
		} else {
			log.Println("使用配置的Deepseek服务地址:", config.DeepseekURL)
		}

		if config.OpenAIURL = os.Getenv("OPENAI_URL"); config.OpenAIURL == "" {
			config.OpenAIURL = "http://127.0.0.1:7862" // local default
			log.Println("使用默认OpenAI服务地址:", config.OpenAIURL)
		} else {
			log.Println("使用配置的OpenAI服务地址:", config.OpenAIURL)
		}

		// Self-URL used by the keep-alive pinger; optional.
		if config.ServiceURL = os.Getenv("SERVICE_URL"); config.ServiceURL == "" {
			log.Println("SERVICE_URL 未设置,防休眠功能将被禁用")
		} else {
			log.Println("防休眠URL已配置:", config.ServiceURL)
		}

		log.Println("配置加载完成")
	})
	return config
}
// countTokensWithClaude counts input tokens by calling the official
// Anthropic count_tokens endpoint.
//
// Messages with empty content are dropped before sending (the API rejects
// them); an error is returned if no messages remain. The caller is expected
// to have verified that config.AnthropicKey is set.
func countTokensWithClaude(req TokenCountRequest) (TokenCountResponse, error) {
	log.Printf("开始Claude API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages))

	// Drop empty-content messages and warn about non-standard roles, which
	// the API may also reject.
	filteredMessages := make([]Message, 0, len(req.Messages))
	for i, msg := range req.Messages {
		if msg.Content == "" {
			log.Printf("警告: 消息 #%d 内容为空,将被过滤掉", i)
			continue
		}
		if msg.Role != "user" && msg.Role != "assistant" {
			log.Printf("警告: 消息 #%d 角色'%s'不是标准角色(user/assistant),可能导致请求失败", i, msg.Role)
		}
		filteredMessages = append(filteredMessages, msg)
	}
	if len(filteredMessages) == 0 {
		log.Printf("错误: 过滤后没有有效消息")
		return TokenCountResponse{}, fmt.Errorf("没有有效消息:所有消息内容都为空")
	}

	// Rebuild the request with the filtered message list.
	filteredReq := TokenCountRequest{
		Model:    req.Model,
		Messages: filteredMessages,
		System:   req.System,
	}
	if len(filteredMessages) != len(req.Messages) {
		log.Printf("消息过滤: 原始消息数=%d, 过滤后消息数=%d", len(req.Messages), len(filteredMessages))
	}

	// Bounded timeout so a stalled upstream cannot hang the handler forever.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(filteredReq)
	if err != nil {
		log.Printf("错误: 序列化Claude请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	// Log the payload only when it is small enough to be readable.
	if len(data) < 1000 {
		log.Printf("Claude请求内容: %s", string(data))
	} else {
		log.Printf("Claude请求内容较大,长度=%d字节", len(data))
	}

	request, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages/count_tokens", bytes.NewBuffer(data))
	if err != nil {
		log.Printf("错误: 创建Claude请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}
	request.Header.Set("x-api-key", config.AnthropicKey)
	request.Header.Set("anthropic-version", "2023-06-01")
	request.Header.Set("content-type", "application/json")

	response, err := client.Do(request)
	if err != nil {
		log.Printf("错误: 发送请求到Anthropic API失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("发送请求到Anthropic API失败: %v", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Read the error body (best effort) and classify common failures.
		var errorBody []byte
		errorBody, _ = io.ReadAll(response.Body)
		log.Printf("错误: Claude API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody))
		errorStr := string(errorBody)
		if response.StatusCode == http.StatusUnauthorized || strings.Contains(errorStr, "invalid_api_key") {
			log.Printf("错误: Claude API密钥无效或过期")
			return TokenCountResponse{}, fmt.Errorf("Claude API验证失败,请检查API Key是否有效: %s", string(errorBody))
		} else if response.StatusCode == http.StatusBadRequest {
			if strings.Contains(errorStr, "empty content") {
				log.Printf("错误: 请求包含空内容的消息")
				return TokenCountResponse{}, fmt.Errorf("请求格式错误: 消息不能有空内容: %s", string(errorBody))
			} else if strings.Contains(errorStr, "invalid_request_error") {
				log.Printf("错误: 无效的请求格式")
				return TokenCountResponse{}, fmt.Errorf("无效的请求格式: %s", string(errorBody))
			}
		}
		return TokenCountResponse{}, fmt.Errorf("Claude API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody))
	}

	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		log.Printf("错误: 解码Claude响应失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err)
	}
	log.Printf("Claude API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens)
	return result, nil
}
// countTokensWithGemini counts tokens via the Google generative-ai SDK.
// The system prompt (if any) and all messages are flattened into a single
// text prompt, since CountTokens operates on plain content.
func countTokensWithGemini(req TokenCountRequest) (TokenCountResponse, error) {
	log.Printf("开始Gemini API请求: 模型=%s, 消息数量=%d", req.Model, len(req.Messages))
	if config.GoogleKey == "" {
		log.Printf("错误: Gemini API密钥未设置")
		return TokenCountResponse{}, fmt.Errorf("GOOGLE_API_KEY 未设置")
	}

	ctx := context.Background()
	client, err := genai.NewClient(ctx, option.WithAPIKey(config.GoogleKey))
	if err != nil {
		log.Printf("错误: 创建Gemini客户端失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("创建Gemini客户端失败: %v", err)
	}
	defer client.Close()

	// Model name was already canonicalized upstream by matchModelName.
	modelName := req.Model
	log.Printf("使用Gemini模型: %s", modelName)
	model := client.GenerativeModel(modelName)

	// Flatten the conversation into one prompt; strings.Builder avoids the
	// quadratic cost of repeated string concatenation.
	var sb strings.Builder
	if req.System != nil && *req.System != "" {
		sb.WriteString(*req.System)
		sb.WriteString("\n\n")
		log.Printf("Gemini请求包含系统提示: %s", *req.System)
	}
	for _, msg := range req.Messages {
		switch msg.Role {
		case "user":
			sb.WriteString("用户: ")
		case "assistant":
			sb.WriteString("助手: ")
		default:
			sb.WriteString(msg.Role)
			sb.WriteString(": ")
		}
		sb.WriteString(msg.Content)
		sb.WriteString("\n")
	}

	log.Printf("开始计算Gemini tokens...")
	tokResp, err := model.CountTokens(ctx, genai.Text(sb.String()))
	if err != nil {
		log.Printf("错误: 计算Gemini token失败: %v", err)
		// Surface likely credential problems in the log for easier triage.
		if strings.Contains(err.Error(), "invalid_api_key") || strings.Contains(err.Error(), "permission_denied") {
			log.Printf("错误: Gemini API密钥可能无效或过期")
		}
		return TokenCountResponse{}, fmt.Errorf("计算Gemini token失败: %v", err)
	}

	log.Printf("Gemini API请求成功: 模型=%s, 输入tokens=%d", req.Model, tokResp.TotalTokens)
	return TokenCountResponse{InputTokens: int(tokResp.TotalTokens)}, nil
}
// countTokensWithDeepseek counts tokens by forwarding the request to the
// Deepseek tokenizer sidecar at config.DeepseekURL (POST /count_tokens).
func countTokensWithDeepseek(req TokenCountRequest) (TokenCountResponse, error) {
	log.Printf("开始Deepseek API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.DeepseekURL)

	// Bounded timeout so a stalled sidecar cannot hang the handler forever.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		log.Printf("错误: 序列化Deepseek请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	requestURL := config.DeepseekURL + "/count_tokens"
	log.Printf("发送请求到Deepseek服务: %s", requestURL)
	request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data))
	if err != nil {
		log.Printf("错误: 创建Deepseek请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}
	request.Header.Set("Content-Type", "application/json")

	response, err := client.Do(request)
	if err != nil {
		log.Printf("错误: 发送请求到Deepseek服务失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("发送请求到Deepseek服务失败: %v", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Capture the error body (best effort) for the log and the error.
		var errorBody []byte
		errorBody, _ = io.ReadAll(response.Body)
		log.Printf("错误: Deepseek API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody))
		return TokenCountResponse{}, fmt.Errorf("Deepseek API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody))
	}

	var result TokenCountResponse
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		log.Printf("错误: 解码Deepseek响应失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err)
	}
	log.Printf("Deepseek API请求成功: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens)
	return result, nil
}
// countTokensWithOpenAI counts tokens by forwarding the request to the
// OpenAI tokenizer sidecar at config.OpenAIURL (POST /count_tokens). The
// sidecar also reports the model and encoding it actually used, which are
// logged but not returned to the caller.
func countTokensWithOpenAI(req TokenCountRequest) (TokenCountResponse, error) {
	log.Printf("开始OpenAI API请求: 模型=%s, 消息数量=%d, 服务地址=%s", req.Model, len(req.Messages), config.OpenAIURL)

	// Bounded timeout so a stalled sidecar cannot hang the handler forever.
	client := &http.Client{Timeout: 30 * time.Second}
	data, err := json.Marshal(req)
	if err != nil {
		log.Printf("错误: 序列化OpenAI请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("序列化请求失败: %v", err)
	}

	requestURL := config.OpenAIURL + "/count_tokens"
	log.Printf("发送请求到OpenAI服务: %s", requestURL)
	request, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(data))
	if err != nil {
		log.Printf("错误: 创建OpenAI请求失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("创建请求失败: %v", err)
	}
	request.Header.Set("Content-Type", "application/json")

	response, err := client.Do(request)
	if err != nil {
		log.Printf("错误: 发送请求到OpenAI服务失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("发送请求到OpenAI服务失败: %v", err)
	}
	defer response.Body.Close()

	if response.StatusCode != http.StatusOK {
		// Capture the error body (best effort) for the log and the error.
		var errorBody []byte
		errorBody, _ = io.ReadAll(response.Body)
		log.Printf("错误: OpenAI API返回非200状态码: %d, 响应体: %s", response.StatusCode, string(errorBody))
		return TokenCountResponse{}, fmt.Errorf("OpenAI API返回错误状态码: %d, 响应: %s", response.StatusCode, string(errorBody))
	}

	// The sidecar reply carries extra diagnostic fields beyond the count.
	var result struct {
		InputTokens int    `json:"input_tokens"`
		Model       string `json:"model"`
		Encoding    string `json:"encoding"`
	}
	if err := json.NewDecoder(response.Body).Decode(&result); err != nil {
		log.Printf("错误: 解码OpenAI响应失败: %v", err)
		return TokenCountResponse{}, fmt.Errorf("解码响应失败: %v", err)
	}
	log.Printf("OpenAI API请求成功: 模型=%s(实际使用=%s), 编码=%s, 输入tokens=%d",
		req.Model, result.Model, result.Encoding, result.InputTokens)
	return TokenCountResponse{InputTokens: result.InputTokens}, nil
}
// 计算token
func countTokens(c *gin.Context) {
var req TokenCountRequest
if err := c.ShouldBindJSON(&req); err != nil {
log.Printf("错误: 无效的请求格式: %v", err)
c.JSON(http.StatusBadRequest, ErrorResponse{Error: err.Error()})
return
}
// 记录请求详情
systemPrompt := "无"
if req.System != nil && *req.System != "" {
systemPrompt = *req.System
}
log.Printf("收到token计算请求: 原始模型=%s, 消息数量=%d, 系统提示=%s",
req.Model, len(req.Messages), systemPrompt)
// 保存原始模型名称
originalModel := req.Model
// 检查是否为不支持的模型
isUnsupportedModel := true
// 检查是否为支持的模型类型
modelLower := strings.ToLower(req.Model)
if strings.Contains(modelLower, "gpt") || strings.Contains(modelLower, "openai") ||
strings.Contains(modelLower, "o1") || strings.Contains(modelLower, "o3") ||
strings.HasPrefix(modelLower, "claude") ||
strings.Contains(modelLower, "gemini") ||
strings.Contains(modelLower, "deepseek") {
isUnsupportedModel = false
}
// 智能匹配模型名称
req.Model = matchModelName(req.Model)
log.Printf("模型名称匹配结果: 原始=%s -> 匹配=%s", originalModel, req.Model)
var result TokenCountResponse
var err error
// 优先检查是否为Deepseek模型
if strings.Contains(strings.ToLower(req.Model), "deepseek") {
log.Printf("使用Deepseek API计算token")
// 使用Deepseek API
result, err = countTokensWithDeepseek(req)
} else if strings.Contains(strings.ToLower(req.Model), "gpt") || strings.Contains(strings.ToLower(req.Model), "openai") {
log.Printf("使用OpenAI API计算token")
// 使用OpenAI API
result, err = countTokensWithOpenAI(req)
} else if strings.HasPrefix(strings.ToLower(req.Model), "claude") {
log.Printf("使用Claude API计算token")
// 使用Claude API
if config.AnthropicKey == "" {
log.Printf("错误: ANTHROPIC_API_KEY未设置")
c.JSON(http.StatusBadRequest, ErrorResponse{Error: "ANTHROPIC_API_KEY 未设置,无法使用Claude模型"})
return
}
result, err = countTokensWithClaude(req)
} else if strings.Contains(strings.ToLower(req.Model), "gemini") {
log.Printf("使用Gemini API计算token")
// 使用Gemini API
if config.GoogleKey == "" {
log.Printf("错误: GOOGLE_API_KEY未设置")
c.JSON(http.StatusBadRequest, ErrorResponse{Error: "GOOGLE_API_KEY 未设置,无法使用Gemini模型"})
return
}
result, err = countTokensWithGemini(req)
} else if isUnsupportedModel {
log.Printf("不支持的模型: %s, 将使用GPT-4o估算", originalModel)
// 不支持的模型,使用GPT-4o估算
// 创建新的请求,使用GPT-4o
gptReq := req
gptReq.Model = "gpt-4o"
// 使用OpenAI API
estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
if estimateErr == nil {
log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)
// 返回估算值,但添加警告信息,使用400状态码
c.JSON(http.StatusBadRequest, gin.H{
"input_tokens": estimatedResult.InputTokens,
"warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
"estimated_with": "gpt-4o",
"error": fmt.Sprintf("Unsupported model: %s", originalModel),
})
return
} else {
log.Printf("使用GPT-4o估算失败: %v", estimateErr)
c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("Failed to estimate tokens for unsupported model: %s", originalModel)})
return
}
} else {
log.Printf("完全不支持的模型: %s, 将尝试使用GPT-4o估算", originalModel)
// 完全不支持的情况,返回错误但仍提供估算值
// 使用GPT-4o进行估算
gptReq := req
gptReq.Model = "gpt-4o"
estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
if estimateErr == nil {
log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)
c.JSON(http.StatusBadRequest, gin.H{
"input_tokens": estimatedResult.InputTokens,
"warning": fmt.Sprintf("The tokenizer for model '%s' is not supported yet. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
"estimated_with": "gpt-4o",
"error": fmt.Sprintf("Unsupported model: %s", originalModel),
})
} else {
log.Printf("使用GPT-4o估算失败: %v", estimateErr)
c.JSON(http.StatusBadRequest, ErrorResponse{Error: fmt.Sprintf("The tokenizer for model '%s' is not supported yet.", originalModel)})
}
return
}
if err != nil {
log.Printf("计算token失败: %v", err)
// 对所有API调用失败的情况尝试使用GPT-4o估算
log.Printf("API调用失败,尝试使用GPT-4o估算: 原始模型=%s, 错误=%v", req.Model, err)
// 创建新的请求,使用GPT-4o
gptReq := req
gptReq.Model = "gpt-4o"
// 使用OpenAI API进行估算
estimatedResult, estimateErr := countTokensWithOpenAI(gptReq)
if estimateErr == nil {
log.Printf("使用GPT-4o估算成功: 模型=%s, 估算tokens=%d", originalModel, estimatedResult.InputTokens)
// 返回估算值,但添加警告信息和原始错误,使用400状态码
c.JSON(http.StatusBadRequest, gin.H{
"input_tokens": estimatedResult.InputTokens,
"warning": fmt.Sprintf("Token calculation for model '%s' failed. This is an estimation based on gpt-4o and may not be accurate.", originalModel),
"estimated_with": "gpt-4o",
"error": err.Error(),
})
return
} else {
log.Printf("使用GPT-4o估算也失败: %v", estimateErr)
// 如果GPT-4o估算也失败,返回原始错误
c.JSON(http.StatusInternalServerError, ErrorResponse{Error: err.Error()})
return
}
}
// 返回结果
log.Printf("成功计算token: 模型=%s, 输入tokens=%d", req.Model, result.InputTokens)
c.JSON(http.StatusOK, result)
}
// healthCheck is the GET /health handler; it reports liveness and the
// current server time in RFC 3339 format.
func healthCheck(c *gin.Context) {
	body := gin.H{
		"status": "healthy",
		"time":   time.Now().Format(time.RFC3339),
	}
	c.JSON(http.StatusOK, body)
}
// startKeepAlive periodically pings this service's own /health endpoint so
// that free hosting platforms do not put the process to sleep. It is a
// no-op unless SERVICE_URL is configured. The ping goroutine runs for the
// lifetime of the process.
func startKeepAlive() {
	if config.ServiceURL == "" {
		return
	}
	healthURL := fmt.Sprintf("%s/health", config.ServiceURL)
	ticker := time.NewTicker(10 * time.Hour)
	// Dedicated client with a timeout so a hung ping cannot block the loop
	// (the default http.Get client has no timeout).
	client := &http.Client{Timeout: 30 * time.Second}

	go func() {
		log.Printf("Starting keep-alive checks to %s", healthURL)
		// The first check fires immediately, then once per tick.
		for {
			resp, err := client.Get(healthURL)
			if err != nil {
				log.Printf("Keep-alive check failed: %v", err)
			} else {
				resp.Body.Close()
				log.Printf("Keep-alive check successful")
			}
			<-ticker.C
		}
	}()
}
// main wires up configuration, middleware, and routes, then serves HTTP on
// the port given by PORT (default 7860, the Hugging Face Spaces standard).
func main() {
	loadConfig()
	log.Println("=== Token计算服务启动 ===")

	gin.SetMode(gin.ReleaseMode)
	log.Println("设置Gin为发布模式")

	router := gin.Default()
	log.Println("创建Gin路由")

	// Panic recovery plus a combined request-logging / CORS middleware.
	router.Use(gin.Recovery())
	logAndCORS := func(c *gin.Context) {
		begin := time.Now()
		log.Printf("收到请求: %s %s 来自 %s", c.Request.Method, c.Request.URL.Path, c.ClientIP())

		header := c.Writer.Header()
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS")
		header.Set("Access-Control-Allow-Headers", "Content-Type")

		// Short-circuit CORS preflight requests.
		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(204)
			return
		}

		c.Next()

		log.Printf("请求完成: %s %s 状态=%d 耗时=%v",
			c.Request.Method, c.Request.URL.Path, c.Writer.Status(), time.Since(begin))
	}
	router.Use(logAndCORS)

	router.GET("/health", healthCheck)
	router.POST("/count_tokens", countTokens)
	log.Println("配置路由: GET /health, POST /count_tokens")

	port := os.Getenv("PORT")
	if port == "" {
		port = "7860" // Hugging Face Spaces default
		log.Println("使用默认端口: 7860")
	} else {
		log.Println("使用配置端口:", port)
	}

	startKeepAlive()

	log.Printf("=== 服务器启动在端口 %s ===", port)
	if err := router.Run(":" + port); err != nil {
		log.Fatalf("服务器启动失败: %v", err)
	}
}