AUXteam's picture
Set Gemini API version to v1
8c741f6 verified
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { RequestOptions } from "../../types";
import {
GoogleGenerativeAIError,
GoogleGenerativeAIFetchError,
GoogleGenerativeAIRequestInputError,
} from "../errors";
/**
 * Default API host. `RequestUrl.toString` uses it whenever
 * `requestOptions.baseUrl` is unset or empty.
 */
export const DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com";
/**
 * Default API version path segment. `RequestUrl.toString` uses it whenever
 * `requestOptions.apiVersion` is unset or empty.
 */
export const DEFAULT_API_VERSION = "v1beta";
/**
 * We can't `require` package.json when this runs on the web, so rollup swaps
 * the real version number in for this placeholder at build time.
 * NOTE(review): the placeholder text is presumably matched literally by the
 * rollup config — keep the two in sync.
 */
const PACKAGE_VERSION = "__PACKAGE_VERSION__";
/** Library token reported (as `genai-js/<version>`) in `x-goog-api-client`. */
const PACKAGE_LOG_HEADER = "genai-js";
/**
 * API method names. `RequestUrl.toString` appends the selected value to the
 * model path as `{model}:{task}`.
 */
export enum Task {
GENERATE_CONTENT = "generateContent",
STREAM_GENERATE_CONTENT = "streamGenerateContent",
COUNT_TOKENS = "countTokens",
EMBED_CONTENT = "embedContent",
BATCH_EMBED_CONTENTS = "batchEmbedContents",
}
/**
 * Describes a single model endpoint call and renders its URL.
 *
 * The rendered form is `{baseUrl}/{apiVersion}/{model}:{task}`, with
 * `?alt=sse` appended for streaming calls.
 */
export class RequestUrl {
  constructor(
    public model: string,
    public task: Task,
    public apiKey: string,
    public stream: boolean,
    public requestOptions: RequestOptions,
  ) {}

  /**
   * @returns The absolute request URL. An empty or missing `baseUrl` or
   *   `apiVersion` in the request options falls back to the module defaults.
   */
  toString(): string {
    const opts = this.requestOptions;
    const root = opts?.baseUrl || DEFAULT_BASE_URL;
    const version = opts?.apiVersion || DEFAULT_API_VERSION;
    // Streaming responses are requested as server-sent events.
    const query = this.stream ? "?alt=sse" : "";
    return `${root}/${version}/${this.model}:${this.task}${query}`;
  }
}
/**
 * Builds the value sent in the `x-goog-api-client` header.
 *
 * The value is the caller's `apiClient` token (when non-empty) followed by
 * this library's `genai-js/<version>` token, joined by a single space.
 *
 * @param requestOptions - Options that may carry an `apiClient` token.
 * @returns The space-separated client identification string.
 */
export function getClientHeaders(requestOptions: RequestOptions): string {
  const libraryToken = `${PACKAGE_LOG_HEADER}/${PACKAGE_VERSION}`;
  const callerToken = requestOptions?.apiClient;
  return callerToken ? `${callerToken} ${libraryToken}` : libraryToken;
}
/**
 * Assembles the HTTP headers for a model request.
 *
 * Always sets `Content-Type: application/json`, `x-goog-api-client` and
 * `x-goog-api-key`. It then appends any `customHeaders` from the request
 * options.
 *
 * @param url - Request descriptor supplying the API key and request options.
 * @returns The populated `Headers` object.
 * @throws GoogleGenerativeAIRequestInputError if `customHeaders` cannot be
 *   converted to `Headers`, or if it tries to set `x-goog-api-key` or
 *   `x-goog-api-client` (the latter must be set via `apiClient`).
 */
