victor HF Staff committed on
Commit
9092d43
·
unverified ·
1 Parent(s): b9b13da

Omni multimodality (#1880)

Browse files

* Add multimodal fallback support to router

* fix: improve error handling for multimodal model routing

* Handle multimodal messages without files

.env CHANGED
@@ -56,6 +56,9 @@ LLM_ROUTER_FALLBACK_MODEL=
56
  # Arch selection timeout in milliseconds (default 10000)
57
  LLM_ROUTER_ARCH_TIMEOUT_MS=10000
58
 
 
 
 
59
  # Router UI overrides (client-visible)
60
  # Public display name for the router entry in the model list. Defaults to "Omni".
61
  PUBLIC_LLM_ROUTER_DISPLAY_NAME=Omni
 
56
  # Arch selection timeout in milliseconds (default 10000)
57
  LLM_ROUTER_ARCH_TIMEOUT_MS=10000
58
 
59
+ # Enable router multimodal fallback (set to true to allow image inputs via router)
60
+ LLM_ROUTER_ENABLE_MULTIMODAL=false
61
+
62
  # Router UI overrides (client-visible)
63
  # Public display name for the router entry in the model list. Defaults to "Omni".
64
  PUBLIC_LLM_ROUTER_DISPLAY_NAME=Omni
src/lib/server/endpoints/openai/endpointOai.ts CHANGED
@@ -217,11 +217,11 @@ async function prepareMessages(
217
  return Promise.all(
218
  messages.map(async (message) => {
219
  if (message.from === "user" && isMultimodal) {
220
- const parts = [
221
- { type: "text" as const, text: message.content },
222
- ...(await prepareFiles(imageProcessor, message.files ?? [])),
223
- ];
224
- return { role: message.from, content: parts };
225
  }
226
  return { role: message.from, content: message.content };
227
  })
 
217
  return Promise.all(
218
  messages.map(async (message) => {
219
  if (message.from === "user" && isMultimodal) {
220
+ const imageParts = await prepareFiles(imageProcessor, message.files ?? []);
221
+ if (imageParts.length) {
222
+ const parts = [{ type: "text" as const, text: message.content }, ...imageParts];
223
+ return { role: message.from, content: parts };
224
+ }
225
  }
226
  return { role: message.from, content: message.content };
227
  })
src/lib/server/models.ts CHANGED
@@ -288,6 +288,8 @@ const archBase = (config.LLM_ROUTER_ARCH_BASE_URL || "").trim();
288
  const routerLabel = (config.PUBLIC_LLM_ROUTER_DISPLAY_NAME || "Omni").trim() || "Omni";
289
  const routerLogo = (config.PUBLIC_LLM_ROUTER_LOGO_URL || "").trim();
290
  const routerAliasId = (config.PUBLIC_LLM_ROUTER_ALIAS_ID || "omni").trim() || "omni";
 
 
291
 
292
  let decorated = builtModels as any[];
293
 
@@ -309,6 +311,11 @@ if (archBase) {
309
  unlisted: false,
310
  } as any;
311
 
 
 
 
 
 
312
  const aliasBase = await processModel(aliasRaw);
313
  // Create a self-referential ProcessedModel for the router endpoint
314
  let aliasModel: any = {};
 
288
  const routerLabel = (config.PUBLIC_LLM_ROUTER_DISPLAY_NAME || "Omni").trim() || "Omni";
289
  const routerLogo = (config.PUBLIC_LLM_ROUTER_LOGO_URL || "").trim();
290
  const routerAliasId = (config.PUBLIC_LLM_ROUTER_ALIAS_ID || "omni").trim() || "omni";
291
+ const routerMultimodalEnabled =
292
+ (config.LLM_ROUTER_ENABLE_MULTIMODAL || "").toLowerCase() === "true";
293
 
294
  let decorated = builtModels as any[];
295
 
 
311
  unlisted: false,
312
  } as any;
313
 
314
+ if (routerMultimodalEnabled) {
315
+ aliasRaw.multimodal = true;
316
+ aliasRaw.multimodalAcceptedMimetypes = ["image/*"];
317
+ }
318
+
319
  const aliasBase = await processModel(aliasRaw);
320
  // Create a self-referential ProcessedModel for the router endpoint
321
  let aliasModel: any = {};
src/lib/server/router/endpoint.ts CHANGED
@@ -8,6 +8,8 @@ import { getRoutes, resolveRouteModels } from "./policy";
8
 
9
  const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
10
 
 
 
11
  function stripReasoningBlocks(text: string): string {
12
  const stripped = text.replace(REASONING_BLOCK_REGEX, "");
13
  return stripped === text ? text : stripped.trim();
@@ -31,10 +33,13 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
31
  return async function routerEndpoint(params: EndpointParameters) {
32
  const routes = await getRoutes();
33
  const sanitizedMessages = params.messages.map(stripReasoningFromMessage);
34
- const { routeName } = await archSelectRoute(sanitizedMessages);
35
-
36
- const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
37
- const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
 
 
 
38
 
39
  // Helper to create an OpenAI endpoint for a specific candidate model id
40
  async function createCandidateEndpoint(candidateModelId: string): Promise<Endpoint> {
@@ -69,24 +74,71 @@ export async function makeRouterEndpoint(routerModel: ProcessedModel): Promise<E
69
  }
70
 
71
  // Yield router metadata for immediate UI display, using the actual candidate
72
- async function* metadataThenStream(gen: AsyncGenerator<any>, actualModel: string) {
 
 
 
 
73
  yield {
74
  token: { id: 0, text: "", special: true, logprob: 0 },
75
  generated_text: null,
76
  details: null,
77
- routerMetadata: { route: routeName, model: actualModel },
78
  } as any;
79
  for await (const ev of gen) yield ev;
80
  }
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  let lastErr: any = undefined;
83
  for (const candidate of candidates) {
84
  try {
85
  logger.info({ route: routeName, model: candidate }, "[router] trying candidate");
86
  const ep = await createCandidateEndpoint(candidate);
87
  const gen = await ep({ ...params });
88
- // Yield metadata with the actual candidate used
89
- return metadataThenStream(gen, candidate);
90
  } catch (e) {
91
  lastErr = e;
92
  logger.warn(
 
8
 
9
  const REASONING_BLOCK_REGEX = /<think>[\s\S]*?(?:<\/think>|$)/g;
10
 
11
+ const ROUTER_MULTIMODAL_ROUTE = "multimodal";
12
+
13
  function stripReasoningBlocks(text: string): string {
14
  const stripped = text.replace(REASONING_BLOCK_REGEX, "");
15
  return stripped === text ? text : stripped.trim();
 
33
  return async function routerEndpoint(params: EndpointParameters) {
34
  const routes = await getRoutes();
35
  const sanitizedMessages = params.messages.map(stripReasoningFromMessage);
36
+ const routerMultimodalEnabled =
37
+ (config.LLM_ROUTER_ENABLE_MULTIMODAL || "").toLowerCase() === "true";
38
+ const hasImageInput = sanitizedMessages.some((message) =>
39
+ (message.files ?? []).some(
40
+ (file) => typeof file?.mime === "string" && file.mime.startsWith("image/")
41
+ )
42
+ );
43
 
44
  // Helper to create an OpenAI endpoint for a specific candidate model id
45
  async function createCandidateEndpoint(candidateModelId: string): Promise<Endpoint> {
 
74
  }
75
 
76
  // Yield router metadata for immediate UI display, using the actual candidate
77
+ async function* metadataThenStream(
78
+ gen: AsyncGenerator<any>,
79
+ actualModel: string,
80
+ selectedRoute: string
81
+ ) {
82
  yield {
83
  token: { id: 0, text: "", special: true, logprob: 0 },
84
  generated_text: null,
85
  details: null,
86
+ routerMetadata: { route: selectedRoute, model: actualModel },
87
  } as any;
88
  for await (const ev of gen) yield ev;
89
  }
90
 
91
+ async function findFirstMultimodalCandidateId(): Promise<string | undefined> {
92
+ try {
93
+ const mod = await import("../models");
94
+ const all = (mod as any).models as ProcessedModel[];
95
+ const first = all?.find((m) => !m.isRouter && m.multimodal);
96
+ return first?.id ?? first?.name;
97
+ } catch (e) {
98
+ logger.warn({ err: String(e) }, "[router] failed to load models for multimodal lookup");
99
+ return undefined;
100
+ }
101
+ }
102
+
103
+ if (routerMultimodalEnabled && hasImageInput) {
104
+ const multimodalCandidate = await findFirstMultimodalCandidateId();
105
+ if (!multimodalCandidate) {
106
+ throw new Error(
107
+ "No multimodal models are configured for the router. Remove the image or enable a multimodal model."
108
+ );
109
+ }
110
+
111
+ try {
112
+ logger.info(
113
+ { route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate },
114
+ "[router] multimodal input detected; bypassing Arch selection"
115
+ );
116
+ const ep = await createCandidateEndpoint(multimodalCandidate);
117
+ const gen = await ep({ ...params });
118
+ return metadataThenStream(gen, multimodalCandidate, ROUTER_MULTIMODAL_ROUTE);
119
+ } catch (e) {
120
+ logger.error(
121
+ { route: ROUTER_MULTIMODAL_ROUTE, model: multimodalCandidate, err: String(e) },
122
+ "[router] multimodal fallback failed"
123
+ );
124
+ throw new Error(
125
+ "Failed to call the configured multimodal model. Remove the image or try again later."
126
+ );
127
+ }
128
+ }
129
+
130
+ const { routeName } = await archSelectRoute(sanitizedMessages);
131
+
132
+ const fallbackModel = config.LLM_ROUTER_FALLBACK_MODEL || routerModel.id;
133
+ const { candidates } = resolveRouteModels(routeName, routes, fallbackModel);
134
+
135
  let lastErr: any = undefined;
136
  for (const candidate of candidates) {
137
  try {
138
  logger.info({ route: routeName, model: candidate }, "[router] trying candidate");
139
  const ep = await createCandidateEndpoint(candidate);
140
  const gen = await ep({ ...params });
141
+ return metadataThenStream(gen, candidate, routeName);
 
142
  } catch (e) {
143
  lastErr = e;
144
  logger.warn(