/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import {
  EnhancedGenerateContentResponse,
  GenerateContentCandidate,
  GenerateContentResponse,
  GenerateContentStreamResult,
  Part,
} from "../../types";
import { GoogleGenerativeAIError } from "../errors";
import { addHelpers } from "./response-helpers";

/**
 * Matches one complete event at the very start of the buffered text: the
 * literal prefix `data: `, then the remainder of that line (capture group 1),
 * then a blank-line separator spelled as LF LF, CR CR, or CR LF CR LF.
 * The pattern has no `s` flag, so the captured payload cannot span a line
 * terminator.
 */
const responseLineRE = /^data\: (.*)(?:\n\n|\r\r|\r\n\r\n)/;

/**
 * Process a response.body stream from the backend and return an
 * iterator that provides one complete GenerateContentResponse at a time
 * and a promise that resolves with a single aggregated
 * GenerateContentResponse.
/**
 * Decodes the response body as UTF-8 text, splits it into parsed response
 * chunks, and tees the chunk stream so that the same chunks feed both the
 * `stream` iterator and the aggregated `response` promise of the result.
 *
 * @param response - Response from a fetch call
 * @returns An object whose `stream` yields every chunk as it arrives and
 * whose `response` resolves with the aggregate of all chunks once the
 * stream has ended.
 */
export function processStream(response: Response): GenerateContentStreamResult {
  // The body is asserted to be non-null. `fatal: true` asks the decoder to
  // fail on malformed UTF-8 rather than substitute replacement characters.
  const inputStream = response.body!.pipeThrough(
    new TextDecoderStream("utf8", { fatal: true }),
  );
  // Annotating the variable (instead of passing a type argument) keeps this
  // call valid whether or not getResponseStream declares a type parameter.
  const responseStream: ReadableStream<GenerateContentResponse> =
    getResponseStream(inputStream);
  // Each tee branch receives every chunk, so both consumers see all of them.
  const [stream1, stream2] = responseStream.tee();
  return {
    stream: generateResponseSequence(stream1),
    response: getResponsePromise(stream2),
  };
}

/**
 * Drains the stream, collecting every chunk, and resolves with the
 * aggregate of all chunks (with helpers attached) once the stream is done.
 *
 * @param stream - Stream of parsed response chunks
 * @returns The aggregated response, passed through `addHelpers`
 */
async function getResponsePromise(
  stream: ReadableStream<GenerateContentResponse>,
): Promise<EnhancedGenerateContentResponse> {
  const allResponses: GenerateContentResponse[] = [];
  const reader = stream.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) {
      return addHelpers(aggregateResponses(allResponses));
    }
    allResponses.push(value);
  }
}

/**
 * Yields each chunk of the stream, passed through `addHelpers`, in the
 * order it is read, and finishes when the stream is done.
 *
 * @param stream - Stream of parsed response chunks
 */
async function* generateResponseSequence(
  stream: ReadableStream<GenerateContentResponse>,
): AsyncGenerator<EnhancedGenerateContentResponse> {
  const reader = stream.getReader();
  while (true) {
    const { value, done } = await reader.read();
    if (done) {
      break;
    }
    yield addHelpers(value);
  }
}

/**
 * Reads a raw stream from the fetch response and joins incomplete
 * chunks, returning a new stream that provides a single complete
 * GenerateContentResponse in each iteration.
/**
 * Buffers the decoded text of `inputStream` and emits one parsed JSON value
 * for every complete `data: ...` event found at the front of the buffer.
 * Text that does not yet form a complete event stays buffered until more
 * input arrives.
 *
 * The returned stream is errored with a `GoogleGenerativeAIError` when an
 * event payload is not valid JSON, or when the input ends while non-blank
 * text is still buffered.
 *
 * @typeParam T - Type the caller expects each parsed event payload to have;
 * the payload is not validated against it.
 * @param inputStream - Stream of decoded text chunks
 * @returns A stream that provides one parsed payload per event
 */
