KaThaNg commited on
Commit
01d9631
·
verified ·
1 Parent(s): e60212a

Upload 10 files

Browse files
Files changed (10) hide show
  1. Dockerfile +45 -25
  2. auth.go +45 -0
  3. config.go +120 -0
  4. convert.go +236 -0
  5. go.mod +39 -0
  6. go.sum +0 -0
  7. handlers.go +218 -0
  8. main.go +103 -0
  9. sse.go +320 -0
  10. structs.go +183 -0
Dockerfile CHANGED
@@ -1,34 +1,54 @@
1
- # Use an official Python runtime as a parent image
2
- # Using slim variant for smaller image size
3
- FROM python:3.10-slim
4
 
5
- # Set environment variables to prevent Python from writing pyc files and buffering stdout/stderr
6
- ENV PYTHONDONTWRITEBYTECODE 1
7
- ENV PYTHONUNBUFFERED 1
8
- # --- Change default port to 7860 ---
9
- ENV PORT=7860
10
 
11
- # Set the working directory in the container
12
- WORKDIR /app
13
 
14
- # Install system dependencies if needed
15
- # RUN apt-get update && apt-get install -y --no-install-recommends ... && rm -rf /var/lib/apt/lists/*
 
16
 
17
- # Install Python dependencies
18
- RUN pip install --no-cache-dir --upgrade pip
19
- COPY requirements.txt requirements.txt
20
- RUN pip install --no-cache-dir -r requirements.txt
21
 
22
- # Copy the rest of the application code
23
  COPY . .
24
 
25
- # Create a non-root user and switch to it
26
- RUN useradd --create-home --uid 1001 appuser
27
- USER appuser
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- # Expose the port the app runs on
30
- EXPOSE ${PORT}
 
31
 
32
- # Define the command to run the application using Uvicorn
33
- # Use shell form to allow ${PORT} substitution
34
- CMD uvicorn proxy_server:app --host 0.0.0.0 --port ${PORT}
 
1
# Stage 1: Build the Go application
FROM golang:1.21-alpine AS builder

# Static binary for the Linux runtime image.
# FIX: GOARCH was previously pinned to amd64, which broke cross-platform
# builds (e.g. docker buildx --platform linux/arm64). Let the build
# platform decide the architecture.
ENV CGO_ENABLED=0 GOOS=linux

WORKDIR /build

# Copy only the module definition first so the dependency-download layer
# is cached independently of source changes.
# go.sum is intentionally NOT copied here: it is empty in this repo, so
# 'go mod tidy' below regenerates it from the source imports.
COPY go.mod ./

# Download dependencies declared in go.mod.
RUN go mod download

# Copy the rest of the source code AFTER the initial download.
COPY . .

# Synchronize go.mod/go.sum with the actual source imports.
# NOTE(review): running tidy at image-build time makes builds
# non-reproducible; once a real go.sum is committed, replace this with
# 'COPY go.sum ./' plus verification only.
RUN go mod tidy

# Verify dependency checksums (optional but good practice).
RUN go mod verify

# Build the Go application statically linked.
# -ldflags="-w -s" strips DWARF debug info and the symbol table to
# reduce binary size.
RUN go build -ldflags="-w -s" -o /app/proxy-server .

# Stage 2: Create the final minimal image
FROM alpine:latest

# ca-certificates for outbound HTTPS calls, tzdata for timezone info.
# FIX: the original ran 'apk update' before 'apk add --no-cache', which
# stored the package index in the image layer; '--no-cache' alone fetches
# a fresh index without persisting it.
RUN apk add --no-cache ca-certificates tzdata

# Set the working directory.
WORKDIR /app

# Copy the built binary from the builder stage.
COPY --from=builder /app/proxy-server /app/proxy-server

# Expose the port the app runs on.
# Defaults to 7860 if not overridden by ENV PORT in the runtime environment.
EXPOSE 7860

# The application reads its environment variables at runtime.
ENTRYPOINT ["/app/proxy-server"]

