UX-agent / backend /node_modules /@google /generative-ai /src /models /generative-model.test.ts
AUXteam's picture
Set Gemini API version to v1
8c741f6 verified
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { expect, use } from "chai";
import { GenerativeModel } from "./generative-model";
import * as sinonChai from "sinon-chai";
import {
CountTokensRequest,
FunctionCallingMode,
FunctionDeclarationSchemaType,
HarmBlockThreshold,
HarmCategory,
} from "../../types";
import { getMockResponse } from "../../test-utils/mock-response";
import { match, restore, stub } from "sinon";
import * as request from "../requests/request";
// Register the sinon-chai plugin so that `expect(stub)` gains the
// spy/stub assertions used below (e.g. `calledWith`, `called`).
use(sinonChai);
describe("GenerativeModel", () => {
// A model name supplied without any prefix is expected to be exposed on the
// instance as "models/<name>".
it("handles plain model name", () => {
  const { model: resolvedName } = new GenerativeModel("apiKey", {
    model: "my-model",
  });
  expect(resolvedName).to.equal("models/my-model");
});
// A name that already begins with "models/" is expected to be kept as given.
it("handles prefixed model name", () => {
  const params = { model: "models/my-model" };
  const instance = new GenerativeModel("apiKey", params);
  expect(instance.model).to.equal("models/my-model");
});
// A "tunedModels/"-prefixed name is expected to be kept as given rather than
// being re-prefixed with "models/".
it("handles prefixed tuned model name", () => {
  const params = { model: "tunedModels/my-model" };
  const instance = new GenerativeModel("apiKey", params);
  expect(instance.model).to.equal("tunedModels/my-model");
});
// Builds a model with generationConfig, safetySettings, tools, toolConfig,
// systemInstruction and a custom apiVersion, checks that each is readable on
// the instance, and then checks that generateContent hands them to
// makeModelRequest (params in the request body, apiVersion in the options).
it("passes params through to generateContent", async () => {
  const genModel = new GenerativeModel(
    "apiKey",
    {
      model: "my-model",
      generationConfig: {
        temperature: 0,
        responseMimeType: "application/json",
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            testField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      safetySettings: [
        {
          category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
          threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        },
      ],
      tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
      toolConfig: {
        functionCallingConfig: { mode: FunctionCallingMode.NONE },
      },
      systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
    },
    {
      apiVersion: "v6",
    },
  );
  expect(genModel.generationConfig?.temperature).to.equal(0);
  expect(genModel.generationConfig?.responseMimeType).to.equal(
    "application/json",
  );
  expect(genModel.generationConfig?.responseSchema.type).to.equal(
    FunctionDeclarationSchemaType.OBJECT,
  );
  expect(
    genModel.generationConfig?.responseSchema.properties.testField.type,
  ).to.equal(FunctionDeclarationSchemaType.STRING);
  expect(genModel.safetySettings?.length).to.equal(1);
  expect(genModel.tools?.length).to.equal(1);
  expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
    FunctionCallingMode.NONE,
  );
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.generateContent("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      // Fifth argument is the serialized request body; every model-level
      // param must show up somewhere in it.
      match((value: string) => {
        return (
          value.includes("myfunc") &&
          value.includes(FunctionCallingMode.NONE) &&
          value.includes("be friendly") &&
          value.includes("temperature") &&
          value.includes("testField") &&
          value.includes(HarmBlockThreshold.BLOCK_LOW_AND_ABOVE)
        );
      }),
      // Sixth argument is the request options given to the constructor.
      match((value) => {
        return value.apiVersion === "v6";
      }),
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// A systemInstruction given as a plain string is expected to be readable as
// parts[0].text on the model and to appear in the body sent by
// generateContent.
it("passes text-only systemInstruction through to generateContent", async () => {
  const genModel = new GenerativeModel("apiKey", {
    model: "my-model",
    systemInstruction: "be friendly",
  });
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.generateContent("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return value.includes("be friendly");
      }),
      match.any,
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// Params passed directly to generateContent (generationConfig, safetySettings,
// tools, toolConfig, systemInstruction) are expected to replace the values the
// model was constructed with in the outgoing request body.
it("generateContent overrides model values", async () => {
  const genModel = new GenerativeModel("apiKey", {
    model: "my-model",
    generationConfig: {
      temperature: 0,
      responseMimeType: "application/json",
      responseSchema: {
        type: FunctionDeclarationSchemaType.OBJECT,
        properties: {
          testField: {
            type: FunctionDeclarationSchemaType.STRING,
            properties: {},
          },
        },
      },
    },
    safetySettings: [
      {
        category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
      },
    ],
    tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
    toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
    systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
  });
  expect(genModel.tools?.length).to.equal(1);
  expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
    FunctionCallingMode.NONE,
  );
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.generateContent({
      generationConfig: {
        topK: 1,
        responseSchema: {
          type: FunctionDeclarationSchemaType.OBJECT,
          properties: {
            newTestField: {
              type: FunctionDeclarationSchemaType.STRING,
              properties: {},
            },
          },
        },
      },
      safetySettings: [
        {
          category: HarmCategory.HARM_CATEGORY_HARASSMENT,
          threshold: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        },
        {
          category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
          threshold: HarmBlockThreshold.BLOCK_NONE,
        },
      ],
      contents: [{ role: "user", parts: [{ text: "hello" }] }],
      tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
      toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.AUTO } },
      systemInstruction: { role: "system", parts: [{ text: "be formal" }] },
    });
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("otherfunc") &&
          value.includes(FunctionCallingMode.AUTO) &&
          value.includes("be formal") &&
          value.includes("topK") &&
          value.includes("newTestField") &&
          // includes() is case-sensitive: "newTestField" does not contain
          // "testField", so this only trips if the model-level schema leaked.
          !value.includes("testField") &&
          value.includes(HarmCategory.HARM_CATEGORY_HARASSMENT)
        );
      }),
      {},
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// countTokens is expected to send the prompt and the model's systemInstruction
// in the request body and to forward the constructor's request options
// (apiVersion) to makeModelRequest.
it("passes requestOptions through to countTokens", async () => {
  const genModel = new GenerativeModel(
    "apiKey",
    {
      model: "my-model",
      systemInstruction: "you are a cat",
    },
    {
      apiVersion: "v2000",
    },
  );
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.countTokens("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.COUNT_TOKENS,
      match.any,
      false,
      match((value: string) => {
        return value.includes("hello") && value.includes("you are a cat");
      }),
      match((value) => {
        return value.apiVersion === "v2000";
      }),
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// tools, toolConfig and systemInstruction set on the model are expected to be
// carried into the request body when a message is sent through a chat session
// started with no overrides.
it("passes params through to chat.sendMessage", async () => {
  const genModel = new GenerativeModel("apiKey", {
    model: "my-model",
    tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
    toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
    systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
  });
  expect(genModel.tools?.length).to.equal(1);
  expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
    FunctionCallingMode.NONE,
  );
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.startChat().sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("myfunc") &&
          value.includes(FunctionCallingMode.NONE) &&
          value.includes("be friendly")
        );
      }),
      {},
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// generationConfig (including its responseSchema) and systemInstruction set on
// the model are expected to be carried into the request body when a message is
// sent through a chat session started with no overrides.
// The title names generationConfig so that it does not collide with the
// tools/toolConfig variant of this test in test reports.
it("passes generationConfig through to chat.sendMessage", async () => {
  const genModel = new GenerativeModel("apiKey", {
    model: "my-model",
    generationConfig: {
      temperature: 0,
      responseMimeType: "application/json",
      responseSchema: {
        type: FunctionDeclarationSchemaType.OBJECT,
        properties: {
          testField: {
            type: FunctionDeclarationSchemaType.STRING,
            properties: {},
          },
        },
      },
    },
    systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
  });
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  expect(genModel.generationConfig.responseSchema.properties.testField).to
    .exist;
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel.startChat().sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return value.includes("be friendly") && value.includes("testField");
      }),
      {},
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// Params passed to startChat (tools, generationConfig, toolConfig,
// systemInstruction) are expected to replace the values the model was
// constructed with in the body sent by sendMessage.
it("startChat overrides model values", async () => {
  const genModel = new GenerativeModel("apiKey", {
    model: "my-model",
    generationConfig: {
      temperature: 0,
      responseMimeType: "application/json",
      responseSchema: {
        type: FunctionDeclarationSchemaType.OBJECT,
        properties: {
          testField: {
            type: FunctionDeclarationSchemaType.STRING,
            properties: {},
          },
        },
      },
    },
    tools: [{ functionDeclarations: [{ name: "myfunc" }] }],
    toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
    systemInstruction: { role: "system", parts: [{ text: "be friendly" }] },
  });
  expect(genModel.generationConfig.responseSchema.properties.testField).to
    .exist;
  expect(genModel.tools?.length).to.equal(1);
  expect(genModel.toolConfig?.functionCallingConfig.mode).to.equal(
    FunctionCallingMode.NONE,
  );
  expect(genModel.systemInstruction?.parts[0].text).to.equal("be friendly");
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    await genModel
      .startChat({
        tools: [{ functionDeclarations: [{ name: "otherfunc" }] }],
        generationConfig: {
          responseSchema: {
            type: FunctionDeclarationSchemaType.OBJECT,
            properties: {
              newTestField: {
                type: FunctionDeclarationSchemaType.STRING,
                properties: {},
              },
            },
          },
        },
        toolConfig: {
          functionCallingConfig: { mode: FunctionCallingMode.AUTO },
        },
        systemInstruction: { role: "system", parts: [{ text: "be formal" }] },
      })
      .sendMessage("hello");
    expect(makeRequestStub).to.be.calledWith(
      "models/my-model",
      request.Task.GENERATE_CONTENT,
      match.any,
      false,
      match((value: string) => {
        return (
          value.includes("otherfunc") &&
          value.includes(FunctionCallingMode.AUTO) &&
          value.includes("be formal") &&
          value.includes("newTestField") &&
          // includes() is case-sensitive: "newTestField" does not contain
          // "testField", so this only trips if the model-level schema leaked.
          !value.includes("testField")
        );
      }),
      {},
    );
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
// A CountTokensRequest that sets both `contents` and `generateContentRequest`
// is expected to make countTokens reject with an explanatory message, without
// ever reaching makeModelRequest.
it("countTokens errors if contents and generateContentRequest are both defined", async () => {
  const genModel = new GenerativeModel(
    "apiKey",
    {
      model: "my-model",
    },
    {
      apiVersion: "v2000",
    },
  );
  const mockResponse = getMockResponse(
    "unary-success-basic-reply-short.json",
  );
  const makeRequestStub = stub(request, "makeModelRequest").resolves(
    mockResponse as Response,
  );
  try {
    const countTokensRequest: CountTokensRequest = {
      contents: [{ role: "user", parts: [{ text: "hello" }] }],
      generateContentRequest: {
        contents: [{ role: "user", parts: [{ text: "hello" }] }],
      },
    };
    // Capture the rejection by hand: this file registers only sinon-chai, so
    // the promise assertions from chai-as-promised are not guaranteed to be
    // available when the file is run on its own.
    let rejection: unknown;
    try {
      await genModel.countTokens(countTokensRequest);
    } catch (e: unknown) {
      rejection = e;
    }
    // Stays undefined (and fails here) if countTokens resolved instead.
    expect(rejection, "countTokens should have rejected").to.be.instanceOf(
      Error,
    );
    expect((rejection as Error).message).to.include(
      "CountTokensRequest must have one of contents or generateContentRequest, not both.",
    );
    expect(makeRequestStub).to.not.be.called;
  } finally {
    // Un-stub even when an assertion above throws, so a failure here cannot
    // leave makeModelRequest wrapped for the tests that run afterwards.
    restore();
  }
});
});