File size: 12,391 Bytes
31daf3d 79b976b ac22d69 f29cdb8 0ab1dfb b4690c8 0ab1dfb 3a19267 f29cdb8 f3064dd 3d931f5 806c0e0 3d931f5 bb6f8cb e08e6dd 7d7a53f 4e9a7a9 31daf3d 806c0e0 ec59d54 764ecdf 6647bbf 764ecdf b39a7fc d6bddc2 b39a7fc 8cabca8 b39a7fc eb2ef82 b39a7fc 0598c3f b39a7fc 0ab1dfb b39a7fc 52dfa8c 02d0f85 d3928c1 b39a7fc 0ab1dfb 52dfa8c 4489403 b39a7fc 0348c77 a5e332d faa93d9 0a3916d 3a19267 d4c7ddb 764ecdf b39a7fc f29cdb8 4e9a7a9 6a449cd 31daf3d 4e9a7a9 31daf3d 4e9a7a9 b39a7fc 3d931f5 4e9a7a9 d8e426c 4e9a7a9 3d931f5 2b554f9 3d931f5 2b554f9 bb6f8cb fa0afa9 2b554f9 3d931f5 4b62530 3d931f5 e08e6dd f7aef71 b8228c1 3d931f5 d4c7ddb b8228c1 3d931f5 d4c7ddb 3d931f5 b8228c1 3d931f5 4f61809 f7aef71 3d931f5 b39a7fc 3d931f5 b39a7fc 0ab1dfb 31daf3d 0ab1dfb 02d0f85 4e9a7a9 e08e6dd 1db019c 76d1477 b6274e8 02d0f85 48f1340 6f1638a ac2d8ff 1aafdd3 a9e4746 02d0f85 0ab1dfb 31daf3d 7d7a53f bad5446 21b8785 7d7a53f bad5446 baa5d2f bad5446 baa5d2f faa93d9 0ab1dfb baa5d2f a1ae528 68972f0 a1ae528 0ab1dfb 4489403 fbb1115 31daf3d fbb1115 a1ae528 fbb1115 31daf3d fbb1115 b39a7fc 56ea25e 0ab1dfb 4e9a7a9 31daf3d 4e9a7a9 b39a7fc 0a3916d baa5d2f 0a3916d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 | import { config } from "$lib/server/config";
import type { ChatTemplateInput } from "$lib/types/Template";
import { compileTemplate } from "$lib/utils/template";
import { z } from "zod";
import endpoints, { endpointSchema, type Endpoint } from "./endpoints/endpoints";
import { endpointTgi } from "./endpoints/tgi/endpointTgi";
import { sum } from "$lib/utils/sum";
import { embeddingModels, validateEmbeddingModelByName } from "./embeddingModels";
import type { PreTrainedTokenizer } from "@huggingface/transformers";
import JSON5 from "json5";
import { getTokenizer } from "$lib/utils/getTokenizer";
import { logger } from "$lib/server/logger";
import { type ToolInput } from "$lib/types/Tool";
import { fetchJSON } from "$lib/utils/fetchJSON";
import { join, dirname } from "path";
import { fileURLToPath } from "url";
import { findRepoRoot } from "./findRepoRoot";
import { Template } from "@huggingface/jinja";
import { readdirSync } from "fs";
/**
 * Folder that holds local model files.
 *
 * Uses `config.MODELS_STORAGE_PATH` when it is set to a non-empty value. Otherwise falls back
 * to the `models` directory inside whatever `findRepoRoot` returns for this module's own
 * directory (presumably the repository root). The fallback is only evaluated when needed.
 */
export const MODELS_FOLDER =
config.MODELS_STORAGE_PATH ||
join(findRepoRoot(dirname(fileURLToPath(import.meta.url))), "models");
/** Copy of `T` in which the keys listed in `K` become optional; every other key is unchanged. */
type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
/** Everything the model outputs is reasoning; the answer is extracted from it with `regex`. */
const regexReasoningSchema = z.object({
  type: z.literal("regex"),
  regex: z.string(),
});

