/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import { RequestOptions } from "../../types";
import {
  GoogleGenerativeAIError,
  GoogleGenerativeAIFetchError,
  GoogleGenerativeAIRequestInputError,
} from "../errors";

/** Base URL used when `RequestOptions.baseUrl` is not provided. */
export const DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com";

/** API version path segment used when `RequestOptions.apiVersion` is not provided. */
export const DEFAULT_API_VERSION = "v1beta";

/**
 * We can't `require` package.json if this runs on web. We will use rollup to
 * swap in the version number here at build time.
 */
const PACKAGE_VERSION = "__PACKAGE_VERSION__";
const PACKAGE_LOG_HEADER = "genai-js";

/** API method names; each value is appended to the model path after a colon. */
export enum Task {
  GENERATE_CONTENT = "generateContent",
  STREAM_GENERATE_CONTENT = "streamGenerateContent",
  COUNT_TOKENS = "countTokens",
  EMBED_CONTENT = "embedContent",
  BATCH_EMBED_CONTENTS = "batchEmbedContents",
}

/**
 * Describes the endpoint and credentials for a single model request.
 * `toString()` renders the full request URL.
 */
export class RequestUrl {
  /**
   * @param model - Model path segment, interpolated verbatim into the URL.
   * @param task - API method to invoke on the model.
   * @param apiKey - Sent as the `x-goog-api-key` header by `getHeaders`.
   * @param stream - When true, `?alt=sse` is appended to the URL.
   * @param requestOptions - Optional overrides (base URL, API version,
   *   client header, custom headers, timeout).
   */
  constructor(
    public model: string,
    public task: Task,
    public apiKey: string,
    public stream: boolean,
    public requestOptions?: RequestOptions,
  ) {}

  /**
   * Builds `<baseUrl>/<apiVersion>/<model>:<task>`, plus `?alt=sse` for
   * streaming requests. `||` is used deliberately so that empty-string
   * overrides also fall back to the defaults.
   */
  toString(): string {
    const apiVersion = this.requestOptions?.apiVersion || DEFAULT_API_VERSION;
    const baseUrl = this.requestOptions?.baseUrl || DEFAULT_BASE_URL;
    let url = `${baseUrl}/${apiVersion}/${this.model}:${this.task}`;
    if (this.stream) {
      url += "?alt=sse";
    }
    return url;
  }
}

/**
 * Builds the value of the `x-goog-api-client` header: the caller-supplied
 * `apiClient` (if any) followed by this SDK's `genai-js/<version>` token,
 * separated by a space. Kept as a list so more version tokens can be added
 * later.
 *
 * @param requestOptions - Optional request options; only `apiClient` is read.
 * @returns The space-joined client identifier string.
 */
export function getClientHeaders(requestOptions?: RequestOptions): string {
  const clientHeaders: string[] = [];
  if (requestOptions?.apiClient) {
    clientHeaders.push(requestOptions.apiClient);
  }
  clientHeaders.push(`${PACKAGE_LOG_HEADER}/${PACKAGE_VERSION}`);
  return clientHeaders.join(" ");
}

/**
 * Returns a human-readable message for any thrown value. Catch variables are
 * `unknown`, and non-`Error` throws have no `.message`.
 */
function describeError(e: unknown): string {
  return e instanceof Error ? e.message : String(e);
}

/**
 * Builds the HTTP headers for a model request: JSON content type, client
 * identifier, API key, and any user-supplied `customHeaders`.
 *
 * @param url - The request descriptor; its `apiKey` and `requestOptions` are read.
 * @returns The assembled `Headers` object.
 * @throws GoogleGenerativeAIRequestInputError if `customHeaders` cannot be
 *   converted to `Headers`, or if it tries to set `x-goog-api-key` or
 *   `x-goog-api-client`.
 */
export async function getHeaders(url: RequestUrl): Promise<Headers> {
  const headers = new Headers();
  headers.append("Content-Type", "application/json");
  headers.append("x-goog-api-client", getClientHeaders(url.requestOptions));
  headers.append("x-goog-api-key", url.apiKey);

  let customHeaders = url.requestOptions?.customHeaders;
  if (customHeaders) {
    if (!(customHeaders instanceof Headers)) {
      try {
        customHeaders = new Headers(customHeaders);
      } catch (e) {
        throw new GoogleGenerativeAIRequestInputError(
          `unable to convert customHeaders value ${JSON.stringify(
            customHeaders,
          )} to Headers: ${describeError(e)}`,
        );
      }
    }

    // Headers iteration yields lower-cased names, so these lower-case
    // comparisons also catch mixed-case spellings of the reserved headers.
    for (const [headerName, headerValue] of customHeaders.entries()) {
      if (headerName === "x-goog-api-key") {
        throw new GoogleGenerativeAIRequestInputError(
          `Cannot set reserved header name ${headerName}`,
        );
      } else if (headerName === "x-goog-api-client") {
        throw new GoogleGenerativeAIRequestInputError(
          `Header name ${headerName} can only be set using the apiClient field`,
        );
      }
      // NOTE(review): `append` combines with existing values, so a custom
      // "Content-Type" is merged with the default rather than replacing it.
      headers.append(headerName, headerValue);
    }
  }
  return headers;
}

