|
|
package api |
|
|
|
|
|
import ( |
|
|
"augment2api/config" |
|
|
"bufio" |
|
|
"encoding/json" |
|
|
"fmt" |
|
|
"io" |
|
|
"log" |
|
|
"net/http" |
|
|
"strings" |
|
|
"time" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
) |
|
|
|
|
|
|
|
|
type OpenAIRequest struct { |
|
|
Model string `json:"model,omitempty"` |
|
|
Messages []ChatMessage `json:"messages,omitempty"` |
|
|
Stream bool `json:"stream,omitempty"` |
|
|
Temperature float64 `json:"temperature,omitempty"` |
|
|
MaxTokens int `json:"max_tokens,omitempty"` |
|
|
} |
|
|
|
|
|
|
|
|
type OpenAIResponse struct { |
|
|
ID string `json:"id"` |
|
|
Object string `json:"object"` |
|
|
Created int64 `json:"created"` |
|
|
Model string `json:"model"` |
|
|
Choices []Choice `json:"choices"` |
|
|
Usage Usage `json:"usage"` |
|
|
} |
|
|
|
|
|
|
|
|
type OpenAIStreamResponse struct { |
|
|
ID string `json:"id"` |
|
|
Object string `json:"object"` |
|
|
Created int64 `json:"created"` |
|
|
Model string `json:"model"` |
|
|
Choices []StreamChoice `json:"choices"` |
|
|
} |
|
|
|
|
|
type StreamChoice struct { |
|
|
Index int `json:"index"` |
|
|
Delta ChatMessage `json:"delta"` |
|
|
FinishReason *string `json:"finish_reason"` |
|
|
} |
|
|
|
|
|
type Choice struct { |
|
|
Index int `json:"index"` |
|
|
Message ChatMessage `json:"message"` |
|
|
FinishReason *string `json:"finish_reason"` |
|
|
} |
|
|
|
|
|
// ChatMessage is one message of a conversation.
type ChatMessage struct {
	// Role names the author of the message (for example "assistant").
	Role string `json:"role"`
	// Content is left untyped because it may decode as a plain string or as
	// a list of parts; GetContent flattens both forms to text.
	Content interface{} `json:"content"`
}
|
|
|
|
|
|
|
|
func (m ChatMessage) GetContent() string { |
|
|
switch v := m.Content.(type) { |
|
|
case string: |
|
|
return v |
|
|
case []interface{}: |
|
|
var result string |
|
|
for _, item := range v { |
|
|
if contentMap, ok := item.(map[string]interface{}); ok { |
|
|
if text, exists := contentMap["text"]; exists { |
|
|
if textStr, ok := text.(string); ok { |
|
|
result += textStr |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return result |
|
|
default: |
|
|
return "" |
|
|
} |
|
|
} |
|
|
|
|
|
// Usage is the token accounting section of a completion response.
type Usage struct {
	// PromptTokens counts tokens attributed to the request.
	PromptTokens int `json:"prompt_tokens"`
	// CompletionTokens counts tokens attributed to the generated answer.
	CompletionTokens int `json:"completion_tokens"`
	// TotalTokens is the combined count.
	TotalTokens int `json:"total_tokens"`
}
|
|
|
|
|
|
|
|
type AugmentRequest struct { |
|
|
ChatHistory []AugmentChatHistory `json:"chat_history"` |
|
|
Message string `json:"message"` |
|
|
Mode string `json:"mode"` |
|
|
} |
|
|
|
|
|
// AugmentChatHistory is one earlier exchange in an AugmentRequest: a request
// message together with the response text that answered it.
type AugmentChatHistory struct {
	// ResponseText is the answer that was given.
	ResponseText string `json:"response_text"`
	// RequestMessage is the message that was asked.
	RequestMessage string `json:"request_message"`
}
|
|
|
|
|
|
|
|
// AugmentResponse is one JSON object decoded from a line of the upstream
// response body.
type AugmentResponse struct {
	// Text is the piece of answer text carried by this line.
	Text string `json:"text"`
	// Done is true on the line that ends the answer.
	Done bool `json:"done"`
}
|
|
|
|
|
|
|
|
// CodeResponse is the JSON payload posted to the authorization callback.
type CodeResponse struct {
	// Code is the authorization code to exchange for an access token.
	Code string `json:"code"`
	// State is decoded from the payload; no code visible in this file reads it.
	State string `json:"state"`
	// TenantURL is the tenant base URL the code belongs to.
	TenantURL string `json:"tenant_url"`
}
|
|
|
|
|
|
|
|
// ModelObject describes one model entry in a model listing.
type ModelObject struct {
	// ID is the model name.
	ID string `json:"id"`
	// Object is the OpenAI object type string.
	Object string `json:"object"`
	// Created is a timestamp (presumably Unix seconds — confirm with callers).
	Created int `json:"created"`
	// OwnedBy names the owner of the model.
	OwnedBy string `json:"owned_by"`
}
|
|
|
|
|
|
|
|
type ModelsResponse struct { |
|
|
Object string `json:"object"` |
|
|
Data []ModelObject `json:"data"` |
|
|
} |
|
|
|
|
|
|
|
|
// In-memory credentials, assigned by SetAuthInfo.
//
// NOTE(review): nothing visible in this file synchronizes access to these
// variables — confirm whether concurrent handlers can race on them.
var (
	// accessToken is the most recently stored access token.
	accessToken string
	// tenantURL is the tenant base URL stored alongside accessToken.
	tenantURL string
)
|
|
|
|
|
|
|
|
func SetAuthInfo(token, tenant string) { |
|
|
accessToken = token |
|
|
tenantURL = tenant |
|
|
} |
|
|
|
|
|
|
|
|
func GetAuthInfo() (string, string) { |
|
|
if config.AppConfig.CodingMode == "true" { |
|
|
|
|
|
return config.AppConfig.CodingToken, config.AppConfig.TenantURL |
|
|
} |
|
|
|
|
|
|
|
|
token, tenantURL := GetRandomToken() |
|
|
if token != "" && tenantURL != "" { |
|
|
return token, tenantURL |
|
|
} |
|
|
|
|
|
|
|
|
return accessToken, tenantURL |
|
|
} |
|
|
|
|
|
|
|
|
func convertToAugmentRequest(req OpenAIRequest) AugmentRequest { |
|
|
augmentReq := AugmentRequest{ |
|
|
Mode: "CHAT", |
|
|
} |
|
|
|
|
|
if len(req.Messages) > 0 { |
|
|
lastMsg := req.Messages[len(req.Messages)-1] |
|
|
augmentReq.Message = lastMsg.GetContent() |
|
|
} |
|
|
|
|
|
var history []AugmentChatHistory |
|
|
for i := 0; i < len(req.Messages)-1; i += 2 { |
|
|
if i+1 < len(req.Messages) { |
|
|
history = append(history, AugmentChatHistory{ |
|
|
RequestMessage: req.Messages[i].GetContent(), |
|
|
ResponseText: req.Messages[i+1].GetContent(), |
|
|
}) |
|
|
} |
|
|
} |
|
|
|
|
|
augmentReq.ChatHistory = history |
|
|
return augmentReq |
|
|
} |
|
|
|
|
|
|
|
|
func AuthHandler(c *gin.Context, authorizeURL string) { |
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"authorize_url": authorizeURL, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func CallbackHandler(c *gin.Context, getAccessTokenFunc func(string, string, string) (string, error)) { |
|
|
|
|
|
var codeResp CodeResponse |
|
|
if err := c.ShouldBindJSON(&codeResp); err != nil { |
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求数据"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
token, err := getAccessTokenFunc(codeResp.TenantURL, "", codeResp.Code) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
SetAuthInfo(token, codeResp.TenantURL) |
|
|
|
|
|
|
|
|
if err := SaveTokenToRedis(token, codeResp.TenantURL); err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存token到Redis失败: " + err.Error()}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
c.JSON(http.StatusOK, gin.H{ |
|
|
"status": "success", |
|
|
"token": token, |
|
|
}) |
|
|
} |
|
|
|
|
|
|
|
|
func ModelsHandler(c *gin.Context) { |
|
|
|
|
|
response := ModelsResponse{ |
|
|
Object: "list", |
|
|
Data: []ModelObject{ |
|
|
{ |
|
|
ID: "claude-3-7-sonnet-20250219", |
|
|
Object: "model", |
|
|
Created: 1708387201, |
|
|
OwnedBy: "anthropic", |
|
|
}, |
|
|
{ |
|
|
ID: "claude-3.7", |
|
|
Object: "model", |
|
|
Created: 1708387200, |
|
|
OwnedBy: "anthropic", |
|
|
}, |
|
|
}, |
|
|
} |
|
|
|
|
|
c.JSON(http.StatusOK, response) |
|
|
} |
|
|
|
|
|
|
|
|
func ChatCompletionsHandler(c *gin.Context) { |
|
|
token, tenant := GetAuthInfo() |
|
|
if token == "" || tenant == "" { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"}) |
|
|
return |
|
|
} |
|
|
|
|
|
var openAIReq OpenAIRequest |
|
|
if err := c.ShouldBindJSON(&openAIReq); err != nil { |
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求数据"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
augmentReq := convertToAugmentRequest(openAIReq) |
|
|
|
|
|
|
|
|
if openAIReq.Stream { |
|
|
handleStreamRequest(c, augmentReq, openAIReq.Model) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
handleNonStreamRequest(c, augmentReq, openAIReq.Model) |
|
|
} |
|
|
|
|
|
|
|
|
func handleStreamRequest(c *gin.Context, augmentReq AugmentRequest, model string) { |
|
|
c.Header("Content-Type", "text/event-stream") |
|
|
c.Header("Cache-Control", "no-cache") |
|
|
c.Header("Connection", "keep-alive") |
|
|
|
|
|
|
|
|
token, tenant := GetAuthInfo() |
|
|
if token == "" || tenant == "" { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
jsonData, err := json.Marshal(augmentReq) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化请求失败"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
req, err := http.NewRequest("POST", tenant+"chat-stream", strings.NewReader(string(jsonData))) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败"}) |
|
|
return |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", "Bearer "+token) |
|
|
|
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "请求失败: " + err.Error()}) |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
|
|
|
flusher, ok := c.Writer.(http.Flusher) |
|
|
if !ok { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "流式传输不支持"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
reader := bufio.NewReader(resp.Body) |
|
|
responseID := fmt.Sprintf("chatcmpl-%d", time.Now().Unix()) |
|
|
|
|
|
var fullText string |
|
|
for { |
|
|
line, err := reader.ReadString('\n') |
|
|
if err != nil { |
|
|
log.Printf("【err】: %v", err) |
|
|
if err == io.EOF { |
|
|
break |
|
|
} |
|
|
log.Printf("读取响应失败: %v", err) |
|
|
break |
|
|
} |
|
|
|
|
|
line = strings.TrimSpace(line) |
|
|
if line == "" { |
|
|
continue |
|
|
} |
|
|
|
|
|
var augmentResp AugmentResponse |
|
|
if err := json.Unmarshal([]byte(line), &augmentResp); err != nil { |
|
|
log.Printf("解析响应失败: %v", err) |
|
|
continue |
|
|
} |
|
|
|
|
|
fullText += augmentResp.Text |
|
|
|
|
|
|
|
|
streamResp := OpenAIStreamResponse{ |
|
|
ID: responseID, |
|
|
Object: "chat.completion.chunk", |
|
|
Created: time.Now().Unix(), |
|
|
Model: model, |
|
|
Choices: []StreamChoice{ |
|
|
{ |
|
|
Index: 0, |
|
|
Delta: ChatMessage{ |
|
|
Role: "assistant", |
|
|
Content: augmentResp.Text, |
|
|
}, |
|
|
FinishReason: nil, |
|
|
}, |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
if augmentResp.Done { |
|
|
finishReason := "stop" |
|
|
streamResp.Choices[0].FinishReason = &finishReason |
|
|
} |
|
|
|
|
|
|
|
|
jsonResp, err := json.Marshal(streamResp) |
|
|
if err != nil { |
|
|
log.Printf("序列化响应失败: %v", err) |
|
|
continue |
|
|
} |
|
|
|
|
|
fmt.Fprintf(c.Writer, "data: %s\n\n", jsonResp) |
|
|
flusher.Flush() |
|
|
|
|
|
|
|
|
if augmentResp.Done { |
|
|
fmt.Fprintf(c.Writer, "data: [DONE]\n\n") |
|
|
flusher.Flush() |
|
|
break |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// estimateTokenCount returns a rough token estimate for text: the number of
// whitespace-separated fields plus 0.75 (truncated toward zero after
// multiplication) for every rune in the range U+4E00..U+9FFF.
func estimateTokenCount(text string) int {
	hanRunes := 0
	for _, ch := range text {
		if ch < 0x4E00 || ch > 0x9FFF {
			continue
		}
		hanRunes++
	}

	hanTokens := int(float64(hanRunes) * 0.75)
	return len(strings.Fields(text)) + hanTokens
}
|
|
|
|
|
|
|
|
func handleNonStreamRequest(c *gin.Context, augmentReq AugmentRequest, model string) { |
|
|
|
|
|
token, tenant := GetAuthInfo() |
|
|
if token == "" || tenant == "" { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无可用Token,请先在管理页面获取"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
jsonData, err := json.Marshal(augmentReq) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "序列化请求失败"}) |
|
|
return |
|
|
} |
|
|
|
|
|
|
|
|
req, err := http.NewRequest("POST", tenant+"chat-stream", strings.NewReader(string(jsonData))) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "创建请求失败"}) |
|
|
return |
|
|
} |
|
|
|
|
|
req.Header.Set("Content-Type", "application/json") |
|
|
req.Header.Set("Authorization", "Bearer "+token) |
|
|
|
|
|
|
|
|
client := &http.Client{} |
|
|
resp, err := client.Do(req) |
|
|
if err != nil { |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "请求失败: " + err.Error()}) |
|
|
return |
|
|
} |
|
|
defer resp.Body.Close() |
|
|
|
|
|
|
|
|
reader := bufio.NewReader(resp.Body) |
|
|
var fullText string |
|
|
|
|
|
for { |
|
|
line, err := reader.ReadString('\n') |
|
|
if err != nil { |
|
|
if err == io.EOF { |
|
|
break |
|
|
} |
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "读取响应失败: " + err.Error()}) |
|
|
return |
|
|
} |
|
|
|
|
|
line = strings.TrimSpace(line) |
|
|
if line == "" { |
|
|
continue |
|
|
} |
|
|
|
|
|
var augmentResp AugmentResponse |
|
|
if err := json.Unmarshal([]byte(line), &augmentResp); err != nil { |
|
|
continue |
|
|
} |
|
|
|
|
|
fullText += augmentResp.Text |
|
|
|
|
|
if augmentResp.Done { |
|
|
break |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
finishReason := "stop" |
|
|
|
|
|
|
|
|
promptTokens := estimateTokenCount(augmentReq.Message) |
|
|
for _, history := range augmentReq.ChatHistory { |
|
|
promptTokens += estimateTokenCount(history.RequestMessage) |
|
|
promptTokens += estimateTokenCount(history.ResponseText) |
|
|
} |
|
|
completionTokens := estimateTokenCount(fullText) |
|
|
|
|
|
openAIResp := OpenAIResponse{ |
|
|
ID: fmt.Sprintf("chatcmpl-%d", time.Now().Unix()), |
|
|
Object: "chat.completion", |
|
|
Created: time.Now().Unix(), |
|
|
Model: model, |
|
|
Choices: []Choice{ |
|
|
{ |
|
|
Index: 0, |
|
|
Message: ChatMessage{ |
|
|
Role: "assistant", |
|
|
Content: fullText, |
|
|
}, |
|
|
FinishReason: &finishReason, |
|
|
}, |
|
|
}, |
|
|
Usage: Usage{ |
|
|
PromptTokens: promptTokens, |
|
|
CompletionTokens: completionTokens, |
|
|
TotalTokens: promptTokens + completionTokens, |
|
|
}, |
|
|
} |
|
|
|
|
|
c.JSON(http.StatusOK, openAIResp) |
|
|
} |
|
|
|