export function getResponseStream<T>(
  inputStream: ReadableStream<string>,
): ReadableStream<T> {
  const reader = inputStream.getReader();
  return new ReadableStream<T>({
    async start(controller): Promise<void> {
      let bufferedText = "";
      while (true) {
        const { value, done } = await reader.read();
        if (done) {
          // Leftover non-blank text means the input ended mid-event.
          if (bufferedText.trim()) {
            controller.error(
              new GoogleGenerativeAIError("Failed to parse stream"),
            );
          } else {
            controller.close();
          }
          return;
        }
        bufferedText += value;
        // A single read may complete several events; emit all of them
        // before reading again.
        let match = bufferedText.match(responseLineRE);
        while (match) {
          let parsedResponse: T;
          try {
            parsedResponse = JSON.parse(match[1]) as T;
          } catch {
            controller.error(
              new GoogleGenerativeAIError(
                `Error parsing JSON response: "${match[1]}"`,
              ),
            );
            return;
          }
          controller.enqueue(parsedResponse);
          // Drop the consumed event, separator included, from the buffer.
          bufferedText = bufferedText.substring(match[0].length);
          match = bufferedText.match(responseLineRE);
        }
      }
    },
  });
}

/**
 * Aggregates an array of `GenerateContentResponse`s into a single
 * GenerateContentResponse.
/**
 * Merges streamed response chunks into one response.
 *
 * - `promptFeedback` is taken from the last chunk.
 * - `usageMetadata` is taken from the last chunk that carries one.
 * - For each candidate index, `citationMetadata`, `finishReason`,
 *   `finishMessage` and `safetyRatings` are overwritten by every chunk, so
 *   the values of the last chunk containing that candidate win.
 * - Content parts are appended in arrival order, one new object per
 *   incoming part.
 *
 * @param responses - Chunks in the order they were received
 * @returns The aggregated response; `candidates` is only set when at least
 * one chunk contains candidates
 */
export function aggregateResponses(
  responses: GenerateContentResponse[],
): GenerateContentResponse {
  const lastResponse = responses[responses.length - 1];
  const aggregatedResponse: GenerateContentResponse = {
    promptFeedback: lastResponse?.promptFeedback,
  };
  for (const response of responses) {
    if (response.candidates) {
      for (const candidate of response.candidates) {
        // A missing index would otherwise be stored under the array key
        // "undefined" and never show up in the result.
        // NOTE(review): treating a missing index as candidate 0 assumes the
        // backend may omit the field for the first candidate - confirm.
        const i = candidate.index ?? 0;
        if (!aggregatedResponse.candidates) {
          aggregatedResponse.candidates = [];
        }
        let aggregated = aggregatedResponse.candidates[i];
        if (!aggregated) {
          aggregated = { index: i } as GenerateContentCandidate;
          aggregatedResponse.candidates[i] = aggregated;
        }
        // Keep overwriting, the last one will be final
        aggregated.citationMetadata = candidate.citationMetadata;
        aggregated.finishReason = candidate.finishReason;
        aggregated.finishMessage = candidate.finishMessage;
        aggregated.safetyRatings = candidate.safetyRatings;

        /**
         * Candidates should always have content and parts, but this handles
         * possible malformed responses.
         */
        if (candidate.content && candidate.content.parts) {
          if (!aggregated.content) {
            aggregated.content = {
              role: candidate.content.role || "user",
              parts: [],
            };
          }
          for (const part of candidate.content.parts) {
            // A fresh object per part: sharing one object across iterations
            // would push the same reference repeatedly and leak fields from
            // one part into the others.
            const newPart: Partial<Part> = {};
            if (part.text) {
              newPart.text = part.text;
            }
            if (part.functionCall) {
              newPart.functionCall = part.functionCall;
            }
            if (part.executableCode) {
              newPart.executableCode = part.executableCode;
            }
            if (part.codeExecutionResult) {
              newPart.codeExecutionResult = part.codeExecutionResult;
            }
            // A part with none of the recognised fields becomes an empty
            // text part.
            if (Object.keys(newPart).length === 0) {
              newPart.text = "";
            }
            aggregated.content.parts.push(newPart as Part);
          }
        }
      }
    }
    if (response.usageMetadata) {
      aggregatedResponse.usageMetadata = response.usageMetadata;
    }
  }
  return aggregatedResponse;
}