/** Beginning and end tokens delimit the reasoning portion of the answer. */
const tokensReasoningSchema = z.object({
  type: z.literal("tokens"),
  // An empty string means the model starts in reasoning mode.
  beginToken: z.string(),
  endToken: z.string(),
});

/** Everything the model outputs is reasoning; the answer is a summary of it. */
const summarizeReasoningSchema = z.object({
  type: z.literal("summarize"),
});

/** Accepted shapes of a model's `reasoning` setting, told apart by their `type` field. */
const reasoningSchema = z.union([
  regexReasoningSchema,
  tokensReasoningSchema,
  summarizeReasoningSchema,
]);
/**
 * Zod schema for one model definition. In this file it validates the entries of
 * `config.MODELS`, the auto-discovered GGUF files and an inline `config.TASK_MODEL`.
 * The defaults declared below are filled in when a value is parsed.
 */
const modelConfig = z.object({
/** Used as an identifier in DB */
id: z.string().optional(),
/** Used to link to the model page, and for inference */
name: z.string().default(""),
displayName: z.string().min(1).optional(),
description: z.string().min(1).optional(),
logoUrl: z.string().url().optional(),
websiteUrl: z.string().url().optional(),
modelUrl: z.string().url().optional(),
/** Either a single string or a pair of explicit URLs; handed to `getTokenizer` as-is. */
tokenizer: z
.union([
z.string(),
z.object({
tokenizerUrl: z.string().url(),
tokenizerConfigUrl: z.string().url(),
}),
])
.optional(),
datasetName: z.string().min(1).optional(),
datasetUrl: z.string().url().optional(),
preprompt: z.string().default(""),
/** When set, `processModel` fetches this URL and uses the response text as the preprompt. */
prepromptUrl: z.string().url().optional(),
/** Template source compiled with `compileTemplate` when the model provides no GGUF chat template. */
chatPromptTemplate: z.string().optional(),
promptExamples: z
.array(
z.object({
title: z.string().min(1),
prompt: z.string().min(1),
})
)
.optional(),
/** When omitted, `getEndpoint` falls back to a TGI endpoint at `${config.HF_API_ROOT}/${name}`. */
endpoints: z.array(endpointSchema).optional(),
/** Generation parameters; `.passthrough()` keeps keys that are not listed here. */
parameters: z
.object({
temperature: z.number().min(0).max(2).optional(),
truncate: z.number().int().positive().optional(),
max_new_tokens: z.number().int().positive().optional(),
/** `processModel` also exposes this list under `stop_sequences`. */
stop: z.array(z.string()).optional(),
top_p: z.number().positive().optional(),
top_k: z.number().positive().optional(),
repetition_penalty: z.number().min(-2).max(2).optional(),
presence_penalty: z.number().min(-2).max(2).optional(),
})
.passthrough()
.optional(),
multimodal: z.boolean().default(false),
multimodalAcceptedMimetypes: z.array(z.string()).optional(),
tools: z.boolean().default(false),
unlisted: z.boolean().default(false),
embeddingModel: validateEmbeddingModelByName(embeddingModels).optional(),
/** Used to enable/disable system prompt usage */
systemRoleSupported: z.boolean().default(true),
/** Optional description of how reasoning is separated from the answer (see `reasoningSchema`). */
reasoning: reasoningSchema.optional(),
});
const GGUF_EXTENSION = ".gguf";

/**
 * Lists the `.gguf` files sitting directly inside `MODELS_FOLDER`.
 *
 * A missing folder means "no local models" and yields an empty list instead of
 * aborting module initialisation; any other filesystem error is rethrown.
 */
