| | import { |
| | VertexAI, |
| | HarmCategory, |
| | HarmBlockThreshold, |
| | type Content, |
| | type TextPart, |
| | } from "@google-cloud/vertexai"; |
| | import type { Endpoint } from "../endpoints"; |
| | import { z } from "zod"; |
| | import type { Message } from "$lib/types/Message"; |
| | import type { TextGenerationStreamOutput } from "@huggingface/inference"; |
| |
|
/**
 * Harm-block thresholds accepted for the `safetyThreshold` option.
 * Only these five SDK values are allowed.
 */
const vertexSafetyThresholdSchema = z.enum([
	HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
	HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
	HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
	HarmBlockThreshold.BLOCK_NONE,
	HarmBlockThreshold.BLOCK_ONLY_HIGH,
]);

/**
 * Configuration schema for a Google Vertex AI endpoint.
 *
 * - `weight`: positive integer, defaults to `1`.
 * - `model`: model configuration object; intentionally not validated here (`z.any()`).
 * - `type`: must be the literal `"vertex"`.
 * - `location`: Vertex region, defaults to `"europe-west1"`.
 * - `project`: Google Cloud project identifier (required).
 * - `apiEndpoint`: optional override for the Vertex API endpoint.
 * - `safetyThreshold`: optional harm-block threshold (see `vertexSafetyThresholdSchema`).
 */
export const endpointVertexParametersSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("vertex"),
	location: z.string().default("europe-west1"),
	project: z.string(),
	apiEndpoint: z.string().optional(),
	safetyThreshold: vertexSafetyThresholdSchema.optional(),
});
| |
|
/**
 * Harm categories that each receive the configured `safetyThreshold` (one shared threshold for all).
 */
const VERTEX_SAFETY_CATEGORIES = [
	HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
	HarmCategory.HARM_CATEGORY_HARASSMENT,
	HarmCategory.HARM_CATEGORY_HATE_SPEECH,
	HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
	HarmCategory.HARM_CATEGORY_UNSPECIFIED,
];

/**
 * Creates a text-generation endpoint backed by Google Vertex AI.
 *
 * The input is validated with `endpointVertexParametersSchema`. A single `VertexAI` client is
 * created up front and reused by every call of the returned endpoint.
 *
 * On each call, the returned endpoint:
 * - uses a leading `"system"` message (if present) as the system instruction, otherwise
 *   `preprompt`; the caller's `messages` array is left unmodified;
 * - maps `"user"` messages to role `"user"` and all other messages to role `"model"`;
 * - streams the response as `TextGenerationStreamOutput` chunks. The chunk carrying a finish
 *   reason has `token.special === true` and `generated_text` set to the full accumulated text.
 *
 * @param input - Raw endpoint configuration; parsed and defaulted by the schema.
 * @returns An `Endpoint` function resolving to an async generator of stream outputs.
 * @throws ZodError when `input` does not satisfy the schema.
 */
export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
	const { project, location, model, apiEndpoint, safetyThreshold } =
		endpointVertexParametersSchema.parse(input);

	const vertex_ai = new VertexAI({
		project,
		location,
		apiEndpoint,
	});

	return async ({ messages, preprompt, generateSettings }) => {
		const generativeModel = vertex_ai.getGenerativeModel({
			model: model.id ?? model.name,
			// One threshold for every category; when no threshold is configured, send no settings.
			safetySettings: safetyThreshold
				? VERTEX_SAFETY_CATEGORIES.map((category) => ({ category, threshold: safetyThreshold }))
				: undefined,
			generationConfig: {
				maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
				stopSequences: generateSettings?.stop,
				temperature: generateSettings?.temperature ?? 1,
			},
		});

		// A leading system message overrides the preprompt. Destructuring (instead of `shift()`)
		// leaves the caller's array intact and does not throw when `messages` is empty.
		let systemMessage = preprompt;
		let conversation = messages;
		const [firstMessage, ...remainingMessages] = messages;
		if (firstMessage?.from === "system") {
			systemMessage = firstMessage.content;
			conversation = remainingMessages;
		}

		const vertexMessages = conversation.map(
			({ from, content }: Omit<Message, "id">): Content => ({
				role: from === "user" ? "user" : "model",
				parts: [{ text: content }],
			})
		);

		const result = await generativeModel.generateContentStream({
			contents: vertexMessages,
			systemInstruction: systemMessage
				? { role: "system", parts: [{ text: systemMessage }] }
				: undefined,
		});

		let tokenId = 0;
		return (async function* () {
			let generatedText = "";

			for await (const data of result.stream) {
				// A chunk without candidates has nothing more to stream.
				if (!data?.candidates?.length) break;

				const candidate = data.candidates[0];
				// Any finish reason marks the final chunk of the response.
				const isLastChunk = !!candidate.finishReason;
				const textPart = candidate.content?.parts?.find((part) => "text" in part) as
					| TextPart
					| undefined;

				// Skip chunks without text, except the final one: it must still be yielded so the
				// consumer receives `generated_text` even if generation stops without new text.
				if (!textPart && !isLastChunk) continue;

				const content = textPart?.text ?? "";
				generatedText += content;
				const output: TextGenerationStreamOutput = {
					token: {
						id: tokenId++,
						text: content,
						logprob: 0,
						special: isLastChunk,
					},
					generated_text: isLastChunk ? generatedText : null,
					details: null,
				};
				yield output;

				if (isLastChunk) break;
			}
		})();
	};
}
export default endpointVertex;
| |
|