// Package api implements OpenAI-compatible HTTP handlers (models, chat
// completions, OAuth callback) that proxy chat requests to the Augment
// backend.
package api
import (
"augment2api/config"
"bufio"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// OpenAIRequest is the OpenAI-compatible /chat/completions request body.
type OpenAIRequest struct {
	Model       string        `json:"model,omitempty"`       // requested model name; echoed back in responses
	Messages    []ChatMessage `json:"messages,omitempty"`    // conversation so far; the last entry is the current prompt
	Stream      bool          `json:"stream,omitempty"`      // true -> stream the response as SSE chunks
	Temperature float64       `json:"temperature,omitempty"` // accepted but not forwarded by the visible handlers
	MaxTokens   int           `json:"max_tokens,omitempty"`  // accepted but not forwarded by the visible handlers
}
// OpenAIResponse is the OpenAI-compatible non-streaming chat completion
// response body.
type OpenAIResponse struct {
	ID      string   `json:"id"`      // "chatcmpl-<unix timestamp>"
	Object  string   `json:"object"`  // always "chat.completion"
	Created int64    `json:"created"` // unix timestamp
	Model   string   `json:"model"`   // echoed request model
	Choices []Choice `json:"choices"`
	Usage   Usage    `json:"usage"` // heuristic token counts, see estimateTokenCount
}
// OpenAIStreamResponse is a single OpenAI-compatible streaming chunk
// emitted as an SSE "data:" event.
type OpenAIStreamResponse struct {
	ID      string         `json:"id"`      // shared by all chunks of one completion
	Object  string         `json:"object"`  // always "chat.completion.chunk"
	Created int64          `json:"created"` // unix timestamp
	Model   string         `json:"model"`   // echoed request model
	Choices []StreamChoice `json:"choices"`
}
// StreamChoice is one delta entry inside a streaming chunk.
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        ChatMessage `json:"delta"`         // incremental assistant content
	FinishReason *string     `json:"finish_reason"` // nil until the final chunk, then "stop"
}
// Choice is one completed message in a non-streaming response.
type Choice struct {
	Index        int         `json:"index"`
	Message      ChatMessage `json:"message"`
	FinishReason *string     `json:"finish_reason"` // set to "stop" once the upstream reports done
}
// ChatMessage is one OpenAI-style chat message. Content is either a plain
// string or a list of content parts ([]interface{} of maps carrying a
// "text" key), matching OpenAI's multi-part content format.
type ChatMessage struct {
	Role    string      `json:"role"`
	Content interface{} `json:"content"`
}

// GetContent flattens the message content to a plain string. String
// content is returned as-is; multi-part content has its string "text"
// fields concatenated in order. Any other content shape (including nil)
// yields "".
func (m ChatMessage) GetContent() string {
	switch v := m.Content.(type) {
	case string:
		return v
	case []interface{}:
		// Use a Builder instead of += to avoid quadratic concatenation.
		var b strings.Builder
		for _, item := range v {
			part, ok := item.(map[string]interface{})
			if !ok {
				continue // skip non-map parts, as before
			}
			if text, ok := part["text"].(string); ok {
				b.WriteString(text)
			}
		}
		return b.String()
	default:
		return ""
	}
}
// Usage reports token counts for a completion. Values are rough estimates
// produced by estimateTokenCount, not tokenizer-accurate numbers.
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// AugmentRequest is the request body sent to the Augment chat-stream API.
type AugmentRequest struct {
	ChatHistory []AugmentChatHistory `json:"chat_history"` // prior (request, response) pairs
	Message     string               `json:"message"`      // current user prompt
	Mode        string               `json:"mode"`         // always "CHAT" in this file
}

// AugmentChatHistory is one prior exchange in the Augment chat history.
type AugmentChatHistory struct {
	ResponseText   string `json:"response_text"`
	RequestMessage string `json:"request_message"`
}