# Optional: Add a non-root user for security (uncomment if needed).
# RUN addgroup -S appgroup && adduser -S appuser -G appgroup
# USER appuser
auth.go ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "log"
5
+ "net/http"
6
+
7
+ "github.com/gin-gonic/gin"
8
+ )
9
+
10
+ const APIKeyHeaderName = "X-API-Key"
11
+
12
+ // APIKeyAuthMiddleware creates a Gin middleware for API key authentication
13
+ func APIKeyAuthMiddleware(validKeys map[string]bool) gin.HandlerFunc {
14
+ return func(c *gin.Context) {
15
+ apiKey := c.GetHeader(APIKeyHeaderName)
16
+
17
+ if apiKey == "" {
18
+ log.Printf("WARN: [%s] API Key missing in header '%s'", c.ClientIP(), APIKeyHeaderName)
19
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
20
+ "type": "error",
21
+ "error": gin.H{
22
+ "type": "authentication_error",
23
+ "message": "API Key required in header '" + APIKeyHeaderName + "'",
24
+ },
25
+ })
26
+ return
27
+ }
28
+
29
+ if _, isValid := validKeys[apiKey]; !isValid {
30
+ log.Printf("WARN: [%s] Invalid API Key received (length: %d)", c.ClientIP(), len(apiKey))
31
+ c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
32
+ "type": "error",
33
+ "error": gin.H{
34
+ "type": "authentication_error",
35
+ "message": "Invalid or expired API Key",
36
+ },
37
+ })
38
+ return
39
+ }
40
+
41
+ // Log successful authentication (optional, consider security implications)
42
+ // log.Printf("INFO: [%s] Valid API key received (length: %d)", c.ClientIP(), len(apiKey))
43
+ c.Next() // Proceed to the next handler
44
+ }
45
+ }
config.go ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "log"
5
+ "net" // <<< Added import
6
+ "net/http"
7
+ "net/url"
8
+ "os"
9
+ "strconv"
10
+ "strings"
11
+ "time"
12
+ )
13
+
14
// Config holds all configuration for the application. It is populated from
// environment variables by LoadConfig and shared by the HTTP handlers.
type Config struct {
	OpenAIAPIEndpoint string            // Upstream chat-completions URL (OPENAI_API_ENDPOINT)
	OpenAIAPIKey      string            // Bearer token forwarded upstream (OPENAI_API_KEY)
	ProxyAPIKeys      string            // Comma-separated keys (raw PROXY_API_KEYS value)
	ValidAPIKeys      map[string]bool   // Set of valid keys for quick lookup (parsed from ProxyAPIKeys)
	ConnectTimeout    time.Duration     // Dial timeout for upstream connections (CONNECT_TIMEOUT, seconds)
	ReadTimeout       time.Duration     // Wait for upstream response headers (READ_TIMEOUT, seconds)
	WriteTimeout      time.Duration     // WRITE_TIMEOUT, seconds — loaded but not referenced in visible code; TODO confirm intended use
	PoolTimeout       time.Duration     // Note: Go's default transport manages pooling differently
	HTTPProxyURL      *url.URL          // Parsed HTTP_PROXY; nil when unset or invalid
	Port              string            // Listen port (PORT, default 7860)
	GinMode           string            // "debug" or "release" (GIN_MODE)
	LogLevel          string            // For potential future structured logging integration
	UpstreamTransport http.RoundTripper // Custom transport (proxy + timeouts) for the upstream http client
}
30
+
31
+ // LoadConfig reads configuration from environment variables
32
+ func LoadConfig() *Config {
33
+ cfg := &Config{
34
+ OpenAIAPIEndpoint: getEnv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions"),
35
+ OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""),
36
+ ProxyAPIKeys: getEnv("PROXY_API_KEYS", ""),
37
+ ConnectTimeout: getEnvDuration("CONNECT_TIMEOUT", 5*time.Second),
38
+ ReadTimeout: getEnvDuration("READ_TIMEOUT", 180*time.Second),
39
+ WriteTimeout: getEnvDuration("WRITE_TIMEOUT", 30*time.Second),
40
+ PoolTimeout: getEnvDuration("POOL_TIMEOUT", 5*time.Second), // Less directly applicable in Go's default client
41
+ Port: getEnv("PORT", "7860"),
42
+ GinMode: getEnv("GIN_MODE", "release"), // "debug" or "release"
43
+ LogLevel: getEnv("LOG_LEVEL", "INFO"),
44
+ }
45
+
46
+ // Process API Keys into a map for efficient lookup
47
+ cfg.ValidAPIKeys = make(map[string]bool)
48
+ if cfg.ProxyAPIKeys != "" {
49
+ keys := strings.Split(cfg.ProxyAPIKeys, ",")
50
+ for _, key := range keys {
51
+ trimmedKey := strings.TrimSpace(key)
52
+ if trimmedKey != "" {
53
+ cfg.ValidAPIKeys[trimmedKey] = true
54
+ }
55
+ }
56
+ }
57
+
58
+ // Parse HTTP Proxy URL
59
+ proxyStr := getEnv("HTTP_PROXY", "")
60
+ if proxyStr != "" {
61
+ proxyURL, err := url.Parse(proxyStr)
62
+ if err != nil {
63
+ log.Printf("WARN: Invalid HTTP_PROXY URL '%s': %v. Proxy disabled.", proxyStr, err)
64
+ cfg.HTTPProxyURL = nil
65
+ } else {
66
+ cfg.HTTPProxyURL = proxyURL
67
+ log.Printf("Using outbound proxy: %s", cfg.HTTPProxyURL.String())
68
+ }
69
+ }
70
+
71
+ // Configure the shared HTTP client transport
72
+ defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
73
+ if cfg.HTTPProxyURL != nil { // Set proxy only if URL is valid
74
+ defaultTransport.Proxy = http.ProxyURL(cfg.HTTPProxyURL)
75
+ }
76
+ // Configure timeouts (Connect timeout is part of DialContext)
77
+ defaultTransport.DialContext = (&net.Dialer{ // <<< Used net.Dialer here
78
+ Timeout: cfg.ConnectTimeout, // Connect timeout
79
+ KeepAlive: 30 * time.Second, // Keep-alive interval
80
+ }).DialContext
81
+ defaultTransport.TLSHandshakeTimeout = 10 * time.Second // TLS handshake timeout
82
+ defaultTransport.ResponseHeaderTimeout = cfg.ReadTimeout // Timeout waiting for response headers
83
+ // Go's http client manages connection pooling automatically.
84
+ // MaxIdleConns, MaxIdleConnsPerHost can be tuned if needed.
85
+ defaultTransport.MaxIdleConns = 100
86
+ defaultTransport.MaxIdleConnsPerHost = 10
87
+ defaultTransport.IdleConnTimeout = 90 * time.Second
88
+
89
+ cfg.UpstreamTransport = defaultTransport
90
+
91
+ // Log warnings for missing keys
92
+ if cfg.OpenAIAPIKey == "" {
93
+ log.Println("WARN: OPENAI_API_KEY is not set.")
94
+ }
95
+ if len(cfg.ValidAPIKeys) == 0 {
96
+ log.Println("WARN: PROXY_API_KEYS is not set. Proxy is open (no authentication).")
97
+ }
98
+
99
+ return cfg
100
+ }
101
+
102
+ // getEnv reads an environment variable or returns a default value
103
+ func getEnv(key, defaultValue string) string {
104
+ if value, exists := os.LookupEnv(key); exists {
105
+ return value
106
+ }
107
+ return defaultValue
108
+ }
109
+
110
+ // getEnvDuration reads an environment variable as seconds and returns a time.Duration
111
+ func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
112
+ valueStr := getEnv(key, "")
113
+ if valueStr != "" {
114
+ if valueFloat, err := strconv.ParseFloat(valueStr, 64); err == nil {
115
+ return time.Duration(valueFloat * float64(time.Second))
116
+ }
117
+ log.Printf("WARN: Invalid duration format for %s: '%s'. Using default: %v", key, valueStr, defaultValue)
118
+ }
119
+ return defaultValue
120
+ }
convert.go ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "encoding/json"
5
+ "errors"
6
+ "fmt"
7
+ "log"
8
+ "strings"
9
+ )
10
+
11
// estimateTokens gives a rough token-count approximation for text using a
// one-token-per-three-bytes heuristic (mirrors the original Python proxy).
// The empty string is 0 tokens; any non-empty string counts as at least 1.
func estimateTokens(text string) int {
	n := len(text)
	if n == 0 {
		return 0
	}
	if est := n / 3; est > 0 {
		return est
	}
	return 1 // non-empty but shorter than 3 bytes
}
22
+
23
+ // calculateInputTokensFromClaudeRequest estimates input tokens from Claude request
24
+ func calculateInputTokensFromClaudeRequest(claudeReq *ClaudeRequest) int {
25
+ totalChars := 0
26
+
27
+ // Process system prompt
28
+ if len(claudeReq.System) > 0 {
29
+ // Try unmarshaling as string first
30
+ var systemStr string
31
+ if err := json.Unmarshal(claudeReq.System, &systemStr); err == nil {
32
+ totalChars += len(systemStr)
33
+ } else {
34
+ // Try unmarshaling as list of blocks
35
+ var systemBlocks []ClaudeContentBlock
36
+ if err := json.Unmarshal(claudeReq.System, &systemBlocks); err == nil {
37
+ for _, block := range systemBlocks {
38
+ if block.Type == "text" {
39
+ totalChars += len(block.Text)
40
+ }
41
+ }
42
+ } else {
43
+ log.Printf("WARN: Could not parse system prompt format: %s", string(claudeReq.System))
44
+ }
45
+ }
46
+ }
47
+
48
+ // Process messages
49
+ for _, msg := range claudeReq.Messages {
50
+ // Try unmarshaling as string first
51
+ var contentStr string
52
+ if err := json.Unmarshal(msg.Content, &contentStr); err == nil {
53
+ totalChars += len(contentStr)
54
+ } else {
55
+ // Try unmarshaling as list of blocks
56
+ var contentBlocks []ClaudeContentBlock
57
+ if err := json.Unmarshal(msg.Content, &contentBlocks); err == nil {
58
+ for _, block := range contentBlocks {
59
+ if block.Type == "text" {
60
+ totalChars += len(block.Text)
61
+ }
62
+ }
63
+ } else {
64
+ log.Printf("WARN: Could not parse message content format for role %s: %s", msg.Role, string(msg.Content))
65
+ }
66
+ }
67
+ }
68
+
69
+ estimated := estimateTokens(fmt.Sprintf("%d", totalChars)) // Pass total chars as string to estimate
70
+ log.Printf("DEBUG: Estimated input characters: %d, Estimated input tokens: %d", totalChars, estimated)
71
+ return estimated
72
+ }
73
+
74
// convertClaudeRequestToOpenAI converts a Claude Messages-API request into
// an OpenAI chat-completions request.
//
// System prompt and per-message content may each arrive as either a raw
// JSON string or a list of content blocks; both forms are handled, and
// only "text" blocks are kept (multiple blocks are joined with "\n").
// Messages with unsupported roles or unparseable content are skipped with
// a WARN log. Returns an error only when no valid messages remain after
// conversion; the model falls back to "gpt-3.5-turbo" when unset.
func convertClaudeRequestToOpenAI(claudeReq *ClaudeRequest) (*OpenAIRequest, error) {
	openAIMessages := []OpenAIMessage{}

	// --- Handle System Prompt ---
	// Claude allows the system prompt as a string or as content blocks;
	// OpenAI expects a single leading "system" message.
	if len(claudeReq.System) > 0 {
		systemContent := ""
		var systemStr string
		// Try simple string first
		if err := json.Unmarshal(claudeReq.System, &systemStr); err == nil {
			systemContent = systemStr
		} else {
			// Try list of blocks
			var systemBlocks []ClaudeContentBlock
			if err := json.Unmarshal(claudeReq.System, &systemBlocks); err == nil {
				var parts []string
				for _, block := range systemBlocks {
					if block.Type == "text" {
						parts = append(parts, block.Text)
					}
				}
				systemContent = strings.Join(parts, "\n")
			} else {
				log.Printf("WARN: Could not parse system prompt format for conversion: %s", string(claudeReq.System))
				// Decide how to handle - skip system prompt or return error? Skipping for now.
			}
		}
		if systemContent != "" {
			openAIMessages = append(openAIMessages, OpenAIMessage{Role: "system", Content: systemContent})
		}
	}

	// --- Handle Messages ---
	for _, msg := range claudeReq.Messages {
		// Only user/assistant roles map onto OpenAI chat messages.
		if msg.Role != "user" && msg.Role != "assistant" {
			log.Printf("WARN: Skipping message with unsupported role: %s", msg.Role)
			continue
		}

		messageContent := ""
		var contentStr string
		// Try simple string first
		if err := json.Unmarshal(msg.Content, &contentStr); err == nil {
			messageContent = contentStr
		} else {
			// Try list of blocks; non-text blocks (e.g. images) are dropped.
			var contentBlocks []ClaudeContentBlock
			if err := json.Unmarshal(msg.Content, &contentBlocks); err == nil {
				var parts []string
				for _, block := range contentBlocks {
					if block.Type == "text" {
						parts = append(parts, block.Text)
					} else {
						log.Printf("WARN: Skipping non-text content block type '%s' in message for role %s", block.Type, msg.Role)
					}
				}
				messageContent = strings.Join(parts, "\n")
			} else {
				log.Printf("WARN: Could not parse message content format for role %s during conversion: %s", msg.Role, string(msg.Content))
				// Skip message if content parsing fails
				continue
			}
		}

		// Empty user messages are dropped; assistant messages are kept even
		// when empty.
		if messageContent != "" || msg.Role == "assistant" { // Allow empty assistant messages if needed? Check OpenAI spec. Usually needs content.
			openAIMessages = append(openAIMessages, OpenAIMessage{Role: msg.Role, Content: messageContent})
		} else {
			log.Printf("WARN: Skipping message for role %s with no valid text content after parsing.", msg.Role)
		}
	}

	// Nothing survived conversion — upstream would reject an empty request.
	if len(openAIMessages) == 0 {
		return nil, errors.New("conversion resulted in no valid messages for OpenAI request")
	}

	// --- Construct OpenAI Request ---
	openAIReq := &OpenAIRequest{
		Model:       claudeReq.Model, // Use the model specified in Claude request
		Messages:    openAIMessages,
		Stream:      claudeReq.Stream,
		MaxTokens:   claudeReq.MaxTokens,
		Temperature: claudeReq.Temperature,
		TopP:        claudeReq.TopP,
		// Stop sequences map directly onto OpenAI's "stop" field.
		Stop: claudeReq.StopSequences,
	}

	// Default model if not provided
	if openAIReq.Model == "" {
		openAIReq.Model = "gpt-3.5-turbo" // Or get from config
	}

	return openAIReq, nil
}
168
+
169
// finishReasonToStopReason maps known OpenAI finish_reason values onto
// their Claude stop_reason equivalents.
var finishReasonToStopReason = map[string]string{
	"stop":           "end_turn",
	"length":         "max_tokens",
	"function_call":  "tool_use",
	"tool_calls":     "tool_use",
	"content_filter": "stop_sequence", // Or maybe map to an error type?
}