function listLocalGgufFiles(): string[] {
  try {
    return readdirSync(MODELS_FOLDER).filter((fileName) => fileName.endsWith(GGUF_EXTENSION));
  } catch (err) {
    if (err instanceof Error && "code" in err && err.code === "ENOENT") {
      return [];
    }
    throw err;
  }
}

/**
 * Raw model configs derived from the GGUF files found in `MODELS_FOLDER`:
 * one `local` endpoint per file, named after the file without its `.gguf` suffix.
 */
const ggufModelsConfig = listLocalGgufFiles().map((fileName) => ({
  // Strip only the trailing extension; `replace` would remove the first ".gguf" anywhere in the name.
  name: fileName.slice(0, -GGUF_EXTENSION.length),
  endpoints: [
    {
      type: "local" as const,
      modelPath: fileName,
    },
  ],
}));
/**
 * `modelConfig` with a shorthand: a plain string is expanded into a config with a single
 * `local` endpoint before validation. Non-string inputs are validated unchanged.
 *
 * - `hf:<rest>`: the name is the text between the first and second `:`; the display name is
 *   the first two `/`-separated segments of that text; the model path is the string itself.
 * - any other string: the name is the string; the display name drops a `.gguf` ending; the
 *   model path gets an `hf:` prefix when the string contains a `/`.
 */
const turnStringIntoLocalModel = z.preprocess((obj: unknown) => {
  if (typeof obj !== "string") {
    return obj;
  }

  let name = obj;
  let displayName = obj;
  let modelPath = obj;

  if (obj.startsWith("hf:")) {
    const afterPrefix = obj.split(":")[1];
    name = afterPrefix;
    displayName = afterPrefix.split("/").slice(0, 2).join("/");
  } else {
    if (obj.endsWith(".gguf")) {
      displayName = obj.replace(".gguf", "");
    }
    if (obj.includes("/")) {
      modelPath = `hf:${obj}`;
    }
  }

  return {
    name,
    displayName,
    endpoints: [
      {
        type: "local",
        modelPath,
      },
    ],
  } satisfies z.input<typeof modelConfig>;
}, modelConfig);
// Models declared in `config.MODELS`. An unset, empty or blank value means "no models
// configured" (same truthiness convention as `OLD_MODELS`) instead of a JSON5 parse error.
let modelsRaw = z
  .array(turnStringIntoLocalModel)
  .parse(JSON5.parse(config.MODELS?.trim() || "[]"));

// Add the GGUF files found on disk when explicitly requested, or when nothing else is configured.
if (config.LOAD_GGUF_MODELS === "true" || modelsRaw.length === 0) {
  const parsedGgufModels = z.array(modelConfig).parse(ggufModelsConfig);
  modelsRaw = [...modelsRaw, ...parsedGgufModels];
}
/**
 * Builds the function that renders chat inputs into the prompt string for model `m`.
 *
 * Template sources are tried in this order:
 * 1. the chat template stored in the GGUF metadata, when the model has a `local` endpoint;
 * 2. `m.chatPromptTemplate`, compiled with `compileTemplate`;
 * 3. the chat template applied by the model's tokenizer;
 * 4. a built-in `<|im_start|>` / `<|im_end|>` template, used only when the tokenizer cannot
 *    be loaded and no tokenizer was configured explicitly.
 *
 * When an explicitly configured tokenizer fails to load, the error is logged and the process
 * exits with a non-zero status.
 *
 * @param m Parsed model configuration.
 * @returns A render function taking `ChatTemplateInput` and returning the prompt string.
 */