export async function getHeaders(url: RequestUrl): Promise<Headers> {
  const headers = new Headers();
  headers.append("Content-Type", "application/json");
  headers.append("x-goog-api-client", getClientHeaders(url.requestOptions));
  headers.append("x-goog-api-key", url.apiKey);

  const rawCustom = url.requestOptions?.customHeaders;
  if (!rawCustom) {
    return headers;
  }

  let custom: Headers;
  if (rawCustom instanceof Headers) {
    custom = rawCustom;
  } else {
    try {
      custom = new Headers(rawCustom);
    } catch (e) {
      throw new GoogleGenerativeAIRequestInputError(
        `unable to convert customHeaders value ${JSON.stringify(
          rawCustom,
        )} to Headers: ${e.message}`,
      );
    }
  }

  // `Headers` yields lower-cased names, so these reserved-name checks are
  // effectively case-insensitive.
  for (const [name, value] of custom.entries()) {
    switch (name) {
      case "x-goog-api-key":
        throw new GoogleGenerativeAIRequestInputError(
          `Cannot set reserved header name ${name}`,
        );
      case "x-goog-api-client":
        throw new GoogleGenerativeAIRequestInputError(
          `Header name ${name} can only be set using the apiClient field`,
        );
      default:
        headers.append(name, value);
    }
  }
  return headers;
}
/**
 * Builds the URL and `fetch` options for a model call without sending it.
 *
 * The options always use `POST` and pass `body` through unchanged. They carry
 * the headers from `getHeaders`, plus whatever `buildFetchOptions` derives
 * from `requestOptions`.
 *
 * @param model - Model resource path, e.g. as used in `{model}:{task}`.
 * @param task - API method to invoke on the model.
 * @param apiKey - API key sent in the `x-goog-api-key` header.
 * @param stream - Whether to request a server-sent-event stream.
 * @param body - Pre-serialized request body.
 * @param requestOptions - Optional per-request overrides.
 * @returns The rendered URL string and the `RequestInit` to pass to `fetch`.
 */
export async function constructModelRequest(
  model: string,
  task: Task,
  apiKey: string,
  stream: boolean,
  body: string,
  requestOptions?: RequestOptions,
): Promise<{ url: string; fetchOptions: RequestInit }> {
  const requestUrl = new RequestUrl(model, task, apiKey, stream, requestOptions);
  const url = requestUrl.toString();
  // Any timeout starts counting here, before the headers are built and
  // before the request is actually sent.
  const baseOptions = buildFetchOptions(requestOptions);
  const headers = await getHeaders(requestUrl);
  return {
    url,
    fetchOptions: { ...baseOptions, method: "POST", headers, body },
  };
}
/**
 * Builds a model request with `constructModelRequest` and sends it with
 * `makeRequest`.
 *
 * @param fetchFn - Fetch implementation; a parameter so tests can stub it.
 * @returns The successful `Response`; failures surface as thrown errors.
 */
export async function makeModelRequest(
  model: string,
  task: Task,
  apiKey: string,
  stream: boolean,
  body: string,
  requestOptions?: RequestOptions,
  // Allows this to be stubbed for tests
  fetchFn = fetch,
): Promise<Response> {
  const request = await constructModelRequest(
    model,
    task,
    apiKey,
    stream,
    body,
    requestOptions,
  );
  return makeRequest(request.url, request.fetchOptions, fetchFn);
}
/**
 * Calls `fetchFn(url, fetchOptions)` and returns the response when it is ok.
 *
 * @param url - Fully-qualified request URL.
 * @param fetchOptions - Options forwarded unchanged to `fetchFn`.
 * @param fetchFn - Fetch implementation; a parameter so tests can stub it.
 * @returns The `Response` whose `ok` flag is true.
 * @throws Whatever `handleResponseError` throws when `fetchFn` throws or
 *   rejects. Also throws the error raised by `handleResponseNotOk` when
 *   `response.ok` is false.
 */