// mapOpenAIFinishReasonToClaude translates an OpenAI finish reason into the
// corresponding Claude stop reason. A nil pointer (no reason reported) and
// unknown values both fall back to "end_turn"; unknown values are logged.
func mapOpenAIFinishReasonToClaude(openAIFinishReason *string) string {
	if openAIFinishReason == nil {
		return "end_turn"
	}
	reason := *openAIFinishReason
	if mapped, known := finishReasonToStopReason[reason]; known {
		return mapped
	}
	log.Printf("WARN: Unknown OpenAI finish reason '%s', mapping to 'end_turn'", reason)
	return "end_turn"
}
188
+ }
189
+
190
+ // convertOpenAIResponseToClaude converts non-streaming OpenAI response to Claude format
191
+ func convertOpenAIResponseToClaude(openAIResp *OpenAIResponse, claudeRequestID string) (*ClaudeResponse, error) {
192
+ if len(openAIResp.Choices) == 0 {
193
+ return nil, errors.New("OpenAI response has no choices")
194
+ }
195
+
196
+ choice := openAIResp.Choices[0]
197
+ claudeStopReason := mapOpenAIFinishReasonToClaude(choice.FinishReason)
198
+
199
+ // --- Prepare Usage ---
200
+ claudeUsage := ClaudeUsage{}
201
+ if openAIResp.Usage != nil {
202
+ claudeUsage.InputTokens = openAIResp.Usage.PromptTokens
203
+ claudeUsage.OutputTokens = openAIResp.Usage.CompletionTokens
204
+ } else {
205
+ log.Printf("WARN: [%s] Usage data missing in non-streaming OpenAI response", claudeRequestID)
206
+ // Potentially estimate usage here if critical, otherwise leave as zero
207
+ }
208
+
209
+ // --- Prepare Content ---
210
+ claudeContent := []ClaudeContentBlock{
211
+ {
212
+ Type: "text",
213
+ Text: choice.Message.Content, // Assuming message content is always text
214
+ },
215
+ }
216
+
217
+ // --- Construct Claude Response ---
218
+ claudeResp := &ClaudeResponse{
219
+ ID: openAIResp.ID, // Use OpenAI's response ID
220
+ Type: "message",
221
+ Role: "assistant", // Assuming OpenAI response role is assistant
222
+ Content: claudeContent,
223
+ Model: openAIResp.Model, // Use the model OpenAI reported
224
+ StopReason: claudeStopReason,
225
+ StopSequence: nil, // Typically null
226
+ Usage: claudeUsage,
227
+ }
228
+
229
+ // Use original request ID if OpenAI ID is missing (shouldn't happen often)
230
+ if claudeResp.ID == "" {
231
+ log.Printf("WARN: OpenAI response ID missing, using original request ID: %s", claudeRequestID)
232
+ claudeResp.ID = claudeRequestID
233
+ }
234
+
235
+ return claudeResp, nil
236
+ }
go.mod ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module claude-proxy-go // You can change "claude-proxy-go" to any module name you prefer
2
+
3
+ go 1.21 // Or whichever Go version you want to use (should match the Dockerfile)
4
+
5
+ require (
6
+ github.com/gin-contrib/cors v1.7.2
7
+ github.com/gin-gonic/gin v1.10.0
8
+ github.com/google/uuid v1.6.0
9
+ github.com/joho/godotenv v1.5.1 // Optional: for local .env loading
10
+ )
11
+
12
+ require (
13
+ github.com/bytedance/sonic v1.11.6 // indirect
14
+ github.com/bytedance/sonic/loader v0.1.1 // indirect
15
+ github.com/cloudwego/base64x v0.1.4 // indirect
16
+ github.com/cloudwego/iasm v0.2.0 // indirect
17
+ github.com/gabriel-vasile/mimetype v1.4.3 // indirect
18
+ github.com/gin-contrib/sse v0.1.0 // indirect
19
+ github.com/go-playground/locales v0.14.1 // indirect
20
+ github.com/go-playground/universal-translator v0.18.1 // indirect
21
+ github.com/go-playground/validator/v10 v10.20.0 // indirect
22
+ github.com/goccy/go-json v0.10.2 // indirect
23
+ github.com/json-iterator/go v1.1.12 // indirect
24
+ github.com/klauspost/cpuid/v2 v2.2.7 // indirect
25
+ github.com/leodido/go-urn v1.4.0 // indirect
26
+ github.com/mattn/go-isatty v0.0.20 // indirect
27
+ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
28
+ github.com/modern-go/reflect2 v1.0.2 // indirect
29
+ github.com/pelletier/go-toml/v2 v2.2.2 // indirect
30
+ github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
31
+ github.com/ugorji/go/codec v1.2.12 // indirect
32
+ golang.org/x/arch v0.8.0 // indirect
33
+ golang.org/x/crypto v0.23.0 // indirect
34
+ golang.org/x/net v0.25.0 // indirect
35
+ golang.org/x/sys v0.20.0 // indirect
36
+ golang.org/x/text v0.15.0 // indirect
37
+ google.golang.org/protobuf v1.34.1 // indirect
38
+ gopkg.in/yaml.v3 v3.0.1 // indirect
39
+ )
go.sum ADDED
File without changes
handlers.go ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
)
17
+
18
+ // HealthCheckHandler handles the /health endpoint
19
+ func HealthCheckHandler(c *gin.Context) {
20
+ c.JSON(http.StatusOK, gin.H{"status": "healthy"})
21
+ }
22
+
23
+ // MessagesHandler handles the /v1/messages endpoint
24
+ func MessagesHandler(c *gin.Context) {
25
+ requestID := fmt.Sprintf("msg_%s", uuid.NewString()[:24]) // Generate unique request ID
26
+ cfg := LoadConfig() // Load config (consider passing it down instead of reloading)
27
+
28
+ // --- 1. Read and Parse Incoming Request ---
29
+ var claudeReq ClaudeRequest
30
+ bodyBytes, err := io.ReadAll(c.Request.Body)
31
+ if err != nil {
32
+ log.Printf("ERROR: [%s] Failed to read request body: %v", requestID, err)
33
+ sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Could not read request body.")
34
+ return
35
+ }
36
+ // Restore body for potential re-reads (though we don't re-read here)
37
+ c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
38
+
39
+ if err := json.Unmarshal(bodyBytes, &claudeReq); err != nil {
40
+ log.Printf("ERROR: [%s] Failed to decode request JSON: %v. Body: %s", requestID, err, string(bodyBytes))
41
+ sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid JSON format in request body.")
42
+ return
43
+ }
44
+
45
+ isStreaming := claudeReq.Stream
46
+ modelRequested := claudeReq.Model
47
+ if modelRequested == "" {
48
+ modelRequested = "unknown_model"
49
+ } // Handle empty model case
50
+ log.Printf("INFO: [%s] Received request. Stream: %t. Model: %s", requestID, isStreaming, modelRequested)
51
+ // Optional: Log full payload at debug level
52
+ // log.Printf("DEBUG: [%s] Received Payload: %s", requestID, string(bodyBytes))
53
+
54
+ // --- 2. Convert Request Format ---
55
+ openAIReq, err := convertClaudeRequestToOpenAI(&claudeReq)
56
+ if err != nil {
57
+ log.Printf("ERROR: [%s] Failed to convert Claude request to OpenAI format: %v", requestID, err)
58
+ sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Error converting request data: %v", err))
59
+ return
60
+ }
61
+ // Ensure stream flag is correctly set in the converted request
62
+ openAIReq.Stream = isStreaming
63
+
64
+ // --- 3. Prepare and Send Upstream Request ---
65
+ // Marshal the OpenAI request body
66
+ openaiReqBytes, err := json.Marshal(openAIReq)
67
+ if err != nil {
68
+ log.Printf("ERROR: [%s] Failed to marshal OpenAI request JSON: %v", requestID, err)
69
+ sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", "Failed to prepare upstream request.")
70
+ return
71
+ }
72
+
73
+ // Create the HTTP request to the upstream endpoint
74
+ upstreamURL := cfg.OpenAIAPIEndpoint
75
+ req, err := http.NewRequestWithContext(c.Request.Context(), "POST", upstreamURL, bytes.NewBuffer(openaiReqBytes))
76
+ if err != nil {
77
+ log.Printf("ERROR: [%s] Failed to create upstream HTTP request: %v", requestID, err)
78
+ sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", "Failed to create upstream request.")
79
+ return
80
+ }
81
+
82
+ // Set headers for upstream request
83
+ req.Header.Set("Content-Type", "application/json")
84
+ if isStreaming {
85
+ req.Header.Set("Accept", "text/event-stream")
86
+ } else {
87
+ req.Header.Set("Accept", "application/json")
88
+ }
89
+ if cfg.OpenAIAPIKey != "" {
90
+ req.Header.Set("Authorization", "Bearer "+cfg.OpenAIAPIKey)
91
+ }
92
+ // Copy potentially relevant headers from original request? (e.g., User-Agent) - Be cautious about security.
93
+ // req.Header.Set("User-Agent", c.GetHeader("User-Agent"))
94
+
95
+ // Log upstream request details (optional, redact sensitive info)
96
+ // log.Printf("DEBUG: [%s] Sending upstream request to %s. Headers: %v", requestID, upstreamURL, req.Header)
97
+ // log.Printf("DEBUG: [%s] Upstream Payload: %s", requestID, string(openaiReqBytes))
98
+ log.Printf("INFO: [%s] Sending upstream request (Stream=%t) to %s...", requestID, isStreaming, upstreamURL)
99
+
100
+ // --- Execute Upstream Request ---
101
+ // Use a client with the configured transport (timeouts, proxy)
102
+ httpClient := &http.Client{
103
+ Transport: cfg.UpstreamTransport,
104
+ Timeout: 0, // Timeout is handled by the transport's ResponseHeaderTimeout and DialContext Timeout
105
+ }
106
+ startTime := time.Now()
107
+ upstreamResp, err := httpClient.Do(req)
108
+ if err != nil {
109
+ // Handle client-side errors (network, DNS, timeout before connection, etc.)
110
+ log.Printf("ERROR: [%s] Upstream request failed: %v", requestID, err)
111
+ // Check for timeout specifically
112
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // <<< Used net.Error here
113
+ sendClaudeError(c, http.StatusGatewayTimeout, "api_error", fmt.Sprintf("Gateway Timeout connecting to upstream (%v).", cfg.ConnectTimeout))
114
+ } else {
115
+ sendClaudeError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("Bad Gateway: Could not connect to upstream. Error: %v", err))
116
+ }
117
+ return
118
+ }
119
+ // Note: We don't close upstreamResp.Body here yet, it's needed for streaming or reading non-streaming body.
120
+ // It will be closed by the streaming handler or after reading the body in non-streaming case.
121
+ log.Printf("INFO: [%s] Received upstream status: %d (%s)", requestID, upstreamResp.StatusCode, http.StatusText(upstreamResp.StatusCode))
122
+
123
+ // --- 4. Process Upstream Response ---
124
+
125
+ // Handle non-OK status codes
126
+ if upstreamResp.StatusCode != http.StatusOK {
127
+ // Read error body from upstream
128
+ errorBodyBytes, readErr := io.ReadAll(upstreamResp.Body)
129
+ if readErr != nil {
130
+ log.Printf("WARN: [%s] Failed to read upstream error body (Status %d): %v", requestID, upstreamResp.StatusCode, readErr)
131
+ }
132
+ _ = upstreamResp.Body.Close() // Ensure body is closed after reading or error
133
+
134
+ errorBodyStr := string(errorBodyBytes)
135
+ log.Printf("ERROR: [%s] Upstream returned error status %d. Body: %s", requestID, upstreamResp.StatusCode, errorBodyStr)
136
+
137
+ // Try to map the error to Claude format
138
+ // Basic mapping, can be improved by parsing OpenAI error structure if available
139
+ var errorType string
140
+ switch upstreamResp.StatusCode {
141
+ case http.StatusBadRequest:
142
+ errorType = "invalid_request_error"
143
+ case http.StatusUnauthorized:
144
+ errorType = "authentication_error"
145
+ case http.StatusForbidden:
146
+ errorType = "permission_error"
147
+ case http.StatusTooManyRequests:
148
+ errorType = "rate_limit_error"
149
+ case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
150
+ errorType = "api_error"
151
+ default:
152
+ errorType = "api_error" // Default for other errors
153
+ }
154
+ errMsg := fmt.Sprintf("Upstream API error (%d). Details: %s", upstreamResp.StatusCode, strings.TrimSpace(errorBodyStr)) // <<< Used strings.TrimSpace here
155
+ // Truncate long error messages if necessary
156
+ if len(errMsg) > 300 {
157
+ errMsg = errMsg[:300] + "..."
158
+ }
159
+ sendClaudeError(c, upstreamResp.StatusCode, errorType, errMsg)
160
+ return
161
+ }
162
+
163
+ // --- Handle OK response based on streaming ---
164
+ if isStreaming {
165
+ log.Printf("INFO: [%s] Upstream stream received. Starting SSE conversion (Go v1.9.0 - Priority Delta).", requestID)
166
+ // Delegate to the SSE streaming function
167
+ // This function will handle reading upstreamResp.Body and closing it
168
+ streamOpenAIResponseToClaudeSSE(c, upstreamResp, requestID, openAIReq.Model, &claudeReq)
169
+ } else {
170
+ // --- Non-Streaming ---
171
+ log.Printf("INFO: [%s] Upstream non-stream response received. Converting.", requestID)
172
+ defer upstreamResp.Body.Close() // Ensure body is closed after reading
173
+
174
+ // Read and parse upstream JSON response
175
+ var openAIResp OpenAIResponse
176
+ bodyBytes, err := io.ReadAll(upstreamResp.Body)
177
+ if err != nil {
178
+ log.Printf("ERROR: [%s] Failed to read non-streaming upstream response body: %v", requestID, err)
179
+ sendClaudeError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response.")
180
+ return
181
+ }
182
+ if err := json.Unmarshal(bodyBytes, &openAIResp); err != nil {
183
+ log.Printf("ERROR: [%s] Failed to decode non-streaming upstream JSON: %v. Body: %s", requestID, err, string(bodyBytes))
184
+ sendClaudeError(c, http.StatusBadGateway, "api_error", "Upstream API returned invalid JSON.")
185
+ return
186
+ }
187
+
188
+ // Convert OpenAI response to Claude format
189
+ claudeResp, err := convertOpenAIResponseToClaude(&openAIResp, requestID)
190
+ if err != nil {
191
+ log.Printf("ERROR: [%s] Failed to convert non-streaming OpenAI response: %v", requestID, err)
192
+ sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", fmt.Sprintf("Error processing upstream response: %v", err))
193
+ return
194
+ }
195
+
196
+ // Send the converted Claude response
197
+ c.JSON(http.StatusOK, claudeResp)
198
+ log.Printf("INFO: [%s] Successfully processed non-streaming request in %v", requestID, time.Since(startTime))
199
+ }
200
+ }
201
+
202
+ // sendClaudeError is a helper to send standardized Claude error responses
203
+ func sendClaudeError(c *gin.Context, statusCode int, errorType string, message string) {
204
+ errResp := ClaudeErrorResponse{
205
+ Type: "error",
206
+ Error: ClaudeError{
207
+ Type: errorType,
208
+ Message: message,
209
+ },
210
+ }
211
+ // Ensure status code is in valid range, default to 500 if not
212
+ if statusCode < 400 || statusCode > 599 {
213
+ log.Printf("WARN: Invalid status code %d provided for error, defaulting to 500.", statusCode)
214
+ statusCode = http.StatusInternalServerError
215
+ }
216
+ c.AbortWithStatusJSON(statusCode, errResp)
217
+ }
218
+
main.go ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "context"
5
+ "fmt"
6
+ "log"
7
+ "net/http"
8
+ "os"
9
+ "os/signal"
10
+ "syscall"
11
+ "time"
12
+
13
+ "github.com/gin-contrib/cors"
14
+ "github.com/gin-gonic/gin"
15
+ "github.com/joho/godotenv" // Optional: for loading .env file
16
+ )
17
+
18
+ func main() {
19
+ // Load .env file if present (optional, good for local dev)
20
+ _ = godotenv.Load()
21
+
22
+ // Load configuration
23
+ cfg := LoadConfig()
24
+
25
+ // Set Gin mode (release or debug)
26
+ if cfg.GinMode == "release" {
27
+ gin.SetMode(gin.ReleaseMode)
28
+ } else {
29
+ gin.SetMode(gin.DebugMode)
30
+ }
31
+ log.Printf("Starting Go Proxy Server in %s mode...", gin.Mode())
32
+
33
+ // Initialize Gin router
34
+ router := gin.New()
35
+
36
+ // Middleware
37
+ router.Use(gin.Logger()) // Standard Gin logger
38
+ router.Use(gin.Recovery()) // Recover from panics
39
+ // CORS middleware (allow all for simplicity, adjust as needed)
40
+ router.Use(cors.New(cors.Config{
41
+ AllowOrigins: []string{"*"},
42
+ AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
43
+ AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "X-API-Key"}, // Include X-API-Key
44
+ ExposeHeaders: []string{"Content-Length"},
45
+ AllowCredentials: true,
46
+ MaxAge: 12 * time.Hour,
47
+ }))
48
+
49
+ // --- Routes ---
50
+ // Health check
51
+ router.GET("/health", HealthCheckHandler)
52
+
53
+ // Main proxy endpoint group
54
+ v1 := router.Group("/v1")
55
+ {
56
+ // Apply API Key Authentication middleware if keys are configured
57
+ if len(cfg.ValidAPIKeys) > 0 {
58
+ log.Printf("API Key authentication enabled (%d keys configured).", len(cfg.ValidAPIKeys))
59
+ v1.Use(APIKeyAuthMiddleware(cfg.ValidAPIKeys))
60
+ } else {
61
+ log.Println("WARN: No PROXY_API_KEYS configured. Proxy is open (no authentication).")
62
+ }
63
+ v1.POST("/messages", MessagesHandler)
64
+ }
65
+
66
+ // --- Server Setup ---
67
+ server := &http.Server{
68
+ Addr: fmt.Sprintf(":%s", cfg.Port),
69
+ Handler: router,
70
+ // Add timeouts for production hardening
71
+ ReadTimeout: 10 * time.Second,
72
+ WriteTimeout: cfg.ReadTimeout + 30*time.Second, // Ensure write timeout is longer than read timeout for streaming
73
+ IdleTimeout: 120 * time.Second,
74
+ }
75
+
76
+ // --- Graceful Shutdown ---
77
+ // Run server in a goroutine so it doesn't block
78
+ go func() {
79
+ log.Printf("Server listening on port %s", cfg.Port)
80
+ if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
81
+ log.Fatalf("listen: %s\n", err)
82
+ }
83
+ }()
84
+
85
+ // Wait for interrupt signal to gracefully shut down the server
86
+ quit := make(chan os.Signal, 1)
87
+ // kill (no param) default send syscall.SIGTERM
88
+ // kill -2 is syscall.SIGINT
89
+ // kill -9 is syscall.SIGKILL but can't be caught, so don't need to add it
90
+ signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
91
+ <-quit
92
+ log.Println("Shutting down server...")
93
+
94
+ // The context is used to inform the server it has 5 seconds to finish
95
+ // the requests it is currently handling
96
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
97
+ defer cancel()
98
+ if err := server.Shutdown(ctx); err != nil {
99
+ log.Fatal("Server forced to shutdown:", err)
100
+ }
101
+
102
+ log.Println("Server exiting")
103
+ }
sse.go ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "encoding/json"
6
+ "fmt"
7
+ "log"
8
+ "net/http"
9
+ "strings"
10
+ "time"
11
+
12
+ "github.com/gin-gonic/gin"
13
+ )
14
+
15
+ // streamOpenAIResponseToClaudeSSE handles the SSE streaming conversion
16
+ // v1.10.0: Sends message_delta with accumulated usage after each content delta.
17
+ func streamOpenAIResponseToClaudeSSE(
18
+ c *gin.Context,
19
+ upstreamResp *http.Response,
20
+ claudeRequestID string,
21
+ requestedModel string,
22
+ originalClaudeRequest *ClaudeRequest, // Pass original request for token calculation
23
+ ) {
24
+ // Ensure correct headers for SSE are set
25
+ c.Writer.Header().Set("Content-Type", "text/event-stream")
26
+ c.Writer.Header().Set("Cache-Control", "no-cache")
27
+ c.Writer.Header().Set("Connection", "keep-alive")
28
+ c.Writer.Header().Set("X-Content-Type-Options", "nosniff")
29
+ c.Writer.Flush() // Ensure headers are sent immediately
30
+
31
+ // --- State Variables ---
32
+ messageID := claudeRequestID
33
+ accumulatedContent := ""
34
+ var openAIFinishReason *string // Store the pointer
35
+ streamErrorOccurred := false
36
+ var errorDetails *ClaudeError // Store potential error details for final event
37
+
38
+ // Pre-calculate input tokens
39
+ calculatedInputTokens := calculateInputTokensFromClaudeRequest(originalClaudeRequest)
40
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Calculated input tokens: %d", messageID, calculatedInputTokens)
41
+ inputTokens := calculatedInputTokens
42
+ outputTokens := 0 // Initialize output tokens (will be updated frequently)
43
+ finalUsageReceivedFromStream := false
44
+ lastPingTime := time.Now()
45
+ eventIndex := 0 // For logging clarity
46
+
47
+ log.Printf("DEBUG: [%s] Starting SSE AggressiveDelta conversion (Go v1.10.0).", messageID)
48
+
49
+ // Use a channel to signal completion or error from the reading goroutine
50
+ doneChan := make(chan struct{})
51
+ errChan := make(chan error, 1) // Buffered channel for error
52
+
53
+ // Goroutine to read from the upstream response
54
+ go func() {
55
+ defer close(doneChan) // Signal completion when done
56
+ defer upstreamResp.Body.Close() // Ensure body is closed
57
+
58
+ scanner := bufio.NewScanner(upstreamResp.Body)
59
+ for scanner.Scan() {
60
+ // Check for client disconnect *before* processing line
61
+ select {
62
+ case <-c.Request.Context().Done():
63
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Client disconnected detected in read loop.", messageID)
64
+ return // Exit goroutine if client disconnected
65
+ default:
66
+ // Continue processing
67
+ }
68
+
69
+ line := scanner.Text()
70
+ if line == "" {
71
+ continue // Skip empty lines
72
+ }
73
+
74
+ if strings.HasPrefix(line, "data:") {
75
+ dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
76
+ if dataStr == "[DONE]" {
77
+ log.Printf("DEBUG: [%s] SSE AggressiveDelta: Received [DONE] marker.", messageID)
78
+ return // Normal stream completion
79
+ }
80
+
81
+ var chunk OpenAIStreamChunk
82
+ if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
83
+ log.Printf("WARN: [%s] SSE AggressiveDelta: Could not decode JSON chunk: %v. Data: %s", messageID, err, dataStr)
84
+ continue // Skip malformed chunks
85
+ }
86
+
87
+ // Process choices
88
+ if len(chunk.Choices) > 0 {
89
+ choice := chunk.Choices[0]
90
+ if choice.FinishReason != nil {
91
+ openAIFinishReason = choice.FinishReason // Store the pointer
92
+ log.Printf("DEBUG: [%s] SSE AggressiveDelta: Received OpenAI finish_reason: %s", messageID, *openAIFinishReason)
93
+ }
94
+ if choice.Delta.Content != nil {
95
+ contentChunk := *choice.Delta.Content
96
+ accumulatedContent += contentChunk
97
+ currentOutputTokens := estimateTokens(accumulatedContent) // Estimate based on current content
98
+
99
+ // --- Yield content_block_delta ---
100
+ deltaPayload := ClaudeSSEEvent{
101
+ Type: "content_block_delta",
102
+ Index: func() *int { i := 0; return &i }(), // Pointer to 0
103
+ Delta: &ClaudeSSEDelta{
104
+ Type: "text_delta",
105
+ Text: &contentChunk, // Pointer to the chunk
106
+ },
107
+ }
108
+ if !sendSSEEvent(c, "content_block_delta", deltaPayload, messageID, eventIndex) {
109
+ return // Stop if client disconnected
110
+ }
111
+ eventIndex++
112
+
113
+ // --- AGGRESSIVE DELTA: Yield message_delta with current usage ---
114
+ // Only send if output tokens have potentially changed
115
+ if currentOutputTokens != outputTokens {
116
+ outputTokens = currentOutputTokens // Update state
117
+ intermediateUsage := ClaudeSSEUsage{OutputTokens: outputTokens}
118
+ intermediateDeltaPayload := ClaudeSSEEvent{
119
+ Type: "message_delta",
120
+ Delta: &ClaudeSSEDelta{}, // Delta part is empty here, only usage matters
121
+ Usage: &intermediateUsage,
122
+ }
123
+ log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (INTERMEDIATE message_delta with usage): %+v", messageID, eventIndex, intermediateDeltaPayload)
124
+ if !sendSSEEvent(c, "message_delta", intermediateDeltaPayload, messageID, eventIndex) {
125
+ return // Stop if client disconnected
126
+ }
127
+ eventIndex++
128
+ }
129
+ // -----------------------------------------------------------------
130
+ }
131
+ }
132
+
133
+ // Check for OpenAI usage block (still useful for final confirmation)
134
+ if chunk.Usage != nil {
135
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Received usage block in OpenAI stream: %+v", messageID, *chunk.Usage)
136
+ if chunk.Usage.CompletionTokens > 0 {
137
+ // If OpenAI provides a final count, trust it more than estimation
138
+ outputTokens = chunk.Usage.CompletionTokens
139
+ finalUsageReceivedFromStream = true
140
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Using final completion_tokens from stream: %d", messageID, outputTokens)
141
+ }
142
+ if chunk.Usage.PromptTokens != inputTokens && chunk.Usage.PromptTokens > 0 {
143
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Updating input tokens based on stream usage block: %d -> %d", messageID, inputTokens, chunk.Usage.PromptTokens)
144
+ inputTokens = chunk.Usage.PromptTokens
145
+ }
146
+ }
147
+ } else {
148
+ log.Printf("TRACE: [%s] SSE AggressiveDelta: Received non-data line: %s", messageID, line)
149
+ }
150
+
151
+ // Send periodic pings
152
+ if time.Since(lastPingTime) >= 10*time.Second {
153
+ pingPayload := ClaudeSSEEvent{Type: "ping"}
154
+ if !sendSSEEvent(c, "ping", pingPayload, messageID, eventIndex) {
155
+ return // Stop if client disconnected
156
+ }
157
+ eventIndex++
158
+ lastPingTime = time.Now()
159
+ }
160
+ }
161
+
162
+ if err := scanner.Err(); err != nil {
163
+ // Check if the error is due to context cancellation (client disconnect)
164
+ select {
165
+ case <-c.Request.Context().Done():
166
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Upstream read interrupted by client disconnect: %v", messageID, c.Request.Context().Err())
167
+ default:
168
+ log.Printf("ERROR: [%s] SSE AggressiveDelta: Error reading upstream response body: %v", messageID, err)
169
+ errChan <- fmt.Errorf("upstream read error: %w", err)
170
+ }
171
+ }
172
+ }()
173
+
174
+ // --- Initial Events ---
175
+ // Send message_start
176
+ startUsage := ClaudeUsage{InputTokens: calculatedInputTokens, OutputTokens: 0}
177
+ startMessage := ClaudeSSEMessage{ ID: messageID, Type: "message", Role: "assistant", Content: []ClaudeContentBlock{}, Model: requestedModel, StopReason: nil, StopSequence: nil, Usage: startUsage }
178
+ startEvent := ClaudeSSEEvent{Type: "message_start", Message: &startMessage}
179
+ if !sendSSEEvent(c, "message_start", startEvent, messageID, eventIndex) { return }
180
+ eventIndex++
181
+
182
+ // Send content_block_start
183
+ contentStartBlock := ClaudeSSEContentBlock{Type: "text", Text: ""}
184
+ contentStartEvent := ClaudeSSEEvent{ Type: "content_block_start", Index: func() *int { i := 0; return &i }(), ContentBlock: &contentStartBlock }
185
+ if !sendSSEEvent(c, "content_block_start", contentStartEvent, messageID, eventIndex) { return }
186
+ eventIndex++
187
+
188
+ // Send initial ping
189
+ pingPayload := ClaudeSSEEvent{Type: "ping"}
190
+ if !sendSSEEvent(c, "ping", pingPayload, messageID, eventIndex) { return }
191
+ eventIndex++
192
+ lastPingTime = time.Now()
193
+
194
+ // --- Wait for completion or error or client disconnect ---
195
+ select {
196
+ case <-doneChan:
197
+ log.Printf("DEBUG: [%s] SSE AggressiveDelta: Upstream reading finished.", messageID)
198
+ case err := <-errChan:
199
+ log.Printf("ERROR: [%s] SSE AggressiveDelta: Received error from reading goroutine: %v", messageID, err)
200
+ streamErrorOccurred = true
201
+ errorDetails = &ClaudeError{Type: "api_error", Message: fmt.Sprintf("Error reading upstream response: %v", err)}
202
+ case <-c.Request.Context().Done():
203
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Client disconnected during stream processing: %v", messageID, c.Request.Context().Err())
204
+ streamErrorOccurred = true // Treat disconnect as a type of error for cleanup
205
+ errorDetails = &ClaudeError{Type: "client_disconnect", Message: "Client disconnected during stream"}
206
+ }
207
+
208
+ // --- Finally Block Logic ---
209
+ log.Printf("DEBUG: [%s] SSE AggressiveDelta: Entering finally block logic. Finish_reason: %v, Error: %t", messageID, openAIFinishReason, streamErrorOccurred)
210
+
211
+ // Determine final Claude stop reason
212
+ var claudeStopReason string
213
+ if streamErrorOccurred && errorDetails != nil && errorDetails.Type == "client_disconnect" {
214
+ claudeStopReason = "client_disconnect"
215
+ } else if streamErrorOccurred {
216
+ claudeStopReason = "error"
217
+ } else {
218
+ claudeStopReason = mapOpenAIFinishReasonToClaude(openAIFinishReason)
219
+ }
220
+
221
+ // Finalize token counts (use last known value, potentially from stream or final estimation)
222
+ finalInputTokens := inputTokens
223
+ finalOutputTokens := outputTokens // Use the value updated during the stream or from OpenAI's usage block
224
+
225
+ // If no explicit usage received, do a final estimation/forcing
226
+ if !finalUsageReceivedFromStream {
227
+ log.Printf("WARN: [%s] SSE AggressiveDelta: Final usage not explicitly received. Doing final estimate.", messageID)
228
+ estimatedOutput := estimateTokens(accumulatedContent)
229
+ finalOutputTokens = max(1, estimatedOutput)
230
+ if accumulatedContent == "" { finalOutputTokens = 0 }
231
+ log.Printf("WARN: [%s] SSE AggressiveDelta: Final Estimated/Forced output tokens: %d", messageID, finalOutputTokens)
232
+ } else {
233
+ // If usage *was* received, still force minimum 1 if non-zero
234
+ finalOutputTokens = max(1, outputTokens)
235
+ if outputTokens == 0 { finalOutputTokens = 0 }
236
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Using/Forced final usage from stream: output=%d", messageID, finalOutputTokens)
237
+ }
238
+ finalInputTokens = max(0, finalInputTokens)
239
+ finalOutputTokens = max(0, finalOutputTokens)
240
+
241
+ // Prepare usage data structures
242
+ // The *last* message_delta sent needs the final output tokens for chat-api billing hack
243
+ finalHackUsageData := ClaudeSSEUsage{OutputTokens: finalOutputTokens}
244
+ finalStopUsageData := ClaudeSSEUsage{ InputTokens: &finalInputTokens, OutputTokens: finalOutputTokens }
245
+
246
+ log.Printf("INFO: [%s] SSE AggressiveDelta: Stream finished. Stop Reason: %s. Final Usage: Input=%d Output=%d", messageID, claudeStopReason, finalInputTokens, finalOutputTokens)
247
+
248
+ // --- Yield Closing Events (Priority Final Delta First) ---
249
+
250
+ // *** Send FINAL message_delta WITH FINAL usage FIRST ***
251
+ // This ensures the most accurate count is sent last, potentially overwriting intermediate ones in chat-api
252
+ finalDeltaStopReason := claudeStopReason
253
+ priorityFinalDeltaPayload := ClaudeSSEEvent{
254
+ Type: "message_delta",
255
+ Delta: &ClaudeSSEDelta{
256
+ StopReason: &finalDeltaStopReason,
257
+ StopSequence: nil,
258
+ },
259
+ Usage: &finalHackUsageData, // Use the final calculated/forced output tokens
260
+ }
261
+ log.Printf("WARN: [%s] SSE AggressiveDelta: Yielding Event %d (PRIORITY FINAL message_delta WITH HACKED USAGE): %+v", messageID, eventIndex, priorityFinalDeltaPayload)
262
+ _ = sendSSEEvent(c, "message_delta", priorityFinalDeltaPayload, messageID, eventIndex) // Try to send even if disconnected
263
+ eventIndex++
264
+
265
+ // Send content_block_stop
266
+ contentStopPayload := ClaudeSSEEvent{ Type: "content_block_stop", Index: func() *int { i := 0; return &i }()}
267
+ log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (content_block_stop)", messageID, eventIndex)
268
+ _ = sendSSEEvent(c, "content_block_stop", contentStopPayload, messageID, eventIndex)
269
+ eventIndex++
270
+
271
+ // Send message_stop
272
+ messageStopPayload := ClaudeSSEEvent{ Type: "message_stop", Usage: &finalStopUsageData }
273
+ log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (message_stop)", messageID, eventIndex)
274
+ _ = sendSSEEvent(c, "message_stop", messageStopPayload, messageID, eventIndex)
275
+ eventIndex++
276
+
277
+ // Send error event if needed
278
+ if streamErrorOccurred && errorDetails != nil && errorDetails.Type != "client_disconnect" {
279
+ errorPayload := ClaudeSSEEvent{ Type: "error", Error: errorDetails }
280
+ log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (error)", messageID, eventIndex)
281
+ _ = sendSSEEvent(c, "error", errorPayload, messageID, eventIndex)
282
+ eventIndex++
283
+ }
284
+
285
+ log.Printf("INFO: [%s] Completed sending SSE AggressiveDelta stream.", messageID)
286
+ }
287
+
288
+ // sendSSEEvent sends a single SSE event and checks for client disconnect
289
+ func sendSSEEvent(c *gin.Context, eventName string, data interface{}, requestID string, eventIndex int) bool {
290
+ select {
291
+ case <-c.Request.Context().Done():
292
+ // Client disconnected
293
+ log.Printf("INFO: [%s] Client disconnected before sending SSE event %d (%s).", requestID, eventIndex, eventName)
294
+ return false
295
+ default:
296
+ // Client still connected, try sending
297
+ jsonData, err := json.Marshal(data)
298
+ if err != nil {
299
+ log.Printf("ERROR: [%s] Failed to marshal SSE event %d (%s): %v", requestID, eventIndex, eventName, err)
300
+ return true // Continue trying other events even if one fails marshaling?
301
+ }
302
+ // Use fmt.Fprintf for potentially better handling with Gin's writer interface
303
+ _, err = fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", eventName, string(jsonData))
304
+ if err != nil {
305
+ // This error often indicates the client disconnected during the write
306
+ log.Printf("WARN: [%s] Failed to write SSE event %d (%s) to client: %v. Client likely disconnected.", requestID, eventIndex, eventName, err)
307
+ return false // Stop processing if write fails
308
+ }
309
+ c.Writer.Flush() // Ensure data is sent immediately
310
+ return true
311
+ }
312
+ }
313
+
314
// max returns the larger of two ints.
// (Go 1.21+ has a built-in max that could replace this helper at call sites.)
func max(a, b int) int {
	if a < b {
		return b
	}
	return a
}
structs.go ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import "encoding/json"
4
+
5
// --- Claude API Structs (Anthropic Format) ---
//
// NOTE: field order is preserved because it determines JSON marshaling order
// on the wire; do not reorder.

