|
|
import type { |
|
|
Endpoint, |
|
|
EndpointParameters, |
|
|
EndpointMessage, |
|
|
TextGenerationStreamOutputSimplified, |
|
|
} from "../endpoints/endpoints"; |
|
|
import endpoints from "../endpoints/endpoints"; |
|
|
import type { ProcessedModel } from "../models"; |
|
|
import { config } from "$lib/server/config"; |
|
|
import { logger } from "$lib/server/logger"; |
|
|
import { archSelectRoute } from "./arch"; |
|
|
import { getRoutes, resolveRouteModels } from "./policy"; |
|
|
import { getApiToken } from "$lib/server/apiToken"; |
|
|
import { ROUTER_FAILURE } from "./types"; |
|
|
import { |
|
|
hasActiveToolsSelection, |
|
|
isRouterToolsBypassEnabled, |
|
|
pickToolsCapableModel, |
|
|
ROUTER_TOOLS_ROUTE, |
|
|
} from "./toolsRoute"; |
|
|
import { getConfiguredMultimodalModelId } from "./multimodal"; |
|
|
|
|
|
// Matches one <think>…</think> reasoning block. The closing tag is optional
// (`|$`) so a block truncated mid-stream is still stripped to end of string.
// Global flag: replace() removes every block, not just the first.
const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;

