Spaces:
Paused
Paused
Upload 10 files
Browse files- Dockerfile +45 -25
- auth.go +45 -0
- config.go +120 -0
- convert.go +236 -0
- go.mod +39 -0
- go.sum +0 -0
- handlers.go +218 -0
- main.go +103 -0
- sse.go +320 -0
- structs.go +183 -0
Dockerfile
CHANGED
|
@@ -1,34 +1,54 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
| 3 |
-
FROM python:3.10-slim
|
| 4 |
|
| 5 |
-
# Set environment variables
|
| 6 |
-
ENV
|
| 7 |
-
ENV PYTHONUNBUFFERED 1
|
| 8 |
-
# --- Change default port to 7860 ---
|
| 9 |
-
ENV PORT=7860
|
| 10 |
|
| 11 |
-
|
| 12 |
-
WORKDIR /app
|
| 13 |
|
| 14 |
-
#
|
| 15 |
-
|
|
|
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
RUN
|
| 21 |
|
| 22 |
-
# Copy the rest of the
|
| 23 |
COPY . .
|
| 24 |
|
| 25 |
-
#
|
| 26 |
-
RUN
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
-
#
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
#
|
| 34 |
-
|
|
|
|
| 1 |
+
# Stage 1: Build the Go application
|
| 2 |
+
FROM golang:1.21-alpine AS builder
|
|
|
|
| 3 |
|
| 4 |
+
# Set necessary environment variables
|
| 5 |
+
ENV CGO_ENABLED=0 GOOS=linux GOARCH=amd64
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
WORKDIR /build
|
|
|
|
| 8 |
|
| 9 |
+
# Copy only the module definition file first
|
| 10 |
+
COPY go.mod ./
|
| 11 |
+
# DO NOT copy go.sum here initially.
|
| 12 |
|
| 13 |
+
# Download dependencies and ensure go.sum is consistent
|
| 14 |
+
# 'go mod tidy' synchronizes the go.mod and go.sum files with the source code imports.
|
| 15 |
+
# It should be run after copying the source code.
|
| 16 |
+
RUN go mod download
|
| 17 |
|
| 18 |
+
# Copy the rest of the source code AFTER initial download
|
| 19 |
COPY . .
|
| 20 |
|
| 21 |
+
# Now run tidy to ensure go.mod and go.sum match the code
|
| 22 |
+
RUN go mod tidy
|
| 23 |
+
|
| 24 |
+
# Verify dependencies (optional but good practice)
|
| 25 |
+
RUN go mod verify
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Build the Go application statically linked
|
| 29 |
+
# -ldflags="-w -s" reduces binary size by removing debug info
|
| 30 |
+
RUN go build -ldflags="-w -s" -o /app/proxy-server .
|
| 31 |
+
|
| 32 |
+
# Stage 2: Create the final minimal image
|
| 33 |
+
FROM alpine:latest
|
| 34 |
+
|
| 35 |
+
# Install ca-certificates for HTTPS calls and tzdata for timezone info
|
| 36 |
+
RUN apk update && apk add --no-cache ca-certificates tzdata
|
| 37 |
+
|
| 38 |
+
# Set the working directory
|
| 39 |
+
WORKDIR /app
|
| 40 |
+
|
| 41 |
+
# Copy the built binary from the builder stage
|
| 42 |
+
COPY --from=builder /app/proxy-server /app/proxy-server
|
| 43 |
+
|
| 44 |
+
# Expose the port the app runs on (using the default from config or ENV)
|
| 45 |
+
# Defaulting to 7860 if not overridden by ENV PORT in runtime environment
|
| 46 |
+
EXPOSE 7860
|
| 47 |
|
| 48 |
+
# Set the entrypoint command to run the binary
|
| 49 |
+
# The application will read environment variables at runtime
|
| 50 |
+
ENTRYPOINT ["/app/proxy-server"]
|
| 51 |
|
| 52 |
+
# Optional: Add a non-root user for security (Uncomment if needed)
|
| 53 |
+
# RUN addgroup -S appgroup && adduser -S appuser -G appgroup
|
| 54 |
+
# USER appuser
|
auth.go
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
"net/http"
|
| 6 |
+
|
| 7 |
+
"github.com/gin-gonic/gin"
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
// APIKeyHeaderName is the HTTP request header that APIKeyAuthMiddleware reads
// the caller's API key from.
const APIKeyHeaderName = "X-API-Key"
|
| 11 |
+
|
| 12 |
+
// APIKeyAuthMiddleware creates a Gin middleware for API key authentication
|
| 13 |
+
func APIKeyAuthMiddleware(validKeys map[string]bool) gin.HandlerFunc {
|
| 14 |
+
return func(c *gin.Context) {
|
| 15 |
+
apiKey := c.GetHeader(APIKeyHeaderName)
|
| 16 |
+
|
| 17 |
+
if apiKey == "" {
|
| 18 |
+
log.Printf("WARN: [%s] API Key missing in header '%s'", c.ClientIP(), APIKeyHeaderName)
|
| 19 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
| 20 |
+
"type": "error",
|
| 21 |
+
"error": gin.H{
|
| 22 |
+
"type": "authentication_error",
|
| 23 |
+
"message": "API Key required in header '" + APIKeyHeaderName + "'",
|
| 24 |
+
},
|
| 25 |
+
})
|
| 26 |
+
return
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
if _, isValid := validKeys[apiKey]; !isValid {
|
| 30 |
+
log.Printf("WARN: [%s] Invalid API Key received (length: %d)", c.ClientIP(), len(apiKey))
|
| 31 |
+
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
| 32 |
+
"type": "error",
|
| 33 |
+
"error": gin.H{
|
| 34 |
+
"type": "authentication_error",
|
| 35 |
+
"message": "Invalid or expired API Key",
|
| 36 |
+
},
|
| 37 |
+
})
|
| 38 |
+
return
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
// Log successful authentication (optional, consider security implications)
|
| 42 |
+
// log.Printf("INFO: [%s] Valid API key received (length: %d)", c.ClientIP(), len(apiKey))
|
| 43 |
+
c.Next() // Proceed to the next handler
|
| 44 |
+
}
|
| 45 |
+
}
|
config.go
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"log"
|
| 5 |
+
"net" // <<< Added import
|
| 6 |
+
"net/http"
|
| 7 |
+
"net/url"
|
| 8 |
+
"os"
|
| 9 |
+
"strconv"
|
| 10 |
+
"strings"
|
| 11 |
+
"time"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
// Config is the complete runtime configuration of the proxy. LoadConfig
// populates every field from environment variables.
type Config struct {
	// OpenAIAPIEndpoint is the upstream URL (env OPENAI_API_ENDPOINT).
	OpenAIAPIEndpoint string

	// OpenAIAPIKey is the upstream credential (env OPENAI_API_KEY); may be empty.
	OpenAIAPIKey string

	// ProxyAPIKeys is the raw, comma-separated value of PROXY_API_KEYS.
	ProxyAPIKeys string

	// ValidAPIKeys is ProxyAPIKeys split and trimmed into a set for lookups.
	ValidAPIKeys map[string]bool

	// ConnectTimeout is used as the dial timeout of UpstreamTransport.
	ConnectTimeout time.Duration

	// ReadTimeout is used as the response-header timeout of UpstreamTransport.
	ReadTimeout time.Duration

	// WriteTimeout is read from WRITE_TIMEOUT; LoadConfig itself does not
	// apply it to the transport.
	WriteTimeout time.Duration

	// PoolTimeout is read from POOL_TIMEOUT; LoadConfig itself does not apply
	// it to the transport (Go's transport manages pooling on its own).
	PoolTimeout time.Duration

	// HTTPProxyURL is the parsed HTTP_PROXY value, or nil when it is unset or
	// could not be parsed.
	HTTPProxyURL *url.URL

	// Port is the value of PORT (default "7860").
	Port string

	// GinMode is the value of GIN_MODE, e.g. "debug" or "release".
	GinMode string

	// LogLevel is the value of LOG_LEVEL, kept for future structured logging.
	LogLevel string

	// UpstreamTransport is the transport, built by LoadConfig, that HTTP
	// clients talking to the upstream API should use.
	UpstreamTransport http.RoundTripper
}
|
| 30 |
+
|
| 31 |
+
// LoadConfig reads configuration from environment variables
|
| 32 |
+
func LoadConfig() *Config {
|
| 33 |
+
cfg := &Config{
|
| 34 |
+
OpenAIAPIEndpoint: getEnv("OPENAI_API_ENDPOINT", "https://api.openai.com/v1/chat/completions"),
|
| 35 |
+
OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""),
|
| 36 |
+
ProxyAPIKeys: getEnv("PROXY_API_KEYS", ""),
|
| 37 |
+
ConnectTimeout: getEnvDuration("CONNECT_TIMEOUT", 5*time.Second),
|
| 38 |
+
ReadTimeout: getEnvDuration("READ_TIMEOUT", 180*time.Second),
|
| 39 |
+
WriteTimeout: getEnvDuration("WRITE_TIMEOUT", 30*time.Second),
|
| 40 |
+
PoolTimeout: getEnvDuration("POOL_TIMEOUT", 5*time.Second), // Less directly applicable in Go's default client
|
| 41 |
+
Port: getEnv("PORT", "7860"),
|
| 42 |
+
GinMode: getEnv("GIN_MODE", "release"), // "debug" or "release"
|
| 43 |
+
LogLevel: getEnv("LOG_LEVEL", "INFO"),
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
// Process API Keys into a map for efficient lookup
|
| 47 |
+
cfg.ValidAPIKeys = make(map[string]bool)
|
| 48 |
+
if cfg.ProxyAPIKeys != "" {
|
| 49 |
+
keys := strings.Split(cfg.ProxyAPIKeys, ",")
|
| 50 |
+
for _, key := range keys {
|
| 51 |
+
trimmedKey := strings.TrimSpace(key)
|
| 52 |
+
if trimmedKey != "" {
|
| 53 |
+
cfg.ValidAPIKeys[trimmedKey] = true
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
// Parse HTTP Proxy URL
|
| 59 |
+
proxyStr := getEnv("HTTP_PROXY", "")
|
| 60 |
+
if proxyStr != "" {
|
| 61 |
+
proxyURL, err := url.Parse(proxyStr)
|
| 62 |
+
if err != nil {
|
| 63 |
+
log.Printf("WARN: Invalid HTTP_PROXY URL '%s': %v. Proxy disabled.", proxyStr, err)
|
| 64 |
+
cfg.HTTPProxyURL = nil
|
| 65 |
+
} else {
|
| 66 |
+
cfg.HTTPProxyURL = proxyURL
|
| 67 |
+
log.Printf("Using outbound proxy: %s", cfg.HTTPProxyURL.String())
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
// Configure the shared HTTP client transport
|
| 72 |
+
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
| 73 |
+
if cfg.HTTPProxyURL != nil { // Set proxy only if URL is valid
|
| 74 |
+
defaultTransport.Proxy = http.ProxyURL(cfg.HTTPProxyURL)
|
| 75 |
+
}
|
| 76 |
+
// Configure timeouts (Connect timeout is part of DialContext)
|
| 77 |
+
defaultTransport.DialContext = (&net.Dialer{ // <<< Used net.Dialer here
|
| 78 |
+
Timeout: cfg.ConnectTimeout, // Connect timeout
|
| 79 |
+
KeepAlive: 30 * time.Second, // Keep-alive interval
|
| 80 |
+
}).DialContext
|
| 81 |
+
defaultTransport.TLSHandshakeTimeout = 10 * time.Second // TLS handshake timeout
|
| 82 |
+
defaultTransport.ResponseHeaderTimeout = cfg.ReadTimeout // Timeout waiting for response headers
|
| 83 |
+
// Go's http client manages connection pooling automatically.
|
| 84 |
+
// MaxIdleConns, MaxIdleConnsPerHost can be tuned if needed.
|
| 85 |
+
defaultTransport.MaxIdleConns = 100
|
| 86 |
+
defaultTransport.MaxIdleConnsPerHost = 10
|
| 87 |
+
defaultTransport.IdleConnTimeout = 90 * time.Second
|
| 88 |
+
|
| 89 |
+
cfg.UpstreamTransport = defaultTransport
|
| 90 |
+
|
| 91 |
+
// Log warnings for missing keys
|
| 92 |
+
if cfg.OpenAIAPIKey == "" {
|
| 93 |
+
log.Println("WARN: OPENAI_API_KEY is not set.")
|
| 94 |
+
}
|
| 95 |
+
if len(cfg.ValidAPIKeys) == 0 {
|
| 96 |
+
log.Println("WARN: PROXY_API_KEYS is not set. Proxy is open (no authentication).")
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return cfg
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// getEnv returns the value of the environment variable key. defaultValue is
// returned only when the variable is not set at all; a variable that is set
// to the empty string yields "".
func getEnv(key, defaultValue string) string {
	value, isSet := os.LookupEnv(key)
	if !isSet {
		return defaultValue
	}
	return value
}
|
| 109 |
+
|
| 110 |
+
// getEnvDuration reads the environment variable key as a number of seconds
// (fractions allowed, e.g. "1.5") and returns it as a time.Duration.
//
// defaultValue is returned when the variable is unset or empty. It is also
// returned, after logging a warning, when the value is not a number, is
// negative, is NaN or infinite, or is too large to fit in a time.Duration.
func getEnvDuration(key string, defaultValue time.Duration) time.Duration {
	// os.Getenv yields "" for both an unset and an empty variable, which is
	// exactly the "use the default silently" case.
	valueStr := os.Getenv(key)
	if valueStr == "" {
		return defaultValue
	}

	if seconds, err := strconv.ParseFloat(valueStr, 64); err == nil {
		nanos := seconds * float64(time.Second)
		// ParseFloat happily accepts "NaN", "Inf" and huge magnitudes, and
		// converting an out-of-range float to an integer type is
		// implementation-dependent. Both comparisons are false for NaN, so
		// this single range check rejects every unusable value.
		if nanos >= 0 && nanos < 1<<63 {
			return time.Duration(nanos)
		}
	}

	log.Printf("WARN: Invalid duration format for %s: '%s'. Using default: %v", key, valueStr, defaultValue)
	return defaultValue
}
|
convert.go
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"encoding/json"
|
| 5 |
+
"errors"
|
| 6 |
+
"fmt"
|
| 7 |
+
"log"
|
| 8 |
+
"strings"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
// estimateTokens gives a rough token count for text using a fixed heuristic of
// one token per three bytes. The empty string counts as zero tokens and any
// non-empty string counts as at least one.
func estimateTokens(text string) int {
	switch size := len(text); {
	case size == 0:
		return 0
	case size < 3:
		// Integer division would give 0; never report 0 for real content.
		return 1
	default:
		return size / 3
	}
}
|
| 22 |
+
|
| 23 |
+
// calculateInputTokensFromClaudeRequest estimates input tokens from Claude request
|
| 24 |
+
func calculateInputTokensFromClaudeRequest(claudeReq *ClaudeRequest) int {
|
| 25 |
+
totalChars := 0
|
| 26 |
+
|
| 27 |
+
// Process system prompt
|
| 28 |
+
if len(claudeReq.System) > 0 {
|
| 29 |
+
// Try unmarshaling as string first
|
| 30 |
+
var systemStr string
|
| 31 |
+
if err := json.Unmarshal(claudeReq.System, &systemStr); err == nil {
|
| 32 |
+
totalChars += len(systemStr)
|
| 33 |
+
} else {
|
| 34 |
+
// Try unmarshaling as list of blocks
|
| 35 |
+
var systemBlocks []ClaudeContentBlock
|
| 36 |
+
if err := json.Unmarshal(claudeReq.System, &systemBlocks); err == nil {
|
| 37 |
+
for _, block := range systemBlocks {
|
| 38 |
+
if block.Type == "text" {
|
| 39 |
+
totalChars += len(block.Text)
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
} else {
|
| 43 |
+
log.Printf("WARN: Could not parse system prompt format: %s", string(claudeReq.System))
|
| 44 |
+
}
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// Process messages
|
| 49 |
+
for _, msg := range claudeReq.Messages {
|
| 50 |
+
// Try unmarshaling as string first
|
| 51 |
+
var contentStr string
|
| 52 |
+
if err := json.Unmarshal(msg.Content, &contentStr); err == nil {
|
| 53 |
+
totalChars += len(contentStr)
|
| 54 |
+
} else {
|
| 55 |
+
// Try unmarshaling as list of blocks
|
| 56 |
+
var contentBlocks []ClaudeContentBlock
|
| 57 |
+
if err := json.Unmarshal(msg.Content, &contentBlocks); err == nil {
|
| 58 |
+
for _, block := range contentBlocks {
|
| 59 |
+
if block.Type == "text" {
|
| 60 |
+
totalChars += len(block.Text)
|
| 61 |
+
}
|
| 62 |
+
}
|
| 63 |
+
} else {
|
| 64 |
+
log.Printf("WARN: Could not parse message content format for role %s: %s", msg.Role, string(msg.Content))
|
| 65 |
+
}
|
| 66 |
+
}
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
estimated := estimateTokens(fmt.Sprintf("%d", totalChars)) // Pass total chars as string to estimate
|
| 70 |
+
log.Printf("DEBUG: Estimated input characters: %d, Estimated input tokens: %d", totalChars, estimated)
|
| 71 |
+
return estimated
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
// convertClaudeRequestToOpenAI converts Claude request to OpenAI format
|
| 75 |
+
func convertClaudeRequestToOpenAI(claudeReq *ClaudeRequest) (*OpenAIRequest, error) {
|
| 76 |
+
openAIMessages := []OpenAIMessage{}
|
| 77 |
+
|
| 78 |
+
// --- Handle System Prompt ---
|
| 79 |
+
if len(claudeReq.System) > 0 {
|
| 80 |
+
systemContent := ""
|
| 81 |
+
var systemStr string
|
| 82 |
+
// Try simple string first
|
| 83 |
+
if err := json.Unmarshal(claudeReq.System, &systemStr); err == nil {
|
| 84 |
+
systemContent = systemStr
|
| 85 |
+
} else {
|
| 86 |
+
// Try list of blocks
|
| 87 |
+
var systemBlocks []ClaudeContentBlock
|
| 88 |
+
if err := json.Unmarshal(claudeReq.System, &systemBlocks); err == nil {
|
| 89 |
+
var parts []string
|
| 90 |
+
for _, block := range systemBlocks {
|
| 91 |
+
if block.Type == "text" {
|
| 92 |
+
parts = append(parts, block.Text)
|
| 93 |
+
}
|
| 94 |
+
}
|
| 95 |
+
systemContent = strings.Join(parts, "\n")
|
| 96 |
+
} else {
|
| 97 |
+
log.Printf("WARN: Could not parse system prompt format for conversion: %s", string(claudeReq.System))
|
| 98 |
+
// Decide how to handle - skip system prompt or return error? Skipping for now.
|
| 99 |
+
}
|
| 100 |
+
}
|
| 101 |
+
if systemContent != "" {
|
| 102 |
+
openAIMessages = append(openAIMessages, OpenAIMessage{Role: "system", Content: systemContent})
|
| 103 |
+
}
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
// --- Handle Messages ---
|
| 107 |
+
for _, msg := range claudeReq.Messages {
|
| 108 |
+
if msg.Role != "user" && msg.Role != "assistant" {
|
| 109 |
+
log.Printf("WARN: Skipping message with unsupported role: %s", msg.Role)
|
| 110 |
+
continue
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
messageContent := ""
|
| 114 |
+
var contentStr string
|
| 115 |
+
// Try simple string first
|
| 116 |
+
if err := json.Unmarshal(msg.Content, &contentStr); err == nil {
|
| 117 |
+
messageContent = contentStr
|
| 118 |
+
} else {
|
| 119 |
+
// Try list of blocks
|
| 120 |
+
var contentBlocks []ClaudeContentBlock
|
| 121 |
+
if err := json.Unmarshal(msg.Content, &contentBlocks); err == nil {
|
| 122 |
+
var parts []string
|
| 123 |
+
for _, block := range contentBlocks {
|
| 124 |
+
if block.Type == "text" {
|
| 125 |
+
parts = append(parts, block.Text)
|
| 126 |
+
} else {
|
| 127 |
+
log.Printf("WARN: Skipping non-text content block type '%s' in message for role %s", block.Type, msg.Role)
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
messageContent = strings.Join(parts, "\n")
|
| 131 |
+
} else {
|
| 132 |
+
log.Printf("WARN: Could not parse message content format for role %s during conversion: %s", msg.Role, string(msg.Content))
|
| 133 |
+
// Skip message if content parsing fails
|
| 134 |
+
continue
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
if messageContent != "" || msg.Role == "assistant" { // Allow empty assistant messages if needed? Check OpenAI spec. Usually needs content.
|
| 139 |
+
openAIMessages = append(openAIMessages, OpenAIMessage{Role: msg.Role, Content: messageContent})
|
| 140 |
+
} else {
|
| 141 |
+
log.Printf("WARN: Skipping message for role %s with no valid text content after parsing.", msg.Role)
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
if len(openAIMessages) == 0 {
|
| 146 |
+
return nil, errors.New("conversion resulted in no valid messages for OpenAI request")
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
// --- Construct OpenAI Request ---
|
| 150 |
+
openAIReq := &OpenAIRequest{
|
| 151 |
+
Model: claudeReq.Model, // Use the model specified in Claude request
|
| 152 |
+
Messages: openAIMessages,
|
| 153 |
+
Stream: claudeReq.Stream,
|
| 154 |
+
MaxTokens: claudeReq.MaxTokens,
|
| 155 |
+
Temperature: claudeReq.Temperature,
|
| 156 |
+
TopP: claudeReq.TopP,
|
| 157 |
+
// Stop sequences mapping
|
| 158 |
+
Stop: claudeReq.StopSequences,
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
// Default model if not provided
|
| 162 |
+
if openAIReq.Model == "" {
|
| 163 |
+
openAIReq.Model = "gpt-3.5-turbo" // Or get from config
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
return openAIReq, nil
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// mapOpenAIFinishReasonToClaude translates an OpenAI finish_reason into the
// corresponding Claude stop_reason. A nil pointer and any unrecognised value
// both map to "end_turn"; unrecognised values are additionally logged.
func mapOpenAIFinishReasonToClaude(openAIFinishReason *string) string {
	const fallback = "end_turn"

	if openAIFinishReason == nil {
		return fallback
	}

	switch finish := *openAIFinishReason; finish {
	case "length":
		return "max_tokens"
	case "tool_calls", "function_call":
		return "tool_use"
	case "content_filter":
		// No direct Claude equivalent; reported as a stop sequence.
		return "stop_sequence"
	case "stop":
		return fallback
	default:
		log.Printf("WARN: Unknown OpenAI finish reason '%s', mapping to 'end_turn'", finish)
		return fallback
	}
}
|
| 189 |
+
|
| 190 |
+
// convertOpenAIResponseToClaude converts non-streaming OpenAI response to Claude format
|
| 191 |
+
func convertOpenAIResponseToClaude(openAIResp *OpenAIResponse, claudeRequestID string) (*ClaudeResponse, error) {
|
| 192 |
+
if len(openAIResp.Choices) == 0 {
|
| 193 |
+
return nil, errors.New("OpenAI response has no choices")
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
choice := openAIResp.Choices[0]
|
| 197 |
+
claudeStopReason := mapOpenAIFinishReasonToClaude(choice.FinishReason)
|
| 198 |
+
|
| 199 |
+
// --- Prepare Usage ---
|
| 200 |
+
claudeUsage := ClaudeUsage{}
|
| 201 |
+
if openAIResp.Usage != nil {
|
| 202 |
+
claudeUsage.InputTokens = openAIResp.Usage.PromptTokens
|
| 203 |
+
claudeUsage.OutputTokens = openAIResp.Usage.CompletionTokens
|
| 204 |
+
} else {
|
| 205 |
+
log.Printf("WARN: [%s] Usage data missing in non-streaming OpenAI response", claudeRequestID)
|
| 206 |
+
// Potentially estimate usage here if critical, otherwise leave as zero
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
// --- Prepare Content ---
|
| 210 |
+
claudeContent := []ClaudeContentBlock{
|
| 211 |
+
{
|
| 212 |
+
Type: "text",
|
| 213 |
+
Text: choice.Message.Content, // Assuming message content is always text
|
| 214 |
+
},
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
// --- Construct Claude Response ---
|
| 218 |
+
claudeResp := &ClaudeResponse{
|
| 219 |
+
ID: openAIResp.ID, // Use OpenAI's response ID
|
| 220 |
+
Type: "message",
|
| 221 |
+
Role: "assistant", // Assuming OpenAI response role is assistant
|
| 222 |
+
Content: claudeContent,
|
| 223 |
+
Model: openAIResp.Model, // Use the model OpenAI reported
|
| 224 |
+
StopReason: claudeStopReason,
|
| 225 |
+
StopSequence: nil, // Typically null
|
| 226 |
+
Usage: claudeUsage,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
// Use original request ID if OpenAI ID is missing (shouldn't happen often)
|
| 230 |
+
if claudeResp.ID == "" {
|
| 231 |
+
log.Printf("WARN: OpenAI response ID missing, using original request ID: %s", claudeRequestID)
|
| 232 |
+
claudeResp.ID = claudeRequestID
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
return claudeResp, nil
|
| 236 |
+
}
|
go.mod
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
module claude-proxy-go // You can change "claude-proxy-go" to whatever module name you prefer
|
| 2 |
+
|
| 3 |
+
go 1.21 // Or whichever Go version you want to use (keep it in line with the Dockerfile)
|
| 4 |
+
|
| 5 |
+
require (
|
| 6 |
+
github.com/gin-contrib/cors v1.7.2
|
| 7 |
+
github.com/gin-gonic/gin v1.10.0
|
| 8 |
+
github.com/google/uuid v1.6.0
|
| 9 |
+
github.com/joho/godotenv v1.5.1 // Optional: for local .env loading
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
require (
|
| 13 |
+
github.com/bytedance/sonic v1.11.6 // indirect
|
| 14 |
+
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
| 15 |
+
github.com/cloudwego/base64x v0.1.4 // indirect
|
| 16 |
+
github.com/cloudwego/iasm v0.2.0 // indirect
|
| 17 |
+
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
| 18 |
+
github.com/gin-contrib/sse v0.1.0 // indirect
|
| 19 |
+
github.com/go-playground/locales v0.14.1 // indirect
|
| 20 |
+
github.com/go-playground/universal-translator v0.18.1 // indirect
|
| 21 |
+
github.com/go-playground/validator/v10 v10.20.0 // indirect
|
| 22 |
+
github.com/goccy/go-json v0.10.2 // indirect
|
| 23 |
+
github.com/json-iterator/go v1.1.12 // indirect
|
| 24 |
+
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
| 25 |
+
github.com/leodido/go-urn v1.4.0 // indirect
|
| 26 |
+
github.com/mattn/go-isatty v0.0.20 // indirect
|
| 27 |
+
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
| 28 |
+
github.com/modern-go/reflect2 v1.0.2 // indirect
|
| 29 |
+
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
| 30 |
+
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
| 31 |
+
github.com/ugorji/go/codec v1.2.12 // indirect
|
| 32 |
+
golang.org/x/arch v0.8.0 // indirect
|
| 33 |
+
golang.org/x/crypto v0.23.0 // indirect
|
| 34 |
+
golang.org/x/net v0.25.0 // indirect
|
| 35 |
+
golang.org/x/sys v0.20.0 // indirect
|
| 36 |
+
golang.org/x/text v0.15.0 // indirect
|
| 37 |
+
google.golang.org/protobuf v1.34.1 // indirect
|
| 38 |
+
gopkg.in/yaml.v3 v3.0.1 // indirect
|
| 39 |
+
)
|
go.sum
ADDED
|
File without changes
|
handlers.go
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bytes"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"io"
|
| 8 |
+
"log"
|
| 9 |
+
"net" // <<< Added import
|
| 10 |
+
"net/http"
|
| 11 |
+
"strings" // <<< Added import
|
| 12 |
+
"time"
|
| 13 |
+
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
+
"github.com/google/uuid"
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
// HealthCheckHandler handles the /health endpoint
|
| 19 |
+
func HealthCheckHandler(c *gin.Context) {
|
| 20 |
+
c.JSON(http.StatusOK, gin.H{"status": "healthy"})
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// MessagesHandler handles the /v1/messages endpoint
|
| 24 |
+
func MessagesHandler(c *gin.Context) {
|
| 25 |
+
requestID := fmt.Sprintf("msg_%s", uuid.NewString()[:24]) // Generate unique request ID
|
| 26 |
+
cfg := LoadConfig() // Load config (consider passing it down instead of reloading)
|
| 27 |
+
|
| 28 |
+
// --- 1. Read and Parse Incoming Request ---
|
| 29 |
+
var claudeReq ClaudeRequest
|
| 30 |
+
bodyBytes, err := io.ReadAll(c.Request.Body)
|
| 31 |
+
if err != nil {
|
| 32 |
+
log.Printf("ERROR: [%s] Failed to read request body: %v", requestID, err)
|
| 33 |
+
sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Could not read request body.")
|
| 34 |
+
return
|
| 35 |
+
}
|
| 36 |
+
// Restore body for potential re-reads (though we don't re-read here)
|
| 37 |
+
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
| 38 |
+
|
| 39 |
+
if err := json.Unmarshal(bodyBytes, &claudeReq); err != nil {
|
| 40 |
+
log.Printf("ERROR: [%s] Failed to decode request JSON: %v. Body: %s", requestID, err, string(bodyBytes))
|
| 41 |
+
sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", "Invalid JSON format in request body.")
|
| 42 |
+
return
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
isStreaming := claudeReq.Stream
|
| 46 |
+
modelRequested := claudeReq.Model
|
| 47 |
+
if modelRequested == "" {
|
| 48 |
+
modelRequested = "unknown_model"
|
| 49 |
+
} // Handle empty model case
|
| 50 |
+
log.Printf("INFO: [%s] Received request. Stream: %t. Model: %s", requestID, isStreaming, modelRequested)
|
| 51 |
+
// Optional: Log full payload at debug level
|
| 52 |
+
// log.Printf("DEBUG: [%s] Received Payload: %s", requestID, string(bodyBytes))
|
| 53 |
+
|
| 54 |
+
// --- 2. Convert Request Format ---
|
| 55 |
+
openAIReq, err := convertClaudeRequestToOpenAI(&claudeReq)
|
| 56 |
+
if err != nil {
|
| 57 |
+
log.Printf("ERROR: [%s] Failed to convert Claude request to OpenAI format: %v", requestID, err)
|
| 58 |
+
sendClaudeError(c, http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Error converting request data: %v", err))
|
| 59 |
+
return
|
| 60 |
+
}
|
| 61 |
+
// Ensure stream flag is correctly set in the converted request
|
| 62 |
+
openAIReq.Stream = isStreaming
|
| 63 |
+
|
| 64 |
+
// --- 3. Prepare and Send Upstream Request ---
|
| 65 |
+
// Marshal the OpenAI request body
|
| 66 |
+
openaiReqBytes, err := json.Marshal(openAIReq)
|
| 67 |
+
if err != nil {
|
| 68 |
+
log.Printf("ERROR: [%s] Failed to marshal OpenAI request JSON: %v", requestID, err)
|
| 69 |
+
sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", "Failed to prepare upstream request.")
|
| 70 |
+
return
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
// Create the HTTP request to the upstream endpoint
|
| 74 |
+
upstreamURL := cfg.OpenAIAPIEndpoint
|
| 75 |
+
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", upstreamURL, bytes.NewBuffer(openaiReqBytes))
|
| 76 |
+
if err != nil {
|
| 77 |
+
log.Printf("ERROR: [%s] Failed to create upstream HTTP request: %v", requestID, err)
|
| 78 |
+
sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", "Failed to create upstream request.")
|
| 79 |
+
return
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
// Set headers for upstream request
|
| 83 |
+
req.Header.Set("Content-Type", "application/json")
|
| 84 |
+
if isStreaming {
|
| 85 |
+
req.Header.Set("Accept", "text/event-stream")
|
| 86 |
+
} else {
|
| 87 |
+
req.Header.Set("Accept", "application/json")
|
| 88 |
+
}
|
| 89 |
+
if cfg.OpenAIAPIKey != "" {
|
| 90 |
+
req.Header.Set("Authorization", "Bearer "+cfg.OpenAIAPIKey)
|
| 91 |
+
}
|
| 92 |
+
// Copy potentially relevant headers from original request? (e.g., User-Agent) - Be cautious about security.
|
| 93 |
+
// req.Header.Set("User-Agent", c.GetHeader("User-Agent"))
|
| 94 |
+
|
| 95 |
+
// Log upstream request details (optional, redact sensitive info)
|
| 96 |
+
// log.Printf("DEBUG: [%s] Sending upstream request to %s. Headers: %v", requestID, upstreamURL, req.Header)
|
| 97 |
+
// log.Printf("DEBUG: [%s] Upstream Payload: %s", requestID, string(openaiReqBytes))
|
| 98 |
+
log.Printf("INFO: [%s] Sending upstream request (Stream=%t) to %s...", requestID, isStreaming, upstreamURL)
|
| 99 |
+
|
| 100 |
+
// --- Execute Upstream Request ---
|
| 101 |
+
// Use a client with the configured transport (timeouts, proxy)
|
| 102 |
+
httpClient := &http.Client{
|
| 103 |
+
Transport: cfg.UpstreamTransport,
|
| 104 |
+
Timeout: 0, // Timeout is handled by the transport's ResponseHeaderTimeout and DialContext Timeout
|
| 105 |
+
}
|
| 106 |
+
startTime := time.Now()
|
| 107 |
+
upstreamResp, err := httpClient.Do(req)
|
| 108 |
+
if err != nil {
|
| 109 |
+
// Handle client-side errors (network, DNS, timeout before connection, etc.)
|
| 110 |
+
log.Printf("ERROR: [%s] Upstream request failed: %v", requestID, err)
|
| 111 |
+
// Check for timeout specifically
|
| 112 |
+
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // <<< Used net.Error here
|
| 113 |
+
sendClaudeError(c, http.StatusGatewayTimeout, "api_error", fmt.Sprintf("Gateway Timeout connecting to upstream (%v).", cfg.ConnectTimeout))
|
| 114 |
+
} else {
|
| 115 |
+
sendClaudeError(c, http.StatusBadGateway, "api_error", fmt.Sprintf("Bad Gateway: Could not connect to upstream. Error: %v", err))
|
| 116 |
+
}
|
| 117 |
+
return
|
| 118 |
+
}
|
| 119 |
+
// Note: We don't close upstreamResp.Body here yet, it's needed for streaming or reading non-streaming body.
|
| 120 |
+
// It will be closed by the streaming handler or after reading the body in non-streaming case.
|
| 121 |
+
log.Printf("INFO: [%s] Received upstream status: %d (%s)", requestID, upstreamResp.StatusCode, http.StatusText(upstreamResp.StatusCode))
|
| 122 |
+
|
| 123 |
+
// --- 4. Process Upstream Response ---
|
| 124 |
+
|
| 125 |
+
// Handle non-OK status codes
|
| 126 |
+
if upstreamResp.StatusCode != http.StatusOK {
|
| 127 |
+
// Read error body from upstream
|
| 128 |
+
errorBodyBytes, readErr := io.ReadAll(upstreamResp.Body)
|
| 129 |
+
if readErr != nil {
|
| 130 |
+
log.Printf("WARN: [%s] Failed to read upstream error body (Status %d): %v", requestID, upstreamResp.StatusCode, readErr)
|
| 131 |
+
}
|
| 132 |
+
_ = upstreamResp.Body.Close() // Ensure body is closed after reading or error
|
| 133 |
+
|
| 134 |
+
errorBodyStr := string(errorBodyBytes)
|
| 135 |
+
log.Printf("ERROR: [%s] Upstream returned error status %d. Body: %s", requestID, upstreamResp.StatusCode, errorBodyStr)
|
| 136 |
+
|
| 137 |
+
// Try to map the error to Claude format
|
| 138 |
+
// Basic mapping, can be improved by parsing OpenAI error structure if available
|
| 139 |
+
var errorType string
|
| 140 |
+
switch upstreamResp.StatusCode {
|
| 141 |
+
case http.StatusBadRequest:
|
| 142 |
+
errorType = "invalid_request_error"
|
| 143 |
+
case http.StatusUnauthorized:
|
| 144 |
+
errorType = "authentication_error"
|
| 145 |
+
case http.StatusForbidden:
|
| 146 |
+
errorType = "permission_error"
|
| 147 |
+
case http.StatusTooManyRequests:
|
| 148 |
+
errorType = "rate_limit_error"
|
| 149 |
+
case http.StatusInternalServerError, http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
|
| 150 |
+
errorType = "api_error"
|
| 151 |
+
default:
|
| 152 |
+
errorType = "api_error" // Default for other errors
|
| 153 |
+
}
|
| 154 |
+
errMsg := fmt.Sprintf("Upstream API error (%d). Details: %s", upstreamResp.StatusCode, strings.TrimSpace(errorBodyStr)) // <<< Used strings.TrimSpace here
|
| 155 |
+
// Truncate long error messages if necessary
|
| 156 |
+
if len(errMsg) > 300 {
|
| 157 |
+
errMsg = errMsg[:300] + "..."
|
| 158 |
+
}
|
| 159 |
+
sendClaudeError(c, upstreamResp.StatusCode, errorType, errMsg)
|
| 160 |
+
return
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// --- Handle OK response based on streaming ---
|
| 164 |
+
if isStreaming {
|
| 165 |
+
log.Printf("INFO: [%s] Upstream stream received. Starting SSE conversion (Go v1.9.0 - Priority Delta).", requestID)
|
| 166 |
+
// Delegate to the SSE streaming function
|
| 167 |
+
// This function will handle reading upstreamResp.Body and closing it
|
| 168 |
+
streamOpenAIResponseToClaudeSSE(c, upstreamResp, requestID, openAIReq.Model, &claudeReq)
|
| 169 |
+
} else {
|
| 170 |
+
// --- Non-Streaming ---
|
| 171 |
+
log.Printf("INFO: [%s] Upstream non-stream response received. Converting.", requestID)
|
| 172 |
+
defer upstreamResp.Body.Close() // Ensure body is closed after reading
|
| 173 |
+
|
| 174 |
+
// Read and parse upstream JSON response
|
| 175 |
+
var openAIResp OpenAIResponse
|
| 176 |
+
bodyBytes, err := io.ReadAll(upstreamResp.Body)
|
| 177 |
+
if err != nil {
|
| 178 |
+
log.Printf("ERROR: [%s] Failed to read non-streaming upstream response body: %v", requestID, err)
|
| 179 |
+
sendClaudeError(c, http.StatusBadGateway, "api_error", "Failed to read upstream response.")
|
| 180 |
+
return
|
| 181 |
+
}
|
| 182 |
+
if err := json.Unmarshal(bodyBytes, &openAIResp); err != nil {
|
| 183 |
+
log.Printf("ERROR: [%s] Failed to decode non-streaming upstream JSON: %v. Body: %s", requestID, err, string(bodyBytes))
|
| 184 |
+
sendClaudeError(c, http.StatusBadGateway, "api_error", "Upstream API returned invalid JSON.")
|
| 185 |
+
return
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Convert OpenAI response to Claude format
|
| 189 |
+
claudeResp, err := convertOpenAIResponseToClaude(&openAIResp, requestID)
|
| 190 |
+
if err != nil {
|
| 191 |
+
log.Printf("ERROR: [%s] Failed to convert non-streaming OpenAI response: %v", requestID, err)
|
| 192 |
+
sendClaudeError(c, http.StatusInternalServerError, "internal_server_error", fmt.Sprintf("Error processing upstream response: %v", err))
|
| 193 |
+
return
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
// Send the converted Claude response
|
| 197 |
+
c.JSON(http.StatusOK, claudeResp)
|
| 198 |
+
log.Printf("INFO: [%s] Successfully processed non-streaming request in %v", requestID, time.Since(startTime))
|
| 199 |
+
}
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
// sendClaudeError is a helper to send standardized Claude error responses
|
| 203 |
+
func sendClaudeError(c *gin.Context, statusCode int, errorType string, message string) {
|
| 204 |
+
errResp := ClaudeErrorResponse{
|
| 205 |
+
Type: "error",
|
| 206 |
+
Error: ClaudeError{
|
| 207 |
+
Type: errorType,
|
| 208 |
+
Message: message,
|
| 209 |
+
},
|
| 210 |
+
}
|
| 211 |
+
// Ensure status code is in valid range, default to 500 if not
|
| 212 |
+
if statusCode < 400 || statusCode > 599 {
|
| 213 |
+
log.Printf("WARN: Invalid status code %d provided for error, defaulting to 500.", statusCode)
|
| 214 |
+
statusCode = http.StatusInternalServerError
|
| 215 |
+
}
|
| 216 |
+
c.AbortWithStatusJSON(statusCode, errResp)
|
| 217 |
+
}
|
| 218 |
+
|
main.go
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"context"
|
| 5 |
+
"fmt"
|
| 6 |
+
"log"
|
| 7 |
+
"net/http"
|
| 8 |
+
"os"
|
| 9 |
+
"os/signal"
|
| 10 |
+
"syscall"
|
| 11 |
+
"time"
|
| 12 |
+
|
| 13 |
+
"github.com/gin-contrib/cors"
|
| 14 |
+
"github.com/gin-gonic/gin"
|
| 15 |
+
"github.com/joho/godotenv" // Optional: for loading .env file
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
func main() {
|
| 19 |
+
// Load .env file if present (optional, good for local dev)
|
| 20 |
+
_ = godotenv.Load()
|
| 21 |
+
|
| 22 |
+
// Load configuration
|
| 23 |
+
cfg := LoadConfig()
|
| 24 |
+
|
| 25 |
+
// Set Gin mode (release or debug)
|
| 26 |
+
if cfg.GinMode == "release" {
|
| 27 |
+
gin.SetMode(gin.ReleaseMode)
|
| 28 |
+
} else {
|
| 29 |
+
gin.SetMode(gin.DebugMode)
|
| 30 |
+
}
|
| 31 |
+
log.Printf("Starting Go Proxy Server in %s mode...", gin.Mode())
|
| 32 |
+
|
| 33 |
+
// Initialize Gin router
|
| 34 |
+
router := gin.New()
|
| 35 |
+
|
| 36 |
+
// Middleware
|
| 37 |
+
router.Use(gin.Logger()) // Standard Gin logger
|
| 38 |
+
router.Use(gin.Recovery()) // Recover from panics
|
| 39 |
+
// CORS middleware (allow all for simplicity, adjust as needed)
|
| 40 |
+
router.Use(cors.New(cors.Config{
|
| 41 |
+
AllowOrigins: []string{"*"},
|
| 42 |
+
AllowMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"},
|
| 43 |
+
AllowHeaders: []string{"Origin", "Content-Length", "Content-Type", "Authorization", "X-API-Key"}, // Include X-API-Key
|
| 44 |
+
ExposeHeaders: []string{"Content-Length"},
|
| 45 |
+
AllowCredentials: true,
|
| 46 |
+
MaxAge: 12 * time.Hour,
|
| 47 |
+
}))
|
| 48 |
+
|
| 49 |
+
// --- Routes ---
|
| 50 |
+
// Health check
|
| 51 |
+
router.GET("/health", HealthCheckHandler)
|
| 52 |
+
|
| 53 |
+
// Main proxy endpoint group
|
| 54 |
+
v1 := router.Group("/v1")
|
| 55 |
+
{
|
| 56 |
+
// Apply API Key Authentication middleware if keys are configured
|
| 57 |
+
if len(cfg.ValidAPIKeys) > 0 {
|
| 58 |
+
log.Printf("API Key authentication enabled (%d keys configured).", len(cfg.ValidAPIKeys))
|
| 59 |
+
v1.Use(APIKeyAuthMiddleware(cfg.ValidAPIKeys))
|
| 60 |
+
} else {
|
| 61 |
+
log.Println("WARN: No PROXY_API_KEYS configured. Proxy is open (no authentication).")
|
| 62 |
+
}
|
| 63 |
+
v1.POST("/messages", MessagesHandler)
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
// --- Server Setup ---
|
| 67 |
+
server := &http.Server{
|
| 68 |
+
Addr: fmt.Sprintf(":%s", cfg.Port),
|
| 69 |
+
Handler: router,
|
| 70 |
+
// Add timeouts for production hardening
|
| 71 |
+
ReadTimeout: 10 * time.Second,
|
| 72 |
+
WriteTimeout: cfg.ReadTimeout + 30*time.Second, // Ensure write timeout is longer than read timeout for streaming
|
| 73 |
+
IdleTimeout: 120 * time.Second,
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// --- Graceful Shutdown ---
|
| 77 |
+
// Run server in a goroutine so it doesn't block
|
| 78 |
+
go func() {
|
| 79 |
+
log.Printf("Server listening on port %s", cfg.Port)
|
| 80 |
+
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
| 81 |
+
log.Fatalf("listen: %s\n", err)
|
| 82 |
+
}
|
| 83 |
+
}()
|
| 84 |
+
|
| 85 |
+
// Wait for interrupt signal to gracefully shut down the server
|
| 86 |
+
quit := make(chan os.Signal, 1)
|
| 87 |
+
// kill (no param) default send syscall.SIGTERM
|
| 88 |
+
// kill -2 is syscall.SIGINT
|
| 89 |
+
// kill -9 is syscall.SIGKILL but can't be caught, so don't need to add it
|
| 90 |
+
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
| 91 |
+
<-quit
|
| 92 |
+
log.Println("Shutting down server...")
|
| 93 |
+
|
| 94 |
+
// The context is used to inform the server it has 5 seconds to finish
|
| 95 |
+
// the requests it is currently handling
|
| 96 |
+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
| 97 |
+
defer cancel()
|
| 98 |
+
if err := server.Shutdown(ctx); err != nil {
|
| 99 |
+
log.Fatal("Server forced to shutdown:", err)
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
log.Println("Server exiting")
|
| 103 |
+
}
|
sse.go
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import (
|
| 4 |
+
"bufio"
|
| 5 |
+
"encoding/json"
|
| 6 |
+
"fmt"
|
| 7 |
+
"log"
|
| 8 |
+
"net/http"
|
| 9 |
+
"strings"
|
| 10 |
+
"time"
|
| 11 |
+
|
| 12 |
+
"github.com/gin-gonic/gin"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
// streamOpenAIResponseToClaudeSSE handles the SSE streaming conversion
|
| 16 |
+
// v1.10.0: Sends message_delta with accumulated usage after each content delta.
|
| 17 |
+
func streamOpenAIResponseToClaudeSSE(
|
| 18 |
+
c *gin.Context,
|
| 19 |
+
upstreamResp *http.Response,
|
| 20 |
+
claudeRequestID string,
|
| 21 |
+
requestedModel string,
|
| 22 |
+
originalClaudeRequest *ClaudeRequest, // Pass original request for token calculation
|
| 23 |
+
) {
|
| 24 |
+
// Ensure correct headers for SSE are set
|
| 25 |
+
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
| 26 |
+
c.Writer.Header().Set("Cache-Control", "no-cache")
|
| 27 |
+
c.Writer.Header().Set("Connection", "keep-alive")
|
| 28 |
+
c.Writer.Header().Set("X-Content-Type-Options", "nosniff")
|
| 29 |
+
c.Writer.Flush() // Ensure headers are sent immediately
|
| 30 |
+
|
| 31 |
+
// --- State Variables ---
|
| 32 |
+
messageID := claudeRequestID
|
| 33 |
+
accumulatedContent := ""
|
| 34 |
+
var openAIFinishReason *string // Store the pointer
|
| 35 |
+
streamErrorOccurred := false
|
| 36 |
+
var errorDetails *ClaudeError // Store potential error details for final event
|
| 37 |
+
|
| 38 |
+
// Pre-calculate input tokens
|
| 39 |
+
calculatedInputTokens := calculateInputTokensFromClaudeRequest(originalClaudeRequest)
|
| 40 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Calculated input tokens: %d", messageID, calculatedInputTokens)
|
| 41 |
+
inputTokens := calculatedInputTokens
|
| 42 |
+
outputTokens := 0 // Initialize output tokens (will be updated frequently)
|
| 43 |
+
finalUsageReceivedFromStream := false
|
| 44 |
+
lastPingTime := time.Now()
|
| 45 |
+
eventIndex := 0 // For logging clarity
|
| 46 |
+
|
| 47 |
+
log.Printf("DEBUG: [%s] Starting SSE AggressiveDelta conversion (Go v1.10.0).", messageID)
|
| 48 |
+
|
| 49 |
+
// Use a channel to signal completion or error from the reading goroutine
|
| 50 |
+
doneChan := make(chan struct{})
|
| 51 |
+
errChan := make(chan error, 1) // Buffered channel for error
|
| 52 |
+
|
| 53 |
+
// Goroutine to read from the upstream response
|
| 54 |
+
go func() {
|
| 55 |
+
defer close(doneChan) // Signal completion when done
|
| 56 |
+
defer upstreamResp.Body.Close() // Ensure body is closed
|
| 57 |
+
|
| 58 |
+
scanner := bufio.NewScanner(upstreamResp.Body)
|
| 59 |
+
for scanner.Scan() {
|
| 60 |
+
// Check for client disconnect *before* processing line
|
| 61 |
+
select {
|
| 62 |
+
case <-c.Request.Context().Done():
|
| 63 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Client disconnected detected in read loop.", messageID)
|
| 64 |
+
return // Exit goroutine if client disconnected
|
| 65 |
+
default:
|
| 66 |
+
// Continue processing
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
line := scanner.Text()
|
| 70 |
+
if line == "" {
|
| 71 |
+
continue // Skip empty lines
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
if strings.HasPrefix(line, "data:") {
|
| 75 |
+
dataStr := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
|
| 76 |
+
if dataStr == "[DONE]" {
|
| 77 |
+
log.Printf("DEBUG: [%s] SSE AggressiveDelta: Received [DONE] marker.", messageID)
|
| 78 |
+
return // Normal stream completion
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
var chunk OpenAIStreamChunk
|
| 82 |
+
if err := json.Unmarshal([]byte(dataStr), &chunk); err != nil {
|
| 83 |
+
log.Printf("WARN: [%s] SSE AggressiveDelta: Could not decode JSON chunk: %v. Data: %s", messageID, err, dataStr)
|
| 84 |
+
continue // Skip malformed chunks
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
// Process choices
|
| 88 |
+
if len(chunk.Choices) > 0 {
|
| 89 |
+
choice := chunk.Choices[0]
|
| 90 |
+
if choice.FinishReason != nil {
|
| 91 |
+
openAIFinishReason = choice.FinishReason // Store the pointer
|
| 92 |
+
log.Printf("DEBUG: [%s] SSE AggressiveDelta: Received OpenAI finish_reason: %s", messageID, *openAIFinishReason)
|
| 93 |
+
}
|
| 94 |
+
if choice.Delta.Content != nil {
|
| 95 |
+
contentChunk := *choice.Delta.Content
|
| 96 |
+
accumulatedContent += contentChunk
|
| 97 |
+
currentOutputTokens := estimateTokens(accumulatedContent) // Estimate based on current content
|
| 98 |
+
|
| 99 |
+
// --- Yield content_block_delta ---
|
| 100 |
+
deltaPayload := ClaudeSSEEvent{
|
| 101 |
+
Type: "content_block_delta",
|
| 102 |
+
Index: func() *int { i := 0; return &i }(), // Pointer to 0
|
| 103 |
+
Delta: &ClaudeSSEDelta{
|
| 104 |
+
Type: "text_delta",
|
| 105 |
+
Text: &contentChunk, // Pointer to the chunk
|
| 106 |
+
},
|
| 107 |
+
}
|
| 108 |
+
if !sendSSEEvent(c, "content_block_delta", deltaPayload, messageID, eventIndex) {
|
| 109 |
+
return // Stop if client disconnected
|
| 110 |
+
}
|
| 111 |
+
eventIndex++
|
| 112 |
+
|
| 113 |
+
// --- AGGRESSIVE DELTA: Yield message_delta with current usage ---
|
| 114 |
+
// Only send if output tokens have potentially changed
|
| 115 |
+
if currentOutputTokens != outputTokens {
|
| 116 |
+
outputTokens = currentOutputTokens // Update state
|
| 117 |
+
intermediateUsage := ClaudeSSEUsage{OutputTokens: outputTokens}
|
| 118 |
+
intermediateDeltaPayload := ClaudeSSEEvent{
|
| 119 |
+
Type: "message_delta",
|
| 120 |
+
Delta: &ClaudeSSEDelta{}, // Delta part is empty here, only usage matters
|
| 121 |
+
Usage: &intermediateUsage,
|
| 122 |
+
}
|
| 123 |
+
log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (INTERMEDIATE message_delta with usage): %+v", messageID, eventIndex, intermediateDeltaPayload)
|
| 124 |
+
if !sendSSEEvent(c, "message_delta", intermediateDeltaPayload, messageID, eventIndex) {
|
| 125 |
+
return // Stop if client disconnected
|
| 126 |
+
}
|
| 127 |
+
eventIndex++
|
| 128 |
+
}
|
| 129 |
+
// -----------------------------------------------------------------
|
| 130 |
+
}
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// Check for OpenAI usage block (still useful for final confirmation)
|
| 134 |
+
if chunk.Usage != nil {
|
| 135 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Received usage block in OpenAI stream: %+v", messageID, *chunk.Usage)
|
| 136 |
+
if chunk.Usage.CompletionTokens > 0 {
|
| 137 |
+
// If OpenAI provides a final count, trust it more than estimation
|
| 138 |
+
outputTokens = chunk.Usage.CompletionTokens
|
| 139 |
+
finalUsageReceivedFromStream = true
|
| 140 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Using final completion_tokens from stream: %d", messageID, outputTokens)
|
| 141 |
+
}
|
| 142 |
+
if chunk.Usage.PromptTokens != inputTokens && chunk.Usage.PromptTokens > 0 {
|
| 143 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Updating input tokens based on stream usage block: %d -> %d", messageID, inputTokens, chunk.Usage.PromptTokens)
|
| 144 |
+
inputTokens = chunk.Usage.PromptTokens
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
} else {
|
| 148 |
+
log.Printf("TRACE: [%s] SSE AggressiveDelta: Received non-data line: %s", messageID, line)
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
// Send periodic pings
|
| 152 |
+
if time.Since(lastPingTime) >= 10*time.Second {
|
| 153 |
+
pingPayload := ClaudeSSEEvent{Type: "ping"}
|
| 154 |
+
if !sendSSEEvent(c, "ping", pingPayload, messageID, eventIndex) {
|
| 155 |
+
return // Stop if client disconnected
|
| 156 |
+
}
|
| 157 |
+
eventIndex++
|
| 158 |
+
lastPingTime = time.Now()
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
if err := scanner.Err(); err != nil {
|
| 163 |
+
// Check if the error is due to context cancellation (client disconnect)
|
| 164 |
+
select {
|
| 165 |
+
case <-c.Request.Context().Done():
|
| 166 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Upstream read interrupted by client disconnect: %v", messageID, c.Request.Context().Err())
|
| 167 |
+
default:
|
| 168 |
+
log.Printf("ERROR: [%s] SSE AggressiveDelta: Error reading upstream response body: %v", messageID, err)
|
| 169 |
+
errChan <- fmt.Errorf("upstream read error: %w", err)
|
| 170 |
+
}
|
| 171 |
+
}
|
| 172 |
+
}()
|
| 173 |
+
|
| 174 |
+
// --- Initial Events ---
|
| 175 |
+
// Send message_start
|
| 176 |
+
startUsage := ClaudeUsage{InputTokens: calculatedInputTokens, OutputTokens: 0}
|
| 177 |
+
startMessage := ClaudeSSEMessage{ ID: messageID, Type: "message", Role: "assistant", Content: []ClaudeContentBlock{}, Model: requestedModel, StopReason: nil, StopSequence: nil, Usage: startUsage }
|
| 178 |
+
startEvent := ClaudeSSEEvent{Type: "message_start", Message: &startMessage}
|
| 179 |
+
if !sendSSEEvent(c, "message_start", startEvent, messageID, eventIndex) { return }
|
| 180 |
+
eventIndex++
|
| 181 |
+
|
| 182 |
+
// Send content_block_start
|
| 183 |
+
contentStartBlock := ClaudeSSEContentBlock{Type: "text", Text: ""}
|
| 184 |
+
contentStartEvent := ClaudeSSEEvent{ Type: "content_block_start", Index: func() *int { i := 0; return &i }(), ContentBlock: &contentStartBlock }
|
| 185 |
+
if !sendSSEEvent(c, "content_block_start", contentStartEvent, messageID, eventIndex) { return }
|
| 186 |
+
eventIndex++
|
| 187 |
+
|
| 188 |
+
// Send initial ping
|
| 189 |
+
pingPayload := ClaudeSSEEvent{Type: "ping"}
|
| 190 |
+
if !sendSSEEvent(c, "ping", pingPayload, messageID, eventIndex) { return }
|
| 191 |
+
eventIndex++
|
| 192 |
+
lastPingTime = time.Now()
|
| 193 |
+
|
| 194 |
+
// --- Wait for completion or error or client disconnect ---
|
| 195 |
+
select {
|
| 196 |
+
case <-doneChan:
|
| 197 |
+
log.Printf("DEBUG: [%s] SSE AggressiveDelta: Upstream reading finished.", messageID)
|
| 198 |
+
case err := <-errChan:
|
| 199 |
+
log.Printf("ERROR: [%s] SSE AggressiveDelta: Received error from reading goroutine: %v", messageID, err)
|
| 200 |
+
streamErrorOccurred = true
|
| 201 |
+
errorDetails = &ClaudeError{Type: "api_error", Message: fmt.Sprintf("Error reading upstream response: %v", err)}
|
| 202 |
+
case <-c.Request.Context().Done():
|
| 203 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Client disconnected during stream processing: %v", messageID, c.Request.Context().Err())
|
| 204 |
+
streamErrorOccurred = true // Treat disconnect as a type of error for cleanup
|
| 205 |
+
errorDetails = &ClaudeError{Type: "client_disconnect", Message: "Client disconnected during stream"}
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
// --- Finally Block Logic ---
|
| 209 |
+
log.Printf("DEBUG: [%s] SSE AggressiveDelta: Entering finally block logic. Finish_reason: %v, Error: %t", messageID, openAIFinishReason, streamErrorOccurred)
|
| 210 |
+
|
| 211 |
+
// Determine final Claude stop reason
|
| 212 |
+
var claudeStopReason string
|
| 213 |
+
if streamErrorOccurred && errorDetails != nil && errorDetails.Type == "client_disconnect" {
|
| 214 |
+
claudeStopReason = "client_disconnect"
|
| 215 |
+
} else if streamErrorOccurred {
|
| 216 |
+
claudeStopReason = "error"
|
| 217 |
+
} else {
|
| 218 |
+
claudeStopReason = mapOpenAIFinishReasonToClaude(openAIFinishReason)
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// Finalize token counts (use last known value, potentially from stream or final estimation)
|
| 222 |
+
finalInputTokens := inputTokens
|
| 223 |
+
finalOutputTokens := outputTokens // Use the value updated during the stream or from OpenAI's usage block
|
| 224 |
+
|
| 225 |
+
// If no explicit usage received, do a final estimation/forcing
|
| 226 |
+
if !finalUsageReceivedFromStream {
|
| 227 |
+
log.Printf("WARN: [%s] SSE AggressiveDelta: Final usage not explicitly received. Doing final estimate.", messageID)
|
| 228 |
+
estimatedOutput := estimateTokens(accumulatedContent)
|
| 229 |
+
finalOutputTokens = max(1, estimatedOutput)
|
| 230 |
+
if accumulatedContent == "" { finalOutputTokens = 0 }
|
| 231 |
+
log.Printf("WARN: [%s] SSE AggressiveDelta: Final Estimated/Forced output tokens: %d", messageID, finalOutputTokens)
|
| 232 |
+
} else {
|
| 233 |
+
// If usage *was* received, still force minimum 1 if non-zero
|
| 234 |
+
finalOutputTokens = max(1, outputTokens)
|
| 235 |
+
if outputTokens == 0 { finalOutputTokens = 0 }
|
| 236 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Using/Forced final usage from stream: output=%d", messageID, finalOutputTokens)
|
| 237 |
+
}
|
| 238 |
+
finalInputTokens = max(0, finalInputTokens)
|
| 239 |
+
finalOutputTokens = max(0, finalOutputTokens)
|
| 240 |
+
|
| 241 |
+
// Prepare usage data structures
|
| 242 |
+
// The *last* message_delta sent needs the final output tokens for chat-api billing hack
|
| 243 |
+
finalHackUsageData := ClaudeSSEUsage{OutputTokens: finalOutputTokens}
|
| 244 |
+
finalStopUsageData := ClaudeSSEUsage{ InputTokens: &finalInputTokens, OutputTokens: finalOutputTokens }
|
| 245 |
+
|
| 246 |
+
log.Printf("INFO: [%s] SSE AggressiveDelta: Stream finished. Stop Reason: %s. Final Usage: Input=%d Output=%d", messageID, claudeStopReason, finalInputTokens, finalOutputTokens)
|
| 247 |
+
|
| 248 |
+
// --- Yield Closing Events (Priority Final Delta First) ---
|
| 249 |
+
|
| 250 |
+
// *** Send FINAL message_delta WITH FINAL usage FIRST ***
|
| 251 |
+
// This ensures the most accurate count is sent last, potentially overwriting intermediate ones in chat-api
|
| 252 |
+
finalDeltaStopReason := claudeStopReason
|
| 253 |
+
priorityFinalDeltaPayload := ClaudeSSEEvent{
|
| 254 |
+
Type: "message_delta",
|
| 255 |
+
Delta: &ClaudeSSEDelta{
|
| 256 |
+
StopReason: &finalDeltaStopReason,
|
| 257 |
+
StopSequence: nil,
|
| 258 |
+
},
|
| 259 |
+
Usage: &finalHackUsageData, // Use the final calculated/forced output tokens
|
| 260 |
+
}
|
| 261 |
+
log.Printf("WARN: [%s] SSE AggressiveDelta: Yielding Event %d (PRIORITY FINAL message_delta WITH HACKED USAGE): %+v", messageID, eventIndex, priorityFinalDeltaPayload)
|
| 262 |
+
_ = sendSSEEvent(c, "message_delta", priorityFinalDeltaPayload, messageID, eventIndex) // Try to send even if disconnected
|
| 263 |
+
eventIndex++
|
| 264 |
+
|
| 265 |
+
// Send content_block_stop
|
| 266 |
+
contentStopPayload := ClaudeSSEEvent{ Type: "content_block_stop", Index: func() *int { i := 0; return &i }()}
|
| 267 |
+
log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (content_block_stop)", messageID, eventIndex)
|
| 268 |
+
_ = sendSSEEvent(c, "content_block_stop", contentStopPayload, messageID, eventIndex)
|
| 269 |
+
eventIndex++
|
| 270 |
+
|
| 271 |
+
// Send message_stop
|
| 272 |
+
messageStopPayload := ClaudeSSEEvent{ Type: "message_stop", Usage: &finalStopUsageData }
|
| 273 |
+
log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (message_stop)", messageID, eventIndex)
|
| 274 |
+
_ = sendSSEEvent(c, "message_stop", messageStopPayload, messageID, eventIndex)
|
| 275 |
+
eventIndex++
|
| 276 |
+
|
| 277 |
+
// Send error event if needed
|
| 278 |
+
if streamErrorOccurred && errorDetails != nil && errorDetails.Type != "client_disconnect" {
|
| 279 |
+
errorPayload := ClaudeSSEEvent{ Type: "error", Error: errorDetails }
|
| 280 |
+
log.Printf("TRACE: [%s] SSE AggressiveDelta: Yielding Event %d (error)", messageID, eventIndex)
|
| 281 |
+
_ = sendSSEEvent(c, "error", errorPayload, messageID, eventIndex)
|
| 282 |
+
eventIndex++
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
log.Printf("INFO: [%s] Completed sending SSE AggressiveDelta stream.", messageID)
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// sendSSEEvent sends a single SSE event and checks for client disconnect
|
| 289 |
+
func sendSSEEvent(c *gin.Context, eventName string, data interface{}, requestID string, eventIndex int) bool {
|
| 290 |
+
select {
|
| 291 |
+
case <-c.Request.Context().Done():
|
| 292 |
+
// Client disconnected
|
| 293 |
+
log.Printf("INFO: [%s] Client disconnected before sending SSE event %d (%s).", requestID, eventIndex, eventName)
|
| 294 |
+
return false
|
| 295 |
+
default:
|
| 296 |
+
// Client still connected, try sending
|
| 297 |
+
jsonData, err := json.Marshal(data)
|
| 298 |
+
if err != nil {
|
| 299 |
+
log.Printf("ERROR: [%s] Failed to marshal SSE event %d (%s): %v", requestID, eventIndex, eventName, err)
|
| 300 |
+
return true // Continue trying other events even if one fails marshaling?
|
| 301 |
+
}
|
| 302 |
+
// Use fmt.Fprintf for potentially better handling with Gin's writer interface
|
| 303 |
+
_, err = fmt.Fprintf(c.Writer, "event: %s\ndata: %s\n\n", eventName, string(jsonData))
|
| 304 |
+
if err != nil {
|
| 305 |
+
// This error often indicates the client disconnected during the write
|
| 306 |
+
log.Printf("WARN: [%s] Failed to write SSE event %d (%s) to client: %v. Client likely disconnected.", requestID, eventIndex, eventName, err)
|
| 307 |
+
return false // Stop processing if write fails
|
| 308 |
+
}
|
| 309 |
+
c.Writer.Flush() // Ensure data is sent immediately
|
| 310 |
+
return true
|
| 311 |
+
}
|
| 312 |
+
}
|
| 313 |
+
|
| 314 |
+
// max returns the larger of the two ints a and b.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
structs.go
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
package main
|
| 2 |
+
|
| 3 |
+
import "encoding/json"
|
| 4 |
+
|
| 5 |
+
// --- Claude API Structs (Anthropic Format) ---
|
| 6 |
+
|
| 7 |
+
// ClaudeRequest represents the incoming request structure from the client
// (Anthropic Messages format). Optional scalar fields are pointers so that an
// absent value can be told apart from an explicit zero.
type ClaudeRequest struct {
	Model         string          `json:"model"`
	Messages      []ClaudeMessage `json:"messages"`
	System        json.RawMessage `json:"system,omitempty"`     // Can be string or list of blocks; kept raw for later decoding
	MaxTokens     *int            `json:"max_tokens,omitempty"` // Use pointer for optional fields
	StopSequences []string        `json:"stop_sequences,omitempty"`
	Stream        bool            `json:"stream,omitempty"`
	Temperature   *float64        `json:"temperature,omitempty"`
	TopP          *float64        `json:"top_p,omitempty"`
	// TopK *int `json:"top_k,omitempty"` // OpenAI doesn't support TopK directly
}
|
| 19 |
+
|
| 20 |
+
// ClaudeMessage represents a single message in the Claude request.
type ClaudeMessage struct {
	Role    string          `json:"role"`    // "user" or "assistant"
	Content json.RawMessage `json:"content"` // Can be string or list of blocks; kept raw for later decoding
}
|
| 25 |
+
|
| 26 |
+
// ClaudeContentBlock represents one block within a content array. Only the
// block type and its text are modelled.
type ClaudeContentBlock struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"` // Omitted from JSON when empty
	// Add other block types if needed (e.g., image)
}
|
| 32 |
+
|
| 33 |
+
// ClaudeResponse represents the non-streaming response structure sent to the
// client.
type ClaudeResponse struct {
	ID           string               `json:"id"`
	Type         string               `json:"type"` // e.g., "message"
	Role         string               `json:"role"` // e.g., "assistant"
	Content      []ClaudeContentBlock `json:"content"`
	Model        string               `json:"model"`
	StopReason   string               `json:"stop_reason"`   // e.g., "end_turn", "max_tokens"
	StopSequence *string              `json:"stop_sequence"` // Usually null; no omitempty, so nil is emitted as JSON null
	Usage        ClaudeUsage          `json:"usage"`
}
|
| 44 |
+
|
| 45 |
+
// ClaudeUsage holds the token usage figures in Claude's field naming.
type ClaudeUsage struct {
	// InputTokens is encoded as "input_tokens"; zero is still emitted.
	InputTokens int `json:"input_tokens"`

	// OutputTokens is encoded as "output_tokens"; zero is still emitted.
	OutputTokens int `json:"output_tokens"`
}
|
| 50 |
+
|
| 51 |
+
// ClaudeErrorResponse represents the error structure sent to the client
|
| 52 |
+
type ClaudeErrorResponse struct {
|
| 53 |
+
Type string `json:"type"` // Always "error"
|
| 54 |
+
Error ClaudeError `json:"error"`
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// ClaudeError carries the details of an error reported to the client.
type ClaudeError struct {
	// Type is the error category, e.g. "invalid_request_error" or "api_error".
	Type string `json:"type"`

	// Message is the value of the "message" key.
	Message string `json:"message"`
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
// --- OpenAI API Structs ---
|
| 65 |
+
|
| 66 |
+
// OpenAIRequest represents the request structure sent to the upstream OpenAI API
|
| 67 |
+
type OpenAIRequest struct {
|
| 68 |
+
Model string `json:"model"`
|
| 69 |
+
Messages []OpenAIMessage `json:"messages"`
|
| 70 |
+
MaxTokens *int `json:"max_tokens,omitempty"`
|
| 71 |
+
Temperature *float64 `json:"temperature,omitempty"`
|
| 72 |
+
TopP *float64 `json:"top_p,omitempty"`
|
| 73 |
+
Stop []string `json:"stop,omitempty"`
|
| 74 |
+
Stream bool `json:"stream,omitempty"`
|
| 75 |
+
// N *int `json:"n,omitempty"` // Not typically used with Claude proxy
|
| 76 |
+
// PresencePenalty *float64 `json:"presence_penalty,omitempty"` // Not mapped
|
| 77 |
+
// FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // Not mapped
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
// OpenAIMessage is a single message inside an OpenAIRequest.
type OpenAIMessage struct {
	// Role is "system", "user", or "assistant".
	Role string `json:"role"`

	// Content is always encoded, even when empty.
	Content string `json:"content"`
}
|
| 85 |
+
|
| 86 |
+
// OpenAIResponse represents the non-streaming response from the upstream OpenAI API
|
| 87 |
+
type OpenAIResponse struct {
|
| 88 |
+
ID string `json:"id"`
|
| 89 |
+
Object string `json:"object"` // e.g., "chat.completion"
|
| 90 |
+
Created int64 `json:"created"`
|
| 91 |
+
Model string `json:"model"`
|
| 92 |
+
Choices []OpenAIChoice `json:"choices"`
|
| 93 |
+
Usage *OpenAIUsage `json:"usage,omitempty"` // Pointer as it might be missing in errors
|
| 94 |
+
// SystemFingerprint string `json:"system_fingerprint"` // Optional
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
// OpenAIChoice represents a choice in the OpenAI response
|
| 98 |
+
type OpenAIChoice struct {
|
| 99 |
+
Index int `json:"index"`
|
| 100 |
+
Message OpenAIMessage `json:"message"`
|
| 101 |
+
FinishReason *string `json:"finish_reason"` // Pointer as it can be null
|
| 102 |
+
// Logprobs interface{} `json:"logprobs"` // Not typically used here
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// OpenAIUsage holds the token usage figures in OpenAI's field naming.
type OpenAIUsage struct {
	// PromptTokens is encoded as "prompt_tokens".
	PromptTokens int `json:"prompt_tokens"`

	// CompletionTokens is encoded as "completion_tokens".
	CompletionTokens int `json:"completion_tokens"`

	// TotalTokens is encoded as "total_tokens".
	TotalTokens int `json:"total_tokens"`
}
|
| 111 |
+
|
| 112 |
+
// OpenAIStreamChoice represents a choice within an OpenAI SSE chunk
|
| 113 |
+
type OpenAIStreamChoice struct {
|
| 114 |
+
Index int `json:"index"`
|
| 115 |
+
Delta OpenAIStreamDelta `json:"delta"`
|
| 116 |
+
FinishReason *string `json:"finish_reason"` // Pointer as it can be null
|
| 117 |
+
// Logprobs interface{} `json:"logprobs"` // Not typically used here
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// OpenAIStreamDelta is the delta payload inside an OpenAI SSE chunk.
type OpenAIStreamDelta struct {
	// Role is usually present only in the first delta.
	Role *string `json:"role,omitempty"`

	// Content is a pointer because the JSON value can be null or empty.
	Content *string `json:"content,omitempty"`
}
|
| 125 |
+
|
| 126 |
+
// OpenAIStreamChunk represents the structure of a data chunk in the OpenAI SSE stream
|
| 127 |
+
type OpenAIStreamChunk struct {
|
| 128 |
+
ID string `json:"id"`
|
| 129 |
+
Object string `json:"object"` // e.g., "chat.completion.chunk"
|
| 130 |
+
Created int64 `json:"created"`
|
| 131 |
+
Model string `json:"model"`
|
| 132 |
+
Choices []OpenAIStreamChoice `json:"choices"`
|
| 133 |
+
Usage *OpenAIUsage `json:"usage,omitempty"` // Usually null except maybe in Azure's final chunk?
|
| 134 |
+
// SystemFingerprint string `json:"system_fingerprint"` // Optional
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
// --- Claude SSE Structs (for sending back to client) ---
|
| 139 |
+
|
| 140 |
+
// ClaudeSSEEvent represents a generic Claude SSE event structure for easy marshaling
|
| 141 |
+
type ClaudeSSEEvent struct {
|
| 142 |
+
Type string `json:"type"`
|
| 143 |
+
Index *int `json:"index,omitempty"` // Used in content_block_*
|
| 144 |
+
Message *ClaudeSSEMessage `json:"message,omitempty"` // Used in message_start
|
| 145 |
+
ContentBlock *ClaudeSSEContentBlock `json:"content_block,omitempty"` // Used in content_block_start
|
| 146 |
+
Delta *ClaudeSSEDelta `json:"delta,omitempty"` // Used in content_block_delta, message_delta
|
| 147 |
+
Usage *ClaudeSSEUsage `json:"usage,omitempty"` // Used in message_delta (HACK), message_stop
|
| 148 |
+
Error *ClaudeError `json:"error,omitempty"` // Used in error event
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
// ClaudeSSEMessage is nested within message_start
|
| 152 |
+
type ClaudeSSEMessage struct {
|
| 153 |
+
ID string `json:"id"`
|
| 154 |
+
Type string `json:"type"` // "message"
|
| 155 |
+
Role string `json:"role"` // "assistant"
|
| 156 |
+
Content []ClaudeContentBlock `json:"content"` // Initially empty
|
| 157 |
+
Model string `json:"model"`
|
| 158 |
+
StopReason *string `json:"stop_reason"` // Initially null
|
| 159 |
+
StopSequence *string `json:"stop_sequence"` // Initially null
|
| 160 |
+
Usage ClaudeUsage `json:"usage"` // Initial usage (input tokens)
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
// ClaudeSSEContentBlock is the payload nested inside a content_block_start
// event.
type ClaudeSSEContentBlock struct {
	// Type is "text".
	Type string `json:"type"`

	// Text is initially empty; with no omitempty the key is still emitted.
	Text string `json:"text"`
}
|
| 168 |
+
|
| 169 |
+
// ClaudeSSEDelta is the delta payload shared by content_block_delta and
// message_delta events. Every field is omitted from the JSON when unset.
type ClaudeSSEDelta struct {
	// Type is "text_delta" in content_block_delta.
	Type string `json:"type,omitempty"`

	// Text is the pointer field used by content_block_delta.
	Text *string `json:"text,omitempty"`

	// StopReason is the pointer field used by message_delta.
	StopReason *string `json:"stop_reason,omitempty"`

	// StopSequence is the pointer field used by message_delta (usually null).
	StopSequence *string `json:"stop_sequence,omitempty"`
}
|
| 176 |
+
|
| 177 |
+
// ClaudeSSEUsage is the usage payload nested inside message_delta (HACK) and
// message_stop events.
//
// message_delta only needs output_tokens for the hack; message_stop should
// have both counts.
type ClaudeSSEUsage struct {
	// InputTokens is only set in message_stop and is omitted when nil.
	InputTokens *int `json:"input_tokens,omitempty"`

	// OutputTokens is present in both events (but the value differs).
	OutputTokens int `json:"output_tokens"`
}
|