// ClaudeRequest represents the incoming request structure from the client,
// following the Anthropic Messages API request shape.
type ClaudeRequest struct {
	Model         string          `json:"model"`
	Messages      []ClaudeMessage `json:"messages"`
	System        json.RawMessage `json:"system,omitempty"` // Can be string or list of blocks; decoding is deferred
	MaxTokens     *int            `json:"max_tokens,omitempty"` // Use pointer so "absent" is distinguishable from 0
	StopSequences []string        `json:"stop_sequences,omitempty"`
	Stream        bool            `json:"stream,omitempty"`
	Temperature   *float64        `json:"temperature,omitempty"`
	TopP          *float64        `json:"top_p,omitempty"`
	// TopK *int `json:"top_k,omitempty"` // OpenAI doesn't support TopK directly
}

// ClaudeMessage represents a single conversation turn in the Claude request.
type ClaudeMessage struct {
	Role    string          `json:"role"`    // "user" or "assistant"
	Content json.RawMessage `json:"content"` // Can be string or list of blocks; decoding is deferred
}

// ClaudeContentBlock represents one block within a message's content array.
type ClaudeContentBlock struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
	// Add other block types if needed (e.g., image)
}

// ClaudeResponse represents the non-streaming response structure sent to the client.
type ClaudeResponse struct {
	ID           string               `json:"id"`
	Type         string               `json:"type"` // e.g., "message"
	Role         string               `json:"role"` // e.g., "assistant"
	Content      []ClaudeContentBlock `json:"content"`
	Model        string               `json:"model"`
	StopReason   string               `json:"stop_reason"`   // e.g., "end_turn", "max_tokens"
	StopSequence *string              `json:"stop_sequence"` // Usually null
	Usage        ClaudeUsage          `json:"usage"`
}

