// Revision 8c741f6 (AUXteam): "Set Gemini API version to v1"
/**
* @license
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { expect, use } from "chai";
import { GoogleAICacheManager } from "./cache-manager";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as request from "./request";
import { RpcTask } from "./constants";
import { DEFAULT_API_VERSION } from "../requests/request";
// Register the chai plugins once for this test module. The list order matches
// the original registration order: sinon-chai first, chai-as-promised second.
for (const chaiPlugin of [sinonChai, chaiAsPromised]) {
  use(chaiPlugin);
}
/** Minimal single-turn conversation used as the `contents` payload in requests. */
const FAKE_CONTENTS = [{ role: "user", parts: [{ text: "some text" }] }];

/** Cache resource name carried by the canned server response. */
const FAKE_CACHE_NAME = "cachedContents/hash1234";

/**
 * Stand-in for `Response.json()` on stubbed responses.
 *
 * The return type names the actual payload shape instead of `{}`, which in
 * TypeScript means "any non-nullish value" and would hide the `name` field.
 *
 * @returns A promise resolving to a body that contains only the cache name.
 */
const fakeResponseJson = (): Promise<{ name: string }> =>
  Promise.resolve({ name: FAKE_CACHE_NAME });

/** Model name, already carrying the `models/` prefix, used by the create() tests. */
const model = "models/gemini-1.5-pro-001";
/**
 * Unit tests for GoogleAICacheManager.
 *
 * Tests that call a manager method replace `makeServerRequest` with a sinon
 * stub, then inspect the arguments the stub received (request URL object,
 * headers, serialized body) and the value the manager returns from the canned
 * response. Stubs are undone after every test.
 */
