// File: cursor/utils/utils.go
// Uploaded by cacode ("Upload 48 files", revision 1766992, verified).
// Copyright (c) 2025-2026 libaxuan
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package utils
import (
"bufio"
"context"
"crypto/rand"
"github.com/libaxuan/cursor2api-go/middleware"
"github.com/libaxuan/cursor2api-go/models"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os/exec"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)
// GenerateRandomString returns a string of exactly length characters.
//
// The characters are lowercase hexadecimal digits produced by hex-encoding
// bytes read from crypto/rand. A non-positive length yields "".
//
// If the random source reports an error, the function falls back to the
// decimal digits of the current Unix time in nanoseconds, repeated as often as
// needed so that the result still has the requested length. That fallback is
// not random and must not be relied on for secrets.
func GenerateRandomString(length int) string {
	if length <= 0 {
		return ""
	}

	// Every byte encodes to two hex digits, so round the byte count up and
	// trim the surplus digit for odd lengths.
	raw := make([]byte, (length+1)/2)
	if _, err := rand.Read(raw); err != nil {
		// Best effort: keep the length contract even without randomness.
		fallback := fmt.Sprintf("%d", time.Now().UnixNano())
		for len(fallback) < length {
			fallback += fallback
		}
		return fallback[:length]
	}

	return hex.EncodeToString(raw)[:length]
}
// GenerateChatCompletionID builds an identifier for a chat completion: the
// literal prefix "chatcmpl-" followed by 29 characters obtained from
// GenerateRandomString.
func GenerateChatCompletionID() string {
	const (
		prefix    = "chatcmpl-"
		suffixLen = 29
	)
	return prefix + GenerateRandomString(suffixLen)
}
// GenerateResponseID builds an identifier for a Responses API response: the
// literal prefix "resp_" followed by 24 characters obtained from
// GenerateRandomString.
func GenerateResponseID() string {
	const (
		prefix    = "resp_"
		suffixLen = 24
	)
	return prefix + GenerateRandomString(suffixLen)
}
// GenerateResponseItemID builds an identifier for a Responses API output item.
//
// The result is prefix followed by 24 characters obtained from
// GenerateRandomString. An empty prefix is replaced by "item_".
func GenerateResponseItemID(prefix string) string {
	head := prefix
	if len(head) == 0 {
		head = "item_"
	}
	return head + GenerateRandomString(24)
}
// ParseSSELine extracts the payload of a Server-Sent Events "data" field.
//
// Surrounding whitespace is trimmed from line first. If the line then starts
// with "data:", everything after the colon is returned with surrounding
// whitespace trimmed. The space after the colon is optional, so both
// "data: x" and "data:x" yield "x". Any other line (comments, other fields,
// blank lines) yields "".
func ParseSSELine(line string) string {
	const field = "data:"

	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, field) {
		return ""
	}
	return strings.TrimSpace(line[len(field):])
}
// WriteSSEEvent writes one Server-Sent Events message to w and flushes it.
//
// When event is non-empty, an "event: <event>" line is written first. The data
// is written as one "data: " line per line of input: a raw line break inside a
// single data line would end the field early and the client would discard the
// remainder. CRLF and lone CR are treated as line breaks as well. A blank line
// terminates the message. Payloads without line breaks produce exactly
// "data: <data>\n\n".
//
// If w implements http.Flusher it is flushed after a successful write so the
// client receives the message immediately. The write error, if any, is
// returned.
func WriteSSEEvent(w http.ResponseWriter, event, data string) error {
	var msg strings.Builder
	if event != "" {
		msg.WriteString("event: ")
		msg.WriteString(event)
		msg.WriteByte('\n')
	}

	// Only pay for normalization when the payload actually spans lines.
	normalized := data
	if strings.ContainsAny(data, "\r\n") {
		normalized = strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(data)
	}
	for _, line := range strings.Split(normalized, "\n") {
		msg.WriteString("data: ")
		msg.WriteString(line)
		msg.WriteByte('\n')
	}
	msg.WriteByte('\n')

	if _, err := fmt.Fprint(w, msg.String()); err != nil {
		return err
	}

	// Push the message out instead of leaving it in the response buffer.
	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}
