| import { buildPrompt } from "$lib/buildPrompt"; |
| import { textGenerationStream } from "@huggingface/inference"; |
| import { z } from "zod"; |
| import type { Endpoint } from "../endpoints"; |
|
|
/**
 * Configuration schema for an endpoint hosted on AWS.
 *
 * Parsing applies two defaults: `weight` becomes 1 and `service` becomes
 * "sagemaker" when they are omitted.
 */
export const endpointAwsParametersSchema = z.object({
  /** Positive integer, defaults to 1 (presumably a load-balancing weight across endpoints — confirm). */
  weight: z.number().int().positive().default(1),
  /** Model configuration; accepted as-is without validation. */
  model: z.any(),
  /** Must be exactly "aws" (likely the tag used to select this endpoint kind — confirm with caller). */
  type: z.literal("aws"),
  /** Target URL for generation requests; must be a syntactically valid URL. */
  url: z.string().url(),
  /** AWS access key id; must be non-empty. */
  accessKey: z.string().min(1),
  /** AWS secret access key; must be non-empty. */
  secretKey: z.string().min(1),
  /** Optional AWS session token (for temporary credentials). */
  sessionToken: z.string().optional(),
  /** AWS service to target: "sagemaker" (default) or "lambda". */
  service: z.literal("sagemaker").or(z.literal("lambda")).default("sagemaker"),
  /** Optional AWS region; passed through to the AWS client unchanged. */
  region: z.string().optional(),
});
|
|
/**
 * Builds a text-generation endpoint that talks to a model hosted on AWS
 * (SageMaker or Lambda, per `input.service`).
 *
 * The configuration is validated with `endpointAwsParametersSchema`. Requests
 * are then issued through an `aws4fetch` `AwsClient` built from the supplied
 * credentials (aws4fetch uses them to sign each request).
 *
 * @param input - Raw endpoint configuration; validated and defaulted by the schema.
 * @returns A function that builds a prompt from the conversation and returns the
 *   `textGenerationStream` for it.
 * @throws ZodError when `input` does not match `endpointAwsParametersSchema`.
 * @throws Error when the `aws4fetch` module cannot be imported. The message
 *   includes the underlying reason.
 */
export async function endpointAws(
  input: z.input<typeof endpointAwsParametersSchema>
): Promise<Endpoint> {
  // Validate the configuration first, so a bad config is reported as such
  // regardless of whether the aws4fetch dependency is available.
  const { url, accessKey, secretKey, sessionToken, model, region, service } =
    endpointAwsParametersSchema.parse(input);

  // aws4fetch is imported dynamically (presumably so it is only needed when an
  // AWS endpoint is actually configured — confirm).
  let AwsClient;
  try {
    AwsClient = (await import("aws4fetch")).AwsClient;
  } catch (e: unknown) {
    // Keep the original reason. "Not installed" and "failed while loading"
    // need different fixes.
    const reason = e instanceof Error ? e.message : String(e);
    throw new Error(`Failed to import aws4fetch: ${reason}`);
  }

  const aws = new AwsClient({
    accessKeyId: accessKey,
    secretAccessKey: secretKey,
    sessionToken,
    service,
    region,
  });

  return async ({ messages, preprompt, continueMessage, generateSettings }) => {
    const prompt = await buildPrompt({
      messages,
      continueMessage,
      preprompt,
      model,
    });

    return textGenerationStream(
      {
        // Per-request settings override model defaults. return_full_text is
        // always forced off so only newly generated text is streamed back.
        parameters: { ...model.parameters, ...generateSettings, return_full_text: false },
        model: url,
        inputs: prompt,
      },
      {
        use_cache: false,
        // Route the HTTP call through the AWS client instead of the global fetch.
        fetch: aws.fetch.bind(aws) as typeof fetch,
      }
    );
  };
}
|
|
/** Default export: the AWS endpoint factory, also available as the named export `endpointAws`. */
export { endpointAws as default };
|
|