Spaces:
Paused
Paused
| // Package executor provides runtime execution capabilities for various AI service providers. | |
| // This file implements a Codex executor that uses the Responses API WebSocket transport. | |
| package executor | |
| import ( | |
| "bytes" | |
| "context" | |
| "fmt" | |
| "io" | |
| "net" | |
| "net/http" | |
| "net/url" | |
| "strconv" | |
| "strings" | |
| "sync" | |
| "time" | |
| "github.com/gin-gonic/gin" | |
| "github.com/google/uuid" | |
| "github.com/gorilla/websocket" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/config" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/util" | |
| cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" | |
| cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" | |
| "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" | |
| sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" | |
| log "github.com/sirupsen/logrus" | |
| "github.com/tidwall/gjson" | |
| "github.com/tidwall/sjson" | |
| "golang.org/x/net/proxy" | |
| ) | |
// Tunables for the Codex Responses WebSocket transport.
const (
	// codexResponsesWebsocketBetaHeaderValue is the OpenAI-Beta value used when the
	// request does not already carry a "responses_websockets=" beta flag.
	codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"
	// codexResponsesWebsocketIdleTimeout is the read deadline set before each read
	// on a connection that is not bound to an execution session.
	codexResponsesWebsocketIdleTimeout = 5 * time.Minute
	// codexResponsesWebsocketHandshakeTO bounds the websocket upgrade handshake.
	codexResponsesWebsocketHandshakeTO = 30 * time.Second
)
// CodexWebsocketsExecutor executes Codex Responses requests using a WebSocket transport.
//
// It preserves the existing CodexExecutor HTTP implementation as a fallback for endpoints
// not available over WebSocket (e.g. /responses/compact) and for websocket upgrade failures.
type CodexWebsocketsExecutor struct {
	// CodexExecutor is the embedded HTTP executor that fallbacks are delegated to.
	*CodexExecutor
	// store keeps the websocket sessions looked up by execution session ID.
	store *codexWebsocketSessionStore
}
// codexWebsocketSessionStore is a registry of websocket sessions keyed by string.
// NOTE(review): the key is presumably the execution session ID and mu presumably
// guards sessions; confirm against getOrCreateSession, which is outside this view.
type codexWebsocketSessionStore struct {
	mu       sync.Mutex
	sessions map[string]*codexWebsocketSession
}

