// Copyright (c) 2025-2026 libaxuan // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. package utils import ( "bufio" "context" "crypto/rand" "github.com/libaxuan/cursor2api-go/middleware" "github.com/libaxuan/cursor2api-go/models" "encoding/hex" "encoding/json" "fmt" "io" "net/http" "os/exec" "strings" "time" "github.com/gin-gonic/gin" "github.com/sirupsen/logrus" ) // GenerateRandomString 生成指定长度的随机字符串 func GenerateRandomString(length int) string { if length <= 0 { return "" } byteLen := (length + 1) / 2 bytes := make([]byte, byteLen) if _, err := rand.Read(bytes); err != nil { fallback := fmt.Sprintf("%d", time.Now().UnixNano()) if len(fallback) >= length { return fallback[:length] } return fallback } encoded := hex.EncodeToString(bytes) if len(encoded) < length { encoded += GenerateRandomString(length - len(encoded)) } return encoded[:length] } // GenerateChatCompletionID 生成聊天完成ID func GenerateChatCompletionID() string { return "chatcmpl-" + GenerateRandomString(29) } // GenerateResponseID 生成 Responses API 响应ID func GenerateResponseID() string { return "resp_" + GenerateRandomString(24) } // GenerateResponseItemID 生成 Responses API 输出项ID func GenerateResponseItemID(prefix string) string { if prefix == "" { prefix = "item_" } return prefix + GenerateRandomString(24) } // ParseSSELine 解析SSE数据行 func ParseSSELine(line string) string { line = strings.TrimSpace(line) if strings.HasPrefix(line, "data: ") { return strings.TrimSpace(line[6:]) // 去掉 'data: ' 前缀并去除前导空格 } return "" } // WriteSSEEvent 写入SSE事件 func WriteSSEEvent(w http.ResponseWriter, event, data string) error { if event != "" { if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil { return err } } if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil { return err } // 刷新缓冲区 if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } return nil } // StreamChatCompletion 处理流式聊天完成 // StreamChatCompletion 处理流式聊天完成 func StreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { // 设置SSE头 c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") // 生成响应ID responseID := GenerateChatCompletionID() started := false toolCallIndex := 0 writeChunk := func(delta models.StreamDelta, finishReason *string) { streamResp := models.NewChatCompletionStreamResponse(responseID, modelName, delta, finishReason) if jsonData, err := json.Marshal(streamResp); err == nil { WriteSSEEvent(c.Writer, "", string(jsonData)) } } // 处理流式数据 ctx := c.Request.Context() for { select { case <-ctx.Done(): logrus.Debug("Client disconnected during streaming") return case data, ok := <-chatGenerator: if !ok { // 通道关闭,发送完成事件 reason := "stop" if toolCallIndex > 0 { reason = "tool_calls" } writeChunk(models.StreamDelta{}, stringPtr(reason)) WriteSSEEvent(c.Writer, "", "[DONE]") return } switch v := data.(type) { case models.AssistantEvent: if !started { writeChunk(models.StreamDelta{Role: "assistant"}, nil) started = true } switch v.Kind { case models.AssistantEventText: if v.Text != "" { writeChunk(models.StreamDelta{Content: v.Text}, nil) } case models.AssistantEventToolCall: if v.ToolCall != nil { writeChunk(models.StreamDelta{ ToolCalls: []models.ToolCallDelta{ { Index: toolCallIndex, ID: v.ToolCall.ID, Type: v.ToolCall.Type, Function: &models.FunctionCallDelta{ Name: v.ToolCall.Function.Name, Arguments: v.ToolCall.Function.Arguments, }, }, }, }, nil) toolCallIndex++ } } case string: if !started { writeChunk(models.StreamDelta{Role: "assistant"}, nil) started = true } if v != "" { writeChunk(models.StreamDelta{Content: v}, nil) } case models.Usage: // 使用统计 - 通常在最后发送 continue case error: logrus.WithError(v).Error("Stream generator error") WriteSSEEvent(c.Writer, "", "[DONE]") return default: logrus.Warnf("Unknown data type in stream: %T", v) } } } } // NonStreamChatCompletion 处理非流式聊天完成 func NonStreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { var fullContent strings.Builder var usage models.Usage toolCalls := make([]models.ToolCall, 0, 2) finishReason := "stop" // 收集所有数据 ctx := c.Request.Context() for { select { case <-ctx.Done(): c.JSON(http.StatusRequestTimeout, models.NewErrorResponse( "Request timeout", "timeout_error", "request_timeout", )) return case data, ok := <-chatGenerator: if !ok { // 数据收集完成,返回响应 responseID := GenerateChatCompletionID() message := models.Message{ Role: "assistant", } if fullContent.Len() > 0 || len(toolCalls) == 0 { message.Content = fullContent.String() } if len(toolCalls) > 0 { message.ToolCalls = toolCalls finishReason = "tool_calls" } response := models.NewChatCompletionResponse( responseID, modelName, message, finishReason, usage, ) c.JSON(http.StatusOK, response) return } switch v := data.(type) { case models.AssistantEvent: switch v.Kind { case models.AssistantEventText: fullContent.WriteString(v.Text) case models.AssistantEventToolCall: if v.ToolCall != nil { toolCalls = append(toolCalls, *v.ToolCall) } } case string: fullContent.WriteString(v) case models.Usage: usage = v case error: middleware.HandleError(c, v) return } } } } // ErrorWrapper 错误包装器 func ErrorWrapper(handler func(*gin.Context) error) gin.HandlerFunc { return func(c *gin.Context) { if err := handler(c); err != nil { logrus.WithError(err).Error("Handler error") if !c.Writer.Written() { c.JSON(http.StatusInternalServerError, models.NewErrorResponse( "Internal server error", "internal_error", "", )) } } } } // SafeStreamWrapper 安全流式包装器 func SafeStreamWrapper(handler func(*gin.Context, <-chan interface{}, string), c *gin.Context, chatGenerator <-chan interface{}, modelName string) { defer func() { if r := recover(); r != nil { logrus.WithField("panic", r).Error("Panic in stream handler") if !c.Writer.Written() { c.JSON(http.StatusInternalServerError, models.NewErrorResponse( "Internal server error", "panic_error", "", )) } } }() firstItem, ok := <-chatGenerator if !ok { middleware.HandleError(c, middleware.NewCursorWebError(http.StatusInternalServerError, "empty stream")) return } if err, isErr := firstItem.(error); isErr { middleware.HandleError(c, err) return } buffered := make(chan interface{}, 1) buffered <- firstItem ctx := c.Request.Context() go func() { defer close(buffered) for { select { case <-ctx.Done(): return case item, ok := <-chatGenerator: if !ok { return } select { case buffered <- item: case <-ctx.Done(): return } } } }() handler(c, buffered, modelName) } // CreateHTTPClient 创建HTTP客户端 func CreateHTTPClient(timeout time.Duration) *http.Client { return &http.Client{ Timeout: timeout, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, MaxIdleConns: 100, MaxIdleConnsPerHost: 10, IdleConnTimeout: 90 * time.Second, }, } } // ReadSSEStream 读取SSE流 func ReadSSEStream(ctx context.Context, resp *http.Response, output chan<- interface{}) error { scanner := bufio.NewScanner(resp.Body) scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) defer resp.Body.Close() for scanner.Scan() { select { case <-ctx.Done(): return ctx.Err() default: } line := scanner.Text() data := ParseSSELine(line) if data == "" { continue } if data == "[DONE]" { return nil } // 尝试解析JSON数据 var eventData models.CursorEventData if err := json.Unmarshal([]byte(data), &eventData); err != nil { logrus.WithError(err).Debugf("Failed to parse SSE data: %s", data) continue } // 处理不同类型的事件 switch eventData.Type { case "error": if eventData.ErrorText != "" { return fmt.Errorf("cursor API error: %s", eventData.ErrorText) } case "finish": if eventData.MessageMetadata != nil && eventData.MessageMetadata.Usage != nil { usage := models.Usage{ PromptTokens: eventData.MessageMetadata.Usage.InputTokens, CompletionTokens: eventData.MessageMetadata.Usage.OutputTokens, TotalTokens: eventData.MessageMetadata.Usage.TotalTokens, } output <- usage } return nil default: if eventData.Delta != "" { output <- eventData.Delta } } } return scanner.Err() } // ValidateModel 验证模型名称 func ValidateModel(model string, validModels []string) bool { for _, validModel := range validModels { if validModel == model { return true } } return false } // SanitizeContent 清理内容 func SanitizeContent(content string) string { // 移除可能的恶意内容 content = strings.ReplaceAll(content, "\x00", "") return content } // stringPtr 返回字符串指针 func stringPtr(s string) *string { return &s } // CopyHeaders 复制HTTP头 func CopyHeaders(dst, src http.Header, skipHeaders []string) { skipMap := make(map[string]bool) for _, header := range skipHeaders { skipMap[strings.ToLower(header)] = true } for key, values := range src { if skipMap[strings.ToLower(key)] { continue } for _, value := range values { dst.Add(key, value) } } } // IsJSONContentType 检查是否为JSON内容类型 func IsJSONContentType(contentType string) bool { return strings.Contains(strings.ToLower(contentType), "application/json") } // ReadRequestBody 读取请求体 func ReadRequestBody(r *http.Request) ([]byte, error) { if r.Body == nil { return nil, nil } body, err := io.ReadAll(r.Body) if err != nil { return nil, fmt.Errorf("failed to read request body: %w", err) } return body, nil } // RunJS 执行JavaScript代码并返回标准输出内容 func RunJS(jsCode string) (string, error) { // 添加crypto模块导入并设置为全局变量 // 注意:使用stdin时,我们需要确保代码是自包含的 finalJS := `const crypto = require('crypto').webcrypto; global.crypto = crypto; globalThis.crypto = crypto; // 在Node.js环境中创建window对象 if (typeof window === 'undefined') { global.window = global; } window.crypto = crypto; this.crypto = crypto; ` + jsCode // 执行Node.js命令,使用stdin输入代码 cmd := exec.Command("node") // 设置输入 cmd.Stdin = strings.NewReader(finalJS) output, err := cmd.Output() if err != nil { if exitErr, ok := err.(*exec.ExitError); ok { return "", fmt.Errorf("node.js execution failed (exit code: %d)\nSTDOUT:\n%s\nSTDERR:\n%s", exitErr.ExitCode(), string(output), string(exitErr.Stderr)) } return "", fmt.Errorf("failed to execute node.js: %w", err) } return strings.TrimSpace(string(output)), nil }