/** * @license * Copyright 2024 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ import { aggregateResponses, getResponseStream, processStream, } from "./stream-reader"; import { expect, use } from "chai"; import { restore } from "sinon"; import * as sinonChai from "sinon-chai"; import { getChunkedStream, getMockResponseStreaming, } from "../../test-utils/mock-response"; import { BlockReason, FinishReason, GenerateContentResponse, HarmCategory, HarmProbability, } from "../../types"; use(sinonChai); describe("getResponseStream", () => { afterEach(() => { restore(); }); it("two lines", async () => { const src = [{ text: "A" }, { text: "B" }]; const inputStream = getChunkedStream( src .map((v) => JSON.stringify(v)) .map((v) => "data: " + v + "\r\n\r\n") .join(""), ).pipeThrough(new TextDecoderStream("utf8", { fatal: true })); const responseStream = getResponseStream<{ text: string }>(inputStream); const reader = responseStream.getReader(); const responses: Array<{ text: string }> = []; while (true) { const { done, value } = await reader.read(); if (done) { break; } responses.push(value); } expect(responses).to.deep.equal(src); }); }); describe("processStream", () => { afterEach(() => { restore(); }); it("streaming response - short", async () => { const fakeResponse = getMockResponseStreaming( "streaming-success-basic-reply-short.txt", ); const result = processStream(fakeResponse as Response); for await (const response of result.stream) { expect(response.text()).to.not.be.empty; } const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("Cheyenne"); }); it("streaming response - long", async () => { const fakeResponse = getMockResponseStreaming( "streaming-success-basic-reply-long.txt", ); const result = processStream(fakeResponse as Response); for await (const response of result.stream) { expect(response.text()).to.not.be.empty; } const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("**Cats:**"); expect(aggregatedResponse.text()).to.include("to their owners."); }); it("streaming response - long - big chunk", async () => { const fakeResponse = getMockResponseStreaming( "streaming-success-basic-reply-long.txt", 1e6, ); const result = processStream(fakeResponse as Response); for await (const response of result.stream) { expect(response.text()).to.not.be.empty; } const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("**Cats:**"); expect(aggregatedResponse.text()).to.include("to their owners."); }); it("streaming response - utf8", async () => { const fakeResponse = getMockResponseStreaming("streaming-success-utf8.txt"); const result = processStream(fakeResponse as Response); for await (const response of result.stream) { expect(response.text()).to.not.be.empty; } const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("秋风瑟瑟,叶落纷纷"); expect(aggregatedResponse.text()).to.include("家人围坐在一起"); }); it("streaming response - functioncall", async () => { const fakeResponse = getMockResponseStreaming( "streaming-success-function-call-short.txt", ); const result = processStream(fakeResponse as Response); for await (const response of result.stream) { expect(response.text()).to.be.empty; expect(response.functionCall()).to.be.deep.equal({ name: "getTemperature", args: { city: "San Jose" }, }); } const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.be.empty; expect(aggregatedResponse.functionCall()).to.be.deep.equal({ name: "getTemperature", args: { city: "San Jose" }, }); }); it("candidate had finishReason", async () => { const fakeResponse = getMockResponseStreaming( "streaming-failure-finish-reason-safety.txt", ); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.candidates?.[0].finishReason).to.equal("SAFETY"); expect(aggregatedResponse.text).to.throw("SAFETY"); for await (const response of result.stream) { expect(response.text).to.throw("SAFETY"); } }); it("prompt was blocked", async () => { const fakeResponse = getMockResponseStreaming( "streaming-failure-prompt-blocked-safety.txt", ); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.text).to.throw("SAFETY"); expect(aggregatedResponse.promptFeedback?.blockReason).to.equal("SAFETY"); for await (const response of result.stream) { expect(response.text).to.throw("SAFETY"); } }); it("empty content", async () => { const fakeResponse = getMockResponseStreaming( "streaming-failure-empty-content.txt", ); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.equal(""); for await (const response of result.stream) { expect(response.text()).to.equal(""); } }); it("unknown enum - should ignore", async () => { const fakeResponse = getMockResponseStreaming("streaming-unknown-enum.txt"); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("Cats"); for await (const response of result.stream) { expect(response.text()).to.not.be.empty; } }); it("recitation ending with a missing content field", async () => { const fakeResponse = getMockResponseStreaming( "streaming-failure-recitation-no-content.txt", ); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.text).to.throw("RECITATION"); expect(aggregatedResponse.candidates[0].content.parts[0].text).to.include( "Copyrighted text goes here", ); for await (const response of result.stream) { if (response.candidates[0].finishReason !== FinishReason.RECITATION) { expect(response.text()).to.not.be.empty; } else { expect(response.text).to.throw("RECITATION"); } } }); it("handles citations", async () => { const fakeResponse = getMockResponseStreaming( "streaming-success-citations.txt", ); const result = processStream(fakeResponse as Response); const aggregatedResponse = await result.response; expect(aggregatedResponse.text()).to.include("Quantum mechanics is"); expect( aggregatedResponse.candidates[0].citationMetadata.citationSources.length, ).to.equal(2); let foundCitationMetadata = false; for await (const response of result.stream) { expect(response.text()).to.not.be.empty; if (response.candidates[0].citationMetadata) { foundCitationMetadata = true; } } expect(foundCitationMetadata).to.be.true; }); }); describe("aggregateResponses", () => { it("handles no candidates, and promptFeedback", () => { const responsesToAggregate: GenerateContentResponse[] = [ { promptFeedback: { blockReason: BlockReason.SAFETY, safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.LOW, }, ], }, }, ]; const response = aggregateResponses(responsesToAggregate); expect(response.candidates).to.not.exist; expect(response.promptFeedback.blockReason).to.equal(BlockReason.SAFETY); }); describe("multiple responses, has candidates", () => { let response: GenerateContentResponse; before(() => { const responsesToAggregate: GenerateContentResponse[] = [ { candidates: [ { index: 0, content: { role: "user", parts: [{ text: "hello." }], }, finishReason: FinishReason.STOP, finishMessage: "something", safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_HARASSMENT, probability: HarmProbability.NEGLIGIBLE, }, ], }, ], promptFeedback: { blockReason: BlockReason.SAFETY, safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.LOW, }, ], }, }, { candidates: [ { index: 0, content: { role: "user", parts: [{ text: "angry stuff" }], }, finishReason: FinishReason.STOP, finishMessage: "something", safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_HARASSMENT, probability: HarmProbability.NEGLIGIBLE, }, ], citationMetadata: { citationSources: [ { startIndex: 0, endIndex: 20, uri: "sourceurl", license: "", }, ], }, }, ], promptFeedback: { blockReason: BlockReason.OTHER, safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, probability: HarmProbability.HIGH, }, ], }, }, { candidates: [ { index: 0, content: { role: "user", parts: [{ text: "...more stuff" }], }, finishReason: FinishReason.MAX_TOKENS, finishMessage: "too many tokens", safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.MEDIUM, }, ], citationMetadata: { citationSources: [ { startIndex: 0, endIndex: 20, uri: "sourceurl", license: "", }, { startIndex: 150, endIndex: 155, uri: "sourceurl", license: "", }, ], }, }, ], promptFeedback: { blockReason: BlockReason.OTHER, safetyRatings: [ { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, probability: HarmProbability.HIGH, }, ], }, }, ]; response = aggregateResponses(responsesToAggregate); }); it("aggregates text across responses", () => { expect(response.candidates.length).to.equal(1); expect( response.candidates[0].content.parts.map(({ text }) => text), ).to.deep.equal(["hello.", "angry stuff", "...more stuff"]); }); it("takes the last response's promptFeedback", () => { expect(response.promptFeedback.blockReason).to.equal(BlockReason.OTHER); }); it("takes the last response's finishReason", () => { expect(response.candidates[0].finishReason).to.equal( FinishReason.MAX_TOKENS, ); }); it("takes the last response's finishMessage", () => { expect(response.candidates[0].finishMessage).to.equal("too many tokens"); }); it("takes the last response's candidate safetyRatings", () => { expect(response.candidates[0].safetyRatings[0].category).to.equal( HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, ); expect(response.candidates[0].safetyRatings[0].probability).to.equal( HarmProbability.MEDIUM, ); }); it("collects all citationSources into one array", () => { expect( response.candidates[0].citationMetadata.citationSources.length, ).to.equal(2); expect( response.candidates[0].citationMetadata.citationSources[0].startIndex, ).to.equal(0); expect( response.candidates[0].citationMetadata.citationSources[1].startIndex, ).to.equal(150); }); }); });