Spaces:
Paused
Paused
| /** | |
| * @license | |
| * Copyright 2024 Google LLC | |
| * | |
| * Licensed under the Apache License, Version 2.0 (the "License"); | |
| * you may not use this file except in compliance with the License. | |
| * You may obtain a copy of the License at | |
| * | |
| * http://www.apache.org/licenses/LICENSE-2.0 | |
| * | |
| * Unless required by applicable law or agreed to in writing, software | |
| * distributed under the License is distributed on an "AS IS" BASIS, | |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| * See the License for the specific language governing permissions and | |
| * limitations under the License. | |
| */ | |
| import { expect, use } from "chai"; | |
| import { GenerativeModel } from "./generative-model"; | |
| import * as sinonChai from "sinon-chai"; | |
| import { | |
| CountTokensRequest, | |
| FunctionCallingMode, | |
| FunctionDeclarationSchemaType, | |
| HarmBlockThreshold, | |
| HarmCategory, | |
| } from "../../types"; | |
| import { getMockResponse } from "../../test-utils/mock-response"; | |
| import { match, restore, stub } from "sinon"; | |
| import * as request from "../requests/request"; | |
// Register the sinon-chai plugin so `expect(stub)` supports stub/spy
// assertions such as `.to.be.calledWith(...)` and `.to.not.be.called`.
use(sinonChai);
/**
 * Unit tests for GenerativeModel: model-name normalization, and forwarding of
 * constructor params, per-call overrides and requestOptions to
 * `request.makeModelRequest`. That function is stubbed in every test that
 * sends a request, so no real network call is made.
 */
