| import { z } from "zod"; |
| import { env } from "$env/dynamic/private"; |
| import type { Endpoint } from "../endpoints"; |
| import type { TextGenerationStreamOutput } from "@huggingface/inference"; |
| import type { Cohere, CohereClient } from "cohere-ai"; |
| import { buildPrompt } from "$lib/buildPrompt"; |
|
|
/**
 * Zod schema describing the configuration of a `"cohere"` endpoint.
 *
 * Parsed by `endpointCohere`, which reads `apiKey`, `model` and `raw`
 * from the result.
 */
export const endpointCohereParametersSchema = z.object({
  // Positive integer, defaults to 1.
  // NOTE(review): presumably used for weighted endpoint selection — confirm against caller.
  weight: z.number().int().positive().default(1),

  // Not validated here; `endpointCohere` reads `model.parameters`,
  // `model.id` and `model.name` from it.
  model: z.any(),

  // Literal tag identifying this endpoint kind.
  type: z.literal("cohere"),

  // Falls back to the COHERE_API_TOKEN private environment variable when omitted.
  apiKey: z.string().default(env.COHERE_API_TOKEN),

  // When true, the conversation is rendered with `buildPrompt` and sent as a
  // single raw prompt instead of structured chat history. Defaults to false.
  raw: z.boolean().default(false),
});
|
|
/**
 * Creates an endpoint that streams completions from the Cohere chat API.
 *
 * The returned endpoint function yields one `text-generation` token per
 * streamed chunk and, on `stream-end`, a final special token carrying the
 * full generated text.
 *
 * @param input - Endpoint configuration; validated against
 *   `endpointCohereParametersSchema`.
 * @returns An endpoint function producing an async generator of stream outputs.
 * @throws Error with message "Failed to import cohere-ai" if loading the SDK
 *   or constructing the client fails. A zod error propagates if `input` is invalid.
 *   The generator itself throws if the conversation contains no non-system
 *   message (non-raw mode), or if Cohere finishes with `ERROR`, `ERROR_TOXIC`
 *   or `ERROR_LIMIT`.
 */
export async function endpointCohere(
  input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
  const { apiKey, model, raw } = endpointCohereParametersSchema.parse(input);

  let cohere: CohereClient;

  try {
    // Dynamic import: the SDK is loaded when the endpoint is created,
    // not when this module is loaded (top-level imports are type-only).
    cohere = new (await import("cohere-ai")).CohereClient({
      token: apiKey,
    });
  } catch (e) {
    throw new Error("Failed to import cohere-ai", { cause: e });
  }

  return async ({ messages, preprompt, generateSettings, continueMessage }) => {
    // A leading system message takes precedence over the configured preprompt.
    const firstMessage = messages?.[0];
    const system = firstMessage?.from === "system" ? firstMessage.content : preprompt;

    // Per-request settings override the model's default parameters.
    const parameters = { ...model.parameters, ...generateSettings };

    // Options shared by both the raw and the structured request.
    const sharedOptions = {
      model: model.id ?? model.name,
      p: parameters.top_p,
      k: parameters.top_k,
      maxTokens: parameters.max_new_tokens,
      temperature: parameters.temperature,
      stopSequences: parameters.stop,
      frequencyPenalty: parameters.frequency_penalty,
    };

    return (async function* () {
      // System messages are never sent as chat turns; their content is
      // passed separately (as `preprompt` or `preamble`).
      const conversation = messages.filter((message) => message.from !== "system");

      let stream;
      let tokenId = 0;

      if (raw) {
        const prompt = await buildPrompt({
          messages: conversation,
          model,
          preprompt: system,
          continueMessage,
        });

        stream = await cohere.chatStream({
          ...sharedOptions,
          message: prompt,
          rawPrompting: true,
        });
      } else {
        const formattedMessages = conversation.map((message) => ({
          role: message.from === "user" ? "USER" : "CHATBOT",
          message: message.content,
        })) satisfies Cohere.ChatMessage[];

        // The last turn is sent as `message`; without one there is nothing to send.
        const lastMessage = formattedMessages[formattedMessages.length - 1];
        if (lastMessage === undefined) {
          throw new Error("Cohere endpoint requires at least one non-system message");
        }

        stream = await cohere.chatStream({
          ...sharedOptions,
          chatHistory: formattedMessages.slice(0, -1),
          message: lastMessage.message,
          preamble: system,
        });
      }

      for await (const output of stream) {
        if (output.eventType === "text-generation") {
          yield {
            token: {
              id: tokenId++,
              text: output.text,
              logprob: 0,
              special: false,
            },
            generated_text: null,
            details: null,
          } satisfies TextGenerationStreamOutput;
        } else if (output.eventType === "stream-end") {
          // Surface Cohere-side failures as exceptions instead of a final token.
          if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
            throw new Error(output.finishReason);
          }
          yield {
            token: {
              id: tokenId++,
              text: "",
              logprob: 0,
              special: true,
            },
            generated_text: output.response.text,
            details: null,
          };
        }
      }
    })();
  };
}
|
|