Spaces:
Paused
Paused
/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| import { expect, use } from "chai"; | |
| import { match, restore, stub } from "sinon"; | |
| import * as sinonChai from "sinon-chai"; | |
| import * as chaiAsPromised from "chai-as-promised"; | |
| import { getMockResponse } from "../../test-utils/mock-response"; | |
| import * as request from "../requests/request"; | |
| import { generateContent } from "./generate-content"; | |
| import { | |
| GenerateContentRequest, | |
| HarmBlockThreshold, | |
| HarmCategory, | |
| } from "../../types"; | |
// Register the chai plugins the tests rely on: sinon-chai supplies stub
// assertions such as `calledWith`, chai-as-promised supplies `rejectedWith`.
// `use` returns the chai object, so the two registrations can be chained.
use(sinonChai).use(chaiAsPromised);
/**
 * Request fixture shared by every `generateContent()` call in this suite.
 *
 * It holds a single user turn ("hello"), a `topK` generation setting, and one
 * safety setting (dangerous content, block low and above).
 */
const fakeRequestParams: GenerateContentRequest = {
  contents: [
    {
      parts: [{ text: "hello" }],
      role: "user",
    },
  ],
  generationConfig: { topK: 16 },
  safetySettings: [
    {
      category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
      threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
    },
  ],
};
describe("generateContent()", () => {
  /**
   * Stubs `request.makeModelRequest` so it resolves with the named mock
   * response fixture.
   *
   * @param fixtureName - File name passed to `getMockResponse`.
   * @returns The sinon stub, for later call assertions.
   */
  const stubModelResponse = (fixtureName: string) =>
    stub(request, "makeModelRequest").resolves(
      getMockResponse(fixtureName) as Response,
    );

  /**
   * Asserts that `generateContent("key", "model", …)` forwarded the expected
   * leading arguments to `makeModelRequest`.
   *
   * @param makeRequestStub - Stub installed on `request.makeModelRequest`.
   * @param bodyMatcher - Sinon matcher for the final (request body) argument.
   *   It defaults to `match.any`, which accepts any body.
   */
  const expectGenerateContentCall = (
    makeRequestStub: ReturnType<typeof stubModelResponse>,
    bodyMatcher = match.any,
  ): void => {
    expect(makeRequestStub).to.be.calledWith(
      "model",
      request.Task.GENERATE_CONTENT,
      "key",
      false,
      bodyMatcher,
    );
  };

  afterEach(() => {
    restore();
  });

  it("short response", async () => {
    const makeRequestStub = stubModelResponse(
      "unary-success-basic-reply-short.json",
    );
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(result.response.text()).to.include("Helena");
    expectGenerateContentCall(
      makeRequestStub,
      match((value: string) => value.includes("contents")),
    );
  });

  it("long response", async () => {
    const makeRequestStub = stubModelResponse(
      "unary-success-basic-reply-long.json",
    );
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(result.response.text()).to.include("Use Freshly Ground Coffee");
    expect(result.response.text()).to.include("30 minutes of brewing");
    expectGenerateContentCall(makeRequestStub);
  });

  it("citations", async () => {
    const makeRequestStub = stubModelResponse("unary-success-citations.json");
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(result.response.text()).to.include("Quantum mechanics is");
    // Optional chaining + lengthOf: a missing field produces an assertion
    // failure with a readable message instead of a TypeError.
    expect(
      result.response.candidates?.[0]?.citationMetadata?.citationSources,
    ).to.have.lengthOf(1);
    expectGenerateContentCall(makeRequestStub);
  });

  it("blocked prompt", async () => {
    const makeRequestStub = stubModelResponse(
      "unary-failure-prompt-blocked-safety.json",
    );
    const result = await generateContent("key", "model", fakeRequestParams);
    // Invoke through a closure so the assertion does not depend on `text`
    // working as a detached (unbound) function.
    expect(() => result.response.text()).to.throw("SAFETY");
    expectGenerateContentCall(makeRequestStub);
  });

  it("finishReason safety", async () => {
    const makeRequestStub = stubModelResponse(
      "unary-failure-finish-reason-safety.json",
    );
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(() => result.response.text()).to.throw("SAFETY");
    expectGenerateContentCall(makeRequestStub);
  });

  it("empty content", async () => {
    const makeRequestStub = stubModelResponse(
      "unary-failure-empty-content.json",
    );
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(result.response.text()).to.equal("");
    expectGenerateContentCall(makeRequestStub);
  });

  it("unknown enum - should ignore", async () => {
    const makeRequestStub = stubModelResponse("unary-unknown-enum.json");
    const result = await generateContent("key", "model", fakeRequestParams);
    expect(result.response.text()).to.include("30 minutes of brewing");
    expectGenerateContentCall(makeRequestStub);
  });

  it("image rejected (400)", async () => {
    const mockResponse = getMockResponse("unary-failure-image-rejected.json");
    const errorJson = await mockResponse.json();
    const makeRequestStub = stub(request, "makeModelRequest").rejects(
      new Error(`[400 ] ${errorJson.error.message}`),
    );
    await expect(
      generateContent("key", "model", fakeRequestParams),
    ).to.be.rejectedWith(/400.*invalid argument/);
    // Check the full call shape, consistent with the success-path tests,
    // rather than only that the stub was called.
    expectGenerateContentCall(makeRequestStub);
  });
});