// ClaudeUsage represents the token usage information reported to the client.
type ClaudeUsage struct {
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
}

// ClaudeErrorResponse represents the top-level error envelope sent to the client.
type ClaudeErrorResponse struct {
	Type  string      `json:"type"` // Always "error"
	Error ClaudeError `json:"error"`
}

// ClaudeError represents the detailed error information inside an error
// response or SSE error event.
type ClaudeError struct {
	Type    string `json:"type"` // e.g., "invalid_request_error", "api_error"
	Message string `json:"message"`
}
62
+
63
+
64
// --- OpenAI API Structs ---
//
// NOTE: field order is preserved because it determines JSON marshaling order
// on the wire; do not reorder.

// OpenAIRequest represents the request structure sent to the upstream
// OpenAI-compatible chat completions API.
type OpenAIRequest struct {
	Model       string          `json:"model"`
	Messages    []OpenAIMessage `json:"messages"`
	MaxTokens   *int            `json:"max_tokens,omitempty"` // Pointer so "absent" is distinguishable from 0
	Temperature *float64        `json:"temperature,omitempty"`
	TopP        *float64        `json:"top_p,omitempty"`
	Stop        []string        `json:"stop,omitempty"`
	Stream      bool            `json:"stream,omitempty"`
	// N *int `json:"n,omitempty"` // Not typically used with Claude proxy
	// PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Not mapped
	// FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Not mapped
}

