/** * @license * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import { CachedContent, RequestOptions } from "../../types"; import { CachedContentUrl, getHeaders, makeServerRequest } from "./request"; import { CachedContentCreateParams, CachedContentUpdateParams, ListCacheResponse, ListParams, _CachedContentUpdateRequestFields, } from "../../types/server"; import { RpcTask } from "./constants"; import { GoogleGenerativeAIError, GoogleGenerativeAIRequestInputError, } from "../errors"; import { formatSystemInstruction } from "../requests/request-helpers"; /** * Class for managing GoogleAI content caches. * @public */ export class GoogleAICacheManager { constructor( public apiKey: string, private _requestOptions?: RequestOptions, ) {} /** * Upload a new content cache */ async create( createOptions: CachedContentCreateParams, ): Promise { const newCachedContent: CachedContent = { ...createOptions }; if (createOptions.ttlSeconds) { if (createOptions.expireTime) { throw new GoogleGenerativeAIRequestInputError( "You cannot specify both `ttlSeconds` and `expireTime` when creating" + " a content cache. You must choose one.", ); } if (createOptions.systemInstruction) { newCachedContent.systemInstruction = formatSystemInstruction( createOptions.systemInstruction, ); } newCachedContent.ttl = createOptions.ttlSeconds.toString() + "s"; delete (newCachedContent as CachedContentCreateParams).ttlSeconds; } if (!newCachedContent.model) { throw new GoogleGenerativeAIRequestInputError( "Cached content must contain a `model` field.", ); } if (!newCachedContent.model.includes("/")) { // If path is not included, assume it's a non-tuned model. newCachedContent.model = `models/${newCachedContent.model}`; } const url = new CachedContentUrl( RpcTask.CREATE, this.apiKey, this._requestOptions, ); const headers = getHeaders(url); const response = await makeServerRequest( url, headers, JSON.stringify(newCachedContent), ); return response.json(); } /** * List all uploaded content caches */ async list(listParams?: ListParams): Promise { const url = new CachedContentUrl( RpcTask.LIST, this.apiKey, this._requestOptions, ); if (listParams?.pageSize) { url.appendParam("pageSize", listParams.pageSize.toString()); } if (listParams?.pageToken) { url.appendParam("pageToken", listParams.pageToken); } const headers = getHeaders(url); const response = await makeServerRequest(url, headers); return response.json(); } /** * Get a content cache */ async get(name: string): Promise { const url = new CachedContentUrl( RpcTask.GET, this.apiKey, this._requestOptions, ); url.appendPath(parseCacheName(name)); const headers = getHeaders(url); const response = await makeServerRequest(url, headers); return response.json(); } /** * Update an existing content cache */ async update( name: string, updateParams: CachedContentUpdateParams, ): Promise { const url = new CachedContentUrl( RpcTask.UPDATE, this.apiKey, this._requestOptions, ); url.appendPath(parseCacheName(name)); const headers = getHeaders(url); const formattedCachedContent: _CachedContentUpdateRequestFields = { ...updateParams.cachedContent, }; if (updateParams.cachedContent.ttlSeconds) { formattedCachedContent.ttl = updateParams.cachedContent.ttlSeconds.toString() + "s"; delete (formattedCachedContent as CachedContentCreateParams).ttlSeconds; } if (updateParams.updateMask) { url.appendParam( "update_mask", updateParams.updateMask.map((prop) => camelToSnake(prop)).join(","), ); } const response = await makeServerRequest( url, headers, JSON.stringify(formattedCachedContent), ); return response.json(); } /** * Delete content cache with given name */ async delete(name: string): Promise { const url = new CachedContentUrl( RpcTask.DELETE, this.apiKey, this._requestOptions, ); url.appendPath(parseCacheName(name)); const headers = getHeaders(url); await makeServerRequest(url, headers); } } /** * If cache name is prepended with "cachedContents/", remove prefix */ function parseCacheName(name: string): string { if (name.startsWith("cachedContents/")) { return name.split("cachedContents/")[1]; } if (!name) { throw new GoogleGenerativeAIError( `Invalid name ${name}. ` + `Must be in the format "cachedContents/name" or "name"`, ); } return name; } function camelToSnake(str: string): string { return str.replace(/[A-Z]/g, (letter) => `_${letter.toLowerCase()}`); }