| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| package utils |
|
|
| import ( |
| "bufio" |
| "context" |
| "crypto/rand" |
| "github.com/libaxuan/cursor2api-go/middleware" |
| "github.com/libaxuan/cursor2api-go/models" |
| "encoding/hex" |
| "encoding/json" |
| "fmt" |
| "io" |
| "net/http" |
| "os/exec" |
| "strings" |
| "time" |
|
|
| "github.com/gin-gonic/gin" |
| "github.com/sirupsen/logrus" |
| ) |
|
|
| |
// GenerateRandomString returns a string of exactly length characters.
// Normally these are lowercase hex digits derived from crypto/rand.
// A non-positive length yields "".
//
// If crypto/rand fails, the decimal representation of the current Unix time
// in nanoseconds is repeated as needed so the result still has exactly the
// requested length. That fallback is NOT random; it only gives best-effort
// uniqueness.
func GenerateRandomString(length int) string {
	if length <= 0 {
		return ""
	}

	// Each random byte encodes to two hex characters, so round up.
	buf := make([]byte, (length+1)/2)
	if _, err := rand.Read(buf); err != nil {
		seed := fmt.Sprintf("%d", time.Now().UnixNano())
		// Repeat the seed so a short seed still fills the requested length.
		return strings.Repeat(seed, length/len(seed)+1)[:length]
	}

	// len(hex) == 2*len(buf) >= length, so this slice is always in range.
	return hex.EncodeToString(buf)[:length]
}
|
|
| |
| func GenerateChatCompletionID() string { |
| return "chatcmpl-" + GenerateRandomString(29) |
| } |
|
|
| |
| func GenerateResponseID() string { |
| return "resp_" + GenerateRandomString(24) |
| } |
|
|
| |
| func GenerateResponseItemID(prefix string) string { |
| if prefix == "" { |
| prefix = "item_" |
| } |
| return prefix + GenerateRandomString(24) |
| } |
|
|
| |
// ParseSSELine extracts the payload of a Server-Sent Events "data:" line.
//
// Surrounding whitespace is trimmed from both the line and the payload. The
// space after "data:" is optional, as in the SSE specification, so both
// "data: {...}" and "data:{...}" yield "{...}". Any other line (comments,
// "event:"/"id:" fields, blank lines) yields "", as does an empty payload.
func ParseSSELine(line string) string {
	const field = "data:"
	line = strings.TrimSpace(line)
	if !strings.HasPrefix(line, field) {
		return ""
	}
	return strings.TrimSpace(strings.TrimPrefix(line, field))
}
|
|
| |
// WriteSSEEvent writes one Server-Sent Events message to w. It then flushes
// w if w implements http.Flusher.
//
// A non-empty event produces a leading "event: <event>" line. data is split
// on line breaks (\r\n, \r or \n), and each line becomes its own
// "data: <line>" field, as the SSE format requires. A raw newline inside a
// single data field would otherwise end the field early and corrupt the
// stream. The message ends with a blank line. The message is assembled first
// and written in one call. Any write error is returned, and the writer is not
// flushed in that case.
func WriteSSEEvent(w http.ResponseWriter, event, data string) error {
	var msg strings.Builder
	if event != "" {
		msg.WriteString("event: ")
		msg.WriteString(event)
		msg.WriteByte('\n')
	}

	normalized := strings.ReplaceAll(data, "\r\n", "\n")
	normalized = strings.ReplaceAll(normalized, "\r", "\n")
	for _, line := range strings.Split(normalized, "\n") {
		msg.WriteString("data: ")
		msg.WriteString(line)
		msg.WriteByte('\n')
	}
	msg.WriteByte('\n')

	if _, err := io.WriteString(w, msg.String()); err != nil {
		return err
	}

	if flusher, ok := w.(http.Flusher); ok {
		flusher.Flush()
	}
	return nil
}
|
|
| |
| |
| func StreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { |
| |
| c.Header("Content-Type", "text/event-stream") |
| c.Header("Cache-Control", "no-cache") |
| c.Header("Connection", "keep-alive") |
| c.Header("Access-Control-Allow-Origin", "*") |
|
|
| |
| responseID := GenerateChatCompletionID() |
| started := false |
| toolCallIndex := 0 |
|
|
| writeChunk := func(delta models.StreamDelta, finishReason *string) { |
| streamResp := models.NewChatCompletionStreamResponse(responseID, modelName, delta, finishReason) |
| if jsonData, err := json.Marshal(streamResp); err == nil { |
| WriteSSEEvent(c.Writer, "", string(jsonData)) |
| } |
| } |
|
|
| |
| ctx := c.Request.Context() |
| for { |
| select { |
| case <-ctx.Done(): |
| logrus.Debug("Client disconnected during streaming") |
| return |
|
|
| case data, ok := <-chatGenerator: |
| if !ok { |
| |
| reason := "stop" |
| if toolCallIndex > 0 { |
| reason = "tool_calls" |
| } |
| writeChunk(models.StreamDelta{}, stringPtr(reason)) |
| WriteSSEEvent(c.Writer, "", "[DONE]") |
| return |
| } |
|
|
| switch v := data.(type) { |
| case models.AssistantEvent: |
| if !started { |
| writeChunk(models.StreamDelta{Role: "assistant"}, nil) |
| started = true |
| } |
|
|
| switch v.Kind { |
| case models.AssistantEventText: |
| if v.Text != "" { |
| writeChunk(models.StreamDelta{Content: v.Text}, nil) |
| } |
| case models.AssistantEventToolCall: |
| if v.ToolCall != nil { |
| writeChunk(models.StreamDelta{ |
| ToolCalls: []models.ToolCallDelta{ |
| { |
| Index: toolCallIndex, |
| ID: v.ToolCall.ID, |
| Type: v.ToolCall.Type, |
| Function: &models.FunctionCallDelta{ |
| Name: v.ToolCall.Function.Name, |
| Arguments: v.ToolCall.Function.Arguments, |
| }, |
| }, |
| }, |
| }, nil) |
| toolCallIndex++ |
| } |
| } |
| case string: |
| if !started { |
| writeChunk(models.StreamDelta{Role: "assistant"}, nil) |
| started = true |
| } |
| if v != "" { |
| writeChunk(models.StreamDelta{Content: v}, nil) |
| } |
|
|
| case models.Usage: |
| |
| continue |
|
|
| case error: |
| logrus.WithError(v).Error("Stream generator error") |
| WriteSSEEvent(c.Writer, "", "[DONE]") |
| return |
|
|
| default: |
| logrus.Warnf("Unknown data type in stream: %T", v) |
| } |
| } |
| } |
| } |
|
|
| |
| func NonStreamChatCompletion(c *gin.Context, chatGenerator <-chan interface{}, modelName string) { |
| var fullContent strings.Builder |
| var usage models.Usage |
| toolCalls := make([]models.ToolCall, 0, 2) |
| finishReason := "stop" |
|
|
| |
| ctx := c.Request.Context() |
| for { |
| select { |
| case <-ctx.Done(): |
| c.JSON(http.StatusRequestTimeout, models.NewErrorResponse( |
| "Request timeout", |
| "timeout_error", |
| "request_timeout", |
| )) |
| return |
|
|
| case data, ok := <-chatGenerator: |
| if !ok { |
| |
| responseID := GenerateChatCompletionID() |
| message := models.Message{ |
| Role: "assistant", |
| } |
| if fullContent.Len() > 0 || len(toolCalls) == 0 { |
| message.Content = fullContent.String() |
| } |
| if len(toolCalls) > 0 { |
| message.ToolCalls = toolCalls |
| finishReason = "tool_calls" |
| } |
| response := models.NewChatCompletionResponse( |
| responseID, |
| modelName, |
| message, |
| finishReason, |
| usage, |
| ) |
| c.JSON(http.StatusOK, response) |
| return |
| } |
|
|
| switch v := data.(type) { |
| case models.AssistantEvent: |
| switch v.Kind { |
| case models.AssistantEventText: |
| fullContent.WriteString(v.Text) |
| case models.AssistantEventToolCall: |
| if v.ToolCall != nil { |
| toolCalls = append(toolCalls, *v.ToolCall) |
| } |
| } |
| case string: |
| fullContent.WriteString(v) |
| case models.Usage: |
| usage = v |
| case error: |
| middleware.HandleError(c, v) |
| return |
| } |
| } |
| } |
| } |
|
|
| |
| func ErrorWrapper(handler func(*gin.Context) error) gin.HandlerFunc { |
| return func(c *gin.Context) { |
| if err := handler(c); err != nil { |
| logrus.WithError(err).Error("Handler error") |
|
|
| if !c.Writer.Written() { |
| c.JSON(http.StatusInternalServerError, models.NewErrorResponse( |
| "Internal server error", |
| "internal_error", |
| "", |
| )) |
| } |
| } |
| } |
| } |
|
|
| |
| func SafeStreamWrapper(handler func(*gin.Context, <-chan interface{}, string), c *gin.Context, chatGenerator <-chan interface{}, modelName string) { |
| defer func() { |
| if r := recover(); r != nil { |
| logrus.WithField("panic", r).Error("Panic in stream handler") |
| if !c.Writer.Written() { |
| c.JSON(http.StatusInternalServerError, models.NewErrorResponse( |
| "Internal server error", |
| "panic_error", |
| "", |
| )) |
| } |
| } |
| }() |
|
|
| firstItem, ok := <-chatGenerator |
| if !ok { |
| middleware.HandleError(c, middleware.NewCursorWebError(http.StatusInternalServerError, "empty stream")) |
| return |
| } |
|
|
| if err, isErr := firstItem.(error); isErr { |
| middleware.HandleError(c, err) |
| return |
| } |
|
|
| buffered := make(chan interface{}, 1) |
| buffered <- firstItem |
| ctx := c.Request.Context() |
|
|
| go func() { |
| defer close(buffered) |
| for { |
| select { |
| case <-ctx.Done(): |
| return |
| case item, ok := <-chatGenerator: |
| if !ok { |
| return |
| } |
| select { |
| case buffered <- item: |
| case <-ctx.Done(): |
| return |
| } |
| } |
| } |
| }() |
|
|
| handler(c, buffered, modelName) |
| } |
|
|
| |
// CreateHTTPClient returns an *http.Client whose per-request time limit is
// timeout (0 means no overall limit). Its transport honours the
// HTTP_PROXY/HTTPS_PROXY/NO_PROXY environment variables.
//
// The transport is a clone of http.DefaultTransport, so it keeps the standard
// dial, keep-alive, TLS-handshake and Expect-Continue timeouts and the
// HTTP/2 attempt. These matter most when timeout is 0. The pool is then tuned
// to 100 idle connections overall, 10 per host, and a 90s idle timeout.
func CreateHTTPClient(timeout time.Duration) *http.Client {
	// Fall back to a bare transport if http.DefaultTransport has been
	// replaced by something that is not an *http.Transport.
	transport := &http.Transport{}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	}
	transport.Proxy = http.ProxyFromEnvironment
	transport.MaxIdleConns = 100
	transport.MaxIdleConnsPerHost = 10
	transport.IdleConnTimeout = 90 * time.Second

	return &http.Client{
		Timeout:   timeout,
		Transport: transport,
	}
}
|
|
| |
| func ReadSSEStream(ctx context.Context, resp *http.Response, output chan<- interface{}) error { |
| scanner := bufio.NewScanner(resp.Body) |
| scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) |
| defer resp.Body.Close() |
|
|
| for scanner.Scan() { |
| select { |
| case <-ctx.Done(): |
| return ctx.Err() |
| default: |
| } |
|
|
| line := scanner.Text() |
| data := ParseSSELine(line) |
| if data == "" { |
| continue |
| } |
|
|
| if data == "[DONE]" { |
| return nil |
| } |
|
|
| |
| var eventData models.CursorEventData |
| if err := json.Unmarshal([]byte(data), &eventData); err != nil { |
| logrus.WithError(err).Debugf("Failed to parse SSE data: %s", data) |
| continue |
| } |
|
|
| |
| switch eventData.Type { |
| case "error": |
| if eventData.ErrorText != "" { |
| return fmt.Errorf("cursor API error: %s", eventData.ErrorText) |
| } |
|
|
| case "finish": |
| if eventData.MessageMetadata != nil && eventData.MessageMetadata.Usage != nil { |
| usage := models.Usage{ |
| PromptTokens: eventData.MessageMetadata.Usage.InputTokens, |
| CompletionTokens: eventData.MessageMetadata.Usage.OutputTokens, |
| TotalTokens: eventData.MessageMetadata.Usage.TotalTokens, |
| } |
| output <- usage |
| } |
| return nil |
|
|
| default: |
| if eventData.Delta != "" { |
| output <- eventData.Delta |
| } |
| } |
| } |
|
|
| return scanner.Err() |
| } |
|
|
| |
// ValidateModel reports whether model appears in validModels. The comparison
// is exact and case-sensitive. A nil or empty list never matches.
func ValidateModel(model string, validModels []string) bool {
	for i := range validModels {
		if model == validModels[i] {
			return true
		}
	}
	return false
}
|
|
| |
// SanitizeContent returns content with every NUL byte ("\x00") removed. All
// other bytes are left untouched.
func SanitizeContent(content string) string {
	const nul = "\x00"
	if !strings.Contains(content, nul) {
		return content
	}
	return strings.ReplaceAll(content, nul, "")
}
|
|
| |
// stringPtr returns a pointer to a fresh copy of s, for optional string
// fields that take *string.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
|
| |
// CopyHeaders appends every header value in src to dst, except headers whose
// names match an entry in skipHeaders. Name matching is case-insensitive.
// Existing values in dst are kept, and copied values are added after them.
func CopyHeaders(dst, src http.Header, skipHeaders []string) {
	skipped := make(map[string]struct{}, len(skipHeaders))
	for _, name := range skipHeaders {
		skipped[strings.ToLower(name)] = struct{}{}
	}

	for name, vals := range src {
		if _, skip := skipped[strings.ToLower(name)]; skip {
			continue
		}
		for i := range vals {
			dst.Add(name, vals[i])
		}
	}
}
|
|
| |
// IsJSONContentType reports whether contentType names a JSON media type.
//
// Matching is case-insensitive. Any value containing "application/json",
// with or without parameters such as "; charset=utf-8", matches. So does any
// media type using the RFC 6839 "+json" structured-syntax suffix, such as
// "application/problem+json" or "application/vnd.api+json".
func IsJSONContentType(contentType string) bool {
	ct := strings.ToLower(contentType)
	if strings.Contains(ct, "application/json") {
		return true
	}
	// Check only the media type itself, not its parameters.
	if i := strings.IndexByte(ct, ';'); i >= 0 {
		ct = ct[:i]
	}
	return strings.HasSuffix(strings.TrimSpace(ct), "+json")
}
|
|
| |
// ReadRequestBody reads the whole request body and returns it. A nil body
// yields (nil, nil). Read failures are wrapped with context.
//
// The body is consumed and not restored, so later readers of r.Body see it
// already drained.
func ReadRequestBody(r *http.Request) ([]byte, error) {
	body := r.Body
	if body == nil {
		return nil, nil
	}

	data, err := io.ReadAll(body)
	if err == nil {
		return data, nil
	}
	return nil, fmt.Errorf("failed to read request body: %w", err)
}
|
|
| |
// RunJS executes jsCode with a local `node` binary and returns its trimmed
// standard output.
//
// A prelude runs first. It exposes Node's WebCrypto object as `crypto` on
// `global`, `globalThis`, `this`, and a `window` object. The prelude creates
// `window` (aliased to `global`) when it does not exist, so browser-style
// scripts can run under Node. The combined script is fed to node on stdin.
//
// When node exits non-zero, the error reports the exit code, captured stdout
// and stderr. Other failures, such as node not being found, are wrapped.
//
// NOTE(review): jsCode is executed as-is with the privileges of this process,
// so only pass trusted scripts. The process has no timeout, so a script that
// never terminates blocks the caller indefinitely.
func RunJS(jsCode string) (string, error) {
	const jsPrelude = `const crypto = require('crypto').webcrypto;
global.crypto = crypto;
globalThis.crypto = crypto;
// 在Node.js环境中创建window对象
if (typeof window === 'undefined') { global.window = global; }
window.crypto = crypto;
this.crypto = crypto;
`

	cmd := exec.Command("node")
	cmd.Stdin = strings.NewReader(jsPrelude + jsCode)

	stdout, err := cmd.Output()
	if err == nil {
		return strings.TrimSpace(string(stdout)), nil
	}

	exitErr, isExit := err.(*exec.ExitError)
	if !isExit {
		return "", fmt.Errorf("failed to execute node.js: %w", err)
	}
	return "", fmt.Errorf("node.js execution failed (exit code: %d)\nSTDOUT:\n%s\nSTDERR:\n%s",
		exitErr.ExitCode(), string(stdout), string(exitErr.Stderr))
}
|
|