Spaces:
Paused
Paused
| /** | |
| * @license | |
| * Copyright 2024 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| import { | |
| aggregateResponses, | |
| getResponseStream, | |
| processStream, | |
| } from "./stream-reader"; | |
| import { expect, use } from "chai"; | |
| import { restore } from "sinon"; | |
| import * as sinonChai from "sinon-chai"; | |
| import { | |
| getChunkedStream, | |
| getMockResponseStreaming, | |
| } from "../../test-utils/mock-response"; | |
| import { | |
| BlockReason, | |
| FinishReason, | |
| GenerateContentResponse, | |
| HarmCategory, | |
| HarmProbability, | |
| } from "../../types"; | |
// Register the sinon-chai plugin so chai's `expect` gains sinon-aware
// assertions.
// NOTE(review): no sinon spies/stubs are created in the tests visible in this
// file; this registration (and the `restore()` calls in afterEach hooks) may be
// unnecessary — confirm before removing.
use(sinonChai);
describe("getResponseStream", () => {
  afterEach(() => {
    restore();
  });

  it("two lines", async () => {
    // Two JSON payloads framed as server-sent-event "data:" records, each
    // terminated by a blank CRLF line, then decoded from bytes to text.
    const expected = [{ text: "A" }, { text: "B" }];
    const sseBody = expected
      .map((item) => `data: ${JSON.stringify(item)}\r\n\r\n`)
      .join("");
    const textStream = getChunkedStream(sseBody).pipeThrough(
      new TextDecoderStream("utf8", { fatal: true }),
    );

    // Drain the parsed stream, collecting every emitted object in order.
    const reader = getResponseStream<{ text: string }>(textStream).getReader();
    const received: Array<{ text: string }> = [];
    let chunk = await reader.read();
    while (!chunk.done) {
      received.push(chunk.value);
      chunk = await reader.read();
    }

    expect(received).to.deep.equal(expected);
  });
});
describe("processStream", () => {
  afterEach(() => {
    restore();
  });

  /**
   * Iterates `stream` to completion, running `check` against every chunk.
   *
   * Returns the number of chunks consumed so each test can assert the stream
   * was non-empty. Without that check, assertions placed inside a
   * `for await` loop pass vacuously when the stream yields nothing.
   *
   * @param stream - The chunk stream to drain (e.g. `result.stream`).
   * @param check - Assertion(s) to run on each chunk.
   * @returns How many chunks were seen.
   */
  async function forEachChunk<T>(
    stream: AsyncIterable<T>,
    check: (chunk: T) => void,
  ): Promise<number> {
    let count = 0;
    for await (const chunk of stream) {
      check(chunk);
      count++;
    }
    return count;
  }

  it("streaming response - short", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-short.txt",
    );
    const result = processStream(fakeResponse as Response);
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
    });
    expect(chunkCount).to.be.greaterThan(0);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Cheyenne");
  });
  it("streaming response - long", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-long.txt",
    );
    const result = processStream(fakeResponse as Response);
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
    });
    expect(chunkCount).to.be.greaterThan(0);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("**Cats:**");
    expect(aggregatedResponse.text()).to.include("to their owners.");
  });
  it("streaming response - long - big chunk", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-basic-reply-long.txt",
      1e6,
    );
    const result = processStream(fakeResponse as Response);
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
    });
    expect(chunkCount).to.be.greaterThan(0);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("**Cats:**");
    expect(aggregatedResponse.text()).to.include("to their owners.");
  });
  it("streaming response - utf8", async () => {
    const fakeResponse = getMockResponseStreaming("streaming-success-utf8.txt");
    const result = processStream(fakeResponse as Response);
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
    });
    expect(chunkCount).to.be.greaterThan(0);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("秋风瑟瑟,叶落纷纷");
    expect(aggregatedResponse.text()).to.include("家人围坐在一起");
  });
  it("streaming response - functioncall", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-function-call-short.txt",
    );
    const result = processStream(fakeResponse as Response);
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.be.empty;
      expect(response.functionCall()).to.be.deep.equal({
        name: "getTemperature",
        args: { city: "San Jose" },
      });
    });
    expect(chunkCount).to.be.greaterThan(0);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.be.empty;
    expect(aggregatedResponse.functionCall()).to.be.deep.equal({
      name: "getTemperature",
      args: { city: "San Jose" },
    });
  });
  // The failure cases below await the aggregated response before draining the
  // stream, mirroring the original ordering.
  it("candidate had finishReason", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-finish-reason-safety.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.candidates?.[0].finishReason).to.equal("SAFETY");
    expect(aggregatedResponse.text).to.throw("SAFETY");
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text).to.throw("SAFETY");
    });
    expect(chunkCount).to.be.greaterThan(0);
  });
  it("prompt was blocked", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-prompt-blocked-safety.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text).to.throw("SAFETY");
    expect(aggregatedResponse.promptFeedback?.blockReason).to.equal("SAFETY");
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text).to.throw("SAFETY");
    });
    expect(chunkCount).to.be.greaterThan(0);
  });
  it("empty content", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-empty-content.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.equal("");
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.equal("");
    });
    // Without this, an empty stream would also satisfy every check above.
    expect(chunkCount).to.be.greaterThan(0);
  });
  it("unknown enum - should ignore", async () => {
    const fakeResponse = getMockResponseStreaming("streaming-unknown-enum.txt");
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Cats");
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
    });
    expect(chunkCount).to.be.greaterThan(0);
  });
  it("recitation ending with a missing content field", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-failure-recitation-no-content.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text).to.throw("RECITATION");
    expect(aggregatedResponse.candidates[0].content.parts[0].text).to.include(
      "Copyrighted text goes here",
    );
    // Ordinary chunks carry text; the chunk flagged RECITATION must throw.
    const chunkCount = await forEachChunk(result.stream, (response) => {
      if (response.candidates[0].finishReason !== FinishReason.RECITATION) {
        expect(response.text()).to.not.be.empty;
      } else {
        expect(response.text).to.throw("RECITATION");
      }
    });
    expect(chunkCount).to.be.greaterThan(0);
  });
  it("handles citations", async () => {
    const fakeResponse = getMockResponseStreaming(
      "streaming-success-citations.txt",
    );
    const result = processStream(fakeResponse as Response);
    const aggregatedResponse = await result.response;
    expect(aggregatedResponse.text()).to.include("Quantum mechanics is");
    expect(
      aggregatedResponse.candidates[0].citationMetadata.citationSources.length,
    ).to.equal(2);
    let foundCitationMetadata = false;
    const chunkCount = await forEachChunk(result.stream, (response) => {
      expect(response.text()).to.not.be.empty;
      if (response.candidates[0].citationMetadata) {
        foundCitationMetadata = true;
      }
    });
    expect(chunkCount).to.be.greaterThan(0);
    expect(foundCitationMetadata).to.be.true;
  });
});
// Tests for aggregateResponses: merging a list of GenerateContentResponse
// chunks into a single response.
describe("aggregateResponses", () => {
  // A lone chunk carrying only promptFeedback (no candidates) should yield a
  // response with no candidates that still exposes the block reason.
  it("handles no candidates, and promptFeedback", () => {
    const responsesToAggregate: GenerateContentResponse[] = [
      {
        promptFeedback: {
          blockReason: BlockReason.SAFETY,
          safetyRatings: [
            {
              category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
              probability: HarmProbability.LOW,
            },
          ],
        },
      },
    ];
    const response = aggregateResponses(responsesToAggregate);
    expect(response.candidates).to.not.exist;
    expect(response.promptFeedback.blockReason).to.equal(BlockReason.SAFETY);
  });
  describe("multiple responses, has candidates", () => {
    // Aggregated once in `before` and shared, read-only, by every `it` below.
    let response: GenerateContentResponse;
    before(() => {
      // Three chunks, all for candidate index 0. Text, finishReason,
      // finishMessage, safetyRatings, citationMetadata and promptFeedback vary
      // between chunks so the assertions can tell which chunk's value is kept.
      const responsesToAggregate: GenerateContentResponse[] = [
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "hello." }],
              },
              finishReason: FinishReason.STOP,
              finishMessage: "something",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_HARASSMENT,
                  probability: HarmProbability.NEGLIGIBLE,
                },
              ],
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.SAFETY,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                probability: HarmProbability.LOW,
              },
            ],
          },
        },
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "angry stuff" }],
              },
              finishReason: FinishReason.STOP,
              finishMessage: "something",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_HARASSMENT,
                  probability: HarmProbability.NEGLIGIBLE,
                },
              ],
              citationMetadata: {
                citationSources: [
                  {
                    startIndex: 0,
                    endIndex: 20,
                    uri: "sourceurl",
                    license: "",
                  },
                ],
              },
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.OTHER,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                probability: HarmProbability.HIGH,
              },
            ],
          },
        },
        {
          candidates: [
            {
              index: 0,
              content: {
                role: "user",
                parts: [{ text: "...more stuff" }],
              },
              finishReason: FinishReason.MAX_TOKENS,
              finishMessage: "too many tokens",
              safetyRatings: [
                {
                  category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                  probability: HarmProbability.MEDIUM,
                },
              ],
              citationMetadata: {
                citationSources: [
                  {
                    startIndex: 0,
                    endIndex: 20,
                    uri: "sourceurl",
                    license: "",
                  },
                  {
                    startIndex: 150,
                    endIndex: 155,
                    uri: "sourceurl",
                    license: "",
                  },
                ],
              },
            },
          ],
          promptFeedback: {
            blockReason: BlockReason.OTHER,
            safetyRatings: [
              {
                category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                probability: HarmProbability.HIGH,
              },
            ],
          },
        },
      ];
      response = aggregateResponses(responsesToAggregate);
    });
    // All chunks target candidate index 0, so they collapse into a single
    // candidate whose parts keep chunk order.
    it("aggregates text across responses", () => {
      expect(response.candidates.length).to.equal(1);
      expect(
        response.candidates[0].content.parts.map(({ text }) => text),
      ).to.deep.equal(["hello.", "angry stuff", "...more stuff"]);
    });
    // NOTE(review): chunks 2 and 3 both carry BlockReason.OTHER, so this only
    // shows that the first chunk's feedback is not kept; it cannot tell "last"
    // from "second". Consider giving chunk 3 a distinct value.
    it("takes the last response's promptFeedback", () => {
      expect(response.promptFeedback.blockReason).to.equal(BlockReason.OTHER);
    });
    // MAX_TOKENS appears only in chunk 3, so this pins last-chunk-wins.
    it("takes the last response's finishReason", () => {
      expect(response.candidates[0].finishReason).to.equal(
        FinishReason.MAX_TOKENS,
      );
    });
    it("takes the last response's finishMessage", () => {
      expect(response.candidates[0].finishMessage).to.equal("too many tokens");
    });
    it("takes the last response's candidate safetyRatings", () => {
      expect(response.candidates[0].safetyRatings[0].category).to.equal(
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
      );
      expect(response.candidates[0].safetyRatings[0].probability).to.equal(
        HarmProbability.MEDIUM,
      );
    });
    // NOTE(review): chunk 2 has one source (startIndex 0) and chunk 3 has two
    // (startIndex 0 and 150). A result of length 2 is consistent with both
    // "last chunk's citationMetadata wins" and "de-duplicated concatenation",
    // but not with plain concatenation (which would give 3). The title claims
    // collection; confirm the intended semantics and use a fixture that
    // distinguishes them.
    it("collects all citationSources into one array", () => {
      expect(
        response.candidates[0].citationMetadata.citationSources.length,
      ).to.equal(2);
      expect(
        response.candidates[0].citationMetadata.citationSources[0].startIndex,
      ).to.equal(0);
      expect(
        response.candidates[0].citationMetadata.citationSources[1].startIndex,
      ).to.equal(150);
    });
  });
});