// OpenAIMessage represents a message in the OpenAI request/response.
// Content is a plain string (no multi-part content support here).
type OpenAIMessage struct {
	Role    string `json:"role"` // "system", "user", or "assistant"
	Content string `json:"content"`
}

// OpenAIResponse represents the non-streaming response from the upstream OpenAI API.
type OpenAIResponse struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"` // e.g., "chat.completion"
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []OpenAIChoice `json:"choices"`
	Usage   *OpenAIUsage   `json:"usage,omitempty"` // Pointer as it might be missing in errors
	// SystemFingerprint string `json:"system_fingerprint"` // Optional
}

// OpenAIChoice represents one completion choice in the OpenAI response.
type OpenAIChoice struct {
	Index        int           `json:"index"`
	Message      OpenAIMessage `json:"message"`
	FinishReason *string       `json:"finish_reason"` // Pointer as it can be null
	// Logprobs interface{} `json:"logprobs"` // Not typically used here
}

// OpenAIUsage represents the token usage information from OpenAI.
type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// OpenAIStreamChoice represents a choice within an OpenAI SSE chunk.
type OpenAIStreamChoice struct {
	Index        int               `json:"index"`
	Delta        OpenAIStreamDelta `json:"delta"`
	FinishReason *string           `json:"finish_reason"` // Pointer as it can be null
	// Logprobs interface{} `json:"logprobs"` // Not typically used here
}

