File size: 5,785 Bytes
3b53c7a
 
 
 
 
 
7bf1507
c4408b8
7bf1507
 
 
 
 
 
 
9092d43
 
7bf1507
 
 
 
 
 
 
3b53c7a
7bf1507
 
 
 
 
 
 
 
 
 
 
 
c4408b8
 
 
 
7bf1507
 
 
9092d43
 
 
 
 
 
 
7bf1507
 
 
 
 
 
 
3b53c7a
7bf1507
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4408b8
 
7bf1507
 
 
c4408b8
7bf1507
 
 
 
 
 
 
9092d43
3b53c7a
9092d43
 
 
7bf1507
 
 
 
9092d43
3b53c7a
7bf1507
 
 
9092d43
 
 
3b53c7a
9092d43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4408b8
9092d43
 
 
 
3b53c7a
7bf1507
 
 
 
 
9092d43
7bf1507
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import type {
	Endpoint,
	EndpointParameters,
	EndpointMessage,
	TextGenerationStreamOutputSimplified,
} from "../endpoints/endpoints";
import endpoints from "../endpoints/endpoints";
import type { ProcessedModel, EndpointOptions } from "../models";
import { config } from "$lib/server/config";
import { logger } from "$lib/server/logger";
import { archSelectRoute } from "./arch";
import { getRoutes, resolveRouteModels } from "./policy";

/**
 * Matches a `<think>` block: from the opening tag up to the nearest closing
 * `</think>`, or up to the end of the input when the block is never closed.
 */
const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;

/** Route name reported in logs and router metadata when Arch selection is bypassed for image input. */
const ROUTER_MULTIMODAL_ROUTE = "multimodal";

/**
 * Remove every `<think>` block from `text`.
 *
 * @param text - Raw message text that may contain reasoning blocks.
 * @returns The input unchanged (whitespace included) when it contains no
 *   reasoning block; otherwise the remaining text with surrounding whitespace
 *   trimmed.
 */
function stripReasoningBlocks(text: string): string {
	const withoutReasoning = text.replace(REASONING_BLOCK_REGEX, "");
	if (withoutReasoning === text) {
		// Nothing was removed: hand back the original, untrimmed.
		return text;
	}
	return withoutReasoning.trim();
}

/**
 * Build a copy of `message` without its `reasoning` field and, when the
 * content is a plain string, without any `<think>` blocks in that content.
 * Non-string content is passed through untouched.
 *
 * @param message - Message to sanitize; it is not mutated.
 * @returns A new message object carrying every other field of the input.
 */
function stripReasoningFromMessage(message: EndpointMessage): EndpointMessage {
	// Pull `reasoning` out so it is not copied; `void` marks the binding as intentionally unused.
	const { reasoning: _discardedReasoning, ...withoutReasoning } = message;
	void _discardedReasoning;

	if (typeof message.content !== "string") {
		return { ...withoutReasoning, content: message.content };
	}
	return { ...withoutReasoning, content: stripReasoningBlocks(message.content) };
}

/**
 * Find the first model in the application's model list that satisfies
 * `predicate`.
 *
 * The `../models` module is imported lazily at call time.
 * NOTE(review): presumably this avoids a circular import between the models
 * module and the router — confirm before converting to a static import.
 *
 * Any failure while importing or searching is logged as a warning and reported
 * as "not found" so callers can fall back instead of crashing.
 *
 * @param predicate - Test applied to each configured model.
 * @param purpose - Short label appended to the warning message on failure.
 * @returns The first matching model, or `undefined` when none matches or the
 *   lookup failed.
 */
async function findConfiguredModel(
	predicate: (model: ProcessedModel) => unknown,
	purpose: string
): Promise<ProcessedModel | undefined> {
	try {
		const mod = await import("../models");
		const all = (mod as { models: ProcessedModel[] }).models;
		return all?.find(predicate);
	} catch (e) {
		logger.warn({ err: String(e) }, `[router] failed to load models for ${purpose}`);
		return undefined;
	}
}

/**
 * Create an Endpoint that performs route selection via Arch and then forwards
 * to the selected model (with fallbacks) using the OpenAI-compatible endpoint.
 *
 * On each call the returned endpoint:
 * 1. Loads the routes and strips reasoning from the messages used for routing.
 * 2. If `LLM_ROUTER_ENABLE_MULTIMODAL` is "true" and any message carries an
 *    image file, skips Arch and calls the first configured non-router
 *    multimodal model.
 * 3. Otherwise asks Arch for a route, resolves its candidate models, and tries
 *    them in order until one call succeeds.
 *
 * The stream handed back always starts with one synthetic event carrying
 * `routerMetadata` (route and model actually used).
 *
 * @param routerModel - The router's own model; used as the template for
 *   candidates missing from the model list and as the default fallback id.
 * @param options - Optional endpoint options; `apiKey` is forwarded to Arch and
 *   to the candidate endpoints.
 * @returns The router endpoint function.
 * @throws Error (from the returned endpoint) when no multimodal model is
 *   configured for an image request, when the multimodal call fails, or when
 *   every route candidate fails.
 */