async function getChatPromptRender(
  m: z.infer<typeof modelConfig>
): Promise<ReturnType<typeof compileTemplate<ChatTemplateInput>>> {
  const localEndpoint = m.endpoints?.find((e) => e.type === "local");
  if (localEndpoint) {
    const path = localEndpoint.modelPath ?? `hf:${m.id ?? m.name}`;
    // Loaded lazily so node-llama-cpp is only imported when a local model is configured.
    const { resolveModelFile, readGgufFileInfo } = await import("node-llama-cpp");
    const modelPath = await resolveModelFile(path, MODELS_FOLDER);
    const info = await readGgufFileInfo(modelPath, {
      readTensorInfo: false,
    });
    const embeddedTemplate = info.metadata.tokenizer.chat_template;
    if (embeddedTemplate) {
      // compile with jinja
      const jinjaTemplate = new Template(embeddedTemplate);
      return (inputs: ChatTemplateInput) => jinjaTemplate.render({ ...m, ...inputs });
    }
  }

  if (m.chatPromptTemplate) {
    return compileTemplate<ChatTemplateInput>(m.chatPromptTemplate, m);
  }

  const tokenizerSource = m.tokenizer ?? m.id ?? m.name;
  let tokenizer: PreTrainedTokenizer;
  try {
    tokenizer = await getTokenizer(tokenizerSource);
  } catch (e) {
    // if fetching the tokenizer fails but it wasn't manually set, use the default template
    if (!m.tokenizer) {
      logger.warn(
        `No tokenizer found for model ${m.name}, using default template. Consider setting tokenizer manually or making sure the model is available on the hub.`,
        m
      );
      return compileTemplate<ChatTemplateInput>(
        "{{#if @root.preprompt}}<|im_start|>system\n{{@root.preprompt}}<|im_end|>\n{{/if}}{{#each messages}}{{#ifUser}}<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n{{/ifUser}}{{#ifAssistant}}{{content}}<|im_end|>\n{{/ifAssistant}}{{/each}}",
        m
      );
    }
    // The tokenizer may be configured as an object of URLs; serialise it so the log does not
    // read "[object Object]".
    const tokenizerLabel =
      typeof tokenizerSource === "string" ? tokenizerSource : JSON.stringify(tokenizerSource);
    logger.error(
      e,
      `Failed to load tokenizer ${tokenizerLabel} make sure the model is available on the hub and you have access to any gated models.`
    );
    // A fatal configuration error must not be reported as a successful exit.
    process.exit(1);
  }

  const renderTemplate = ({ messages, preprompt, tools, continueMessage }: ChatTemplateInput) => {
    let formattedMessages: {
      role: string;
      content: string;
      tool_calls?: { id: string; tool_call_id: string; output: string }[];
    }[] = messages.map((message) => ({
      content: message.content,
      role: message.from,
    }));

    if (!m.systemRoleSupported) {
      // Drop system messages and fold the first one into the leading user message.
      const firstSystemMessage = formattedMessages.find((msg) => msg.role === "system");
      formattedMessages = formattedMessages.filter((msg) => msg.role !== "system");
      if (
        firstSystemMessage &&
        formattedMessages.length > 0 &&
        formattedMessages[0].role === "user"
      ) {
        formattedMessages[0].content =
          firstSystemMessage.content + "\n" + formattedMessages[0].content;
      }
    }

    // The list can be empty here (no messages, or only system messages that were just
    // removed), so the first element is read with optional chaining.
    if (preprompt && formattedMessages[0]?.role !== "system") {
      formattedMessages = [
        {
          role: m.systemRoleSupported ? "system" : "user",
          content: preprompt,
        },
        ...formattedMessages,
      ];
    }

    const mappedTools =
      tools?.map((tool) => {
        const inputs: Record<
          string,
          {
            type: ToolInput["type"];
            description: string;
            required: boolean;
          }
        > = {};
        for (const value of tool.inputs) {
          // Inputs with `paramType === "fixed"` are left out of the definitions sent to the template.
          if (value.paramType !== "fixed") {
            inputs[value.name] = {
              type: value.type,
              description: value.description ?? "",
              required: value.paramType === "required",
            };
          }
        }
        return {
          name: tool.name,
          description: tool.description,
          parameter_definitions: inputs,
        };
      }) ?? [];

    const output = tokenizer.apply_chat_template(formattedMessages, {
      tokenize: false,
      add_generation_prompt: !continueMessage,
      tools: mappedTools.length ? mappedTools : undefined,
    });
    if (typeof output !== "string") {
      throw new Error("Failed to apply chat template, the output is not a string");
    }
    return output;
  };
  return renderTemplate;
}
/**
 * Turns a parsed model config into a runtime model:
 * - attaches the prompt renderer from `getChatPromptRender`;
 * - fills `id` and `displayName` from `name` when they are missing or empty;
 * - replaces `preprompt` with the text fetched from `prepromptUrl`, when one is set;
 * - mirrors `parameters.stop` into `parameters.stop_sequences`.
 *
 * @throws Error when `prepromptUrl` answers with a non-2xx status, so an HTTP error page
 *   never ends up being used as the model's preprompt.
 */
