// Package executor provides runtime execution capabilities for various AI service providers. // This file implements the AI Studio executor that routes requests through a websocket-backed // transport for the AI Studio provider. package executor import ( "bytes" "context" "encoding/json" "fmt" "net/http" "net/url" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) // AIStudioExecutor routes AI Studio requests through a websocket-backed transport. type AIStudioExecutor struct { provider string relay *wsrelay.Manager cfg *config.Config } // NewAIStudioExecutor creates a new AI Studio executor instance. // // Parameters: // - cfg: The application configuration // - provider: The provider name // - relay: The websocket relay manager // // Returns: // - *AIStudioExecutor: A new AI Studio executor instance func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor { return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg} } // Identifier returns the executor identifier. func (e *AIStudioExecutor) Identifier() string { return "aistudio" } // PrepareRequest prepares the HTTP request for execution (no-op for AI Studio). func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil } // Execute performs a non-streaming request to the AI Studio API. func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) translatedReq, body, err := e.translateRequest(req, opts, false) if err != nil { return resp, err } endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, Headers: http.Header{"Content-Type": []string{"application/json"}}, Body: body.payload, } var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() } recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: endpoint, Method: http.MethodPost, Headers: wsReq.Headers.Clone(), Body: bytes.Clone(body.payload), Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, }) wsResp, err := e.relay.NonStream(ctx, authID, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return resp, err } recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone()) if len(wsResp.Body) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body)) } if wsResp.Status < 200 || wsResp.Status >= 300 { return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)} } reporter.publish(ctx, parseGeminiUsage(wsResp.Body)) var param any out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m) resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))} return resp, nil } // ExecuteStream performs a streaming request to the AI Studio API. func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) translatedReq, body, err := e.translateRequest(req, opts, true) if err != nil { return nil, err } endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt) wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, Headers: http.Header{"Content-Type": []string{"application/json"}}, Body: body.payload, } var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() } recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: endpoint, Method: http.MethodPost, Headers: wsReq.Headers.Clone(), Body: bytes.Clone(body.payload), Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, }) wsStream, err := e.relay.Stream(ctx, authID, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return nil, err } firstEvent, ok := <-wsStream if !ok { err = fmt.Errorf("wsrelay: stream closed before start") recordAPIResponseError(ctx, e.cfg, err) return nil, err } if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK { metadataLogged := false if firstEvent.Status > 0 { recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone()) metadataLogged = true } var body bytes.Buffer if len(firstEvent.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload)) body.Write(firstEvent.Payload) } if firstEvent.Type == wsrelay.MessageTypeStreamEnd { return nil, statusErr{code: firstEvent.Status, msg: body.String()} } for event := range wsStream { if event.Err != nil { recordAPIResponseError(ctx, e.cfg, event.Err) if body.Len() == 0 { body.WriteString(event.Err.Error()) } break } if !metadataLogged && event.Status > 0 { recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) metadataLogged = true } if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) body.Write(event.Payload) } if event.Type == wsrelay.MessageTypeStreamEnd { break } } return nil, statusErr{code: firstEvent.Status, msg: body.String()} } out := make(chan cliproxyexecutor.StreamChunk) stream = out go func(first wsrelay.StreamEvent) { defer close(out) var param any metadataLogged := false processEvent := func(event wsrelay.StreamEvent) bool { if event.Err != nil { recordAPIResponseError(ctx, e.cfg, event.Err) reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return false } switch event.Type { case wsrelay.MessageTypeStreamStart: if !metadataLogged && event.Status > 0 { recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) metadataLogged = true } case wsrelay.MessageTypeStreamChunk: if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) filtered := FilterSSEUsageMetadata(event.Payload) if detail, ok := parseGeminiStreamUsage(filtered); ok { reporter.publish(ctx, detail) } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m) for i := range lines { out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} } break } case wsrelay.MessageTypeStreamEnd: return false case wsrelay.MessageTypeHTTPResp: if !metadataLogged && event.Status > 0 { recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone()) metadataLogged = true } if len(event.Payload) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload)) } lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m) for i := range lines { out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))} } reporter.publish(ctx, parseGeminiUsage(event.Payload)) return false case wsrelay.MessageTypeError: recordAPIResponseError(ctx, e.cfg, event.Err) reporter.publishFailure(ctx) out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)} return false } return true } if !processEvent(first) { return } for event := range wsStream { if !processEvent(event) { return } } }(firstEvent) return stream, nil } // CountTokens counts tokens for the given request using the AI Studio API. func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { _, body, err := e.translateRequest(req, opts, false) if err != nil { return cliproxyexecutor.Response{}, err } body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig") body.payload, _ = sjson.DeleteBytes(body.payload, "tools") body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings") endpoint := e.buildEndpoint(req.Model, "countTokens", "") wsReq := &wsrelay.HTTPRequest{ Method: http.MethodPost, URL: endpoint, Headers: http.Header{"Content-Type": []string{"application/json"}}, Body: body.payload, } var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID authLabel = auth.Label authType, authValue = auth.AccountInfo() } recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ URL: endpoint, Method: http.MethodPost, Headers: wsReq.Headers.Clone(), Body: bytes.Clone(body.payload), Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, AuthType: authType, AuthValue: authValue, }) resp, err := e.relay.NonStream(ctx, authID, wsReq) if err != nil { recordAPIResponseError(ctx, e.cfg, err) return cliproxyexecutor.Response{}, err } recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone()) if len(resp.Body) > 0 { appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body)) } if resp.Status < 200 || resp.Status >= 300 { return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)} } totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int() if totalTokens <= 0 { return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response") } translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body)) return cliproxyexecutor.Response{Payload: []byte(translated)}, nil } // Refresh refreshes the authentication credentials (no-op for AI Studio). func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { return auth, nil } type translatedPayload struct { payload []byte action string toFormat sdktranslator.Format } func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) { from := opts.SourceFormat to := sdktranslator.FromString("gemini") originalPayload := bytes.Clone(req.Payload) if len(opts.OriginalRequest) > 0 { originalPayload = bytes.Clone(opts.OriginalRequest) } originalTranslated := sdktranslator.TranslateRequest(from, to, req.Model, originalPayload, stream) payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream) payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model) payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload) payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload) payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true) payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload) payload = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", payload, originalTranslated) payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType") payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema") metadataAction := "generateContent" if req.Metadata != nil { if action, _ := req.Metadata["action"].(string); action == "countTokens" { metadataAction = action } } action := metadataAction if stream && action != "countTokens" { action = "streamGenerateContent" } payload, _ = sjson.DeleteBytes(payload, "session_id") return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil } func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string { base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action) if action == "streamGenerateContent" { if alt == "" { return base + "?alt=sse" } return base + "?$alt=" + url.QueryEscape(alt) } if alt != "" && action != "countTokens" { return base + "?$alt=" + url.QueryEscape(alt) } return base } // ensureColonSpacedJSON normalizes JSON objects so that colons are followed by a single space while // keeping the payload otherwise compact. Non-JSON inputs are returned unchanged. func ensureColonSpacedJSON(payload []byte) []byte { trimmed := bytes.TrimSpace(payload) if len(trimmed) == 0 { return payload } var decoded any if err := json.Unmarshal(trimmed, &decoded); err != nil { return payload } indented, err := json.MarshalIndent(decoded, "", " ") if err != nil { return payload } compacted := make([]byte, 0, len(indented)) inString := false skipSpace := false for i := 0; i < len(indented); i++ { ch := indented[i] if ch == '"' { // A quote is escaped only when preceded by an odd number of consecutive backslashes. // For example: "\\\"" keeps the quote inside the string, but "\\\\" closes the string. backslashes := 0 for j := i - 1; j >= 0 && indented[j] == '\\'; j-- { backslashes++ } if backslashes%2 == 0 { inString = !inString } } if !inString { if ch == '\n' || ch == '\r' { skipSpace = true continue } if skipSpace { if ch == ' ' || ch == '\t' { continue } skipSpace = false } } compacted = append(compacted, ch) } return compacted }