| import { z } from "zod"; |
| import type { Endpoint } from "../endpoints"; |
| import type { TextGenerationStreamOutput } from "@huggingface/inference"; |
| import { env } from "$env/dynamic/private"; |
| import { logger } from "$lib/server/logger"; |
|
|
/**
 * Field validators for a `cloudflare` endpoint configuration entry.
 *
 * Kept as a separate constant so the shape can be read on its own; it is
 * wrapped into a zod object schema directly below.
 */
const cloudflareParametersShape = {
  // Positive integer; falls back to 1 when omitted.
  weight: z.number().int().positive().default(1),
  // Not validated here. `endpointCloudflare` reads `model.id` from it.
  model: z.any(),
  // Discriminator: this schema only accepts the literal "cloudflare".
  type: z.literal("cloudflare"),
  // Account id; defaults to the CLOUDFLARE_ACCOUNT_ID private env value.
  accountId: z.string().default(env.CLOUDFLARE_ACCOUNT_ID),
  // API token; defaults to the CLOUDFLARE_API_TOKEN private env value.
  apiToken: z.string().default(env.CLOUDFLARE_API_TOKEN),
};

/**
 * Zod schema describing the parameters accepted by `endpointCloudflare`.
 *
 * Defaults for `accountId` and `apiToken` are read from the private
 * environment once, when this module is evaluated.
 */
export const endpointCloudflareParametersSchema = z.object(cloudflareParametersShape);
|
|
/** Result of decoding one `data: ...` line of the streamed response. */
type CloudflareStreamEvent = { done: true } | { done: false; text: string };

/**
 * Decodes a single line of the streamed response body.
 *
 * @param line - One raw line from the stream, without its trailing newline.
 * @returns `{ done: true }` for the `[DONE]` sentinel, `{ done: false, text }`
 *   for a payload carrying a non-empty string `response`, or `undefined` for
 *   anything else (non-`data:` lines, unparsable JSON, payloads without text).
 */
function parseCloudflareStreamLine(line: string): CloudflareStreamEvent | undefined {
  const trimmed = line.trim();
  if (!trimmed.startsWith("data: ")) {
    return undefined;
  }

  const payload = trimmed.slice("data: ".length);
  if (payload === "[DONE]") {
    return { done: true };
  }

  let data: unknown;
  try {
    data = JSON.parse(payload);
  } catch (e) {
    // A malformed line is logged and skipped rather than aborting the stream.
    logger.error("Failed to parse JSON", e);
    logger.error("Problematic JSON string:", payload);
    return undefined;
  }

  // JSON.parse may legitimately return null or a primitive; guard before
  // reading a property so such a line cannot throw.
  if (typeof data !== "object" || data === null) {
    return undefined;
  }
  const response = (data as { response?: unknown }).response;
  if (typeof response !== "string" || response === "") {
    return undefined;
  }
  return { done: false, text: response };
}

/**
 * Creates an endpoint that streams text generation from Cloudflare's
 * `accounts/{accountId}/ai/run/@hf/{model.id}` REST API.
 *
 * The returned endpoint sends the conversation as a JSON `messages` array with
 * `stream: true`, then yields one `TextGenerationStreamOutput` per received
 * text fragment. It always finishes with a single special token whose
 * `generated_text` holds the concatenation of all fragments, whether the
 * stream ended with a `[DONE]` line or simply closed.
 *
 * @param input - Endpoint configuration, validated against
 *   `endpointCloudflareParametersSchema`.
 * @returns The endpoint function.
 * @throws If `input` fails schema validation; the returned endpoint throws an
 *   `Error` if the HTTP response is not OK or has no body.
 */
export async function endpointCloudflare(
  input: z.input<typeof endpointCloudflareParametersSchema>
): Promise<Endpoint> {
  const { accountId, apiToken, model } = endpointCloudflareParametersSchema.parse(input);
  const apiURL = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/@hf/${model.id}`;

  return async ({ messages, preprompt }) => {
    let messagesFormatted = messages.map((message) => ({
      role: message.from,
      content: message.content,
    }));

    // Make sure the conversation starts with a system message, using the
    // preprompt (or an empty string) when the caller did not supply one.
    if (messagesFormatted?.[0]?.role !== "system") {
      messagesFormatted = [{ role: "system", content: preprompt ?? "" }, ...messagesFormatted];
    }

    const payload = JSON.stringify({
      messages: messagesFormatted,
      stream: true,
    });

    const res = await fetch(apiURL, {
      method: "POST",
      headers: {
        Authorization: `Bearer ${apiToken}`,
        "Content-Type": "application/json",
      },
      body: payload,
    });

    if (!res.ok) {
      throw new Error(`Failed to generate text: ${await res.text()}`);
    }

    // Fail loudly instead of returning a stream that silently yields nothing.
    if (!res.body) {
      throw new Error("Failed to generate text: response has no body");
    }

    const reader = res.body.pipeThrough(new TextDecoderStream()).getReader();

    return (async function* () {
      let generatedText = "";
      let tokenId = 0;
      // Holds received text that has not yet been split into complete lines.
      let buffer = "";
      let finished = false;

      try {
        while (!finished) {
          const chunk = await reader.read();

          if (chunk.done) {
            // The stream closed: terminate whatever is left in the buffer so
            // an unterminated last line is still processed, then stop.
            buffer += "\n";
            finished = true;
          } else {
            buffer += chunk.value;
          }

          for (
            let newlineIndex = buffer.indexOf("\n");
            newlineIndex !== -1;
            newlineIndex = buffer.indexOf("\n")
          ) {
            const line = buffer.slice(0, newlineIndex);
            buffer = buffer.slice(newlineIndex + 1);

            const event = parseCloudflareStreamLine(line);
            if (!event) {
              continue;
            }

            if (event.done) {
              // Ignore anything buffered after the sentinel so no token is
              // yielded after the final one.
              finished = true;
              break;
            }

            generatedText += event.text;
            const output: TextGenerationStreamOutput = {
              token: {
                id: tokenId++,
                text: event.text,
                logprob: 0,
                special: false,
              },
              generated_text: null,
              details: null,
            };
            yield output;
          }
        }

        // Final token: emitted exactly once, carrying the full generated text.
        const finalOutput: TextGenerationStreamOutput = {
          token: {
            id: tokenId++,
            text: "",
            logprob: 0,
            special: true,
          },
          generated_text: generatedText,
          details: null,
        };
        yield finalOutput;
      } finally {
        // Runs on normal completion, on errors, and when the consumer stops
        // iterating early, so the underlying HTTP stream is always released.
        await reader.cancel().catch(() => undefined);
      }
    })();
  };
}
|
|
// Expose the endpoint factory as this module's default export as well.
export { endpointCloudflare as default };
|
|