describe("GenerativeModel", () => {
  // Restore all sinon stubs after every test, including failing ones. A
  // trailing `restore()` inside the test body is skipped when an assertion
  // throws. That would leave makeModelRequest wrapped, and every later
  // `stub(request, "makeModelRequest")` call would then throw "already wrapped".
  afterEach(() => {
    restore();
  });

  /**
   * Replaces `request.makeModelRequest` with a stub that resolves to the
   * canned "unary-success-basic-reply-short.json" mock response.
   * The stub is undone by the `afterEach` hook above.
   */
  function stubMakeModelRequest() {
    const mockResponse = getMockResponse(
      "unary-success-basic-reply-short.json",
    );
    return stub(request, "makeModelRequest").resolves(
      mockResponse as Response,
    );
  }

  it("handles plain model name", () => {
    const genModel = new GenerativeModel("apiKey", { model: "my-model" });
    expect(genModel.model).to.equal("models/my-model");
  });
  it("handles prefixed model name", () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "models/my-model",
    });
    expect(genModel.model).to.equal("models/my-model");
  });
  it("handles prefixed tuned model name", () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "tunedModels/my-model",
    });
    expect(genModel.model).to.equal("tunedModels/my-model");
  });
  it("passes params through to generateContent", async () => {
    const genModel = new GenerativeModel(
      "apiKey",
      {
        model: "my-model",
        generationConfig: {
          temperature: 0,
          responseMimeType: "application/json",
          responseSchema: {
            type: FunctionDeclarationSchemaType.OBJECT,
            properties: {
              testField: {
                type: FunctionDeclarationSchemaType.STRING,
                properties: {},
              },
            },
          },
        },
        safetySettings: [
          {
            category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
          },
        ],
        tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
        toolConfig: {
          functionCallingConfig: { mode: FunctionCallingMode.NONE },
        },
        systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
      },
      {
        apiVersion: "v6",
      },
    );
    expect(genModel.generationConfig?.temperature).to.equal(0);
    expect(genModel.generationConfig?.responseMimeType).to.equal(
      "application/json",
    );
    expect(genModel.generationConfig?.responseSchema.type).to.equal(
      FunctionDeclarationSchemaType.OBJECT,
    );
    expect(
      genModel.generationConfig?.responseSchema.properties.testField.type,
    ).to.equal(FunctionDeclarationSchemaType.STRING);
    expect(genModel.safetySettings?.length).to.equal(1);
    expect(genModel.tools?.length).to.equal(1);
    expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
      FunctionCallingMode.NONE,
    );
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    const makeRequestStub = stubMakeModelRequest();
    await genModel.generateContent("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("myfunc") &&
          value.includes(FunctionCallingMode.NONE) &&
          value.includes("be friendly") &&
          value.includes("temperature") &&
          value.includes("testField") &&
          value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
        );
      }),
      match((value) => {
        return value.apiVersion === "v6";
      }),
    );
  });
  it("passes text-only systemInstruction through to generateContent", async () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "my-model",
      systemInstruction: "be friendly",
    });
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    const makeRequestStub = stubMakeModelRequest();
    await genModel.generateContent("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return value.includes("be friendly");
      }),
      match.any,
    );
  });
  it("generateContent overrides model values", async () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "my-model",
      generationConfig: {
        temperature: 0,
        responseMimeType: "application/json",
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            testField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      safetySettings: [
        {
          category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
          threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        },
      ],
      tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
      toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
      systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
    });
    expect(genModel.tools?.length).to.equal(1);
    expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
      FunctionCallingMode.NONE,
    );
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    const makeRequestStub = stubMakeModelRequest();
    await genModel.generateContent({
      generationConfig: {
        topK: 1,
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            newTestField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      safetySettings: [
        {
          category: HarmCategory.HARM_CATEGORY_HARASSMENT,
          threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        },
        {
          category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
          threshold: HarmBlockThreshold.BLOCK_NONE,
        },
      ],
      contents: [{ role: "user", parts: [{ text: "hello" }] }],
      tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
      toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
      systemInstruction: { role: "system", parts: [{ text: "be formal" }] },
    });
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("otherfunc") &&
          value.includes(FunctionCallingMode.AUTO) &&
          value.includes("be formal") &&
          value.includes("topK") &&
          value.includes("newTestField") &&
          // Case-sensitive: "newTestField" does not contain "testField", so
          // this confirms the model-level schema was replaced, not merged.
          !value.includes("testField") &&
          value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT)
        );
      }),
      {},
    );
  });
  it("passes requestOptions through to countTokens", async () => {
    const genModel = new GenerativeModel(
      "apiKey",
      {
        model: "my-model",
        systemInstruction: "you are a cat",
      },
      {
        apiVersion: "v2000",
      },
    );
    const makeRequestStub = stubMakeModelRequest();
    await genModel.countTokens("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.COUNT_TOKENS,
      match.any,
      false,
      match((value: string) => {
        return value.includes("hello") && value.includes("you are a cat");
      }),
      match((value) => {
        return value.apiVersion === "v2000";
      }),
    );
  });
  it("passes params through to chat.sendMessage", async () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "my-model",
      tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
      toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
      systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
    });
    expect(genModel.tools?.length).to.equal(1);
    expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
      FunctionCallingMode.NONE,
    );
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    const makeRequestStub = stubMakeModelRequest();
    await genModel.startChat().sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("myfunc") &&
          value.includes(FunctionCallingMode.NONE) &&
          value.includes("be friendly")
        );
      }),
      {},
    );
  });
  // Distinct title: this previously duplicated the test above, which made
  // failures ambiguous in reports. This one checks generationConfig.
  it("passes generationConfig and systemInstruction through to chat.sendMessage", async () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "my-model",
      generationConfig: {
        temperature: 0,
        responseMimeType: "application/json",
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            testField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
    });
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    expect(genModel.generationConfig.responseSchema.properties.testField).to
      .exist;
    const makeRequestStub = stubMakeModelRequest();
    await genModel.startChat().sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return value.includes("be friendly") && value.includes("testField");
      }),
      {},
    );
  });
  it("startChat overrides model values", async () => {
    const genModel = new GenerativeModel("apiKey", {
      model: "my-model",
      generationConfig: {
        temperature: 0,
        responseMimeType: "application/json",
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            testField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
      toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
      systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
    });
    expect(genModel.generationConfig.responseSchema.properties.testField).to
      .exist;
    expect(genModel.tools?.length).to.equal(1);
    expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
      FunctionCallingMode.NONE,
    );
    expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
    const makeRequestStub = stubMakeModelRequest();
    await genModel
      .startChat({
        tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
        generationConfig: {
          responseSchema: {
            type: FunctionDeclarationSchemaType.OBJECT,
            properties: {
              newTestField: {
                type: FunctionDeclarationSchemaType.STRING,
                properties: {},
              },
            },
          },
        },
        toolConfig: {
          functionCallingConfig: { mode: FunctionCallingMode.AUTO },
        },
        systemInstruction: { role: "system", parts: [{ text: "be formal" }] },
      })
      .sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("otherfunc") &&
          value.includes(FunctionCallingMode.AUTO) &&
          value.includes("be formal") &&
          value.includes("newTestField") &&
          // Case-sensitive check that the model-level "testField" is gone.
          !value.includes("testField")
        );
      }),
      {},
    );
  });
  it("countTokens errors if contents and generateContentRequest are both defined", async () => {
    const genModel = new GenerativeModel(
      "apiKey",
      {
        model: "my-model",
      },
      {
        apiVersion: "v2000",
      },
    );
    const makeRequestStub = stubMakeModelRequest();
    const countTokensRequest: CountTokensRequest = {
      contents: [{ role: "user", parts: [{ text: "hello" }] }],
      generateContentRequest: {
        contents: [{ role: "user", parts: [{ text: "hello" }] }],
      },
    };
    // Capture the rejection explicitly rather than using
    // `.eventually.be.rejectedWith`. That syntax needs the chai-as-promised
    // plugin, which this file never registers with `use(...)`.
    let rejection: unknown;
    try {
      await genModel.countTokens(countTokensRequest);
    } catch (e) {
      rejection = e;
    }
    expect(rejection, "countTokens should have rejected").to.exist;
    const message =
      rejection instanceof Error ? rejection.message : String(rejection);
    expect(message).to.include(
      "CountTokensRequest must have one of contents or generateContentRequest, not both.",
    );
    expect(makeRequestStub).to.not.be.called;
  });
});