Spaces:
Sleeping
Sleeping
File size: 5,785 Bytes
3b53c7a 7bf1507 c4408b8 7bf1507 9092d43 7bf1507 3b53c7a 7bf1507 c4408b8 7bf1507 9092d43 7bf1507 3b53c7a 7bf1507 c4408b8 7bf1507 c4408b8 7bf1507 9092d43 3b53c7a 9092d43 7bf1507 9092d43 3b53c7a 7bf1507 9092d43 3b53c7a 9092d43 c4408b8 9092d43 3b53c7a 7bf1507 9092d43 7bf1507 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import type {
Endpoint,
EndpointParameters,
EndpointMessage,
TextGenerationStreamOutputSimplified,
} from "../endpoints/endpoints";
import endpoints from "../endpoints/endpoints";
import type { ProcessedModel, EndpointOptions } from "../models";
import { config } from "$lib/server/config";
import { logger } from "$lib/server/logger";
import { archSelectRoute } from "./arch";
import { getRoutes, resolveRouteModels } from "./policy";
// Matches one <think>…</think> reasoning block, or an unterminated <think>…
// that runs to the end of the string (lazy body, so blocks match individually).
const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
const ROUTER_MULTIMODAL_ROUTE = "multimodal";

/**
 * Removes every <think> reasoning block from `text`.
 *
 * Returns the input unchanged (same reference) when no block is present;
 * otherwise returns the stripped text trimmed of surrounding whitespace.
 */
function stripReasoningBlocks(text: string): string {
	const withoutBlocks = text.replace(REASONING_BLOCK_REGEX, "");
	if (withoutBlocks === text) {
		return text;
	}
	return withoutBlocks.trim();
}
/**
 * Returns a copy of `message` with its `reasoning` field dropped and, when the
 * content is a string, any inline <think> blocks stripped from it.
 * Non-string content is passed through untouched.
 */
function stripReasoningFromMessage(message: EndpointMessage): EndpointMessage {
	// Shallow-copy every field except `reasoning`, which must not be forwarded.
	const { reasoning, ...withoutReasoning } = message;
	void reasoning;
	let content = message.content;
	if (typeof content === "string") {
		content = stripReasoningBlocks(content);
	}
	return {
		...withoutReasoning,
		content,
	};
}
/**
 * Create an Endpoint that performs route selection via Arch and then forwards
 * to the selected model (with fallbacks) using the OpenAI-compatible endpoint.
 *
 * The returned endpoint first yields a synthetic metadata event (so the UI can
 * display the chosen route/model immediately), then streams the selected
 * candidate's output. Multimodal (image) inputs bypass Arch entirely when
 * LLM_ROUTER_ENABLE_MULTIMODAL is "true".
 *
 * @param routerModel - The router pseudo-model; also used as a template when a
 *   candidate id has no matching entry in the configured model list.
 * @param options - Optional per-call settings; `options.apiKey` overrides the
 *   configured OPENAI_API_KEY/HF_TOKEN for both Arch selection and candidates.
 * @returns An Endpoint closure performing routing on every invocation.
 */