const processModel = async (m: z.infer<typeof modelConfig>) => {
  const chatPromptRender = await getChatPromptRender(m);

  let preprompt = m.preprompt;
  if (m.prepromptUrl) {
    const response = await fetch(m.prepromptUrl);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch preprompt for model ${m.name} from ${m.prepromptUrl}: ${response.status} ${response.statusText}`
      );
    }
    preprompt = await response.text();
  }

  return {
    ...m,
    chatPromptRender,
    id: m.id || m.name,
    displayName: m.displayName || m.name,
    preprompt,
    parameters: { ...m.parameters, stop_sequences: m.parameters?.stop },
  };
};
/**
 * Returns `m` extended with `getEndpoint`, which creates the endpoint client for one request.
 *
 * Without configured endpoints a TGI endpoint at `${config.HF_API_ROOT}/${m.name}` is used.
 * Otherwise one endpoint is drawn at random with probability proportional to its `weight`
 * and built by the factory matching its `type`.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processModel>>) => {
  const getEndpoint = async (): Promise<Endpoint> => {
    if (!m.endpoints) {
      return endpointTgi({
        type: "tgi",
        url: `${config.HF_API_ROOT}/${m.name}`,
        accessToken: config.HF_TOKEN ?? config.HF_ACCESS_TOKEN,
        weight: 1,
        model: m,
      });
    }

    // Weighted draw: walk the list, consuming each weight until the draw falls inside one.
    const totalWeight = sum(m.endpoints.map((e) => e.weight));
    let remaining = Math.random() * totalWeight;
    let selected: NonNullable<typeof m.endpoints>[number] | undefined;
    for (const candidate of m.endpoints) {
      if (remaining < candidate.weight) {
        selected = candidate;
        break;
      }
      remaining -= candidate.weight;
    }
    if (!selected) {
      throw new Error(`Failed to select endpoint`);
    }

    const args = { ...selected, model: m };
    switch (args.type) {
      case "tgi":
        return endpoints.tgi(args);
      case "local":
        return endpoints.local(args);
      case "inference-client":
        return endpoints.inferenceClient(args);
      case "anthropic":
        return endpoints.anthropic(args);
      case "anthropic-vertex":
        return endpoints.anthropicvertex(args);
      case "bedrock":
        return endpoints.bedrock(args);
      case "aws":
        return await endpoints.aws(args);
      case "openai":
        return await endpoints.openai(args);
      case "llamacpp":
        return endpoints.llamacpp(args);
      case "ollama":
        return endpoints.ollama(args);
      case "vertex":
        return await endpoints.vertex(args);
      case "genai":
        return await endpoints.genai(args);
      case "cloudflare":
        return await endpoints.cloudflare(args);
      case "cohere":
        return await endpoints.cohere(args);
      case "langserve":
        return await endpoints.langserve(args);
      default:
        // for legacy reason
        return endpoints.tgi(args);
    }
  };

  return { ...m, getEndpoint };
};
/**
 * Model ids returned by the Hub models API for warm, conversational text-generation models.
 * Only fetched when `config.isHuggingChat` is set; an empty list otherwise, and also when
 * the request fails (the failure is logged together with its cause).
 */
const inferenceApiIds = config.isHuggingChat
  ? await fetchJSON<{ id: string }[]>(
      "https://huggingface.co/api/models?pipeline_tag=text-generation&inference=warm&filter=conversational"
    )
      .then((arr) => arr?.map((r) => r.id) ?? [])
      .catch((err: unknown) => {
        // Keep the cause in the log; best-effort fallback to "no model has the inference API".
        logger.error(err, "Failed to fetch inference API ids");
        return [];
      })
  : [];
/**
 * All configured models, fully processed: prompt renderer attached, endpoint selector added,
 * and flagged with whether their id appears in `inferenceApiIds`. Models are processed
 * concurrently and keep the order of `modelsRaw`.
 */
export const models = await Promise.all(
  modelsRaw.map(async (rawModel) => {
    const withEndpoint = addEndpoint(await processModel(rawModel));
    return {
      ...withEndpoint,
      hasInferenceAPI: inferenceApiIds.includes(withEndpoint.id ?? withEndpoint.name),
    };
  })
);
/** Shape of one entry of `models`. */
export type ProcessedModel = (typeof models)[number];

/**
 * Zod enum accepting exactly the ids of the configured models.
 * `z.enum` wants a non-empty tuple type, which a mapped array cannot prove, hence the cast.
 */
export const validModelIdSchema = z.enum(
  models.map(({ id }) => id) as [string, ...string[]]
);

/** The first configured model. */
export const defaultModel = models[0];
/**
 * Models that have been deprecated, read from `config.OLD_MODELS` (JSON5).
 * `id` and `displayName` fall back to `name`; `transferTo`, when present, must be the id of
 * a currently configured model. Empty when `OLD_MODELS` is not set.
 */
export const oldModels = !config.OLD_MODELS
  ? []
  : z
      .array(
        z.object({
          id: z.string().optional(),
          name: z.string().min(1),
          displayName: z.string().min(1).optional(),
          transferTo: validModelIdSchema.optional(),
        })
      )
      .parse(JSON5.parse(config.OLD_MODELS))
      .map((old) => {
        const id = old.id || old.name;
        const displayName = old.displayName || old.name;
        return { ...old, id, displayName };
      });
/**
 * Builds a zod enum that accepts exactly the ids of the given models.
 *
 * @param _models Models whose ids are valid values; must contain at least one model.
 * @returns A `z.enum` over the ids, in the order of `_models`.
 * @throws Error when `_models` is empty, since an enum needs at least one value.
 */
export const validateModel = (_models: BackendModel[]) => {
  const [first, ...rest] = _models;
  if (!first) {
    throw new Error("validateModel requires at least one model");
  }
  // `z.enum` is typed over a non-empty tuple, so the first id is passed separately.
  return z.enum([first.id, ...rest.map((m) => m.id)]);
};
/**
 * Model used for internal tasks.
 *
 * - `TASK_MODEL` unset: the default model.
 * - `TASK_MODEL` equal to the `name` of a configured model: that model.
 * - otherwise: `TASK_MODEL` is parsed as a JSON5 model config and processed like any other.
 *
 * The trailing `?? defaultModel` only applies if the lookup and the processed config are
 * both nullish, which the `processModel` in this file never produces.
 */
export const taskModel = addEndpoint(
  !config.TASK_MODEL
    ? defaultModel
    : ((models.find((candidate) => candidate.name === config.TASK_MODEL) ||
        (await processModel(modelConfig.parse(JSON5.parse(config.TASK_MODEL))))) ??
      defaultModel)
);
/**
 * A processed model in which the listed fields may be omitted.
 */
export type BackendModel = Optional<
  ProcessedModel,
  "hasInferenceAPI" | "multimodal" | "parameters" | "preprompt" | "tools" | "unlisted"
>;
|