// Route name reported in routerMetadata when the multimodal bypass is taken.
const ROUTER_MULTIMODAL_ROUTE = "multimodal";
|
|
|
|
|
|
|
|
let cachedModels: ProcessedModel[] | undefined; |
|
|
|
|
|
async function getModels(): Promise<ProcessedModel[]> { |
|
|
if (!cachedModels) { |
|
|
const mod = await import("../models"); |
|
|
cachedModels = (mod as { models: ProcessedModel[] }).models; |
|
|
} |
|
|
return cachedModels; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HTTPError extends Error { |
|
|
constructor( |
|
|
message: string, |
|
|
public statusCode?: number |
|
|
) { |
|
|
super(message); |
|
|
this.name = "HTTPError"; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function extractUpstreamError(error: unknown): { message: string; statusCode?: number } { |
|
|
|
|
|
if (error && typeof error === "object") { |
|
|
const err = error as Record<string, unknown>; |
|
|
|
|
|
|
|
|
if ( |
|
|
err.error && |
|
|
typeof err.error === "object" && |
|
|
"message" in err.error && |
|
|
typeof err.error.message === "string" |
|
|
) { |
|
|
return { |
|
|
message: err.error.message, |
|
|
statusCode: typeof err.status === "number" ? err.status : undefined, |
|
|
}; |
|
|
} |
|
|
|
|
|
|
|
|
if (typeof err.statusCode === "number" && typeof err.message === "string") { |
|
|
return { message: err.message, statusCode: err.statusCode }; |
|
|
} |
|
|
|
|
|
|
|
|
if (typeof err.status === "number" && typeof err.message === "string") { |
|
|
return { message: err.message, statusCode: err.status }; |
|
|
} |
|
|
|
|
|
|
|
|
if (typeof err.message === "string") { |
|
|
return { message: err.message }; |
|
|
} |
|
|
} |
|
|
|
|
|
return { message: String(error) }; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function isPolicyError(statusCode?: number): boolean { |
|
|
if (!statusCode) return false; |
|
|
|
|
|
return statusCode === 400 || statusCode === 401 || statusCode === 402 || statusCode === 403; |
|
|
} |
|
|
|
|
|
function stripReasoningBlocks(text: string): string { |
|
|
const stripped = text.replace(REASONING_BLOCK_REGEX, ""); |
|
|
return stripped === text ? text : stripped.trim(); |
|
|
} |
|
|
|
|
|
function stripReasoningFromMessage(message: EndpointMessage): EndpointMessage { |
|
|
const content = |
|
|
typeof message.content === "string" ? stripReasoningBlocks(message.content) : message.content; |
|
|
return { |
|
|
...message, |
|
|
content, |
|
|
}; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds the "router" meta-endpoint: for each request it picks a concrete
 * model and proxies the call to it through the OpenAI-compatible endpoint.
 *
 * Per-request selection order:
 *   1. Multimodal bypass — image attachment present and LLM_ROUTER_ENABLE_MULTIMODAL=true.
 *   2. Tools bypass — tools actively selected and the router tools bypass enabled.
 *   3. Arch route selection, then each candidate model of the resolved route in order.
 *
 * The returned stream always starts with a synthetic metadata event carrying
 * `routerMetadata` ({ route, model }) so the client can show which model served
 * the request.
 *
 * @param routerModel The model entry representing the router itself; used as a
 *   template when a candidate id is not found in the local model list, and as
 *   the last-resort fallback model id.
 * @throws HTTPError when an upstream failure has a status code, Error otherwise.
 */
export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<Endpoint> {
	return async function routerEndpoint(params: EndpointParameters) {
		const routes = await getRoutes();
		// Strip <think> reasoning blocks so they don't influence route selection
		// and aren't replayed to the downstream model.
		const sanitizedMessages = params.messages.map(stripReasoningFromMessage);
		const routerMultimodalEnabled =
			(config.LLM_ROUTER_ENABLE_MULTIMODAL || "").toLowerCase() === "true";
		const routerToolsEnabled = isRouterToolsBypassEnabled();
		// True when any message has an image attachment (by MIME type).
		const hasImageInput = sanitizedMessages.some((message) =>
			(message.files ?? []).some(
				(file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
			)
		);

		const hasToolsActive = hasActiveToolsSelection(params.locals);

		// Builds an OpenAI-compatible endpoint for the given candidate model id.
		// If the id isn't in the local model list, a clone of the router model is
		// used with the candidate's id so the upstream router can still resolve it.
		async function createCandidateEndpoint(candidateModelId: string): Promise<Endpoint> {
			let modelForCall: ProcessedModel | undefined;
			try {
				const all = await getModels();
				modelForCall = all?.find((m) => m.id === candidateModelId || m.name === candidateModelId);
			} catch (e) {
				// Best-effort lookup: fall through to the synthetic clone below.
				logger.warn({ err: String(e) }, "[router] failed to load models for candidate lookup");
			}

			if (!modelForCall) {
				// Synthesize a model entry from the router model, keeping its
				// endpoint configuration but targeting the candidate id.
				modelForCall = {
					...routerModel,
					id: candidateModelId,
					name: candidateModelId,
					displayName: candidateModelId,
				} as ProcessedModel;
			}

			return endpoints.openai({
				type: "openai",
				// Trailing slash is trimmed so path concatenation stays consistent.
				baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
				apiKey: getApiToken(params.locals),
				model: modelForCall,
				streamingSupported: true,
			});
		}

		// Emits one synthetic empty token carrying routerMetadata (route + model),
		// then forwards the upstream generation stream unchanged.
		async function* metadataThenStream(
			gen: AsyncGenerator<TextGenerationStreamOutputSimplified>,
			actualModel: string,
			selectedRoute: string
		) {
			yield {
				token: { id: 0, text: "", special: true, logprob: 0 },
				generated_text: null,
				details: null,
				routerMetadata: { route: selectedRoute, model: actualModel },
			};
			for await (const ev of gen) yield ev;
		}

		// --- 1. Multimodal bypass: image input skips Arch selection entirely. ---
		if (routerMultimodalEnabled && hasImageInput) {
			let multimodalCandidate: string | undefined;
			try {
				const all = await getModels();
				multimodalCandidate = getConfiguredMultimodalModelId(all);
			} catch (e) {
				logger.warn({ err: String(e) }, "[router] failed to load models for multimodal lookup");
			}
			if (!multimodalCandidate) {
				// Misconfiguration is fatal here: there is no sensible text-only
				// fallback for an image-bearing request.
				throw new Error(
					"Router multimodal is enabled but LLM_ROUTER_MULTIMODAL_MODEL is not correctly configured. Remove the image or configure a multimodal model via LLM_ROUTER_MULTIMODAL_MODEL."
				);
			}

			try {
				logger.info(
					{ route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate },
					"[router] multimodal input detected; bypassing Arch selection"
				);
				const ep = await createCandidateEndpoint(multimodalCandidate);
				const gen = await ep({ ...params });
				return metadataThenStream(gen, multimodalCandidate, ROUTER_MULTIMODAL_ROUTE);
			} catch (e) {
				// No further fallback for multimodal: propagate with the upstream
				// status code when one is available.
				const { message, statusCode } = extractUpstreamError(e);
				logger.error(
					{
						route: ROUTER_MULTIMODAL_ROUTE,
						model: multimodalCandidate,
						err: message,
						...(statusCode && { status: statusCode }),
					},
					"[router] multimodal fallback failed"
				);
				throw statusCode ? new HTTPError(message, statusCode) : new Error(message);
			}
		}

		// Looks up a tools-capable model; returns undefined on lookup failure so
		// the request can fall through to normal Arch routing.
		async function findToolsCandidateModel(): Promise<ProcessedModel | undefined> {
			try {
				const all = await getModels();
				return pickToolsCapableModel(all);
			} catch (e) {
				logger.warn({ err: String(e) }, "[router] failed to load models for tools lookup");
				return undefined;
			}
		}

		// --- 2. Tools bypass: active tool selection skips Arch selection. ---
		if (routerToolsEnabled && hasToolsActive) {
			const toolsModel = await findToolsCandidateModel();
			const toolsCandidate = toolsModel?.id ?? toolsModel?.name;
			if (!toolsCandidate) {
				// No tools-capable model found: silently fall through to Arch routing.
			} else {
				try {
					logger.info(
						{ route: ROUTER_TOOLS_ROUTE, model: toolsCandidate },
						"[router] tools active; bypassing Arch selection"
					);
					const ep = await createCandidateEndpoint(toolsCandidate);
					const gen = await ep({ ...params });
					return metadataThenStream(gen, toolsCandidate, ROUTER_TOOLS_ROUTE);
				} catch (e) {
					// Like the multimodal path: no further fallback once the tools
					// endpoint has been attempted.
					const { message, statusCode } = extractUpstreamError(e);
					logger.error(
						{
							route: ROUTER_TOOLS_ROUTE,
							model: toolsCandidate,
							err: message,
							...(statusCode && { status: statusCode }),
						},
						"[router] tools fallback failed"
					);
					throw statusCode ? new HTTPError(message, statusCode) : new Error(message);
				}
			}
		}

		// --- 3. Arch-based route selection. ---
		const routeSelection = await archSelectRoute(sanitizedMessages, undefined, params.locals);

		if (routeSelection.routeName === ROUTER_FAILURE && routeSelection.error) {
			const { message, statusCode } = routeSelection.error;

			if (isPolicyError(statusCode)) {
				// 400/401/402/403 from the Arch router mean the request itself is
				// rejected — retrying a fallback model would fail the same way.
				logger.error(
					{ err: message, ...(statusCode && { status: statusCode }) },
					"[router] arch router failed with policy error, propagating to client"
				);
				throw statusCode ? new HTTPError(message, statusCode) : new Error(message);
			}

			// Transient Arch failure: continue with the fallback model below
			// (resolveRouteModels handles the ROUTER_FAILURE route name).
			logger.warn(
				{ err: message, ...(statusCode && { status: statusCode }) },
				"[router] arch router failed with transient error, attempting fallback"
			);
		}

		const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
		const { candidates } = resolveRouteModels(routeSelection.routeName, routes, fallbackModel);

		// Try each candidate of the selected route in order; first success wins.
		let lastErr: unknown = undefined;
		for (const candidate of candidates) {
			try {
				logger.info(
					{ route: routeSelection.routeName, model: candidate },
					"[router] trying candidate"
				);
				const ep = await createCandidateEndpoint(candidate);
				const gen = await ep({ ...params });
				return metadataThenStream(gen, candidate, routeSelection.routeName);
			} catch (e) {
				lastErr = e;
				const { message: errMsg, statusCode: errStatus } = extractUpstreamError(e);
				logger.warn(
					{
						route: routeSelection.routeName,
						model: candidate,
						err: errMsg,
						...(errStatus && { status: errStatus }),
					},
					"[router] candidate failed"
				);
				continue;
			}
		}

		// Every candidate failed (NOTE(review): if `candidates` is empty, lastErr
		// is undefined and this throws "undefined" — confirm candidates is
		// guaranteed non-empty by resolveRouteModels).
		const { message, statusCode } = extractUpstreamError(lastErr);
		throw statusCode ? new HTTPError(message, statusCode) : new Error(message);
	};
}
|
|
|