Spaces:
Paused
Paused
| // Package executor provides runtime execution capabilities for various AI service providers. | |
| // This file implements the Vertex AI Gemini executor that talks to Google Vertex AI | |
| // endpoints using service account credentials or API keys. | |
| package executor | |
| import ( | |
| "bufio" | |
| "bytes" | |
| "context" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "net/http" | |
| "strings" | |
| "time" | |
| vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/config" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor/helps" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" | |
| "github.com/router-for-me/CLIProxyAPI/v6/internal/util" | |
| cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" | |
| cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" | |
| sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" | |
| log "github.com/sirupsen/logrus" | |
| "github.com/tidwall/gjson" | |
| "github.com/tidwall/sjson" | |
| "golang.org/x/oauth2" | |
| "golang.org/x/oauth2/google" | |
| ) | |
// vertexAPIVersion is the API version path segment placed right after the
// base URL when building Vertex AI request URLs.
const vertexAPIVersion = "v1"
// isImagenModel reports whether model names an Imagen image generation model.
// The match is a case-insensitive substring search for "imagen".
func isImagenModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "imagen")
}
| // getVertexAction returns the appropriate action for the given model. | |
| // Imagen models use "predict", while Gemini models use "generateContent". | |
| func getVertexAction(model string, isStream bool) string { | |
| if isImagenModel(model) { | |
| return "predict" | |
| } | |
| if isStream { | |
| return "streamGenerateContent" | |
| } | |
| return "generateContent" | |
| } | |
| // convertImagenToGeminiResponse converts Imagen API response to Gemini format | |
| // so it can be processed by the standard translation pipeline. | |
| // This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview. | |
| func convertImagenToGeminiResponse(data []byte, model string) []byte { | |
| predictions := gjson.GetBytes(data, "predictions") | |
| if !predictions.Exists() || !predictions.IsArray() { | |
| return data | |
| } | |
| // Build Gemini-compatible response with inlineData | |
| parts := make([]map[string]any, 0) | |
| for _, pred := range predictions.Array() { | |
| imageData := pred.Get("bytesBase64Encoded").String() | |
| mimeType := pred.Get("mimeType").String() | |
| if mimeType == "" { | |
| mimeType = "image/png" | |
| } | |
| if imageData != "" { | |
| parts = append(parts, map[string]any{ | |
| "inlineData": map[string]any{ | |
| "mimeType": mimeType, | |
| "data": imageData, | |
| }, | |
| }) | |
| } | |
| } | |
| // Generate unique response ID using timestamp | |
| responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano()) | |
| response := map[string]any{ | |
| "candidates": []map[string]any{{ | |
| "content": map[string]any{ | |
| "parts": parts, | |
| "role": "model", | |
| }, | |
| "finishReason": "STOP", | |
| }}, | |
| "responseId": responseId, | |
| "modelVersion": model, | |
| // Imagen API doesn't return token counts, set to 0 for tracking purposes | |
| "usageMetadata": map[string]any{ | |
| "promptTokenCount": 0, | |
| "candidatesTokenCount": 0, | |
| "totalTokenCount": 0, | |
| }, | |
| } | |
| result, err := json.Marshal(response) | |
| if err != nil { | |
| return data | |
| } | |
| return result | |
| } | |
| // convertToImagenRequest converts a Gemini-style request to Imagen API format. | |
| // Imagen API uses a different structure: instances[].prompt instead of contents[]. | |
| func convertToImagenRequest(payload []byte) ([]byte, error) { | |
| // Extract prompt from Gemini-style contents | |
| prompt := "" | |
| // Try to get prompt from contents[0].parts[0].text | |
| contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text") | |
| if contentsText.Exists() { | |
| prompt = contentsText.String() | |
| } | |
| // If no contents, try messages format (OpenAI-compatible) | |
| if prompt == "" { | |
| messagesText := gjson.GetBytes(payload, "messages.#.content") | |
| if messagesText.Exists() && messagesText.IsArray() { | |
| for _, msg := range messagesText.Array() { | |
| if msg.String() != "" { | |
| prompt = msg.String() | |
| break | |
| } | |
| } | |
| } | |
| } | |
| // If still no prompt, try direct prompt field | |
| if prompt == "" { | |
| directPrompt := gjson.GetBytes(payload, "prompt") | |
| if directPrompt.Exists() { | |
| prompt = directPrompt.String() | |
| } | |
| } | |
| if prompt == "" { | |
| return nil, fmt.Errorf("imagen: no prompt found in request") | |
| } | |
| // Build Imagen API request | |
| imagenReq := map[string]any{ | |
| "instances": []map[string]any{ | |
| { | |
| "prompt": prompt, | |
| }, | |
| }, | |
| "parameters": map[string]any{ | |
| "sampleCount": 1, | |
| }, | |
| } | |
| // Extract optional parameters | |
| if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() { | |
| imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String() | |
| } | |
| if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() { | |
| imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int()) | |
| } | |
| if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() { | |
| imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String() | |
| } | |
| return json.Marshal(imagenReq) | |
| } | |
| // GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials. | |
| type GeminiVertexExecutor struct { | |
| cfg *config.Config | |
| } | |
| // NewGeminiVertexExecutor creates a new Vertex AI Gemini executor instance. | |
| // | |
| // Parameters: | |
| // - cfg: The application configuration | |
| // | |
| // Returns: | |
| // - *GeminiVertexExecutor: A new Vertex AI Gemini executor instance | |
| func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { | |
| return &GeminiVertexExecutor{cfg: cfg} | |
| } | |
| // Identifier returns the executor identifier. | |
| func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } | |
| // PrepareRequest injects Vertex credentials into the outgoing HTTP request. | |
| func (e *GeminiVertexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { | |
| if req == nil { | |
| return nil | |
| } | |
| apiKey, _ := vertexAPICreds(auth) | |
| if strings.TrimSpace(apiKey) != "" { | |
| req.Header.Set("x-goog-api-key", apiKey) | |
| req.Header.Del("Authorization") | |
| return nil | |
| } | |
| _, _, saJSON, errCreds := vertexCreds(auth) | |
| if errCreds != nil { | |
| return errCreds | |
| } | |
| token, errToken := vertexAccessToken(req.Context(), e.cfg, auth, saJSON) | |
| if errToken != nil { | |
| return errToken | |
| } | |
| if strings.TrimSpace(token) == "" { | |
| return statusErr{code: http.StatusUnauthorized, msg: "missing access token"} | |
| } | |
| req.Header.Set("Authorization", "Bearer "+token) | |
| req.Header.Del("x-goog-api-key") | |
| return nil | |
| } | |
| // HttpRequest injects Vertex credentials into the request and executes it. | |
| func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { | |
| if req == nil { | |
| return nil, fmt.Errorf("vertex executor: request is nil") | |
| } | |
| if ctx == nil { | |
| ctx = req.Context() | |
| } | |
| httpReq := req.WithContext(ctx) | |
| if err := e.PrepareRequest(httpReq, auth); err != nil { | |
| return nil, err | |
| } | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| return httpClient.Do(httpReq) | |
| } | |
| // Execute performs a non-streaming request to the Vertex AI API. | |
| func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { | |
| if opts.Alt == "responses/compact" { | |
| return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} | |
| } | |
| // Try API key authentication first | |
| apiKey, baseURL := vertexAPICreds(auth) | |
| // If no API key found, fall back to service account authentication | |
| if apiKey == "" { | |
| projectID, location, saJSON, errCreds := vertexCreds(auth) | |
| if errCreds != nil { | |
| return resp, errCreds | |
| } | |
| return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) | |
| } | |
| // Use API key authentication | |
| return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) | |
| } | |
| // ExecuteStream performs a streaming request to the Vertex AI API. | |
| func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { | |
| if opts.Alt == "responses/compact" { | |
| return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"} | |
| } | |
| // Try API key authentication first | |
| apiKey, baseURL := vertexAPICreds(auth) | |
| // If no API key found, fall back to service account authentication | |
| if apiKey == "" { | |
| projectID, location, saJSON, errCreds := vertexCreds(auth) | |
| if errCreds != nil { | |
| return nil, errCreds | |
| } | |
| return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) | |
| } | |
| // Use API key authentication | |
| return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) | |
| } | |
| // CountTokens counts tokens for the given request using the Vertex AI API. | |
| func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { | |
| // Try API key authentication first | |
| apiKey, baseURL := vertexAPICreds(auth) | |
| // If no API key found, fall back to service account authentication | |
| if apiKey == "" { | |
| projectID, location, saJSON, errCreds := vertexCreds(auth) | |
| if errCreds != nil { | |
| return cliproxyexecutor.Response{}, errCreds | |
| } | |
| return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) | |
| } | |
| // Use API key authentication | |
| return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) | |
| } | |
| // Refresh refreshes the authentication credentials (no-op for Vertex). | |
| func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { | |
| return auth, nil | |
| } | |
| // executeWithServiceAccount handles authentication using service account credentials. | |
| // This method contains the original service account authentication logic. | |
| func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| var body []byte | |
| // Handle Imagen models with special request format | |
| if isImagenModel(baseModel) { | |
| imagenBody, errImagen := convertToImagenRequest(req.Payload) | |
| if errImagen != nil { | |
| return resp, errImagen | |
| } | |
| body = imagenBody | |
| } else { | |
| // Standard Gemini translation flow | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| originalPayloadSource := req.Payload | |
| if len(opts.OriginalRequest) > 0 { | |
| originalPayloadSource = opts.OriginalRequest | |
| } | |
| originalPayload := originalPayloadSource | |
| originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) | |
| body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return resp, err | |
| } | |
| body = fixGeminiImageAspectRatio(baseModel, body) | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) | |
| body, _ = sjson.SetBytes(body, "model", baseModel) | |
| } | |
| action := getVertexAction(baseModel, false) | |
| if req.Metadata != nil { | |
| if a, _ := req.Metadata["action"].(string); a == "countTokens" { | |
| action = "countTokens" | |
| } | |
| } | |
| baseURL := vertexBaseURL(location) | |
| url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action) | |
| if opts.Alt != "" && action != "countTokens" { | |
| url = url + fmt.Sprintf("?$alt=%s", opts.Alt) | |
| } | |
| body, _ = sjson.DeleteBytes(body, "session_id") | |
| httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) | |
| if errNewReq != nil { | |
| return resp, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { | |
| httpReq.Header.Set("Authorization", "Bearer "+token) | |
| } else if errTok != nil { | |
| log.Errorf("vertex executor: access token error: %v", errTok) | |
| return resp, statusErr{code: 500, msg: "internal server error"} | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: body, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return resp, errDo | |
| } | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| err = statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| return resp, err | |
| } | |
| data, errRead := io.ReadAll(httpResp.Body) | |
| if errRead != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errRead) | |
| return resp, errRead | |
| } | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, data) | |
| reporter.Publish(ctx, helps.ParseGeminiUsage(data)) | |
| // For Imagen models, convert response to Gemini format before translation | |
| // This ensures Imagen responses use the same format as gemini-3-pro-image-preview | |
| if isImagenModel(baseModel) { | |
| data = convertImagenToGeminiResponse(data, baseModel) | |
| } | |
| // Standard Gemini translation (works for both Gemini and converted Imagen responses) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| var param any | |
| out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) | |
| resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} | |
| return resp, nil | |
| } | |
| // executeWithAPIKey handles authentication using API key credentials. | |
| func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| originalPayloadSource := req.Payload | |
| if len(opts.OriginalRequest) > 0 { | |
| originalPayloadSource = opts.OriginalRequest | |
| } | |
| originalPayload := originalPayloadSource | |
| originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) | |
| body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return resp, err | |
| } | |
| body = fixGeminiImageAspectRatio(baseModel, body) | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) | |
| body, _ = sjson.SetBytes(body, "model", baseModel) | |
| action := getVertexAction(baseModel, false) | |
| if req.Metadata != nil { | |
| if a, _ := req.Metadata["action"].(string); a == "countTokens" { | |
| action = "countTokens" | |
| } | |
| } | |
| // For API key auth, use simpler URL format without project/location | |
| if baseURL == "" { | |
| baseURL = "https://aiplatform.googleapis.com" | |
| } | |
| url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) | |
| if opts.Alt != "" && action != "countTokens" { | |
| url = url + fmt.Sprintf("?$alt=%s", opts.Alt) | |
| } | |
| body, _ = sjson.DeleteBytes(body, "session_id") | |
| httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) | |
| if errNewReq != nil { | |
| return resp, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if apiKey != "" { | |
| httpReq.Header.Set("x-goog-api-key", apiKey) | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: body, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return resp, errDo | |
| } | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| err = statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| return resp, err | |
| } | |
| data, errRead := io.ReadAll(httpResp.Body) | |
| if errRead != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errRead) | |
| return resp, errRead | |
| } | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, data) | |
| reporter.Publish(ctx, helps.ParseGeminiUsage(data)) | |
| var param any | |
| out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m) | |
| resp = cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()} | |
| return resp, nil | |
| } | |
| // executeStreamWithServiceAccount handles streaming authentication using service account credentials. | |
| func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| originalPayloadSource := req.Payload | |
| if len(opts.OriginalRequest) > 0 { | |
| originalPayloadSource = opts.OriginalRequest | |
| } | |
| originalPayload := originalPayloadSource | |
| originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) | |
| body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return nil, err | |
| } | |
| body = fixGeminiImageAspectRatio(baseModel, body) | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) | |
| body, _ = sjson.SetBytes(body, "model", baseModel) | |
| action := getVertexAction(baseModel, true) | |
| baseURL := vertexBaseURL(location) | |
| url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action) | |
| // Imagen models don't support streaming, skip SSE params | |
| if !isImagenModel(baseModel) { | |
| if opts.Alt == "" { | |
| url = url + "?alt=sse" | |
| } else { | |
| url = url + fmt.Sprintf("?$alt=%s", opts.Alt) | |
| } | |
| } | |
| body, _ = sjson.DeleteBytes(body, "session_id") | |
| httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) | |
| if errNewReq != nil { | |
| return nil, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { | |
| httpReq.Header.Set("Authorization", "Bearer "+token) | |
| } else if errTok != nil { | |
| log.Errorf("vertex executor: access token error: %v", errTok) | |
| return nil, statusErr{code: 500, msg: "internal server error"} | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: body, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return nil, errDo | |
| } | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| } | |
| out := make(chan cliproxyexecutor.StreamChunk) | |
| go func() { | |
| defer close(out) | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| scanner := bufio.NewScanner(httpResp.Body) | |
| scanner.Buffer(nil, streamScannerBuffer) | |
| var param any | |
| for scanner.Scan() { | |
| line := scanner.Bytes() | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, line) | |
| if detail, ok := helps.ParseGeminiStreamUsage(line); ok { | |
| reporter.Publish(ctx, detail) | |
| } | |
| lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) | |
| for i := range lines { | |
| out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} | |
| } | |
| } | |
| lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) | |
| for i := range lines { | |
| out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} | |
| } | |
| if errScan := scanner.Err(); errScan != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errScan) | |
| reporter.PublishFailure(ctx) | |
| out <- cliproxyexecutor.StreamChunk{Err: errScan} | |
| } | |
| }() | |
| return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil | |
| } | |
| // executeStreamWithAPIKey handles streaming authentication using API key credentials. | |
| func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| reporter := helps.NewUsageReporter(ctx, e.Identifier(), baseModel, auth) | |
| defer reporter.TrackFailure(ctx, &err) | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| originalPayloadSource := req.Payload | |
| if len(opts.OriginalRequest) > 0 { | |
| originalPayloadSource = opts.OriginalRequest | |
| } | |
| originalPayload := originalPayloadSource | |
| originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true) | |
| body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true) | |
| body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return nil, err | |
| } | |
| body = fixGeminiImageAspectRatio(baseModel, body) | |
| requestedModel := helps.PayloadRequestedModel(opts, req.Model) | |
| body = helps.ApplyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) | |
| body, _ = sjson.SetBytes(body, "model", baseModel) | |
| action := getVertexAction(baseModel, true) | |
| // For API key auth, use simpler URL format without project/location | |
| if baseURL == "" { | |
| baseURL = "https://aiplatform.googleapis.com" | |
| } | |
| url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) | |
| // Imagen models don't support streaming, skip SSE params | |
| if !isImagenModel(baseModel) { | |
| if opts.Alt == "" { | |
| url = url + "?alt=sse" | |
| } else { | |
| url = url + fmt.Sprintf("?$alt=%s", opts.Alt) | |
| } | |
| } | |
| body, _ = sjson.DeleteBytes(body, "session_id") | |
| httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) | |
| if errNewReq != nil { | |
| return nil, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if apiKey != "" { | |
| httpReq.Header.Set("x-goog-api-key", apiKey) | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: body, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return nil, errDo | |
| } | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| } | |
| out := make(chan cliproxyexecutor.StreamChunk) | |
| go func() { | |
| defer close(out) | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| scanner := bufio.NewScanner(httpResp.Body) | |
| scanner.Buffer(nil, streamScannerBuffer) | |
| var param any | |
| for scanner.Scan() { | |
| line := scanner.Bytes() | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, line) | |
| if detail, ok := helps.ParseGeminiStreamUsage(line); ok { | |
| reporter.Publish(ctx, detail) | |
| } | |
| lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m) | |
| for i := range lines { | |
| out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} | |
| } | |
| } | |
| lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m) | |
| for i := range lines { | |
| out <- cliproxyexecutor.StreamChunk{Payload: lines[i]} | |
| } | |
| if errScan := scanner.Err(); errScan != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errScan) | |
| reporter.PublishFailure(ctx) | |
| out <- cliproxyexecutor.StreamChunk{Err: errScan} | |
| } | |
| }() | |
| return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil | |
| } | |
| // countTokensWithServiceAccount counts tokens using service account credentials. | |
| func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) | |
| translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return cliproxyexecutor.Response{}, err | |
| } | |
| translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq) | |
| translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel) | |
| respCtx := context.WithValue(ctx, "alt", opts.Alt) | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") | |
| baseURL := vertexBaseURL(location) | |
| url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "countTokens") | |
| httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) | |
| if errNewReq != nil { | |
| return cliproxyexecutor.Response{}, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { | |
| httpReq.Header.Set("Authorization", "Bearer "+token) | |
| } else if errTok != nil { | |
| log.Errorf("vertex executor: access token error: %v", errTok) | |
| return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: translatedReq, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return cliproxyexecutor.Response{}, errDo | |
| } | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| } | |
| data, errRead := io.ReadAll(httpResp.Body) | |
| if errRead != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errRead) | |
| return cliproxyexecutor.Response{}, errRead | |
| } | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, data) | |
| count := gjson.GetBytes(data, "totalTokens").Int() | |
| out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) | |
| return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil | |
| } | |
| // countTokensWithAPIKey handles token counting using API key credentials. | |
| func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { | |
| baseModel := thinking.ParseSuffix(req.Model).ModelName | |
| from := opts.SourceFormat | |
| to := sdktranslator.FromString("gemini") | |
| translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) | |
| translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier()) | |
| if err != nil { | |
| return cliproxyexecutor.Response{}, err | |
| } | |
| translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq) | |
| translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel) | |
| respCtx := context.WithValue(ctx, "alt", opts.Alt) | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") | |
| translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") | |
| // For API key auth, use simpler URL format without project/location | |
| if baseURL == "" { | |
| baseURL = "https://aiplatform.googleapis.com" | |
| } | |
| url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "countTokens") | |
| httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) | |
| if errNewReq != nil { | |
| return cliproxyexecutor.Response{}, errNewReq | |
| } | |
| httpReq.Header.Set("Content-Type", "application/json") | |
| if apiKey != "" { | |
| httpReq.Header.Set("x-goog-api-key", apiKey) | |
| } | |
| applyGeminiHeaders(httpReq, auth) | |
| var attrs map[string]string | |
| if auth != nil { | |
| attrs = auth.Attributes | |
| } | |
| util.ApplyCustomHeadersFromAttrs(httpReq, attrs) | |
| var authID, authLabel, authType, authValue string | |
| if auth != nil { | |
| authID = auth.ID | |
| authLabel = auth.Label | |
| authType, authValue = auth.AccountInfo() | |
| } | |
| helps.RecordAPIRequest(ctx, e.cfg, helps.UpstreamRequestLog{ | |
| URL: url, | |
| Method: http.MethodPost, | |
| Headers: httpReq.Header.Clone(), | |
| Body: translatedReq, | |
| Provider: e.Identifier(), | |
| AuthID: authID, | |
| AuthLabel: authLabel, | |
| AuthType: authType, | |
| AuthValue: authValue, | |
| }) | |
| httpClient := helps.NewProxyAwareHTTPClient(ctx, e.cfg, auth, 0) | |
| httpResp, errDo := httpClient.Do(httpReq) | |
| if errDo != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errDo) | |
| return cliproxyexecutor.Response{}, errDo | |
| } | |
| defer func() { | |
| if errClose := httpResp.Body.Close(); errClose != nil { | |
| log.Errorf("vertex executor: close response body error: %v", errClose) | |
| } | |
| }() | |
| helps.RecordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) | |
| if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { | |
| b, _ := io.ReadAll(httpResp.Body) | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, b) | |
| helps.LogWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, helps.SummarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) | |
| return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} | |
| } | |
| data, errRead := io.ReadAll(httpResp.Body) | |
| if errRead != nil { | |
| helps.RecordAPIResponseError(ctx, e.cfg, errRead) | |
| return cliproxyexecutor.Response{}, errRead | |
| } | |
| helps.AppendAPIResponseChunk(ctx, e.cfg, data) | |
| count := gjson.GetBytes(data, "totalTokens").Int() | |
| out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) | |
| return cliproxyexecutor.Response{Payload: out, Headers: httpResp.Header.Clone()}, nil | |
| } | |
| // vertexCreds extracts project, location and raw service account JSON from auth metadata. | |
| func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) { | |
| if a == nil || a.Metadata == nil { | |
| return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata") | |
| } | |
| if v, ok := a.Metadata["project_id"].(string); ok { | |
| projectID = strings.TrimSpace(v) | |
| } | |
| if projectID == "" { | |
| // Some service accounts may use "project"; still prefer standard field | |
| if v, ok := a.Metadata["project"].(string); ok { | |
| projectID = strings.TrimSpace(v) | |
| } | |
| } | |
| if projectID == "" { | |
| return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials") | |
| } | |
| if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" { | |
| location = strings.TrimSpace(v) | |
| } else { | |
| location = "us-central1" | |
| } | |
| var sa map[string]any | |
| if raw, ok := a.Metadata["service_account"].(map[string]any); ok { | |
| sa = raw | |
| } | |
| if sa == nil { | |
| return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials") | |
| } | |
| normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa) | |
| if errNorm != nil { | |
| return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm) | |
| } | |
| saJSON, errMarshal := json.Marshal(normalized) | |
| if errMarshal != nil { | |
| return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal) | |
| } | |
| return projectID, location, saJSON, nil | |
| } | |
| // vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern. | |
| func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { | |
| if a == nil { | |
| return "", "" | |
| } | |
| if a.Attributes != nil { | |
| apiKey = a.Attributes["api_key"] | |
| baseURL = a.Attributes["base_url"] | |
| } | |
| if apiKey == "" && a.Metadata != nil { | |
| if v, ok := a.Metadata["access_token"].(string); ok { | |
| apiKey = v | |
| } | |
| } | |
| return | |
| } | |
// vertexBaseURL returns the Vertex AI host URL for the given location.
//
// Surrounding whitespace is ignored. A blank location falls back to
// "us-central1", the exact value "global" maps to the location-less host, and
// any other value is used as the regional host prefix.
func vertexBaseURL(location string) string {
	region := strings.TrimSpace(location)
	switch region {
	case "global":
		return "https://aiplatform.googleapis.com"
	case "":
		region = "us-central1"
	}
	return fmt.Sprintf("https://%s-aiplatform.googleapis.com", region)
}
| func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) { | |
| if httpClient := helps.NewProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil { | |
| ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | |
| } | |
| // Use cloud-platform scope for Vertex AI. | |
| creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform") | |
| if errCreds != nil { | |
| return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds) | |
| } | |
| tok, errTok := creds.TokenSource.Token() | |
| if errTok != nil { | |
| return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok) | |
| } | |
| return tok.AccessToken, nil | |
| } | |
| // resolveVertexConfig finds the matching vertex-api-key configuration entry for the given auth. | |
| func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey { | |
| if auth == nil || e.cfg == nil { | |
| return nil | |
| } | |
| var attrKey, attrBase string | |
| if auth.Attributes != nil { | |
| attrKey = strings.TrimSpace(auth.Attributes["api_key"]) | |
| attrBase = strings.TrimSpace(auth.Attributes["base_url"]) | |
| } | |
| for i := range e.cfg.VertexCompatAPIKey { | |
| entry := &e.cfg.VertexCompatAPIKey[i] | |
| cfgKey := strings.TrimSpace(entry.APIKey) | |
| cfgBase := strings.TrimSpace(entry.BaseURL) | |
| if attrKey != "" && attrBase != "" { | |
| if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) { | |
| return entry | |
| } | |
| continue | |
| } | |
| if attrKey != "" && strings.EqualFold(cfgKey, attrKey) { | |
| if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) { | |
| return entry | |
| } | |
| } | |
| if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) { | |
| return entry | |
| } | |
| } | |
| if attrKey != "" { | |
| for i := range e.cfg.VertexCompatAPIKey { | |
| entry := &e.cfg.VertexCompatAPIKey[i] | |
| if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) { | |
| return entry | |
| } | |
| } | |
| } | |
| return nil | |
| } | |