// AugmentResponse is one newline-delimited JSON chunk returned by the
// Augment chat-stream endpoint.
type AugmentResponse struct {
	Text string `json:"text"` // incremental response text
	Done bool   `json:"done"` // true on the final chunk
}
// CodeResponse carries the OAuth authorization code posted back to the
// callback endpoint by the authorization service.
type CodeResponse struct {
	Code      string `json:"code"`
	State     string `json:"state"` // opaque state; not validated by CallbackHandler — NOTE(review): confirm CSRF checking happens elsewhere
	TenantURL string `json:"tenant_url"`
}
// ModelObject is one entry in the OpenAI-compatible model list.
type ModelObject struct {
	ID      string `json:"id"`
	Object  string `json:"object"` // always "model"
	Created int    `json:"created"`
	OwnedBy string `json:"owned_by"`
}

// ModelsResponse is the OpenAI-compatible /models list response.
type ModelsResponse struct {
	Object string        `json:"object"` // always "list"
	Data   []ModelObject `json:"data"`
}
// In-memory fallback credentials, written by SetAuthInfo after a
// successful OAuth callback and read by GetAuthInfo when no pooled token
// is available.
var (
	accessToken string
	tenantURL   string
)

// SetAuthInfo stores the in-memory access token and tenant URL.
func SetAuthInfo(token, tenant string) {
	accessToken, tenantURL = token, tenant
}
// GetAuthInfo returns the (token, tenantURL) pair to use for upstream
// requests. In coding/debug mode it returns the statically configured
// credentials; otherwise it prefers a random token from the pool and
// falls back to the in-memory credentials set by SetAuthInfo.
func GetAuthInfo() (string, string) {
	if config.AppConfig.CodingMode == "true" {
		// Debug mode: use the fixed token from config.
		return config.AppConfig.CodingToken, config.AppConfig.TenantURL
	}
	// Prefer a random token from the pool.
	if token, tenant := GetRandomToken(); token != "" && tenant != "" {
		return token, tenant
	}
	// Fall back to the in-memory pair. (The original declared a local
	// `tenantURL` that shadowed the package-level variable, so this branch
	// returned accessToken paired with the pool's tenant URL instead of the
	// matching in-memory one.)
	return accessToken, tenantURL
}
// convertToAugmentRequest converts an OpenAI-style chat request into an
// Augment request: the final message becomes Message, and the preceding
// messages are paired off as (request, response) exchanges in ChatHistory.
func convertToAugmentRequest(req OpenAIRequest) AugmentRequest {
	augmentReq := AugmentRequest{Mode: "CHAT"}

	n := len(req.Messages)
	if n == 0 {
		return augmentReq
	}
	augmentReq.Message = req.Messages[n-1].GetContent()

	// Only messages *before* the current prompt belong in the history.
	// (The original loop bounded the pair end by len(req.Messages), so with
	// an even message count the final message was duplicated into history.)
	prior := req.Messages[:n-1]
	var history []AugmentChatHistory // stays nil when there are no pairs, matching original JSON output
	for i := 0; i+1 < len(prior); i += 2 {
		history = append(history, AugmentChatHistory{
			RequestMessage: prior[i].GetContent(),
			ResponseText:   prior[i+1].GetContent(),
		})
	}
	augmentReq.ChatHistory = history
	return augmentReq
}
// AuthHandler returns the authorize URL the client should visit to start
// the OAuth flow.
func AuthHandler(c *gin.Context, authorizeURL string) {
	payload := gin.H{"authorize_url": authorizeURL}
	c.JSON(http.StatusOK, payload)
}
// CallbackHandler handles the OAuth callback: it exchanges the posted
// authorization code for an access token, stores the token both in memory
// and in Redis, and echoes the token back to the caller.
//
// getAccessTokenFunc is invoked as (tenantURL, "", code); the empty second
// argument is presumably a verifier/secret unused here — NOTE(review):
// confirm against the implementation.
func CallbackHandler(c *gin.Context, getAccessTokenFunc func(string, string, string) (string, error)) {
	// 1. Parse the posted code/tenant payload.
	var codeResp CodeResponse
	if err := c.ShouldBindJSON(&codeResp); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求数据"})
		return
	}
	// 2. Exchange the authorization code for an access token.
	token, err := getAccessTokenFunc(codeResp.TenantURL, "", codeResp.Code)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// 3. Keep the pair in memory as the fallback for GetAuthInfo.
	SetAuthInfo(token, codeResp.TenantURL)
	// 4. Persist to Redis. A failure here fails the request even though the
	// in-memory copy was already set above.
	if err := SaveTokenToRedis(token, codeResp.TenantURL); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "保存token到Redis失败: " + err.Error()})
		return
	}
	// 5. Report success, echoing the token.
	c.JSON(http.StatusOK, gin.H{
		"status": "success",
		"token":  token,
	})
}
// ModelsHandler serves the OpenAI-compatible model list. The catalog is
// fixed: two Anthropic Claude model IDs.
func ModelsHandler(c *gin.Context) {
	catalog := []ModelObject{
		{ID: "claude-3-7-sonnet-20250219", Object: "model", Created: 1708387201, OwnedBy: "anthropic"},
		{ID: "claude-3.7", Object: "model", Created: 1708387200, OwnedBy: "anthropic"},
	}
	c.JSON(http.StatusOK, ModelsResponse{
		Object: "list",
		Data:   catalog,
	})
}
// ChatCompletionsHandler serves the OpenAI-compatible /chat/completions
// endpoint, dispatching to the streaming or non-streaming path.
func ChatCompletionsHandler(c *gin.Context) {
	// Require a usable token before touching the request body.
	token, tenant := GetAuthInfo()
	if token == "" || tenant == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"})
		return
	}

	var chatReq OpenAIRequest
	if err := c.ShouldBindJSON(&chatReq); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求数据"})
		return
	}

	// Translate to the Augment wire format, then dispatch on the stream flag.
	augmentReq := convertToAugmentRequest(chatReq)
	if chatReq.Stream {
		handleStreamRequest(c, augmentReq, chatReq.Model)
	} else {
		handleNonStreamRequest(c, augmentReq, chatReq.Model)
	}
}
// handleStreamRequest forwards augmentReq to the Augment chat-stream
// endpoint and relays its newline-delimited JSON chunks to the client as
// OpenAI-compatible SSE "chat.completion.chunk" events, ending with
// "data: [DONE]".
//
// Fixes over the original: SSE headers are set only after the upstream
// request succeeds (error paths previously sent JSON under a
// text/event-stream Content-Type), non-200 upstream responses are reported
// instead of being parsed as chunks, a final line delivered together with
// io.EOF is no longer dropped, the unused fullText accumulator is gone,
// and io.EOF is no longer logged on every stream end.
func handleStreamRequest(c *gin.Context, augmentReq AugmentRequest, model string) {
	// Resolve credentials (also checked by the caller; kept so this handler
	// is safe to invoke directly).
	token, tenant := GetAuthInfo()
	if token == "" || tenant == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"})
		return
	}

	jsonData, err := json.Marshal(augmentReq)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化请求失败"})
		return
	}

	// NOTE(review): assumes tenant ends with "/" — confirm how tenant URLs
	// are stored before changing the concatenation.
	req, err := http.NewRequest("POST", tenant+"chat-stream", strings.NewReader(string(jsonData)))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败"})
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+token)

	// Deliberately no client timeout: the response is an open-ended stream.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "请求失败: " + err.Error()})
		return
	}
	defer resp.Body.Close()

	// Surface upstream failures instead of streaming their bodies as chunks.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("上游请求失败: %d %s", resp.StatusCode, body)})
		return
	}

	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "流式传输不支持"})
		return
	}

	// Commit to an SSE response only now that streaming is guaranteed.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")

	reader := bufio.NewReader(resp.Body)
	responseID := fmt.Sprintf("chatcmpl-%d", time.Now().Unix())

	for {
		line, readErr := reader.ReadString('\n')
		// Process data returned alongside readErr too: ReadString hands back
		// a final unterminated line together with io.EOF.
		if trimmed := strings.TrimSpace(line); trimmed != "" {
			var augmentResp AugmentResponse
			if err := json.Unmarshal([]byte(trimmed), &augmentResp); err != nil {
				log.Printf("解析响应失败: %v", err)
			} else {
				streamResp := OpenAIStreamResponse{
					ID:      responseID,
					Object:  "chat.completion.chunk",
					Created: time.Now().Unix(),
					Model:   model,
					Choices: []StreamChoice{
						{
							Index: 0,
							Delta: ChatMessage{
								Role:    "assistant",
								Content: augmentResp.Text,
							},
							FinishReason: nil,
						},
					},
				}
				// Mark the terminal chunk.
				if augmentResp.Done {
					finishReason := "stop"
					streamResp.Choices[0].FinishReason = &finishReason
				}
				if jsonResp, err := json.Marshal(streamResp); err != nil {
					log.Printf("序列化响应失败: %v", err)
				} else {
					fmt.Fprintf(c.Writer, "data: %s\n\n", jsonResp)
					flusher.Flush()
				}
				if augmentResp.Done {
					fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
					flusher.Flush()
					return
				}
			}
		}
		if readErr != nil {
			if readErr != io.EOF {
				log.Printf("读取响应失败: %v", readErr)
			}
			return
		}
	}
}
// estimateTokenCount returns a rough token estimate for text: each
// whitespace-separated word counts as one token, and each CJK unified
// ideograph (U+4E00..U+9FFF) adds roughly 0.75 of a token. This is a
// heuristic, not a real tokenizer.
func estimateTokenCount(text string) int {
	// Count CJK ideographs rune-by-rune.
	var cjk int
	for _, r := range text {
		if r >= 0x4E00 && r <= 0x9FFF {
			cjk++
		}
	}
	// Whitespace-delimited fields approximate English words/punctuation.
	words := len(strings.Fields(text))
	return words + int(float64(cjk)*0.75)
}
// handleNonStreamRequest sends augmentReq to the Augment chat-stream
// endpoint, accumulates the text of every chunk until Done (or EOF), and
// returns one OpenAI-compatible chat.completion response with estimated
// token usage.
//
// Fixes over the original: non-200 upstream responses are reported instead
// of being silently parsed as chunks, a final line delivered together with
// io.EOF is no longer dropped, and the text is accumulated with a
// strings.Builder instead of repeated string concatenation.
func handleNonStreamRequest(c *gin.Context, augmentReq AugmentRequest, model string) {
	// Resolve credentials (also checked by the caller; kept so this handler
	// is safe to invoke directly).
	token, tenant := GetAuthInfo()
	if token == "" || tenant == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"})
		return
	}

	jsonData, err := json.Marshal(augmentReq)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化请求失败"})
		return
	}

	// NOTE(review): assumes tenant ends with "/" — confirm how tenant URLs
	// are stored before changing the concatenation.
	req, err := http.NewRequest("POST", tenant+"chat-stream", strings.NewReader(string(jsonData)))
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败"})
		return
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+token)

	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "请求失败: " + err.Error()})
		return
	}
	defer resp.Body.Close()

	// Surface upstream failures instead of parsing their bodies as chunks.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("上游请求失败: %d %s", resp.StatusCode, body)})
		return
	}

	// Accumulate chunk text until the upstream reports done (or the stream
	// ends at EOF).
	reader := bufio.NewReader(resp.Body)
	var textBuilder strings.Builder
	for {
		line, readErr := reader.ReadString('\n')
		// Process data returned alongside readErr too: ReadString hands back
		// a final unterminated line together with io.EOF.
		if trimmed := strings.TrimSpace(line); trimmed != "" {
			var augmentResp AugmentResponse
			if jsonErr := json.Unmarshal([]byte(trimmed), &augmentResp); jsonErr == nil {
				textBuilder.WriteString(augmentResp.Text)
				if augmentResp.Done {
					break
				}
			}
			// Unparseable lines are skipped, as before.
		}
		if readErr != nil {
			if readErr == io.EOF {
				break
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": "读取响应失败: " + readErr.Error()})
			return
		}
	}
	fullText := textBuilder.String()

	// Estimate usage: prompt covers the current message plus both sides of
	// every history exchange; completion covers the accumulated reply.
	promptTokens := estimateTokenCount(augmentReq.Message)
	for _, history := range augmentReq.ChatHistory {
		promptTokens += estimateTokenCount(history.RequestMessage)
		promptTokens += estimateTokenCount(history.ResponseText)
	}
	completionTokens := estimateTokenCount(fullText)

	finishReason := "stop"
	openAIResp := OpenAIResponse{
		ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   model,
		Choices: []Choice{
			{
				Index: 0,
				Message: ChatMessage{
					Role:    "assistant",
					Content: fullText,
				},
				FinishReason: &finishReason,
			},
		},
		Usage: Usage{
			PromptTokens:     promptTokens,
			CompletionTokens: completionTokens,
			TotalTokens:      promptTokens + completionTokens,
		},
	}
	c.JSON(http.StatusOK, openAIResp)
}