Spaces:
Paused
Paused
/**
 * @license
 * Copyright 2024 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
import { expect, use } from "chai";
import * as chaiAsPromised from "chai-as-promised";
import { restore, stub } from "sinon";
import * as sinonChai from "sinon-chai";

import { DEFAULT_API_VERSION } from "../requests/request";
import { GoogleAICacheManager } from "./cache-manager";
import { RpcTask } from "./constants";
import * as request from "./request";
// Register the chai plugins used below: sinon-chai, and chai-as-promised
// (which provides the `rejectedWith` assertion on promises).
use(sinonChai);
use(chaiAsPromised);
/** Minimal single-turn conversation used as the `contents` of cache requests. */
const FAKE_CONTENTS = [{ role: "user", parts: [{ text: "some text" }] }];

/** Resource name reported by the fake server responses. */
const FAKE_CACHE_NAME = "cachedContents/hash1234";

/** Stand-in for `Response.json()` that resolves to a payload naming the fake cache. */
const fakeResponseJson: () => Promise<{ name: string }> = () =>
  Promise.resolve({ name: FAKE_CACHE_NAME });

/** Model name, including the `models/` prefix, used by the create() tests. */
const model = "models/gemini-1.5-pro-001";
/**
 * Unit tests for `GoogleAICacheManager`.
 *
 * Tests that issue a request replace `request.makeServerRequest` with a sinon
 * stub resolving to a fake successful response, then inspect the arguments the
 * cache manager passed to the stub (request URL object, headers, and
 * serialized body). All stubs are restored after each test.
 */