// OpenAIStreamDelta represents the incremental delta content within an OpenAI SSE chunk.
type OpenAIStreamDelta struct {
	Role    *string `json:"role,omitempty"`    // Usually only in the first delta
	Content *string `json:"content,omitempty"` // Pointer as it can be null or empty
}

// OpenAIStreamChunk represents the structure of one "data:" chunk in the
// OpenAI SSE stream.
type OpenAIStreamChunk struct {
	ID      string               `json:"id"`
	Object  string               `json:"object"` // e.g., "chat.completion.chunk"
	Created int64                `json:"created"`
	Model   string               `json:"model"`
	Choices []OpenAIStreamChoice `json:"choices"`
	Usage   *OpenAIUsage         `json:"usage,omitempty"` // Usually null except maybe in Azure's final chunk?
	// SystemFingerprint string `json:"system_fingerprint"` // Optional
}
136
+
137
+
138
// --- Claude SSE Structs (for sending back to client) ---
//
// NOTE: field order is preserved because it determines JSON marshaling order
// on the wire; do not reorder.

// ClaudeSSEEvent represents a generic Claude SSE event structure for easy
// marshaling. Exactly one of the optional fields is populated depending on
// Type; omitempty keeps unused fields off the wire.
type ClaudeSSEEvent struct {
	Type         string                 `json:"type"`
	Index        *int                   `json:"index,omitempty"`         // Used in content_block_*
	Message      *ClaudeSSEMessage      `json:"message,omitempty"`       // Used in message_start
	ContentBlock *ClaudeSSEContentBlock `json:"content_block,omitempty"` // Used in content_block_start
	Delta        *ClaudeSSEDelta        `json:"delta,omitempty"`         // Used in content_block_delta, message_delta
	Usage        *ClaudeSSEUsage        `json:"usage,omitempty"`         // Used in message_delta (HACK), message_stop
	Error        *ClaudeError           `json:"error,omitempty"`         // Used in error event
}