export async function makeRouterEndpoint(
	routerModel: ProcessedModel,
	options?: EndpointOptions
): Promise<Endpoint> {
	return async function routerEndpoint(params: EndpointParameters) {
		const routes = await getRoutes();
		// Reasoning (<think> blocks and the `reasoning` field) is stripped before
		// routing so Arch classifies only the user-visible conversation.
		// NOTE(review): only routing sees the sanitized copy — the selected
		// candidate below is called with the original `params` (unsanitized
		// messages). Confirm that forwarding reasoning to candidates is intended.
		const sanitizedMessages = params.messages.map(stripReasoningFromMessage);
		const routerMultimodalEnabled =
			(config.LLM_ROUTER_ENABLE_MULTIMODAL || "").toLowerCase() === "true";
		// An image anywhere in the conversation triggers the multimodal bypass.
		const hasImageInput = sanitizedMessages.some((message) =>
			(message.files ?? []).some(
				(file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
			)
		);
		// Helper to create an OpenAI endpoint for a specific candidate model id
		async function createCandidateEndpoint(candidateModelId: string): Promise<Endpoint> {
			// Try to use the real candidate model config if present in chat-ui's model list
			let modelForCall: ProcessedModel | undefined;
			try {
				// Dynamic import avoids a static circular dependency with ../models.
				const mod = await import("../models");
				const all = (mod as { models: ProcessedModel[] }).models;
				modelForCall = all?.find((m) => m.id === candidateModelId || m.name === candidateModelId);
			} catch (e) {
				// Best-effort lookup: fall through to the cloned-router fallback below.
				logger.warn({ err: String(e) }, "[router] failed to load models for candidate lookup");
			}
			if (!modelForCall) {
				// Fallback: clone router model with candidate id
				modelForCall = {
					...routerModel,
					id: candidateModelId,
					name: candidateModelId,
					displayName: candidateModelId,
				} as ProcessedModel;
			}
			// "sk-" is a last-resort placeholder key when no credential is configured.
			const defaultApiKey = config.OPENAI_API_KEY || config.HF_TOKEN || "sk-";
			return endpoints.openai({
				type: "openai",
				baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
				apiKey: options?.apiKey ?? defaultApiKey,
				model: modelForCall,
				// Ensure streaming path is used
				streamingSupported: true,
			});
		}
		// Yield router metadata for immediate UI display, using the actual candidate
		async function* metadataThenStream(
			gen: AsyncGenerator<TextGenerationStreamOutputSimplified>,
			actualModel: string,
			selectedRoute: string
		) {
			// Synthetic zero-token event carrying only routerMetadata; emitted before
			// any real model output so the UI can show the selection immediately.
			yield {
				token: { id: 0, text: "", special: true, logprob: 0 },
				generated_text: null,
				details: null,
				routerMetadata: { route: selectedRoute, model: actualModel },
			};
			for await (const ev of gen) yield ev;
		}
		// Picks the first configured non-router multimodal model, or undefined if
		// none exists (or the model list cannot be loaded).
		async function findFirstMultimodalCandidateId(): Promise<string | undefined> {
			try {
				const mod = await import("../models");
				const all = (mod as { models: ProcessedModel[] }).models;
				const first = all?.find((m) => !m.isRouter && m.multimodal);
				return first?.id ?? first?.name;
			} catch (e) {
				logger.warn({ err: String(e) }, "[router] failed to load models for multimodal lookup");
				return undefined;
			}
		}
		if (routerMultimodalEnabled && hasImageInput) {
			const multimodalCandidate = await findFirstMultimodalCandidateId();
			if (!multimodalCandidate) {
				throw new Error(
					"No multimodal models are configured for the router. Remove the image or enable a multimodal model."
				);
			}
			try {
				logger.info(
					{ route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate },
					"[router] multimodal input detected; bypassing Arch selection"
				);
				const ep = await createCandidateEndpoint(multimodalCandidate);
				const gen = await ep({ ...params });
				// NOTE: errors raised while *streaming* (inside the generator) are not
				// caught here — only endpoint creation / initial call failures are.
				return metadataThenStream(gen, multimodalCandidate, ROUTER_MULTIMODAL_ROUTE);
			} catch (e) {
				logger.error(
					{ route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate, err: String(e) },
					"[router] multimodal fallback failed"
				);
				// No further fallback for multimodal: surface a user-actionable error.
				throw new Error(
					"Failed to call the configured multimodal model. Remove the image or try again later."
				);
			}
		}
		// Text-only path: ask Arch to classify the conversation into a route.
		const { routeName } = await archSelectRoute(sanitizedMessages, { apiKey: options?.apiKey });
		const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
		const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
		// Try candidates in order; first one whose initial call succeeds wins.
		let lastErr: unknown = undefined;
		for (const candidate of candidates) {
			try {
				logger.info({ route: routeName, model: candidate }, "[router] trying candidate");
				const ep = await createCandidateEndpoint(candidate);
				const gen = await ep({ ...params });
				return metadataThenStream(gen, candidate, routeName);
			} catch (e) {
				lastErr = e;
				logger.warn(
					{ route: routeName, model: candidate, err: String(e) },
					"[router] candidate failed"
				);
				continue;
			}
		}
		// Exhausted all candidates — throw to signal upstream failure
		throw new Error(`Routing failed for route=${routeName}: ${String(lastErr)}`);
	};
}
|