export async function makeRequest(
  url: string,
  fetchOptions: RequestInit,
  fetchFn = fetch,
): Promise<Response> {
  let res;
  try {
    res = await fetchFn(url, fetchOptions);
  } catch (fetchFailure) {
    // Always throws; `res` is therefore assigned past this point.
    handleResponseError(fetchFailure, url);
  }
  if (res.ok) {
    return res;
  }
  // Always throws for a non-ok response; the return below is never reached.
  await handleResponseNotOk(res, url);
  return res;
}
/**
 * Converts a failure thrown by `fetch` into an SDK error and throws it.
 *
 * SDK input/fetch errors are rethrown unchanged. Anything else is wrapped in
 * a `GoogleGenerativeAIError` whose message names the URL. The original
 * error's stack is carried over when it has one.
 *
 * @param e - The thrown value. It may be any type, not only an `Error`.
 * @param url - URL that was being fetched, used in the error message.
 * @throws Always; this function never returns normally.
 */
function handleResponseError(e: unknown, url: string): never {
  if (
    e instanceof GoogleGenerativeAIFetchError ||
    e instanceof GoogleGenerativeAIRequestInputError
  ) {
    throw e;
  }
  // Duck-type instead of `instanceof Error` so that non-Error throwables
  // (null, strings, cross-realm errors, stubbed fetch rejections) are still
  // wrapped rather than crashing this handler with a TypeError.
  const thrown =
    typeof e === "object" && e !== null
      ? (e as { message?: unknown; stack?: unknown })
      : undefined;
  const reason =
    thrown !== undefined && thrown.message !== undefined
      ? String(thrown.message)
      : String(e);
  const err = new GoogleGenerativeAIError(
    `Error fetching from ${url}: ${reason}`,
  );
  // Keep the wrapper's own stack unless the cause provides a better one.
  if (typeof thrown?.stack === "string") {
    err.stack = thrown.stack;
  }
  throw err;
}
/**
 * Throws a `GoogleGenerativeAIFetchError` describing a non-ok HTTP response.
 *
 * If the body parses as JSON with an `error` object, its `message` (and, when
 * present, the JSON-encoded `details`) is appended to the status line. Any
 * failure while reading or interpreting the body is ignored, and the error
 * then reports only the status.
 *
 * @param response - The non-ok response; its body is consumed.
 * @param url - URL that was fetched, used in the error message.
 * @throws GoogleGenerativeAIFetchError always.
 */
async function handleResponseNotOk(
  response: Response,
  url: string,
): Promise<void> {
  let detailText = "";
  let errorDetails;
  try {
    const { error } = await response.json();
    detailText = error.message;
    if (error.details) {
      detailText += ` ${JSON.stringify(error.details)}`;
      errorDetails = error.details;
    }
  } catch {
    // Best effort: a missing, non-JSON, or unexpectedly shaped body just
    // means the error carries no server-provided detail.
  }
  const statusLine = `[${response.status} ${response.statusText}]`;
  throw new GoogleGenerativeAIFetchError(
    `Error fetching from ${url.toString()}: ${statusLine} ${detailText}`,
    response.status,
    response.statusText,
    errorDetails,
  );
}
/**
 * Generates the request options to be passed to the fetch API.
 *
 * When `requestOptions.timeout` is a non-negative number, the result carries
 * an `AbortSignal`. That signal is aborted `timeout` milliseconds after this
 * call. Otherwise the result is an empty object.
 *
 * @param requestOptions - The user-defined request options.
 * @returns The generated request options.
 */
function buildFetchOptions(requestOptions?: RequestOptions): RequestInit {
  const fetchOptions: RequestInit = {};
  const timeout = requestOptions?.timeout;
  // Explicit null check: `null >= 0` is true in JS, which would otherwise
  // abort the request immediately when a caller passes `timeout: null`.
  if (timeout != null && timeout >= 0) {
    const abortController = new AbortController();
    const timer: unknown = setTimeout(() => abortController.abort(), timeout);
    // The timer is never cleared once the request settles. In Node.js,
    // `setTimeout` returns a Timeout object; unref-ing it stops a long
    // timeout from keeping the process alive after the work is done. In
    // browsers it returns a number, and this call is a no-op.
    (timer as { unref?: () => void }).unref?.();
    fetchOptions.signal = abortController.signal;
  }
  return fetchOptions;
}