/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  generateContent,
  generateContentStream,
} from "../methods/generate-content";
import {
  BatchEmbedContentsRequest,
  BatchEmbedContentsResponse,
  CachedContent,
  Content,
  CountTokensRequest,
  CountTokensResponse,
  EmbedContentRequest,
  EmbedContentResponse,
  GenerateContentRequest,
  GenerateContentResult,
  GenerateContentStreamResult,
  GenerationConfig,
  ModelParams,
  Part,
  RequestOptions,
  SafetySetting,
  StartChatParams,
  Tool,
  ToolConfig,
} from "../../types";
import { ChatSession } from "../methods/chat-session";
import { countTokens } from "../methods/count-tokens";
import { batchEmbedContents, embedContent } from "../methods/embed-content";
import {
  formatCountTokensInput,
  formatEmbedContentInput,
  formatGenerateContentInput,
  formatSystemInstruction,
} from "../requests/request-helpers";

/**
 * Class for generative model APIs.
 *
 * Holds model-level defaults (generation config, safety settings, tools,
 * system instruction, cached content) that are merged into every
 * generate / chat request. Per-request values override these defaults.
 * @public
 */
export class GenerativeModel {
  /** Model resource name, always containing a "/" (e.g. "models/<name>"). */
  model: string;
  /** Default generation config; `{}` when none was supplied. */
  generationConfig: GenerationConfig;
  /** Default safety settings; `[]` when none were supplied. */
  safetySettings: SafetySetting[];
  /** Options forwarded to every underlying API call made by this instance. */
  requestOptions: RequestOptions;
  /** Default tools, if any. */
  tools?: Tool[];
  /** Default tool config, if any. */
  toolConfig?: ToolConfig;
  /** System instruction, normalised by `formatSystemInstruction`. */
  systemInstruction?: Content;
  // NOTE(review): declared non-optional, but it is assigned straight from
  // `modelParams.cachedContent` and read below via `?.`, which suggests it
  // may be undefined at runtime. Confirm against `ModelParams` before relying
  // on it being set.
  /** Cached content associated with this model, if any. */
  cachedContent: CachedContent;

  /**
   * @param apiKey - API key passed to every underlying API call.
   * @param modelParams - Model name plus model-level request defaults.
   * @param requestOptions - Options forwarded to every API call; defaults to `{}`.
   */
  constructor(
    public apiKey: string,
    modelParams: ModelParams,
    requestOptions?: RequestOptions,
  ) {
    if (modelParams.model.includes("/")) {
      // Models may be named "models/model-name" or "tunedModels/model-name"
      this.model = modelParams.model;
    } else {
      // If path is not included, assume it's a non-tuned model.
      this.model = `models/${modelParams.model}`;
    }
    this.generationConfig = modelParams.generationConfig || {};
    this.safetySettings = modelParams.safetySettings || [];
    this.tools = modelParams.tools;
    this.toolConfig = modelParams.toolConfig;
    this.systemInstruction = formatSystemInstruction(
      modelParams.systemInstruction,
    );
    this.cachedContent = modelParams.cachedContent;
    this.requestOptions = requestOptions || {};
  }

  /**
   * Builds the model-level request fields shared by `generateContent`,
   * `generateContentStream` and `startChat`.
   *
   * A fresh object is returned on every call. Callers spread their own
   * per-request values after it so those values take precedence.
   */
  private buildBaseRequestParams() {
    return {
      generationConfig: this.generationConfig,
      safetySettings: this.safetySettings,
      tools: this.tools,
      toolConfig: this.toolConfig,
      systemInstruction: this.systemInstruction,
      // Only the cache's resource name is sent with these requests.
      cachedContent: this.cachedContent?.name,
    };
  }

  /**
   * Makes a single non-streaming call to the model
   * and returns an object containing a single {@link GenerateContentResponse}.
   *
   * @param request - A full request, a single string, or an array of strings/parts.
   *   Fields set here override this model's defaults.
   * @returns The result of the `generateContent` API call.
   */
  async generateContent(
    request: GenerateContentRequest | string | Array<string | Part>,
  ): Promise<GenerateContentResult> {
    const formattedParams = formatGenerateContentInput(request);
    return generateContent(
      this.apiKey,
      this.model,
      {
        ...this.buildBaseRequestParams(),
        ...formattedParams,
      },
      this.requestOptions,
    );
  }

  /**
   * Makes a single streaming call to the model
   * and returns an object containing an iterable stream that iterates
   * over all chunks in the streaming response as well as
   * a promise that returns the final aggregated response.
   *
   * @param request - A full request, a single string, or an array of strings/parts.
   *   Fields set here override this model's defaults.
   * @returns The result of the `generateContentStream` API call.
   */
  async generateContentStream(
    request: GenerateContentRequest | string | Array<string | Part>,
  ): Promise<GenerateContentStreamResult> {
    const formattedParams = formatGenerateContentInput(request);
    return generateContentStream(
      this.apiKey,
      this.model,
      {
        ...this.buildBaseRequestParams(),
        ...formattedParams,
      },
      this.requestOptions,
    );
  }

  /**
   * Gets a new {@link ChatSession} instance which can be used for
   * multi-turn chats.
   *
   * @param startChatParams - Optional chat parameters. Any field given here
   *   overrides the corresponding model-level default.
   * @returns A new chat session bound to this model's API key, name and options.
   */
  startChat(startChatParams?: StartChatParams): ChatSession {
    return new ChatSession(
      this.apiKey,
      this.model,
      {
        ...this.buildBaseRequestParams(),
        ...startChatParams,
      },
      this.requestOptions,
    );
  }

  /**
   * Counts the tokens in the provided request.
   *
   * Unlike the generate methods, this passes the whole `cachedContent`
   * object (not just its name) to `formatCountTokensInput`.
   *
   * @param request - A full request, a single string, or an array of strings/parts.
   * @returns The result of the `countTokens` API call.
   */
  async countTokens(
    request: CountTokensRequest | string | Array<string | Part>,
  ): Promise<CountTokensResponse> {
    const formattedParams = formatCountTokensInput(request, {
      model: this.model,
      generationConfig: this.generationConfig,
      safetySettings: this.safetySettings,
      tools: this.tools,
      toolConfig: this.toolConfig,
      systemInstruction: this.systemInstruction,
      cachedContent: this.cachedContent,
    });
    return countTokens(
      this.apiKey,
      this.model,
      formattedParams,
      this.requestOptions,
    );
  }

  /**
   * Embeds the provided content.
   *
   * Model-level generation defaults are not included in this request.
   *
   * @param request - A full request, a single string, or an array of strings/parts.
   * @returns The result of the `embedContent` API call.
   */
  async embedContent(
    request: EmbedContentRequest | string | Array<string | Part>,
  ): Promise<EmbedContentResponse> {
    const formattedParams = formatEmbedContentInput(request);
    return embedContent(
      this.apiKey,
      this.model,
      formattedParams,
      this.requestOptions,
    );
  }

  /**
   * Embeds an array of {@link EmbedContentRequest}s.
   *
   * @param batchEmbedContentRequest - The batch request, forwarded unchanged.
   * @returns The result of the `batchEmbedContents` API call.
   */
  async batchEmbedContents(
    batchEmbedContentRequest: BatchEmbedContentsRequest,
  ): Promise<BatchEmbedContentsResponse> {
    return batchEmbedContents(
      this.apiKey,
      this.model,
      batchEmbedContentRequest,
      this.requestOptions,
    );
  }
}