// ClaudeSSEMessage is the message object nested within a message_start event.
type ClaudeSSEMessage struct {
	ID           string               `json:"id"`
	Type         string               `json:"type"`    // "message"
	Role         string               `json:"role"`    // "assistant"
	Content      []ClaudeContentBlock `json:"content"` // Initially empty
	Model        string               `json:"model"`
	StopReason   *string              `json:"stop_reason"`   // Initially null
	StopSequence *string              `json:"stop_sequence"` // Initially null
	Usage        ClaudeUsage          `json:"usage"`         // Initial usage (input tokens)
}

// ClaudeSSEContentBlock is the block object nested within a content_block_start event.
type ClaudeSSEContentBlock struct {
	Type string `json:"type"` // "text"
	Text string `json:"text"` // Initially empty
}

// ClaudeSSEDelta is nested within content_block_delta and message_delta
// events; which fields are set depends on the event type.
type ClaudeSSEDelta struct {
	Type         string  `json:"type,omitempty"`          // "text_delta" in content_block_delta
	Text         *string `json:"text,omitempty"`          // Pointer for content_block_delta
	StopReason   *string `json:"stop_reason,omitempty"`   // Pointer for message_delta
	StopSequence *string `json:"stop_sequence,omitempty"` // Pointer for message_delta (usually null)
}

// ClaudeSSEUsage is nested within message_delta (HACK) and message_stop events.
type ClaudeSSEUsage struct {
	// Note: message_delta only needs output_tokens for the hack;
	// message_stop should have both (InputTokens set via pointer).
	InputTokens  *int `json:"input_tokens,omitempty"` // Only in message_stop
	OutputTokens int  `json:"output_tokens"`          // In both (but value differs)
}