describe("GoogleAICacheManager", () => {
  /** Non-default manager options shared by the "(with options)" tests. */
  const customOptions = { apiVersion: "v3000", baseUrl: "http://mysite.com" };

  /** Canned JSON body for the list endpoint: a single cached content entry. */
  const listResponseJson = (): Promise<unknown> =>
    Promise.resolve({ cachedContents: [{ name: FAKE_CACHE_NAME }] });

  /** Canned JSON body for the delete endpoint: an empty object. */
  const emptyResponseJson = (): Promise<unknown> => Promise.resolve({});

  /**
   * Replaces `makeServerRequest` with a stub that resolves to an OK response.
   *
   * @param json - Function used as the response's `json()` method; defaults to
   *   the shared `fakeResponseJson` fixture.
   * @returns The sinon stub, so callers can inspect the recorded arguments.
   */
  const stubMakeRequest = (json: () => Promise<unknown> = fakeResponseJson) =>
    stub(request, "makeServerRequest").resolves({ ok: true, json } as Response);

  afterEach(() => {
    // Undo every stub so each test starts from the real module.
    restore();
  });

  it("stores api key", () => {
    const cacheManager = new GoogleAICacheManager("apiKey");
    expect(cacheManager.apiKey).to.equal("apiKey");
  });

  // ---- create() ----

  it("passes create request info", async () => {
    const displayName = "a display name.";
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.create({
      model,
      contents: FAKE_CONTENTS,
      ttlSeconds: 30,
      systemInstruction: "talk like a cat",
      tools: [{ functionDeclarations: [{ name: "myFn" }] }],
      toolConfig: { functionCallingConfig: {} },
      displayName,
    });
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url, headers, body] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.CREATE);
    expect(headers).to.be.instanceOf(Headers);
    const requestBody = JSON.parse(body as string);
    expect(requestBody.model).to.equal(model);
    expect(requestBody.contents).to.deep.equal(FAKE_CONTENTS);
    expect(requestBody.ttl).to.deep.equal("30s");
    expect(requestBody.displayName).to.equal(displayName);
    expect(requestBody.systemInstruction.parts[0].text).to.equal(
      "talk like a cat",
    );
    expect(requestBody.tools[0].functionDeclarations[0].name).to.equal("myFn");
    expect(requestBody.toolConfig.functionCallingConfig).to.exist;
  });

  it("create() formats unprefixed model name", async () => {
    // The input must NOT carry the "models/" prefix; otherwise this test only
    // checks that the model is passed through and never exercises formatting.
    const unprefixedModel = "gemini-1.5-pro-001";
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.create({
      model: unprefixedModel,
      contents: FAKE_CONTENTS,
    });
    const requestBody = JSON.parse(makeRequestStub.args[0][2] as string);
    expect(requestBody.model).to.equal(`models/${unprefixedModel}`);
  });

  it("create() errors without a model name", async () => {
    // Stubbed so that a validation regression cannot reach the real request
    // layer; the sibling "throws on bad name" tests do the same.
    stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(
      cacheManager.create({
        contents: FAKE_CONTENTS,
      }),
    ).to.be.rejectedWith("Cached content must contain a `model` field.");
  });

  it("create() errors if ttlSeconds and expireTime are both provided", async () => {
    // Stubbed for the same reason as the missing-model test above.
    stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(
      cacheManager.create({
        model,
        contents: FAKE_CONTENTS,
        ttlSeconds: 40,
        expireTime: new Date().toISOString(),
      }),
    ).to.be.rejectedWith("You cannot specify");
  });

  it("passes create request info (with options)", async () => {
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey", customOptions);
    await cacheManager.create({
      model,
      contents: FAKE_CONTENTS,
    });
    const [url, headers] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.CREATE);
    expect(headers).to.be.instanceOf(Headers);
    const requestUrl = url.toString();
    expect(requestUrl).to.include("v3000/cachedContents");
    expect(requestUrl).to.match(/^http:\/\/mysite\.com/);
  });

  // ---- update() ----

  it("passes update request info", async () => {
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.update(FAKE_CACHE_NAME, {
      cachedContent: {
        ttlSeconds: 30,
      },
    });
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url, headers, body] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.UPDATE);
    expect(headers).to.be.instanceOf(Headers);
    const requestBody = JSON.parse(body as string);
    expect(requestBody.ttl).to.deep.equal("30s");
  });

  // ---- list() ----

  it("passes list request info", async () => {
    const makeRequestStub = stubMakeRequest(listResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.list();
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    expect(url.toString()).to.match(/\/cachedContents$/);
  });

  it("passes list request info with params", async () => {
    const makeRequestStub = stubMakeRequest(listResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.list({
      pageSize: 3,
      pageToken: "abc",
    });
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    const requestUrl = url.toString();
    expect(requestUrl).to.include("pageSize=3");
    expect(requestUrl).to.include("pageToken=abc");
  });

  it("passes list request info with options", async () => {
    const makeRequestStub = stubMakeRequest(listResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", customOptions);
    const result = await cacheManager.list();
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    const requestUrl = url.toString();
    expect(requestUrl).to.match(/\/cachedContents$/);
    expect(requestUrl).to.include("v3000/cachedContents");
    expect(requestUrl).to.match(/^http:\/\/mysite\.com/);
  });

  // ---- get() ----

  it("passes get request info with name prefix provided", async () => {
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.get("cachedContents/hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes get request info with no name prefix", async () => {
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.get("hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes get request info (with options)", async () => {
    const makeRequestStub = stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey", customOptions);
    const result = await cacheManager.get("hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    const requestUrl = url.toString();
    expect(requestUrl).to.include("v3000/cachedContents/hash1234");
    expect(requestUrl).to.match(/^http:\/\/mysite\.com/);
  });

  it("get throws on bad name", async () => {
    stubMakeRequest();
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(cacheManager.get("")).to.be.rejectedWith("Invalid name");
  });

  // ---- delete() ----

  it("passes delete request info (no prefix)", async () => {
    const makeRequestStub = stubMakeRequest(emptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.delete("hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes delete request info (prefix)", async () => {
    const makeRequestStub = stubMakeRequest(emptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.delete("cachedContents/hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes delete request info (with options)", async () => {
    const makeRequestStub = stubMakeRequest(emptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", customOptions);
    await cacheManager.delete("hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    const requestUrl = url.toString();
    expect(requestUrl).to.include("v3000/cachedContents/hash1234");
    expect(requestUrl).to.match(/^http:\/\/mysite\.com/);
  });

  it("delete throws on bad name", async () => {
    stubMakeRequest(emptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(cacheManager.delete("")).to.be.rejectedWith("Invalid name");
  });
});