export async function makeRouterEndpoint(
	routerModel: ProcessedModel,
	options?: EndpointOptions
): Promise<Endpoint> {
	return async function routerEndpoint(params: EndpointParameters) {
		const routes = await getRoutes();
		const sanitizedMessages = params.messages.map(stripReasoningFromMessage);
		const routerMultimodalEnabled =
			(config.LLM_ROUTER_ENABLE_MULTIMODAL || "").toLowerCase() === "true";
		const hasImageInput = sanitizedMessages.some((message) =>
			(message.files ?? []).some(
				(file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
			)
		);

		// Helper to create an OpenAI endpoint for a specific candidate model id
		async function createCandidateEndpoint(candidateModelId: string): Promise<Endpoint> {
			// Try to use the real candidate model config if present in chat-ui's model list
			const configuredModel = await findConfiguredModel(
				(m) => m.id === candidateModelId || m.name === candidateModelId,
				"candidate lookup"
			);

			// Fallback: clone router model with candidate id
			const modelForCall: ProcessedModel =
				configuredModel ??
				({
					...routerModel,
					id: candidateModelId,
					name: candidateModelId,
					displayName: candidateModelId,
				} as ProcessedModel);

			const defaultApiKey = config.OPENAI_API_KEY || config.HF_TOKEN || "sk-";

			return endpoints.openai({
				type: "openai",
				// Drop a single trailing slash from the configured base URL.
				baseURL: (config.OPENAI_BASE_URL || "https://router.huggingface.co/v1").replace(/\/$/, ""),
				apiKey: options?.apiKey ?? defaultApiKey,
				model: modelForCall,
				// Ensure streaming path is used
				streamingSupported: true,
			});
		}

		// Yield router metadata for immediate UI display, using the actual candidate
		async function* metadataThenStream(
			gen: AsyncGenerator<TextGenerationStreamOutputSimplified>,
			actualModel: string,
			selectedRoute: string
		) {
			yield {
				token: { id: 0, text: "", special: true, logprob: 0 },
				generated_text: null,
				details: null,
				routerMetadata: { route: selectedRoute, model: actualModel },
			};
			for await (const ev of gen) yield ev;
		}

		// Id (or name, when the id is missing) of the first non-router multimodal model, if any.
		async function findFirstMultimodalCandidateId(): Promise<string | undefined> {
			const first = await findConfiguredModel(
				(m) => !m.isRouter && m.multimodal,
				"multimodal lookup"
			);
			return first?.id ?? first?.name;
		}

		if (routerMultimodalEnabled && hasImageInput) {
			const multimodalCandidate = await findFirstMultimodalCandidateId();
			if (!multimodalCandidate) {
				throw new Error(
					"No multimodal models are configured for the router. Remove the image or enable a multimodal model."
				);
			}

			try {
				logger.info(
					{ route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate },
					"[router] multimodal input detected; bypassing Arch selection"
				);
				const ep = await createCandidateEndpoint(multimodalCandidate);
				const gen = await ep({ ...params });
				return metadataThenStream(gen, multimodalCandidate, ROUTER_MULTIMODAL_ROUTE);
			} catch (e) {
				// The underlying error is logged; callers get a user-facing message instead.
				logger.error(
					{ route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate, err: String(e) },
					"[router] multimodal fallback failed"
				);
				throw new Error(
					"Failed to call the configured multimodal model. Remove the image or try again later."
				);
			}
		}

		// Route selection sees the sanitized messages; the candidate call below receives `params` as given.
		const { routeName } = await archSelectRoute(sanitizedMessages, { apiKey: options?.apiKey });

		const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
		const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);

		let lastErr: unknown = undefined;
		let attempts = 0;
		for (const candidate of candidates) {
			attempts += 1;
			try {
				logger.info({ route: routeName, model: candidate }, "[router] trying candidate");
				const ep = await createCandidateEndpoint(candidate);
				const gen = await ep({ ...params });
				return metadataThenStream(gen, candidate, routeName);
			} catch (e) {
				// Remember the failure and move on to the next candidate.
				lastErr = e;
				logger.warn(
					{ route: routeName, model: candidate, err: String(e) },
					"[router] candidate failed"
				);
			}
		}

		// Exhausted all candidates — throw to signal upstream failure.
		// When nothing was tried there is no underlying error to report, so say so explicitly.
		const reason = attempts === 0 ? "no candidate models were resolved" : String(lastErr);
		throw new Error(`Routing failed for route=${routeName}: ${reason}`);
	};
}