FEOP / sse.go
KaThaNg's picture
Upload 12 files
a0e64e6 verified
package main
import (
"bufio"
"encoding/json"
"fmt"
"log"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// streamOpenAIResponseToClaudeSSE converts an upstream OpenAI streaming (SSE)
// response into Claude-style SSE events and writes them to the client.
//
// Events written to the client, in order:
//   - message_start and content_block_start, synthesized before any upstream
//     data is read;
//   - one content_block_delta per upstream chunk that carries content, followed
//     by a message_delta with the running output-token count whenever that
//     count changes;
//   - a final message_delta (stop_reason and output tokens), content_block_stop
//     and message_stop (input and output tokens);
//   - an error event when reading the upstream body failed.
//
// Parameters:
//   - c: gin context of the client request; its request context is used to
//     detect client disconnects.
//   - upstreamResp: upstream response whose body is read line by line. The
//     body is always closed before this function returns.
//   - claudeRequestID: used as the message ID and in every log line.
//   - requestedModel: model name reported in message_start.
//   - originalClaudeRequest: passed to estimateTokensFromClaudeRequest to
//     estimate the input tokens.
//
// The upstream body is read in a separate goroutine so that a client
// disconnect is noticed even while a read is blocked. That goroutine is the
// only writer to the client while it runs; this function waits for it to
// finish before it touches shared state or writes the closing events.
func streamOpenAIResponseToClaudeSSE(
	c *gin.Context,
	upstreamResp *http.Response,
	claudeRequestID string,
	requestedModel string, // Model name determined in handler
	originalClaudeRequest *ClaudeRequest,
) {
	// bufio.Scanner rejects lines longer than its buffer limit (64 KiB by
	// default); allow larger SSE data lines before giving up.
	const maxSSELineBytes = 1 << 20

	// The body is closed exactly once, and only from this goroutine: either
	// early (to unblock the reader on client disconnect) or on return.
	bodyClosed := false
	closeUpstreamBody := func() {
		if bodyClosed {
			return
		}
		bodyClosed = true
		upstreamResp.Body.Close()
	}
	defer closeUpstreamBody()

	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Content-Type-Options", "nosniff")
	c.Writer.Flush()

	ctx := c.Request.Context()
	messageID := claudeRequestID
	accumulatedContent := ""
	var openAIFinishReason *string
	streamErrorOccurred := false
	var errorDetails *ClaudeError
	startTime := time.Now()

	// All content is emitted as a single text block at index 0.
	newIndex := func(i int) *int { return &i }

	calculatedInputTokens := estimateTokensFromClaudeRequest(originalClaudeRequest)
	log.Printf("INFO: [%s] (OpenAI->Claude SSE) Bắt đầu SSE. Input tokens ước tính: %d. Model: %s", messageID, calculatedInputTokens, requestedModel)

	// Replaced by the upstream prompt-token count if the stream provides one.
	currentOpenAIInputTokens := calculatedInputTokens
	outputTokens := 0
	finalUsageReceivedFromStream := false // Set when upstream explicitly reports completion tokens.
	eventIndex := 0

	// Send message_start and content_block_start BEFORE the reader goroutine
	// starts, so deltas can never precede message_start and the response
	// writer is never used by two goroutines at once.
	startUsage := ClaudeUsage{InputTokens: calculatedInputTokens, OutputTokens: 0}
	startMessage := ClaudeSSEMessage{ID: messageID, Type: "message", Role: "assistant", Content: []ClaudeContentBlock{}, Model: requestedModel, Usage: startUsage}
	startEvent := ClaudeSSEEvent{Type: "message_start", Message: &startMessage}
	if !sendSSEEvent(c, "message_start", startEvent, messageID, eventIndex, true, "(OpenAI->Claude SSE)") {
		return
	}
	eventIndex++

	contentStartBlock := ClaudeSSEContentBlock{Type: "text", Text: ""} // Text starts out empty.
	contentStartEvent := ClaudeSSEEvent{Type: "content_block_start", Index: newIndex(0), ContentBlock: &contentStartBlock}
	if !sendSSEEvent(c, "content_block_start", contentStartEvent, messageID, eventIndex, true, "(OpenAI->Claude SSE)") {
		return
	}
	eventIndex++

	doneChan := make(chan struct{})
	errChan := make(chan error, 1) // Buffered so the reader never blocks on reporting.

	go func() {
		defer close(doneChan)

		scanner := bufio.NewScanner(upstreamResp.Body)
		scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxSSELineBytes)
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				log.Printf("INFO: [%s] (OpenAI->Claude SSE) SSE: Client ngắt kết nối trong vòng lặp đọc.", messageID)
				return
			default:
			}

			line := scanner.Text()
			if line == "" {
				continue
			}
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if dataStr == "[DONE]" {
				log.Printf("INFO: [%s] (OpenAI->Claude SSE) SSE: Nhận được dấu hiệu [DONE].", messageID)
				return // OpenAI stream finished
			}

			var chunk OpenAIStreamChunk
			if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
				log.Printf("WARN: [%s] (OpenAI->Claude SSE) SSE: Không thể giải mã chunk JSON OpenAI: %v. Data: %s", messageID, err, dataStr)
				continue
			}

			if len(chunk.Choices) > 0 {
				choice := chunk.Choices[0]
				if choice.FinishReason != nil {
					openAIFinishReason = choice.FinishReason
				}
				if choice.Delta.Content != nil {
					contentChunk := *choice.Delta.Content
					accumulatedContent += contentChunk

					// Forward the text to the client as a content_block_delta.
					deltaPayload := ClaudeSSEEvent{
						Type:  "content_block_delta",
						Index: newIndex(0),
						Delta: &ClaudeSSEDelta{
							Type: "text_delta",
							Text: &contentChunk,
						},
					}
					if !sendSSEEvent(c, "content_block_delta", deltaPayload, messageID, eventIndex, false, "(OpenAI->Claude SSE)") {
						return // Client disconnected
					}
					eventIndex++

					// "Hack": emit a message_delta with the updated
					// output-token estimate whenever it changes.
					currentOutputTokens := estimateTokens(accumulatedContent)
					if currentOutputTokens != outputTokens {
						outputTokens = currentOutputTokens
						intermediateUsage := ClaudeSSEUsage{OutputTokens: outputTokens}
						intermediateDeltaPayload := ClaudeSSEEvent{
							Type:  "message_delta",
							Delta: &ClaudeSSEDelta{}, // Empty delta for this event type.
							Usage: &intermediateUsage,
						}
						if !sendSSEEvent(c, "message_delta", intermediateDeltaPayload, messageID, eventIndex, false, "(OpenAI->Claude SSE)") {
							return // Client disconnected
						}
						eventIndex++
					}
				}
			}

			if chunk.Usage != nil { // Some OpenAI versions send usage in chunks.
				if chunk.Usage.CompletionTokens > 0 {
					outputTokens = chunk.Usage.CompletionTokens // Use OpenAI's count if available.
					finalUsageReceivedFromStream = true
				}
				if chunk.Usage.PromptTokens > 0 {
					currentOpenAIInputTokens = chunk.Usage.PromptTokens
				}
			}
		}

		if err := scanner.Err(); err != nil {
			if ctx.Err() != nil {
				log.Printf("INFO: [%s] (OpenAI->Claude SSE) SSE: Đọc upstream bị gián đoạn bởi client ngắt kết nối: %v", messageID, ctx.Err())
				return
			}
			log.Printf("ERROR: [%s] (OpenAI->Claude SSE) SSE: Lỗi đọc nội dung phản hồi upstream: %v", messageID, err)
			errChan <- fmt.Errorf("lỗi đọc upstream: %w", err)
		}
	}()

	// Wait for the reader. On client disconnect, close the body to unblock a
	// pending read and still wait, so the reader no longer touches the shared
	// variables or the gin context once we continue.
	select {
	case <-doneChan:
	case <-ctx.Done():
		closeUpstreamBody()
		<-doneChan
	}

	// Classify the outcome only after the reader has finished; checking
	// errChan here (instead of racing it against doneChan in one select)
	// guarantees a reported read error is never dropped.
	if ctx.Err() != nil {
		log.Printf("INFO: [%s] (OpenAI->Claude SSE) SSE: Client ngắt kết nối trong quá trình xử lý stream: %v", messageID, ctx.Err())
		streamErrorOccurred = true
		errorDetails = &ClaudeError{Type: "client_disconnect", Message: "Client ngắt kết nối trong quá trình stream"}
	} else {
		select {
		case err := <-errChan:
			log.Printf("ERROR: [%s] (OpenAI->Claude SSE) SSE: Nhận lỗi từ goroutine đọc: %v", messageID, err)
			streamErrorOccurred = true
			errorDetails = &ClaudeError{Type: "api_error", Message: fmt.Sprintf("Lỗi đọc phản hồi upstream: %v", err)}
		default:
			// Normal completion
		}
	}

	var claudeStopReason string
	switch {
	case streamErrorOccurred && errorDetails != nil && errorDetails.Type == "client_disconnect":
		claudeStopReason = "client_disconnect"
	case streamErrorOccurred:
		claudeStopReason = "error"
	default:
		claudeStopReason = mapOpenAIFinishReasonToClaude(openAIFinishReason)
	}

	finalInputTokens := currentOpenAIInputTokens // Possibly updated from the OpenAI stream.
	finalOutputTokens := outputTokens
	if !finalUsageReceivedFromStream { // No usage from upstream: estimate from the accumulated text.
		finalOutputTokens = estimateTokens(accumulatedContent)
	}
	finalOutputTokens = max(0, finalOutputTokens)
	finalInputTokens = max(0, finalInputTokens)

	log.Printf("INFO: [%s] (OpenAI->Claude SSE) SSE Stream hoàn tất. Lý do dừng: %s. Input: %d, Output: %d. Thời gian: %v. OpenAI Finish Reason: %s",
		messageID, claudeStopReason, finalInputTokens, finalOutputTokens, time.Since(startTime), safeDeref(openAIFinishReason))

	// The closing events are best effort: if the client is gone,
	// sendSSEEvent reports false and there is nothing left to do about it.

	// "Hack": final message_delta with stop_reason and usage (output_tokens only).
	finalHackUsageData := ClaudeSSEUsage{OutputTokens: finalOutputTokens}
	finalDeltaStopReason := claudeStopReason
	priorityFinalDeltaPayload := ClaudeSSEEvent{
		Type: "message_delta",
		Delta: &ClaudeSSEDelta{
			StopReason:   &finalDeltaStopReason,
			StopSequence: nil,
		},
		Usage: &finalHackUsageData,
	}
	_ = sendSSEEvent(c, "message_delta", priorityFinalDeltaPayload, messageID, eventIndex, false, "(OpenAI->Claude SSE)")
	eventIndex++

	// content_block_stop
	contentStopPayload := ClaudeSSEEvent{Type: "content_block_stop", Index: newIndex(0)}
	_ = sendSSEEvent(c, "content_block_stop", contentStopPayload, messageID, eventIndex, false, "(OpenAI->Claude SSE)")
	eventIndex++

	// message_stop with full usage
	finalStopUsageData := ClaudeSSEUsage{InputTokens: &finalInputTokens, OutputTokens: finalOutputTokens}
	messageStopPayload := ClaudeSSEEvent{Type: "message_stop", Usage: &finalStopUsageData}
	_ = sendSSEEvent(c, "message_stop", messageStopPayload, messageID, eventIndex, true, "(OpenAI->Claude SSE)")
	eventIndex++

	if streamErrorOccurred && errorDetails != nil && errorDetails.Type != "client_disconnect" {
		errorPayload := ClaudeSSEEvent{Type: "error", Error: errorDetails}
		_ = sendSSEEvent(c, "error", errorPayload, messageID, eventIndex, true, "(OpenAI->Claude SSE)")
	}
}
// proxyClaudeNativeSSE relays an SSE stream from a native Claude upstream to
// the client and adds the usage "hack": extra message_delta events carrying
// the running output-token count.
//
// Events written to the client, in order:
//   - message_start and content_block_start, synthesized locally before any
//     upstream data is read;
//   - for every upstream content_block_delta that carries text: a
//     content_block_delta, followed by a message_delta with the updated
//     output-token estimate whenever it changes;
//   - for every upstream message_delta that carries usage: a message_delta
//     with the current output-token count and the upstream stop
//     reason/sequence;
//   - upstream ping events;
//   - a final message_delta (stop_reason and output tokens), content_block_stop
//     and message_stop (input and output tokens);
//   - an error event when upstream reported an error or reading failed.
//
// Upstream event types other than the ones listed are not forwarded.
//
// Parameters:
//   - c: gin context of the client request; its request context is used to
//     detect client disconnects.
//   - upstreamResp: upstream response whose body is read line by line. The
//     body is always closed before this function returns.
//   - claudeRequestID: used as the message ID and in every log line.
//   - originalClaudeRequest: supplies the model name and is passed to
//     estimateTokensFromClaudeRequest to estimate the input tokens.
//
// The upstream body is read in a separate goroutine so that a client
// disconnect is noticed even while a read is blocked. That goroutine is the
// only writer to the client while it runs; this function waits for it to
// finish before it touches shared state or writes the closing events.
func proxyClaudeNativeSSE(
	c *gin.Context,
	upstreamResp *http.Response,
	claudeRequestID string,
	originalClaudeRequest *ClaudeRequest, // Needed for the input-token estimate and the model name.
) {
	// bufio.Scanner rejects lines longer than its buffer limit (64 KiB by
	// default); allow larger SSE data lines before giving up.
	const maxSSELineBytes = 1 << 20

	// The body is closed exactly once, and only from this goroutine: either
	// early (to unblock the reader on client disconnect) or on return. This
	// also covers the early returns below, which happen before the reader
	// goroutine exists.
	bodyClosed := false
	closeUpstreamBody := func() {
		if bodyClosed {
			return
		}
		bodyClosed = true
		upstreamResp.Body.Close()
	}
	defer closeUpstreamBody()

	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("X-Content-Type-Options", "nosniff")
	if anthropicVersion := upstreamResp.Header.Get("anthropic-version"); anthropicVersion != "" {
		c.Writer.Header().Set("anthropic-version", anthropicVersion)
	}
	c.Writer.Flush()

	ctx := c.Request.Context()
	messageID := claudeRequestID
	requestedModel := originalClaudeRequest.Model
	accumulatedContent := ""
	streamErrorOccurred := false
	var errorDetails *ClaudeError
	startTime := time.Now()

	// Fallback content-block index when upstream does not provide one.
	newIndex := func(i int) *int { return &i }

	calculatedInputTokens := estimateTokensFromClaudeRequest(originalClaudeRequest)
	log.Printf("INFO: [%s] (Claude Native SSE with Hack) Bắt đầu SSE. Input tokens ước tính: %d. Model: %s", messageID, calculatedInputTokens, requestedModel)

	outputTokens := 0
	eventIndex := 0
	var upstreamFinalStopReason string
	// Set once upstream reports usage in message_delta or message_stop; the
	// final count is then taken from outputTokens instead of being
	// re-estimated from the text.
	upstreamUsageReceived := false

	// Send a locally built message_start to the client.
	startUsage := ClaudeUsage{InputTokens: calculatedInputTokens, OutputTokens: 0}
	startMessage := ClaudeSSEMessage{ID: messageID, Type: "message", Role: "assistant", Content: []ClaudeContentBlock{}, Model: requestedModel, Usage: startUsage}
	startEvent := ClaudeSSEEvent{Type: "message_start", Message: &startMessage}
	if !sendSSEEvent(c, "message_start", startEvent, messageID, eventIndex, true, "(Claude Native SSE with Hack)") {
		return // Client disconnected
	}
	eventIndex++

	// Send a locally built content_block_start to the client.
	contentStartBlock := ClaudeSSEContentBlock{Type: "text", Text: ""}
	contentStartEvent := ClaudeSSEEvent{Type: "content_block_start", Index: newIndex(0), ContentBlock: &contentStartBlock}
	if !sendSSEEvent(c, "content_block_start", contentStartEvent, messageID, eventIndex, true, "(Claude Native SSE with Hack)") {
		return // Client disconnected
	}
	eventIndex++

	doneChan := make(chan struct{})
	errChan := make(chan error, 1) // Buffered so the reader never blocks on reporting.

	go func() {
		defer close(doneChan)

		scanner := bufio.NewScanner(upstreamResp.Body)
		scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), maxSSELineBytes)
		for scanner.Scan() {
			select {
			case <-ctx.Done():
				log.Printf("INFO: [%s] (Claude Native SSE with Hack) Client ngắt kết nối trong vòng lặp đọc.", messageID)
				return
			default:
			}

			line := scanner.Text()
			if line == "" { // Skip empty lines between events
				continue
			}
			if !strings.HasPrefix(line, "data:") {
				continue
			}
			dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))

			var upstreamEvent ClaudeSSEEvent
			if err := json.Unmarshal([]byte(dataStr), &upstreamEvent); err != nil {
				log.Printf("WARN: [%s] (Claude Native SSE with Hack) Không thể giải mã chunk JSON Claude native: %v. Data: %s", messageID, err, dataStr)
				continue
			}

			switch upstreamEvent.Type {
			case "content_block_delta":
				if upstreamEvent.Delta == nil || upstreamEvent.Delta.Text == nil {
					break
				}
				textChunk := *upstreamEvent.Delta.Text
				accumulatedContent += textChunk

				clientContentDelta := ClaudeSSEEvent{
					Type:  "content_block_delta",
					Index: upstreamEvent.Index,
					Delta: &ClaudeSSEDelta{
						Type: "text_delta",
						Text: &textChunk,
					},
				}
				if clientContentDelta.Index == nil {
					clientContentDelta.Index = newIndex(0)
				}
				if !sendSSEEvent(c, "content_block_delta", clientContentDelta, messageID, eventIndex, false, "(Claude Native SSE with Hack)") {
					return
				}
				eventIndex++

				// "Hack": emit a message_delta with the updated
				// output-token estimate whenever it changes.
				currentOutputTokens := estimateTokens(accumulatedContent)
				if currentOutputTokens != outputTokens {
					outputTokens = currentOutputTokens
					intermediateUsage := ClaudeSSEUsage{OutputTokens: outputTokens}
					intermediateDeltaPayload := ClaudeSSEEvent{
						Type:  "message_delta",
						Delta: &ClaudeSSEDelta{},
						Usage: &intermediateUsage,
					}
					if !sendSSEEvent(c, "message_delta", intermediateDeltaPayload, messageID, eventIndex, false, "(Claude Native SSE with Hack)") {
						return
					}
					eventIndex++
				}

			case "message_delta":
				if upstreamEvent.Usage != nil {
					upstreamUsageReceived = true
					// Never let the count sent to the client go backwards.
					if upstreamEvent.Usage.OutputTokens > outputTokens {
						outputTokens = upstreamEvent.Usage.OutputTokens
					}
					currentHackUsage := ClaudeSSEUsage{OutputTokens: outputTokens}

					// Preserve stop_reason/sequence from upstream's message_delta if present.
					var deltaDetails ClaudeSSEDelta
					if upstreamEvent.Delta != nil {
						deltaDetails.StopReason = upstreamEvent.Delta.StopReason
						deltaDetails.StopSequence = upstreamEvent.Delta.StopSequence
					}
					hackMessageDelta := ClaudeSSEEvent{
						Type:  "message_delta",
						Delta: &deltaDetails,
						Usage: &currentHackUsage,
					}
					if !sendSSEEvent(c, "message_delta", hackMessageDelta, messageID, eventIndex, false, "(Claude Native SSE with Hack - from upstream delta)") {
						return
					}
					eventIndex++
				}
				// Capture the stop reason if it is in this message_delta.
				if upstreamEvent.Delta != nil && upstreamEvent.Delta.StopReason != nil {
					upstreamFinalStopReason = *upstreamEvent.Delta.StopReason
				}

			case "message_stop":
				log.Printf("INFO: [%s] (Claude Native SSE with Hack) Nhận message_stop từ upstream.", messageID)
				if upstreamEvent.Message != nil && upstreamEvent.Message.StopReason != nil {
					upstreamFinalStopReason = *upstreamEvent.Message.StopReason
				}
				if upstreamEvent.Usage != nil {
					upstreamUsageReceived = true
					if upstreamEvent.Usage.OutputTokens > outputTokens {
						outputTokens = upstreamEvent.Usage.OutputTokens
					}
				}
				return

			case "error":
				log.Printf("ERROR: [%s] (Claude Native SSE with Hack) Nhận lỗi từ upstream: %+v", messageID, upstreamEvent.Error)
				streamErrorOccurred = true
				errorDetails = upstreamEvent.Error
				return

			case "ping":
				pingEvent := ClaudeSSEEvent{Type: "ping"}
				if !sendSSEEvent(c, "ping", pingEvent, messageID, eventIndex, false, "(Claude Native SSE with Hack - ping)") {
					return
				}
				eventIndex++
			}
		}

		if err := scanner.Err(); err != nil {
			if ctx.Err() != nil {
				log.Printf("INFO: [%s] (Claude Native SSE with Hack) Đọc upstream bị gián đoạn bởi client ngắt kết nối: %v", messageID, ctx.Err())
				return
			}
			log.Printf("ERROR: [%s] (Claude Native SSE with Hack) Lỗi đọc nội dung phản hồi upstream: %v", messageID, err)
			errChan <- fmt.Errorf("lỗi đọc upstream: %w", err)
		}
	}()

	// Wait for the reader. On client disconnect, close the body to unblock a
	// pending read and still wait, so the reader no longer touches the shared
	// variables or the gin context once we continue.
	select {
	case <-doneChan:
	case <-ctx.Done():
		closeUpstreamBody()
		<-doneChan
	}

	// Classify the outcome only after the reader has finished; checking
	// errChan here (instead of racing it against doneChan in one select)
	// guarantees a reported read error is never dropped.
	if ctx.Err() != nil {
		log.Printf("INFO: [%s] (Claude Native SSE with Hack) Client ngắt kết nối trong quá trình xử lý stream: %v", messageID, ctx.Err())
		streamErrorOccurred = true
		if errorDetails == nil {
			errorDetails = &ClaudeError{Type: "client_disconnect", Message: "Client ngắt kết nối trong quá trình stream"}
		}
	} else {
		select {
		case err := <-errChan:
			log.Printf("ERROR: [%s] (Claude Native SSE with Hack) Nhận lỗi từ goroutine đọc: %v", messageID, err)
			streamErrorOccurred = true
			if errorDetails == nil {
				errorDetails = &ClaudeError{Type: "api_error", Message: fmt.Sprintf("Lỗi đọc phản hồi upstream: %v", err)}
			}
		default:
			// Normal completion
		}
	}

	var finalStopReasonClient string
	switch {
	case streamErrorOccurred && errorDetails != nil && errorDetails.Type == "client_disconnect":
		finalStopReasonClient = "client_disconnect"
	case streamErrorOccurred:
		finalStopReasonClient = "error"
	case upstreamFinalStopReason != "":
		finalStopReasonClient = upstreamFinalStopReason
	default:
		finalStopReasonClient = "end_turn"
	}

	// outputTokens already includes any usage reported by upstream. Only fall
	// back to estimating from the accumulated text when upstream reported no
	// usage at all, so the estimate never replaces an upstream-provided count.
	finalOutputTokens := outputTokens
	if !upstreamUsageReceived && accumulatedContent != "" {
		finalOutputTokens = estimateTokens(accumulatedContent)
	}
	finalOutputTokens = max(0, finalOutputTokens)

	log.Printf("INFO: [%s] (Claude Native SSE with Hack) SSE Stream hoàn tất. Lý do dừng: %s. Input: %d, Output: %d. Thời gian: %v. Upstream Stop Reason: %s",
		messageID, finalStopReasonClient, calculatedInputTokens, finalOutputTokens, time.Since(startTime), upstreamFinalStopReason)

	// The closing events are best effort: if the client is gone,
	// sendSSEEvent reports false and there is nothing left to do about it.

	// "Hack": final message_delta with stop_reason and usage (output_tokens only).
	finalHackUsageData := ClaudeSSEUsage{OutputTokens: finalOutputTokens}
	finalDeltaStopReason := finalStopReasonClient
	priorityFinalDeltaPayload := ClaudeSSEEvent{
		Type: "message_delta",
		Delta: &ClaudeSSEDelta{
			StopReason:   &finalDeltaStopReason,
			StopSequence: nil,
		},
		Usage: &finalHackUsageData,
	}
	_ = sendSSEEvent(c, "message_delta", priorityFinalDeltaPayload, messageID, eventIndex, false, "(Claude Native SSE with Hack)")
	eventIndex++

	contentStopPayload := ClaudeSSEEvent{Type: "content_block_stop", Index: newIndex(0)}
	_ = sendSSEEvent(c, "content_block_stop", contentStopPayload, messageID, eventIndex, false, "(Claude Native SSE with Hack)")
	eventIndex++

	// The client-facing message_stop reports the locally calculated input
	// tokens, consistent with the message_start sent earlier.
	finalClientInputTokens := calculatedInputTokens
	finalStopUsageData := ClaudeSSEUsage{InputTokens: &finalClientInputTokens, OutputTokens: finalOutputTokens}
	messageStopPayload := ClaudeSSEEvent{Type: "message_stop", Usage: &finalStopUsageData}
	_ = sendSSEEvent(c, "message_stop", messageStopPayload, messageID, eventIndex, true, "(Claude Native SSE with Hack)")
	eventIndex++

	if streamErrorOccurred && errorDetails != nil && errorDetails.Type != "client_disconnect" {
		errorPayload := ClaudeSSEEvent{Type: "error", Error: errorDetails}
		_ = sendSSEEvent(c, "error", errorPayload, messageID, eventIndex, true, "(Claude Native SSE with Hack)")
	}
}
// sendSSEEvent writes one SSE event ("event: <name>\ndata: <json>\n\n") to the
// client and flushes it.
//
// Parameters:
//   - c: gin context whose writer receives the event and whose request
//     context is checked for cancellation first.
//   - eventName: SSE event name.
//   - data: payload, marshalled to JSON for the data line.
//   - requestID, eventIndex, logPrefix: used only in log messages.
//   - shouldLog: whether a failed send is logged. Failures for message_stop
//     and error events are always logged.
//
// It returns false when the client is gone (request context done or the write
// failed) and true otherwise. A payload that cannot be marshalled is logged
// and skipped, and still returns true so the caller keeps streaming.
func sendSSEEvent(c *gin.Context, eventName string, data interface{}, requestID string, eventIndex int, shouldLog bool, logPrefix string) bool {
	logFailure := shouldLog || eventName == "message_stop" || eventName == "error"

	select {
	case <-c.Request.Context().Done():
		if logFailure {
			log.Printf("INFO: [%s] %s Client ngắt kết nối trước khi gửi SSE event %d (%s).", requestID, logPrefix, eventIndex, eventName)
		}
		return false
	default:
	}

	jsonData, err := json.Marshal(data)
	if err != nil {
		log.Printf("ERROR: [%s] %s Không thể marshal SSE event %d (%s): %v", requestID, logPrefix, eventIndex, eventName, err)
		return true
	}

	if _, err := fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", eventName, jsonData); err != nil {
		if logFailure {
			// eventName was missing from the argument list here, which shifted
			// err into the "(%s)" slot and left "%v" without an argument.
			log.Printf("WARN: [%s] %s Không thể ghi SSE event %d (%s) cho client: %v. Client có thể đã ngắt kết nối.", requestID, logPrefix, eventIndex, eventName, err)
		}
		return false
	}
	c.Writer.Flush()
	return true
}
// max returns the larger of a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
// safeDeref returns the string s points to, or the literal "nil" when s is a
// nil pointer, so optional strings can be logged without a nil check.
func safeDeref(s *string) string {
	if s != nil {
		return *s
	}
	return "nil"
}