/**
 * Assembles the URL string and `fetch` options for a POST to a model endpoint.
 *
 * @param model - Model path segment.
 * @param task - API method to invoke.
 * @param apiKey - API key sent in the `x-goog-api-key` header.
 * @param stream - Whether to request server-sent events (`?alt=sse`).
 * @param body - Pre-serialized request body.
 * @param requestOptions - Optional overrides; `timeout` becomes an abort signal.
 * @returns The request URL and the `RequestInit` to pass to `fetch`.
 */
export async function constructModelRequest(
  model: string,
  task: Task,
  apiKey: string,
  stream: boolean,
  body: string,
  requestOptions?: RequestOptions,
): Promise<{ url: string; fetchOptions: RequestInit }> {
  const url = new RequestUrl(model, task, apiKey, stream, requestOptions);
  return {
    url: url.toString(),
    fetchOptions: {
      ...buildFetchOptions(requestOptions),
      method: "POST",
      headers: await getHeaders(url),
      body,
    },
  };
}

/**
 * Builds and sends a model request, returning the successful response.
 *
 * @param fetchFn - Fetch implementation; defaults to the global `fetch` and
 *   exists so tests can stub it.
 * @returns The `Response` (status is ok).
 * @throws See `makeRequest`.
 */
export async function makeModelRequest(
  model: string,
  task: Task,
  apiKey: string,
  stream: boolean,
  body: string,
  requestOptions?: RequestOptions,
  // Allows this to be stubbed for tests
  fetchFn = fetch,
): Promise<Response> {
  const { url, fetchOptions } = await constructModelRequest(
    model,
    task,
    apiKey,
    stream,
    body,
    requestOptions,
  );
  return makeRequest(url, fetchOptions, fetchFn);
}

/**
 * Performs the fetch and converts failures into SDK errors.
 *
 * @returns The response when `response.ok` is true.
 * @throws GoogleGenerativeAIFetchError for non-ok HTTP responses.
 * @throws GoogleGenerativeAIError wrapping any other fetch failure; SDK errors
 *   that were already thrown are re-thrown unchanged.
 */
export async function makeRequest(
  url: string,
  fetchOptions: RequestInit,
  fetchFn = fetch,
): Promise<Response> {
  let response: Response;
  try {
    response = await fetchFn(url, fetchOptions);
  } catch (e) {
    // Returns `never`, so control flow knows `response` is assigned below.
    handleResponseError(e, url);
  }
  if (!response.ok) {
    await handleResponseNotOk(response, url);
  }
  return response;
}

/**
 * Re-throws SDK fetch/input errors as-is; wraps anything else in a
 * GoogleGenerativeAIError that names the URL. The original stack is kept when
 * the thrown value is an `Error`.
 */
function handleResponseError(e: unknown, url: string): never {
  if (
    e instanceof GoogleGenerativeAIFetchError ||
    e instanceof GoogleGenerativeAIRequestInputError
  ) {
    throw e;
  }
  const err = new GoogleGenerativeAIError(
    `Error fetching from ${url}: ${describeError(e)}`,
  );
  if (e instanceof Error) {
    err.stack = e.stack;
  }
  throw err;
}

/**
 * Throws a GoogleGenerativeAIFetchError describing a non-ok response. It tries
 * to read `error.message` and `error.details` from a JSON body. If the body is
 * missing or not JSON, the message contains only the status line.
 */
async function handleResponseNotOk(
  response: Response,
  url: string,
): Promise<never> {
  let message = "";
  let errorDetails;
  try {
    const json = await response.json();
    const error = json?.error;
    const parts: string[] = [];
    // Only include fields that are present; avoids the text "undefined".
    if (typeof error?.message === "string") {
      parts.push(error.message);
    }
    if (error?.details) {
      parts.push(JSON.stringify(error.details));
      errorDetails = error.details;
    }
    message = parts.join(" ");
  } catch {
    // Body missing or not JSON: fall back to the status line only.
  }
  throw new GoogleGenerativeAIFetchError(
    `Error fetching from ${url}: [${response.status} ${
      response.statusText
    }] ${message}`,
    response.status,
    response.statusText,
    errorDetails,
  );
}

/**
 * Generates the request options to be passed to the fetch API.
 * @param requestOptions - The user-defined request options.
 * @returns The generated request options; includes an abort `signal` when a
 *   non-negative numeric `timeout` (milliseconds) is given.
 */
function buildFetchOptions(requestOptions?: RequestOptions): RequestInit {
  const fetchOptions: RequestInit = {};
  const timeout = requestOptions?.timeout;
  // Explicit numeric guard: a loose `>= 0` comparison would treat `null` as 0
  // and abort the request immediately.
  if (typeof timeout === "number" && timeout >= 0) {
    const abortController = new AbortController();
    // NOTE(review): this timer is never cleared. It stays pending for
    // `timeout` ms even after the request completes, and it also bounds
    // reading of the response body (presumably intentional for streams —
    // confirm).
    setTimeout(() => abortController.abort(), timeout);
    fetchOptions.signal = abortController.signal;
  }
  return fetchOptions;
}