// StreamChatCompletion streams the items received from chatGenerator to the
// client as Server-Sent Events carrying chat-completion chunks.
//
// Every chunk shares one response ID generated at the start of the call and
// the given modelName. Before the first assistant content (a
// models.AssistantEvent or a plain string) a chunk with Role "assistant" is
// sent. Non-empty text becomes a content chunk; each tool call becomes a
// tool-call delta whose index counts up from zero. models.Usage items are
// ignored. Values of any other type are logged and skipped.
//
// The stream ends in one of three ways:
//   - the channel is closed: a final chunk with finish reason "tool_calls"
//     (if at least one tool call was sent) or "stop" is written, followed by
//     "[DONE]";
//   - an error value is received: it is logged and "[DONE]" is written;
//   - the request context is done: the function returns without writing.
//
// Marshal and write failures do not abort the stream; they are logged.
func StreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) {
	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	responseID := GenerateChatCompletionID()
	started := false
	toolCallIndex := 0

	// writeEvent sends one raw SSE data payload. Write failures are expected
	// when the client goes away, so they are only logged at debug level.
	writeEvent := func(payload string) {
		if err := WriteSSEEvent(c.Writer, "", payload); err != nil {
			logrus.WithError(err).Debug("Failed to write SSE event")
		}
	}

	// writeChunk wraps delta in a stream response and sends it.
	writeChunk := func(delta models.StreamDelta, finishReason *string) {
		streamResp := models.NewChatCompletionStreamResponse(responseID, modelName, delta, finishReason)
		jsonData, err := json.Marshal(streamResp)
		if err != nil {
			logrus.WithError(err).Error("Failed to marshal stream chunk")
			return
		}
		writeEvent(string(jsonData))
	}

	ctx := c.Request.Context()
	for {
		select {
		case <-ctx.Done():
			logrus.Debug("Client disconnected during streaming")
			return
		case data, ok := <-chatGenerator:
			if !ok {
				// Channel closed: emit the terminating chunk and marker.
				reason := "stop"
				if toolCallIndex > 0 {
					reason = "tool_calls"
				}
				writeChunk(models.StreamDelta{}, stringPtr(reason))
				writeEvent("[DONE]")
				return
			}

			switch v := data.(type) {
			case models.AssistantEvent:
				if !started {
					writeChunk(models.StreamDelta{Role: "assistant"}, nil)
					started = true
				}
				switch v.Kind {
				case models.AssistantEventText:
					if v.Text != "" {
						writeChunk(models.StreamDelta{Content: v.Text}, nil)
					}
				case models.AssistantEventToolCall:
					if v.ToolCall != nil {
						writeChunk(models.StreamDelta{
							ToolCalls: []models.ToolCallDelta{
								{
									Index: toolCallIndex,
									ID:    v.ToolCall.ID,
									Type:  v.ToolCall.Type,
									Function: &models.FunctionCallDelta{
										Name:      v.ToolCall.Function.Name,
										Arguments: v.ToolCall.Function.Arguments,
									},
								},
							},
						}, nil)
						toolCallIndex++
					}
				}
			case string:
				if !started {
					writeChunk(models.StreamDelta{Role: "assistant"}, nil)
					started = true
				}
				if v != "" {
					writeChunk(models.StreamDelta{Content: v}, nil)
				}
			case models.Usage:
				// Usage statistics are not forwarded on the streaming path.
				continue
			case error:
				logrus.WithError(v).Error("Stream generator error")
				writeEvent("[DONE]")
				return
			default:
				logrus.Warnf("Unknown data type in stream: %T", v)
			}
		}
	}
}
// NonStreamChatCompletion drains chatGenerator and answers with a single JSON
// chat-completion response.
//
// Text (plain strings and text events) is concatenated, tool calls are
// collected, and the last models.Usage value received is reported as usage.
// When the channel is closed the response is sent with status 200; its finish
// reason is "tool_calls" if any tool call was collected and "stop" otherwise,
// and the message content is set unless there are tool calls and no text.
//
// An error value received from the channel is passed to
// middleware.HandleError and ends the call. If the request context is done
// first, a 408 "Request timeout" error response is written instead.
func NonStreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) {
	var (
		text  strings.Builder
		stats models.Usage
	)
	calls := make([]models.ToolCall, 0, 2)
	reason := "stop"

	// respond assembles the final message from everything collected so far.
	respond := func() {
		id := GenerateChatCompletionID()
		reply := models.Message{
			Role: "assistant",
		}
		onlyToolCalls := text.Len() == 0 && len(calls) > 0
		if !onlyToolCalls {
			reply.Content = text.String()
		}
		if len(calls) > 0 {
			reply.ToolCalls = calls
			reason = "tool_calls"
		}
		c.JSON(http.StatusOK, models.NewChatCompletionResponse(id, modelName, reply, reason, stats))
	}

	reqCtx := c.Request.Context()
	for {
		select {
		case <-reqCtx.Done():
			c.JSON(http.StatusRequestTimeout, models.NewErrorResponse(
				"Request timeout",
				"timeout_error",
				"request_timeout",
			))
			return
		case item, open := <-chatGenerator:
			if !open {
				respond()
				return
			}

			switch v := item.(type) {
			case string:
				text.WriteString(v)
			case models.Usage:
				stats = v
			case models.AssistantEvent:
				if v.Kind == models.AssistantEventText {
					text.WriteString(v.Text)
				} else if v.Kind == models.AssistantEventToolCall && v.ToolCall != nil {
					calls = append(calls, *v.ToolCall)
				}
			case error:
				middleware.HandleError(c, v)
				return
			}
		}
	}
}
// ErrorWrapper adapts a handler that returns an error into a gin.HandlerFunc.
//
// A non-nil error is logged. If nothing has been written to the response yet,
// a 500 "Internal server error" JSON error response is sent as well.
func ErrorWrapper(handler func(*gin.Context) error) gin.HandlerFunc {
	return func(c *gin.Context) {
		err := handler(c)
		if err == nil {
			return
		}

		logrus.WithError(err).Error("Handler error")
		if c.Writer.Written() {
			// The response has already started; it cannot be replaced.
			return
		}
		c.JSON(http.StatusInternalServerError, models.NewErrorResponse(
			"Internal server error",
			"internal_error",
			"",
		))
	}
}
// SafeStreamWrapper runs handler on chatGenerator with two safeguards.
//
// First, the initial item is inspected before handler is invoked: a closed
// channel is reported through middleware.HandleError as an "empty stream"
// error, and an error value is passed to middleware.HandleError directly. In
// both cases handler is never called.
//
// Second, a panic raised while the function runs is recovered and logged; if
// nothing has been written to the response yet, a 500 error response is sent.
//
// Otherwise handler receives a relay channel that yields the inspected first
// item followed by the remaining items. The relay is fed by a goroutine that
// stops, and closes the relay, when chatGenerator is closed or the request
// context is done.
func SafeStreamWrapper(handler func(*gin.Context, <-chan interface{}, string), c *gin.Context, chatGenerator <-chan interface{}, modelName string) {
	defer func() {
		rec := recover()
		if rec == nil {
			return
		}
		logrus.WithField("panic", rec).Error("Panic in stream handler")
		if c.Writer.Written() {
			return
		}
		c.JSON(http.StatusInternalServerError, models.NewErrorResponse(
			"Internal server error",
			"panic_error",
			"",
		))
	}()

	head, open := <-chatGenerator
	if !open {
		middleware.HandleError(c, middleware.NewCursorWebError(http.StatusInternalServerError, "empty stream"))
		return
	}
	if streamErr, failed := head.(error); failed {
		middleware.HandleError(c, streamErr)
		return
	}

	// Capacity one lets the already-received item be queued without a reader.
	relay := make(chan interface{}, 1)
	relay <- head

	reqCtx := c.Request.Context()
	pump := func() {
		defer close(relay)
		for {
			select {
			case <-reqCtx.Done():
				return
			case next, more := <-chatGenerator:
				if !more {
					return
				}
				select {
				case relay <- next:
				case <-reqCtx.Done():
					return
				}
			}
		}
	}
	go pump()

	handler(c, relay, modelName)
}
// CreateHTTPClient returns a new http.Client with the given overall request
// timeout.
//
// The client uses its own transport, which takes proxy settings from the
// environment, keeps at most 100 idle connections in total and 10 per host,
// and closes idle connections after 90 seconds.
func CreateHTTPClient(timeout time.Duration) *http.Client {
	transport := &http.Transport{
		Proxy:               http.ProxyFromEnvironment,
		MaxIdleConns:        100,
		MaxIdleConnsPerHost: 10,
		IdleConnTimeout:     90 * time.Second,
	}

	client := new(http.Client)
	client.Timeout = timeout
	client.Transport = transport
	return client
}
// ReadSSEStream reads a Server-Sent Events response line by line and forwards
// its content to output.
//
// Each data payload (as extracted by ParseSSELine) is decoded as a
// models.CursorEventData; payloads that fail to decode are logged at debug
// level and skipped. Events are handled by type:
//   - "error" with a non-empty error text: the text is returned as an error;
//   - "finish": usage metadata, when present, is sent as a models.Usage, then
//     reading stops with a nil error;
//   - anything else: a non-empty delta is sent as a string.
//
// Reading also stops with a nil error on a "[DONE]" payload. When the body
// ends, the scanner's error (nil on a clean end) is returned. A single line
// may be at most 1 MiB long.
//
// Cancellation of ctx is honoured both between lines and while waiting to send
// on output, so the function cannot block forever on a consumer that has
// stopped receiving; ctx.Err() is returned in that case. resp.Body is always
// closed before returning. The output channel is not closed here.
func ReadSSEStream(ctx context.Context, resp *http.Response, output chan<- interface{}) error {
	defer resp.Body.Close()

	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)

	// emit delivers item to output unless ctx is cancelled first.
	emit := func(item interface{}) error {
		select {
		case output <- item:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	for scanner.Scan() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		data := ParseSSELine(scanner.Text())
		if data == "" {
			continue
		}
		if data == "[DONE]" {
			return nil
		}

		var eventData models.CursorEventData
		if err := json.Unmarshal([]byte(data), &eventData); err != nil {
			logrus.WithError(err).Debugf("Failed to parse SSE data: %s", data)
			continue
		}

		switch eventData.Type {
		case "error":
			if eventData.ErrorText != "" {
				return fmt.Errorf("cursor API error: %s", eventData.ErrorText)
			}
		case "finish":
			if eventData.MessageMetadata != nil && eventData.MessageMetadata.Usage != nil {
				usage := models.Usage{
					PromptTokens:     eventData.MessageMetadata.Usage.InputTokens,
					CompletionTokens: eventData.MessageMetadata.Usage.OutputTokens,
					TotalTokens:      eventData.MessageMetadata.Usage.TotalTokens,
				}
				if err := emit(usage); err != nil {
					return err
				}
			}
			return nil
		default:
			if eventData.Delta != "" {
				if err := emit(eventData.Delta); err != nil {
					return err
				}
			}
		}
	}
	return scanner.Err()
}
// ValidateModel reports whether model is exactly equal to one of the entries
// in validModels. The comparison is case-sensitive; an empty or nil list
// never matches.
func ValidateModel(model string, validModels []string) bool {
	for i := 0; i < len(validModels); i++ {
		if model == validModels[i] {
			return true
		}
	}
	return false
}
// SanitizeContent returns content with every NUL byte ("\x00") removed. All
// other bytes are kept unchanged and in order.
func SanitizeContent(content string) string {
	const nul = "\x00"

	// Cutting at each NUL and gluing the pieces back together drops exactly
	// the NUL bytes.
	pieces := strings.Split(content, nul)
	return strings.Join(pieces, "")
}
// stringPtr returns a pointer to a fresh copy of s, which is convenient for
// filling optional *string fields from a literal.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// CopyHeaders appends every header value in src to dst, except for headers
// named in skipHeaders.
//
// Names in skipHeaders are matched case-insensitively. Values are added with
// dst.Add, so values already present in dst are kept.
func CopyHeaders(dst, src http.Header, skipHeaders []string) {
	excluded := make(map[string]struct{}, len(skipHeaders))
	for i := range skipHeaders {
		excluded[strings.ToLower(skipHeaders[i])] = struct{}{}
	}

	for name, vals := range src {
		if _, skip := excluded[strings.ToLower(name)]; skip {
			continue
		}
		for i := range vals {
			dst.Add(name, vals[i])
		}
	}
}
// IsJSONContentType reports whether contentType contains "application/json",
// ignoring case. Parameters such as a charset do not prevent a match.
func IsJSONContentType(contentType string) bool {
	const jsonMediaType = "application/json"

	lowered := strings.ToLower(contentType)
	return strings.Contains(lowered, jsonMediaType)
}
// ReadRequestBody reads the whole body of r into memory and returns it.
//
// A request without a body yields (nil, nil). A read failure is returned
// wrapped with the message "failed to read request body". The body is not
// closed here.
func ReadRequestBody(r *http.Request) ([]byte, error) {
	src := r.Body
	if src == nil {
		return nil, nil
	}

	payload, readErr := io.ReadAll(src)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read request body: %w", readErr)
	}
	return payload, nil
}
// RunJS executes jsCode with the "node" executable and returns what the script
// wrote to standard output, with surrounding whitespace trimmed.
//
// The code is fed to node on standard input, preceded by a short preamble that
// makes the webcrypto implementation available as crypto on global,
// globalThis, window (created as an alias of global when undefined) and this.
//
// If node exits unsuccessfully, the returned error contains the exit code and
// the captured standard output and standard error. Any other failure to run
// node is returned wrapped.
func RunJS(jsCode string) (string, error) {
	// The preamble is part of the executed program; keep it byte-for-byte.
	const preamble = `const crypto = require('crypto').webcrypto;
global.crypto = crypto;
globalThis.crypto = crypto;
// 在Node.js环境中创建window对象
if (typeof window === 'undefined') { global.window = global; }
window.crypto = crypto;
this.crypto = crypto;
`

	node := exec.Command("node")
	node.Stdin = strings.NewReader(preamble + jsCode)

	stdout, err := node.Output()
	if err == nil {
		return strings.TrimSpace(string(stdout)), nil
	}

	exitErr, exited := err.(*exec.ExitError)
	if !exited {
		return "", fmt.Errorf("failed to execute node.js: %w", err)
	}
	return "", fmt.Errorf("node.js execution failed (exit code: %d)\nSTDOUT:\n%s\nSTDERR:\n%s",
		exitErr.ExitCode(), string(stdout), string(exitErr.Stderr))
}