describe("GoogleAICacheManager", () => {
  /** Constructor options used by the "(with options)" tests. */
  const CUSTOM_OPTIONS = {
    apiVersion: "v3000",
    baseUrl: "http://mysite.com",
  };

  /** Fake `Response.json()` for list calls: one cached content entry. */
  const fakeListResponseJson = (): Promise<unknown> =>
    Promise.resolve({ cachedContents: [{ name: FAKE_CACHE_NAME }] });

  /** Fake `Response.json()` for delete calls: an empty object. */
  const fakeEmptyResponseJson = (): Promise<unknown> => Promise.resolve({});

  /**
   * Stubs `request.makeServerRequest` so that it resolves to an `ok` response
   * whose `json()` is the supplied function.
   *
   * @param json - Function used as the fake response's `json()` method.
   * @returns The sinon stub, so callers can inspect the recorded arguments.
   */
  function stubServerRequest(json: () => Promise<unknown>) {
    return stub(request, "makeServerRequest").resolves({
      ok: true,
      json,
    } as Response);
  }

  afterEach(() => {
    restore();
  });

  it("stores api key", () => {
    const cacheManager = new GoogleAICacheManager("apiKey");
    expect(cacheManager.apiKey).to.equal("apiKey");
  });

  it("passes create request info", async () => {
    const displayName = "a display name.";
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.create({
      model,
      contents: FAKE_CONTENTS,
      ttlSeconds: 30,
      systemInstruction: "talk like a cat",
      tools: [{ functionDeclarations: [{ name: "myFn" }] }],
      toolConfig: { functionCallingConfig: {} },
      displayName,
    });
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url, headers, body] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.CREATE);
    expect(headers).to.be.instanceOf(Headers);
    const requestBody = JSON.parse(body as string);
    expect(requestBody.model).to.equal(model);
    expect(requestBody.contents).to.deep.equal(FAKE_CONTENTS);
    expect(requestBody.ttl).to.deep.equal("30s");
    expect(requestBody.displayName).to.equal(displayName);
    expect(requestBody.systemInstruction.parts[0].text).to.equal(
      "talk like a cat",
    );
    expect(requestBody.tools[0].functionDeclarations[0].name).to.equal("myFn");
    expect(requestBody.toolConfig.functionCallingConfig).to.exist;
  });

  it("create() formats unprefixed model name", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.create({
      // Deliberately omit the "models/" prefix; the request body is expected
      // to carry the prefixed form.
      // NOTE(review): assumes create() prepends "models/" to names without a
      // path - confirm against cache-manager.ts.
      model: "gemini-1.5-pro-001",
      contents: FAKE_CONTENTS,
    });
    const requestBody = JSON.parse(makeRequestStub.args[0][2] as string);
    expect(requestBody.model).to.equal(model);
  });

  it("create() errors without a model name", async () => {
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(
      cacheManager.create({
        contents: FAKE_CONTENTS,
      }),
    ).to.be.rejectedWith("Cached content must contain a `model` field.");
  });

  it("create() errors if ttlSeconds and expireTime are both provided", async () => {
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(
      cacheManager.create({
        model,
        contents: FAKE_CONTENTS,
        ttlSeconds: 40,
        expireTime: new Date().toISOString(),
      }),
    ).to.be.rejectedWith("You cannot specify");
  });

  it("passes create request info (with options)", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", CUSTOM_OPTIONS);
    await cacheManager.create({
      model,
      contents: FAKE_CONTENTS,
    });
    const [url, headers] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.CREATE);
    expect(headers).to.be.instanceOf(Headers);
    const urlString = url.toString();
    expect(urlString).to.include("v3000/cachedContents");
    expect(urlString).to.match(/^http:\/\/mysite\.com/);
  });

  it("passes update request info", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.update(FAKE_CACHE_NAME, {
      cachedContent: {
        ttlSeconds: 30,
      },
    });
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url, headers, body] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.UPDATE);
    expect(headers).to.be.instanceOf(Headers);
    const requestBody = JSON.parse(body as string);
    expect(requestBody.ttl).to.deep.equal("30s");
  });

  it("passes list request info", async () => {
    const makeRequestStub = stubServerRequest(fakeListResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.list();
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    expect(url.toString()).to.match(/\/cachedContents$/);
  });

  it("passes list request info with params", async () => {
    const makeRequestStub = stubServerRequest(fakeListResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.list({
      pageSize: 3,
      pageToken: "abc",
    });
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    const urlString = url.toString();
    expect(urlString).to.include("pageSize=3");
    expect(urlString).to.include("pageToken=abc");
  });

  it("passes list request info with options", async () => {
    const makeRequestStub = stubServerRequest(fakeListResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", CUSTOM_OPTIONS);
    const result = await cacheManager.list();
    expect(result.cachedContents[0].name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.LIST);
    const urlString = url.toString();
    expect(urlString).to.match(/\/cachedContents$/);
    expect(urlString).to.include("v3000/cachedContents");
    expect(urlString).to.match(/^http:\/\/mysite\.com/);
  });

  it("passes get request info with name prefix provided", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.get("cachedContents/hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes get request info with no name prefix", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    const result = await cacheManager.get("hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes get request info (with options)", async () => {
    const makeRequestStub = stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", CUSTOM_OPTIONS);
    const result = await cacheManager.get("hash1234");
    expect(result.name).to.equal(FAKE_CACHE_NAME);
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.GET);
    const urlString = url.toString();
    expect(urlString).to.include("v3000/cachedContents/hash1234");
    expect(urlString).to.match(/^http:\/\/mysite\.com/);
  });

  it("get throws on bad name", async () => {
    stubServerRequest(fakeResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(cacheManager.get("")).to.be.rejectedWith("Invalid name");
  });

  it("passes delete request info (no prefix)", async () => {
    const makeRequestStub = stubServerRequest(fakeEmptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.delete("hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes delete request info (prefix)", async () => {
    const makeRequestStub = stubServerRequest(fakeEmptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await cacheManager.delete("cachedContents/hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    expect(url.toString()).to.include(
      `${DEFAULT_API_VERSION}/cachedContents/hash1234`,
    );
  });

  it("passes delete request info (with options)", async () => {
    const makeRequestStub = stubServerRequest(fakeEmptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey", CUSTOM_OPTIONS);
    await cacheManager.delete("hash1234");
    const [url] = makeRequestStub.args[0];
    expect(url.task).to.equal(RpcTask.DELETE);
    const urlString = url.toString();
    expect(urlString).to.include("v3000/cachedContents/hash1234");
    expect(urlString).to.match(/^http:\/\/mysite\.com/);
  });

  it("delete throws on bad name", async () => {
    stubServerRequest(fakeEmptyResponseJson);
    const cacheManager = new GoogleAICacheManager("apiKey");
    await expect(cacheManager.delete("")).to.be.rejectedWith("Invalid name");
  });
});