// globalCodexWebsocketSessionStore is the process-wide store that
// NewCodexWebsocketsExecutor assigns to every executor it creates.
var globalCodexWebsocketSessionStore = &codexWebsocketSessionStore{
	sessions: make(map[string]*codexWebsocketSession),
}
// codexWebsocketSession holds the websocket state attached to one execution session.
type codexWebsocketSession struct {
	// sessionID identifies the session. NOTE(review): assigned outside this view.
	sessionID string
	// reqMu is held by Execute/ExecuteStream for the whole lifetime of a request,
	// so requests sharing a session run one at a time.
	reqMu sync.Mutex
	// connMu presumably guards conn, wsURL, authID and readerConn; confirm against
	// ensureUpstreamConn / invalidateUpstreamConn, which are outside this view.
	connMu sync.Mutex
	conn   *websocket.Conn
	wsURL  string
	authID string
	// writeMu serializes data writes and pong replies on the connection.
	writeMu sync.Mutex
	// activeMu guards activeCh, activeDone and activeCancel.
	activeMu sync.Mutex
	// activeCh is the read channel of the request currently using the session.
	activeCh chan codexWebsocketRead
	// activeDone is closed when the current active request is replaced or cleared.
	activeDone <-chan struct{}
	// activeCancel closes activeDone.
	activeCancel context.CancelFunc
	// readerConn: NOTE(review) presumably the connection a background reader is
	// attached to; not used in the visible part of the file.
	readerConn *websocket.Conn
}
| func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { | |
| return &CodexWebsocketsExecutor{ | |
| CodexExecutor: NewCodexExecutor(cfg), | |
| store: globalCodexWebsocketSessionStore, | |
| } | |
| } | |
// codexWebsocketRead is one event delivered on a session read channel: either a
// received frame (msgType, payload) or a read failure (err).
type codexWebsocketRead struct {
	// conn is the connection the event came from; readers drop events whose conn
	// is not the one they are currently waiting on.
	conn    *websocket.Conn
	msgType int
	payload []byte
	err     error
}
| func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { | |
| if s == nil { | |
| return | |
| } | |
| s.activeMu.Lock() | |
| if s.activeCancel != nil { | |
| s.activeCancel() | |
| s.activeCancel = nil | |
| s.activeDone = nil | |
| } | |
| s.activeCh = ch | |
| if ch != nil { | |
| activeCtx, activeCancel := context.WithCancel(context.Background()) | |
| s.activeDone = activeCtx.Done() | |
| s.activeCancel = activeCancel | |
| } | |
| s.activeMu.Unlock() | |
| } | |
| func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { | |
| if s == nil { | |
| return | |
| } | |
| s.activeMu.Lock() | |
| if s.activeCh == ch { | |
| s.activeCh = nil | |
| if s.activeCancel != nil { | |
| s.activeCancel() | |
| } | |
| s.activeCancel = nil | |
| s.activeDone = nil | |
| } | |
| s.activeMu.Unlock() | |
| } | |
| func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { | |
| if s == nil { | |
| return fmt.Errorf("codex websockets executor: session is nil") | |
| } | |
| if conn == nil { | |
| return fmt.Errorf("codex websockets executor: websocket conn is nil") | |
| } | |
| s.writeMu.Lock() | |
| defer s.writeMu.Unlock() | |
| return conn.WriteMessage(msgType, payload) | |
| } | |
| func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) { | |
| if s == nil || conn == nil { | |
| return | |
| } | |
| conn.SetPingHandler(func(appData string) error { | |
| s.writeMu.Lock() | |
| defer s.writeMu.Unlock() | |
| // Reply pongs from the same write lock to avoid concurrent writes. | |
| return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(10*time.Second)) | |
| }) | |
| } | |
| func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { | |
| if ctx == nil { | |
| ctx = context.Background() | |
| } | |
| if opts.Alt == "responses/compact" { | |
| return e.CodexExecutor.executeCompact(ctx, auth, req, opts) | |
| } | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| apiKey, baseURL := codexCreds(auth) | |
| if baseURL == "" { | |
| baseURL = "https://chatgpt.com/backend-api/codex" | |
| } | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("codex") | |
| originalPayloadSource := req.Payload | |
| if len(opts.OriginalRequest) > 0 { | |
| originalPayloadSource = opts.OriginalRequest | |
| } | |
| originalPayload := originalPayloadSource | |
| originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) | |
| body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return resp, err | |
| } | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) | |
| body, _ = sjson.SetBytes(body, "model", baseModel) | |
| body, _ = sjson.SetBytes(body, "stream", true) | |
| body, _ = sjson.DeleteBytes(body, "previous_response_id") | |
| body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") | |
| body, _ = sjson.DeleteBytes(body, "safety_identifier") | |
| if !gjson.GetBytes(body, "instructions").Exists() { | |
| body, _ = sjson.SetBytes(body, "instructions", "") | |
| } | |
| httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" | |
| wsURL, err := buildCodexResponsesWebsocketURL(httpURL) | |
| if err != nil { | |
| return resp, err | |
| } | |
| body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) | |
| wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| executionSessionID := executionSessionIDFromOptions(opts) | |
| var sess *codexWebsocketSession | |
| if executionSessionID != "" { | |
| sess = e.getOrCreateSession(executionSessionID) | |
| sess.reqMu.Lock() | |
| defer sess.reqMu.Unlock() | |
| } | |
| wsReqBody := buildCodexWebsocketRequestBody(body) | |
| wsReqLog := helps.UpstreamRequestLog{ | |
| URL: wsURL, | |
| Method: "WEBSOCKET", | |
| Headers: wsHeaders.Clone(), | |
| Body: wsReqBody, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| } | |
| helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog) | |
| conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) | |
| if errDial != nil { | |
| bodyErr := websocketHandshakeBody(respHS) | |
| if respHS != nil { | |
| helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr) | |
| } | |
| if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { | |
| return e.CodexExecutor.Execute(ctx, auth, req, opts) | |
| } | |
| if respHS != nil && respHS.StatusCode > 0 { | |
| return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} | |
| } | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial) | |
| return resp, errDial | |
| } | |
| recordAPIWebsocketHandshake(ctx, e.cfg, respHS) | |
| if sess == nil { | |
| logCodexWebsocketConnected(executionSessionID, authID, wsURL) | |
| defer func() { | |
| reason := "completed" | |
| if err != nil { | |
| reason = "error" | |
| } | |
| logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, reason, err) | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| }() | |
| } | |
| var readCh chan codexWebsocketRead | |
| if sess != nil { | |
| readCh = make(chan codexWebsocketRead, 4096) | |
| sess.setActive(readCh) | |
| defer sess.clearActive(readCh) | |
| } | |
| if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "send_error", errSend) | |
| // Retry once with a fresh websocket connection. This is mainly to handle | |
| // upstream closing the socket between sequential requests within the same | |
| // execution session. | |
| connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) | |
| if errDialRetry == nil && connRetry != nil { | |
| wsReqBodyRetry := buildCodexWebsocketRequestBody(body) | |
| helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: wsURL, | |
| Method: "WEBSOCKET", | |
| Headers: wsHeaders.Clone(), | |
| Body: wsReqBodyRetry, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry) | |
| if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil { | |
| conn = connRetry | |
| wsReqBody = wsReqBodyRetry | |
| } else { | |
| e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry) | |
| return resp, errSendRetry | |
| } | |
| } else { | |
| closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error") | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry) | |
| return resp, errDialRetry | |
| } | |
| } else { | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend) | |
| return resp, errSend | |
| } | |
| } | |
| for { | |
| if ctx != nil && ctx.Err() != nil { | |
| return resp, ctx.Err() | |
| } | |
| msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) | |
| if errRead != nil { | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead) | |
| return resp, errRead | |
| } | |
| if msgType != websocket.TextMessage { | |
| if msgType == websocket.BinaryMessage { | |
| err = fmt.Errorf("codex websockets executor: unexpected binary message") | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) | |
| } | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err) | |
| return resp, err | |
| } | |
| continue | |
| } | |
| payload = bytes.TrimSpace(payload) | |
| if len(payload) == 0 { | |
| continue | |
| } | |
| helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload) | |
| if wsErr, ok := parseCodexWebsocketError(payload); ok { | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) | |
| } | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr) | |
| return resp, wsErr | |
| } | |
| payload = normalizeCodexWebsocketCompletion(payload) | |
| eventType := gjson.GetBytes(payload, "type").String() | |
| if eventType == "response.completed" { | |
| if detail, ok := helps.ParseCodexUsage(payload); ok { | |
| reporter.Publish(ctx, detail) | |
| } | |
| var param any | |
| out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m) | |
| resp = cliproxyexecutor.Response{Payload: out} | |
| return resp, nil | |
| } | |
| } | |
| } | |
| func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) { | |
| log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) | |
| if ctx == nil { | |
| ctx = context.Background() | |
| } | |
| if opts.Alt == "responses/compact" { | |
| return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"} | |
| } | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| apiKey, baseURL := codexCreds(auth) | |
| if baseURL == "" { | |
| baseURL = "https://chatgpt.com/backend-api/codex" | |
| } | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("codex") | |
| body := req.Payload | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return nil, err | |
| } | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel) | |
| httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" | |
| wsURL, err := buildCodexResponsesWebsocketURL(httpURL) | |
| if err != nil { | |
| return nil, err | |
| } | |
| body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) | |
| wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg) | |
| var authID, authLabel, authType, authValue string | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| executionSessionID := executionSessionIDFromOptions(opts) | |
| var sess *codexWebsocketSession | |
| if executionSessionID != "" { | |
| sess = e.getOrCreateSession(executionSessionID) | |
| if sess != nil { | |
| sess.reqMu.Lock() | |
| } | |
| } | |
| wsReqBody := buildCodexWebsocketRequestBody(body) | |
| wsReqLog := helps.UpstreamRequestLog{ | |
| URL: wsURL, | |
| Method: "WEBSOCKET", | |
| Headers: wsHeaders.Clone(), | |
| Body: wsReqBody, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| } | |
| helps.RecordAPIWebsocketRequest(ctx, e.cfg, wsReqLog) | |
| conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) | |
| var upstreamHeaders http.Header | |
| if respHS != nil { | |
| upstreamHeaders = respHS.Header.Clone() | |
| } | |
| if errDial != nil { | |
| bodyErr := websocketHandshakeBody(respHS) | |
| if respHS != nil { | |
| helps.RecordAPIWebsocketUpgradeRejection(ctx, e.cfg, websocketUpgradeRequestLog(wsReqLog), respHS.StatusCode, respHS.Header.Clone(), bodyErr) | |
| } | |
| if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { | |
| return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts) | |
| } | |
| if respHS != nil && respHS.StatusCode > 0 { | |
| return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} | |
| } | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "dial", errDial) | |
| if sess != nil { | |
| sess.reqMu.Unlock() | |
| } | |
| return nil, errDial | |
| } | |
| recordAPIWebsocketHandshake(ctx, e.cfg, respHS) | |
| if sess == nil { | |
| logCodexWebsocketConnected(executionSessionID, authID, wsURL) | |
| } | |
| var readCh chan codexWebsocketRead | |
| if sess != nil { | |
| readCh = make(chan codexWebsocketRead, 4096) | |
| sess.setActive(readCh) | |
| } | |
| if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "send", errSend) | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "send_error", errSend) | |
| // Retry once with a new websocket connection for the same execution session. | |
| connRetry, respHSRetry, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) | |
| if errDialRetry != nil || connRetry == nil { | |
| closeHTTPResponseBody(respHSRetry, "codex websockets executor: close handshake response body error") | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "dial_retry", errDialRetry) | |
| sess.clearActive(readCh) | |
| sess.reqMu.Unlock() | |
| return nil, errDialRetry | |
| } | |
| wsReqBodyRetry := buildCodexWebsocketRequestBody(body) | |
| helps.RecordAPIWebsocketRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: wsURL, | |
| Method: "WEBSOCKET", | |
| Headers: wsHeaders.Clone(), | |
| Body: wsReqBodyRetry, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| recordAPIWebsocketHandshake(ctx, e.cfg, respHSRetry) | |
| if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "send_retry", errSendRetry) | |
| e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) | |
| sess.clearActive(readCh) | |
| sess.reqMu.Unlock() | |
| return nil, errSendRetry | |
| } | |
| conn = connRetry | |
| wsReqBody = wsReqBodyRetry | |
| } else { | |
| logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, "send_error", errSend) | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| return nil, errSend | |
| } | |
| } | |
| out := make(chan cliproxyexecutor.StreamChunk) | |
| go func() { | |
| terminateReason := "completed" | |
| var terminateErr error | |
| defer close(out) | |
| defer func() { | |
| if sess != nil { | |
| sess.clearActive(readCh) | |
| sess.reqMu.Unlock() | |
| return | |
| } | |
| logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, terminateReason, terminateErr) | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| }() | |
| send := func(chunk cliproxyexecutor.StreamChunk) bool { | |
| if ctx == nil { | |
| out <- chunk | |
| return true | |
| } | |
| select { | |
| case out <- chunk: | |
| return true | |
| case <-ctx.Done(): | |
| return false | |
| } | |
| } | |
| var param any | |
| for { | |
| if ctx != nil && ctx.Err() != nil { | |
| terminateReason = "context_done" | |
| terminateErr = ctx.Err() | |
| _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) | |
| return | |
| } | |
| msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) | |
| if errRead != nil { | |
| if sess != nil && ctx != nil && ctx.Err() != nil { | |
| terminateReason = "context_done" | |
| terminateErr = ctx.Err() | |
| _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) | |
| return | |
| } | |
| terminateReason = "read_error" | |
| terminateErr = errRead | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "read", errRead) | |
| reporter.PublishFailure(ctx) | |
| _ = send(cliproxyexecutor.StreamChunk{Err: errRead}) | |
| return | |
| } | |
| if msgType != websocket.TextMessage { | |
| if msgType == websocket.BinaryMessage { | |
| err = fmt.Errorf("codex websockets executor: unexpected binary message") | |
| terminateReason = "unexpected_binary" | |
| terminateErr = err | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "unexpected_binary", err) | |
| reporter.PublishFailure(ctx) | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) | |
| } | |
| _ = send(cliproxyexecutor.StreamChunk{Err: err}) | |
| return | |
| } | |
| continue | |
| } | |
| payload = bytes.TrimSpace(payload) | |
| if len(payload) == 0 { | |
| continue | |
| } | |
| helps.AppendAPIWebsocketResponse(ctx, e.cfg, payload) | |
| if wsErr, ok := parseCodexWebsocketError(payload); ok { | |
| terminateReason = "upstream_error" | |
| terminateErr = wsErr | |
| helps.RecordAPIWebsocketError(ctx, e.cfg, "upstream_error", wsErr) | |
| reporter.PublishFailure(ctx) | |
| if sess != nil { | |
| e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) | |
| } | |
| _ = send(cliproxyexecutor.StreamChunk{Err: wsErr}) | |
| return | |
| } | |
| payload = normalizeCodexWebsocketCompletion(payload) | |
| eventType := gjson.GetBytes(payload, "type").String() | |
| if eventType == "response.completed" || eventType == "response.done" { | |
| if detail, ok := helps.ParseCodexUsage(payload); ok { | |
| reporter.Publish(ctx, detail) | |
| } | |
| } | |
| line := encodeCodexWebsocketAsSSE(payload) | |
| chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m) | |
| for i := range chunks { | |
| if !send(cliproxyexecutor.StreamChunk{Payload: chunks[i]}) { | |
| terminateReason = "context_done" | |
| terminateErr = ctx.Err() | |
| return | |
| } | |
| } | |
| if eventType == "response.completed" || eventType == "response.done" { | |
| return | |
| } | |
| } | |
| }() | |
| return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil | |
| } | |
| func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { | |
| dialer := newProxyAwareWebsocketDialer(e.cfg, auth) | |
| dialer.HandshakeTimeout = codexResponsesWebsocketHandshakeTO | |
| dialer.EnableCompression = true | |
| if ctx == nil { | |
| ctx = context.Background() | |
| } | |
| conn, resp, err := dialer.DialContext(ctx, wsURL, headers) | |
| if conn != nil { | |
| // Avoid gorilla/websocket flate tail validation issues on some upstreams/Go versions. | |
| // Negotiating permessage-deflate is fine; we just don't compress outbound messages. | |
| conn.EnableWriteCompression(false) | |
| } | |
| return conn, resp, err | |
| } | |
| func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) error { | |
| if sess != nil { | |
| return sess.writeMessage(conn, websocket.TextMessage, payload) | |
| } | |
| if conn == nil { | |
| return fmt.Errorf("codex websockets executor: websocket conn is nil") | |
| } | |
| return conn.WriteMessage(websocket.TextMessage, payload) | |
| } | |
| func buildCodexWebsocketRequestBody(body []byte) []byte { | |
| if len(body) == 0 { | |
| return nil | |
| } | |
| // Match codex-rs websocket v2 semantics: every request is `response.create`. | |
| // Incremental follow-up turns continue on the same websocket using | |
| // `previous_response_id` + incremental `input`, not `response.append`. | |
| wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create") | |
| if errSet == nil && len(wsReqBody) > 0 { | |
| return wsReqBody | |
| } | |
| fallback := bytes.Clone(body) | |
| fallback, _ = sjson.SetBytes(fallback, "type", "response.create") | |
| return fallback | |
| } | |
| func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, conn *websocket.Conn, readCh chan codexWebsocketRead) (int, []byte, error) { | |
| if sess == nil { | |
| if conn == nil { | |
| return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil") | |
| } | |
| _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) | |
| msgType, payload, errRead := conn.ReadMessage() | |
| return msgType, payload, errRead | |
| } | |
| if conn == nil { | |
| return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil") | |
| } | |
| if readCh == nil { | |
| return 0, nil, fmt.Errorf("codex websockets executor: session read channel is nil") | |
| } | |
| for { | |
| select { | |
| case <-ctx.Done(): | |
| return 0, nil, ctx.Err() | |
| case ev, ok := <-readCh: | |
| if !ok { | |
| return 0, nil, fmt.Errorf("codex websockets executor: session read channel closed") | |
| } | |
| if ev.conn != conn { | |
| continue | |
| } | |
| if ev.err != nil { | |
| return 0, nil, ev.err | |
| } | |
| return ev.msgType, ev.payload, nil | |
| } | |
| } | |
| } | |
| func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer { | |
| dialer := &websocket.Dialer{ | |
| Proxy: http.ProxyFromEnvironment, | |
| HandshakeTimeout: codexResponsesWebsocketHandshakeTO, | |
| EnableCompression: true, | |
| NetDialContext: (&net.Dialer{ | |
| Timeout: 30 * time.Second, | |
| KeepAlive: 30 * time.Second, | |
| }).DialContext, | |
| } | |
| proxyURL := "" | |
| if auth != nil { | |
| proxyURL = strings.TrimSpace(auth.ProxyURL) | |
| } | |
| if proxyURL == "" && cfg != nil { | |
| proxyURL = strings.TrimSpace(cfg.ProxyURL) | |
| } | |
| if proxyURL == "" { | |
| return dialer | |
| } | |
| setting, errParse := proxyutil.Parse(proxyURL) | |
| if errParse != nil { | |
| log.Errorf("codex websockets executor: %v", errParse) | |
| return dialer | |
| } | |
| switch setting.Mode { | |
| case proxyutil.ModeDirect: | |
| dialer.Proxy = nil | |
| return dialer | |
| case proxyutil.ModeProxy: | |
| default: | |
| return dialer | |
| } | |
| switch setting.URL.Scheme { | |
| case "socks5", "socks5h": | |
| var proxyAuth *proxy.Auth | |
| if setting.URL.User != nil { | |
| username := setting.URL.User.Username() | |
| password, _ := setting.URL.User.Password() | |
| proxyAuth = &proxy.Auth{User: username, Password: password} | |
| } | |
| socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct) | |
| if errSOCKS5 != nil { | |
| log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) | |
| return dialer | |
| } | |
| dialer.Proxy = nil | |
| dialer.NetDialContext = func(_ context.Context, network, addr string) (net.Conn, error) { | |
| return socksDialer.Dial(network, addr) | |
| } | |
| case "http", "https": | |
| dialer.Proxy = http.ProxyURL(setting.URL) | |
| default: | |
| log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme) | |
| } | |
| return dialer | |
| } | |
// buildCodexResponsesWebsocketURL converts an HTTP(S) endpoint URL into its
// websocket equivalent: http becomes ws and https becomes wss. Any other scheme
// is left untouched. Surrounding whitespace is trimmed before parsing, and a
// parse failure is returned as-is.
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
	u, errParse := url.Parse(strings.TrimSpace(httpURL))
	if errParse != nil {
		return "", errParse
	}
	if scheme := strings.ToLower(u.Scheme); scheme == "http" {
		u.Scheme = "ws"
	} else if scheme == "https" {
		u.Scheme = "wss"
	}
	return u.String(), nil
}
| func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) { | |
| headers := http.Header{} | |
| if len(rawJSON) == 0 { | |
| return rawJSON, headers | |
| } | |
| var cache helps.CodexCache | |
| if from == "claude" { | |
| userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id") | |
| if userIDResult.Exists() { | |
| key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) | |
| if cached, ok := helps.GetCodexCache(key); ok { | |
| cache = cached | |
| } else { | |
| cache = helps.CodexCache{ | |
| ID: uuid.New().String(), | |
| Expire: time.Now().Add(1 * time.Hour), | |
| } | |
| helps.SetCodexCache(key, cache) | |
| } | |
| } | |
| } else if from == "openai-response" { | |
| if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() { | |
| cache.ID = promptCacheKey.String() | |
| } | |
| } | |
| if cache.ID != "" { | |
| rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID) | |
| headers.Set("Conversation_id", cache.ID) | |
| } | |
| return rawJSON, headers | |
| } | |
| func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header { | |
| if headers == nil { | |
| headers = http.Header{} | |
| } | |
| if strings.TrimSpace(token) != "" { | |
| headers.Set("Authorization", "Bearer "+token) | |
| } | |
| var ginHeaders http.Header | |
| if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil { | |
| ginHeaders = ginCtx.Request.Header.Clone() | |
| } | |
| _, cfgBetaFeatures := codexHeaderDefaults(cfg, auth) | |
| ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "") | |
| misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") | |
| misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") | |
| misc.EnsureHeader(headers, ginHeaders, "x-client-request-id", "") | |
| misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "") | |
| misc.EnsureHeader(headers, ginHeaders, "Version", "") | |
| betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta")) | |
| if betaHeader == "" && ginHeaders != nil { | |
| betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta")) | |
| } | |
| if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") { | |
| betaHeader = codexResponsesWebsocketBetaHeaderValue | |
| } | |
| headers.Set("OpenAI-Beta", betaHeader) | |
| if strings.Contains(headers.Get("User-Agent"), "Mac OS") { | |
| misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) | |
| } | |
| headers.Del("User-Agent") | |
| isAPIKey := false | |
| if auth != nil && auth.Attributes != nil { | |
| if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" { | |
| isAPIKey = true | |
| } | |
| } | |
| if originator := strings.TrimSpace(ginHeaders.Get("Originator")); originator != "" { | |
| headers.Set("Originator", originator) | |
| } else if !isAPIKey { | |
| headers.Set("Originator", codexOriginator) | |
| } | |
| if !isAPIKey { | |
| if auth != nil && auth.Metadata != nil { | |
| if accountID, ok := auth.Metadata["account_id"].(string); ok { | |
| if trimmed := strings.TrimSpace(accountID); trimmed != "" { | |
| headers.Set("Chatgpt-Account-Id", trimmed) | |
| } | |
| } | |
| } | |
| } | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(&http.Request{Header: headers}, attrs) | |
| return headers | |
| } | |
| func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) { | |
| if cfg == nil || auth == nil { | |
| return "", "" | |
| } | |
| if auth.Attributes != nil { | |
| if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" { | |
| return "", "" | |
| } | |
| } | |
| return strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent), strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures) | |
| } | |
// ensureHeaderWithPriority sets key on target only when target has no non-blank
// value for it yet. The first non-blank candidate wins, in this order: the value
// from source, then configValue, then fallbackValue. Values are trimmed before
// being stored. A nil target is a no-op and a nil source is treated as empty.
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
	if target == nil || strings.TrimSpace(target.Get(key)) != "" {
		return
	}
	// Header.Get on a nil header returns "", so source needs no nil guard.
	candidates := [...]string{source.Get(key), configValue, fallbackValue}
	for _, candidate := range candidates {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			target.Set(key, trimmed)
			return
		}
	}
}
// ensureHeaderWithConfigPrecedence sets key on target only when target has no
// non-blank value for it yet. The first non-blank candidate wins, in this order:
// configValue, then the value from source, then fallbackValue. Values are trimmed
// before being stored. A nil target is a no-op and a nil source is treated as empty.
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
	if target == nil || strings.TrimSpace(target.Get(key)) != "" {
		return
	}
	// Header.Get on a nil header returns "", so source needs no nil guard.
	candidates := [...]string{configValue, source.Get(key), fallbackValue}
	for _, candidate := range candidates {
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			target.Set(key, trimmed)
			return
		}
	}
}
| type statusErrWithHeaders struct { | |
| statusErr | |
| headers http.Header | |
| } | |
| func (e statusErrWithHeaders) Headers() http.Header { | |
| if e.headers == nil { | |
| return nil | |
| } | |
| return e.headers.Clone() | |
| } | |
| func parseCodexWebsocketError(payload []byte) (error, bool) { | |
| if len(payload) == 0 { | |
| return nil, false | |
| } | |
| if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "error" { | |
| return nil, false | |
| } | |
| status := int(gjson.GetBytes(payload, "status").Int()) | |
| if status == 0 { | |
| status = int(gjson.GetBytes(payload, "status_code").Int()) | |
| } | |
| if status <= 0 { | |
| return nil, false | |
| } | |
| out := []byte(`{}`) | |
| if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() { | |
| raw := errNode.Raw | |
| if errNode.Type == gjson.String { | |
| raw = errNode.Raw | |
| } | |
| out, _ = sjson.SetRawBytes(out, "error", []byte(raw)) | |
| } else { | |
| out, _ = sjson.SetBytes(out, "error.type", "server_error") | |
| out, _ = sjson.SetBytes(out, "error.message", http.StatusText(status)) | |
| } | |
| headers := parseCodexWebsocketErrorHeaders(payload) | |
| return statusErrWithHeaders{ | |
| statusErr: statusErr{code: status, msg: string(out)}, | |
| headers: headers, | |
| }, true | |
| } | |
| func parseCodexWebsocketErrorHeaders(payload []byte) http.Header { | |
| headersNode := gjson.GetBytes(payload, "headers") | |
| if !headersNode.Exists() || !headersNode.IsObject() { | |
| return nil | |
| } | |
| mapped := make(http.Header) | |
| headersNode.ForEach(func(key, value gjson.Result) bool { | |
| name := strings.TrimSpace(key.String()) | |
| if name == "" { | |
| return true | |
| } | |
| switch value.Type { | |
| case gjson.String: | |
| if v := strings.TrimSpace(value.String()); v != "" { | |
| mapped.Set(name, v) | |
| } | |
| case gjson.Number, gjson.True, gjson.False: | |
| if v := strings.TrimSpace(value.Raw); v != "" { | |
| mapped.Set(name, v) | |
| } | |
| default: | |
| } | |
| return true | |
| }) | |
| if len(mapped) == 0 { | |
| return nil | |
| } | |
| return mapped | |
| } | |
| func normalizeCodexWebsocketCompletion(payload []byte) []byte { | |
| if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.done" { | |
| updated, err := sjson.SetBytes(payload, "type", "response.completed") | |
| if err == nil && len(updated) > 0 { | |
| return updated | |
| } | |
| } | |
| return payload | |
| } | |
// encodeCodexWebsocketAsSSE returns payload prefixed with "data: " in a newly
// allocated slice. No line terminator is appended. An empty payload yields nil.
func encodeCodexWebsocketAsSSE(payload []byte) []byte {
	const prefix = "data: "
	if len(payload) == 0 {
		return nil
	}
	out := make([]byte, len(prefix)+len(payload))
	n := copy(out, prefix)
	copy(out[n:], payload)
	return out
}
| func websocketUpgradeRequestLog(info helps.UpstreamRequestLog) helps.UpstreamRequestLog { | |
| upgradeInfo := info | |
| upgradeInfo.URL = helps.WebsocketUpgradeRequestURL(info.URL) | |
| upgradeInfo.Method = http.MethodGet | |
| upgradeInfo.Body = nil | |
| upgradeInfo.Headers = info.Headers.Clone() | |
| if upgradeInfo.Headers == nil { | |
| upgradeInfo.Headers = make(http.Header) | |
| } | |
| if strings.TrimSpace(upgradeInfo.Headers.Get("Connection")) == "" { | |
| upgradeInfo.Headers.Set("Connection", "Upgrade") | |
| } | |
| if strings.TrimSpace(upgradeInfo.Headers.Get("Upgrade")) == "" { | |
| upgradeInfo.Headers.Set("Upgrade", "websocket") | |
| } | |
| return upgradeInfo | |
| } | |
| func recordAPIWebsocketHandshake(ctx context.Context, cfg *config.Config, resp *http.Response) { | |
| if resp == nil { | |
| return | |
| } | |
| helps.RecordAPIWebsocketHandshake(ctx, cfg, resp.StatusCode, resp.Header.Clone()) | |
| closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error") | |
| } | |
| func websocketHandshakeBody(resp *http.Response) []byte { | |
| if resp == nil || resp.Body == nil { | |
| return nil | |
| } | |
| body, _ := io.ReadAll(resp.Body) | |
| closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error") | |
| if len(body) == 0 { | |
| return nil | |
| } | |
| return body | |
| } | |
| func closeHTTPResponseBody(resp *http.Response, logPrefix string) { | |
| if resp == nil || resp.Body == nil { | |
| return | |
| } | |
| if errClose := resp.Body.Close(); errClose != nil { | |
| log.Errorf("%s: %v", logPrefix, errClose) | |
| } | |
| } | |
| func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string { | |
| if len(opts.Metadata) == 0 { | |
| return "" | |
| } | |
| raw, ok := opts.Metadata[cliproxyexecutor.ExecutionSessionMetadataKey] | |
| if !ok || raw == nil { | |
| return "" | |
| } | |
| switch v := raw.(type) { | |
| case string: | |
| return strings.TrimSpace(v) | |
| case []byte: | |
| return strings.TrimSpace(string(v)) | |
| default: | |
| return "" | |
| } | |
| } | |
| func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWebsocketSession { | |
| sessionID = strings.TrimSpace(sessionID) | |
| if sessionID == "" { | |
| return nil | |
| } | |
| if e == nil { | |
| return nil | |
| } | |
| store := e.store | |
| if store == nil { | |
| store = globalCodexWebsocketSessionStore | |
| } | |
| store.mu.Lock() | |
| defer store.mu.Unlock() | |
| if store.sessions == nil { | |
| store.sessions = make(map[string]*codexWebsocketSession) | |
| } | |
| if sess, ok := store.sessions[sessionID]; ok && sess != nil { | |
| return sess | |
| } | |
| sess := &codexWebsocketSession{sessionID: sessionID} | |
| store.sessions[sessionID] = sess | |
| return sess | |
| } | |
| func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { | |
| if sess == nil { | |
| return e.dialCodexWebsocket(ctx, auth, wsURL, headers) | |
| } | |
| sess.connMu.Lock() | |
| conn := sess.conn | |
| readerConn := sess.readerConn | |
| sess.connMu.Unlock() | |
| if conn != nil { | |
| if readerConn != conn { | |
| sess.connMu.Lock() | |
| sess.readerConn = conn | |
| sess.connMu.Unlock() | |
| sess.configureConn(conn) | |
| go e.readUpstreamLoop(sess, conn) | |
| } | |
| return conn, nil, nil | |
| } | |
| conn, resp, errDial := e.dialCodexWebsocket(ctx, auth, wsURL, headers) | |
| if errDial != nil { | |
| return nil, resp, errDial | |
| } | |
| sess.connMu.Lock() | |
| if sess.conn != nil { | |
| previous := sess.conn | |
| sess.connMu.Unlock() | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| return previous, nil, nil | |
| } | |
| sess.conn = conn | |
| sess.wsURL = wsURL | |
| sess.authID = authID | |
| sess.readerConn = conn | |
| sess.connMu.Unlock() | |
| sess.configureConn(conn) | |
| go e.readUpstreamLoop(sess, conn) | |
| logCodexWebsocketConnected(sess.sessionID, authID, wsURL) | |
| return conn, resp, nil | |
| } | |
| func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, conn *websocket.Conn) { | |
| if e == nil || sess == nil || conn == nil { | |
| return | |
| } | |
| for { | |
| _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) | |
| msgType, payload, errRead := conn.ReadMessage() | |
| if errRead != nil { | |
| sess.activeMu.Lock() | |
| ch := sess.activeCh | |
| done := sess.activeDone | |
| sess.activeMu.Unlock() | |
| if ch != nil { | |
| select { | |
| case ch <- codexWebsocketRead{conn: conn, err: errRead}: | |
| case <-done: | |
| default: | |
| } | |
| sess.clearActive(ch) | |
| close(ch) | |
| } | |
| e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) | |
| return | |
| } | |
| if msgType != websocket.TextMessage { | |
| if msgType == websocket.BinaryMessage { | |
| errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") | |
| sess.activeMu.Lock() | |
| ch := sess.activeCh | |
| done := sess.activeDone | |
| sess.activeMu.Unlock() | |
| if ch != nil { | |
| select { | |
| case ch <- codexWebsocketRead{conn: conn, err: errBinary}: | |
| case <-done: | |
| default: | |
| } | |
| sess.clearActive(ch) | |
| close(ch) | |
| } | |
| e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) | |
| return | |
| } | |
| continue | |
| } | |
| sess.activeMu.Lock() | |
| ch := sess.activeCh | |
| done := sess.activeDone | |
| sess.activeMu.Unlock() | |
| if ch == nil { | |
| continue | |
| } | |
| select { | |
| case ch <- codexWebsocketRead{conn: conn, msgType: msgType, payload: payload}: | |
| case <-done: | |
| } | |
| } | |
| } | |
| func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) { | |
| if sess == nil || conn == nil { | |
| return | |
| } | |
| sess.connMu.Lock() | |
| current := sess.conn | |
| authID := sess.authID | |
| wsURL := sess.wsURL | |
| sessionID := sess.sessionID | |
| if current == nil || current != conn { | |
| sess.connMu.Unlock() | |
| return | |
| } | |
| sess.conn = nil | |
| if sess.readerConn == conn { | |
| sess.readerConn = nil | |
| } | |
| sess.connMu.Unlock() | |
| logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err) | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| } | |
| func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) { | |
| sessionID = strings.TrimSpace(sessionID) | |
| if e == nil { | |
| return | |
| } | |
| if sessionID == "" { | |
| return | |
| } | |
| if sessionID == cliproxyauth.CloseAllExecutionSessionsID { | |
| // Executor replacement can happen during hot reload (config/credential changes). | |
| // Do not force-close upstream websocket sessions here, otherwise in-flight | |
| // downstream websocket requests get interrupted. | |
| return | |
| } | |
| store := e.store | |
| if store == nil { | |
| store = globalCodexWebsocketSessionStore | |
| } | |
| store.mu.Lock() | |
| sess := store.sessions[sessionID] | |
| delete(store.sessions, sessionID) | |
| store.mu.Unlock() | |
| e.closeExecutionSession(sess, "session_closed") | |
| } | |
| func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) { | |
| if e == nil { | |
| return | |
| } | |
| store := e.store | |
| if store == nil { | |
| store = globalCodexWebsocketSessionStore | |
| } | |
| store.mu.Lock() | |
| sessions := make([]*codexWebsocketSession, 0, len(store.sessions)) | |
| for sessionID, sess := range store.sessions { | |
| delete(store.sessions, sessionID) | |
| if sess != nil { | |
| sessions = append(sessions, sess) | |
| } | |
| } | |
| store.mu.Unlock() | |
| for i := range sessions { | |
| e.closeExecutionSession(sessions[i], reason) | |
| } | |
| } | |
| func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { | |
| closeCodexWebsocketSession(sess, reason) | |
| } | |
| func closeCodexWebsocketSession(sess *codexWebsocketSession, reason string) { | |
| if sess == nil { | |
| return | |
| } | |
| reason = strings.TrimSpace(reason) | |
| if reason == "" { | |
| reason = "session_closed" | |
| } | |
| sess.connMu.Lock() | |
| conn := sess.conn | |
| authID := sess.authID | |
| wsURL := sess.wsURL | |
| sess.conn = nil | |
| if sess.readerConn == conn { | |
| sess.readerConn = nil | |
| } | |
| sessionID := sess.sessionID | |
| sess.connMu.Unlock() | |
| if conn == nil { | |
| return | |
| } | |
| logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, nil) | |
| if errClose := conn.Close(); errClose != nil { | |
| log.Errorf("codex websockets executor: close websocket error: %v", errClose) | |
| } | |
| } | |
| func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { | |
| log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) | |
| } | |
| func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { | |
| if err != nil { | |
| log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err) | |
| return | |
| } | |
| log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) | |
| } | |
| // CloseCodexWebsocketSessionsForAuthID closes all active Codex upstream websocket sessions | |
| // associated with the supplied auth ID. | |
| func CloseCodexWebsocketSessionsForAuthID(authID string, reason string) { | |
| authID = strings.TrimSpace(authID) | |
| if authID == "" { | |
| return | |
| } | |
| reason = strings.TrimSpace(reason) | |
| if reason == "" { | |
| reason = "auth_removed" | |
| } | |
| store := globalCodexWebsocketSessionStore | |
| if store == nil { | |
| return | |
| } | |
| type sessionItem struct { | |
| sessionID string | |
| sess *codexWebsocketSession | |
| } | |
| store.mu.Lock() | |
| items := make([]sessionItem, 0, len(store.sessions)) | |
| for sessionID, sess := range store.sessions { | |
| items = append(items, sessionItem{sessionID: sessionID, sess: sess}) | |
| } | |
| store.mu.Unlock() | |
| matches := make([]sessionItem, 0) | |
| for i := range items { | |
| sess := items[i].sess | |
| if sess == nil { | |
| continue | |
| } | |
| sess.connMu.Lock() | |
| sessAuthID := strings.TrimSpace(sess.authID) | |
| sess.connMu.Unlock() | |
| if sessAuthID == authID { | |
| matches = append(matches, items[i]) | |
| } | |
| } | |
| if len(matches) == 0 { | |
| return | |
| } | |
| toClose := make([]*codexWebsocketSession, 0, len(matches)) | |
| store.mu.Lock() | |
| for i := range matches { | |
| current, ok := store.sessions[matches[i].sessionID] | |
| if !ok || current == nil || current != matches[i].sess { | |
| continue | |
| } | |
| delete(store.sessions, matches[i].sessionID) | |
| toClose = append(toClose, current) | |
| } | |
| store.mu.Unlock() | |
| for i := range toClose { | |
| closeCodexWebsocketSession(toClose[i], reason) | |
| } | |
| } | |
| // CodexAutoExecutor routes Codex requests to the websocket transport only when: | |
| // 1. The downstream transport is websocket, and | |
| // 2. The selected auth enables websockets. | |
| // | |
| // For non-websocket downstream requests, it always uses the legacy HTTP implementation. | |
| type CodexAutoExecutor struct { | |
| httpExec *CodexExecutor | |
| wsExec *CodexWebsocketsExecutor | |
| } | |
| func NewCodexAutoExecutor(cfg *config.Config) *CodexAutoExecutor { | |
| return &CodexAutoExecutor{ | |
| httpExec: NewCodexExecutor(cfg), | |
| wsExec: NewCodexWebsocketsExecutor(cfg), | |
| } | |
| } | |
| func (e *CodexAutoExecutor) Identifier() string { return "codex" } | |
| func (e *CodexAutoExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { | |
| if e == nil || e.httpExec == nil { | |
| return nil | |
| } | |
| return e.httpExec.PrepareRequest(req, auth) | |
| } | |
| func (e *CodexAutoExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { | |
| if e == nil || e.httpExec == nil { | |
| return nil, fmt.Errorf("codex auto executor: http executor is nil") | |
| } | |
| return e.httpExec.HttpRequest(ctx, auth, req) | |
| } | |
| func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { | |
| if e == nil || e.httpExec == nil || e.wsExec == nil { | |
| return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: executor is nil") | |
| } | |
| if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) { | |
| return e.wsExec.Execute(ctx, auth, req, opts) | |
| } | |
| return e.httpExec.Execute(ctx, auth, req, opts) | |
| } | |
| func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { | |
| if e == nil || e.httpExec == nil || e.wsExec == nil { | |
| return nil, fmt.Errorf("codex auto executor: executor is nil") | |
| } | |
| if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) { | |
| return e.wsExec.ExecuteStream(ctx, auth, req, opts) | |
| } | |
| return e.httpExec.ExecuteStream(ctx, auth, req, opts) | |
| } | |
| func (e *CodexAutoExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { | |
| if e == nil || e.httpExec == nil { | |
| return nil, fmt.Errorf("codex auto executor: http executor is nil") | |
| } | |
| return e.httpExec.Refresh(ctx, auth) | |
| } | |
| func (e *CodexAutoExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { | |
| if e == nil || e.httpExec == nil { | |
| return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: http executor is nil") | |
| } | |
| return e.httpExec.CountTokens(ctx, auth, req, opts) | |
| } | |
| func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) { | |
| if e == nil || e.wsExec == nil { | |
| return | |
| } | |
| e.wsExec.CloseExecutionSession(sessionID) | |
| } | |
| func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool { | |
| if auth == nil { | |
| return false | |
| } | |
| if len(auth.Attributes) > 0 { | |
| if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" { | |
| parsed, errParse := strconv.ParseBool(raw) | |
| if errParse == nil { | |
| return parsed | |
| } | |
| } | |
| } | |
| if len(auth.Metadata) == 0 { | |
| return false | |
| } | |
| raw, ok := auth.Metadata["websockets"] | |
| if !ok || raw == nil { | |
| return false | |
| } | |
| switch v := raw.(type) { | |
| case bool: | |
| return v | |
| case string: | |
| parsed, errParse := strconv.ParseBool(strings.TrimSpace(v)) | |
| if errParse == nil { | |
| return parsed | |
| } | |
| default: | |
| } | |
| return false | |
| } | |