incognitolm commited on
Commit
c70b348
·
1 Parent(s): 0b4a6bb

Update chatStream.js

Browse files
Files changed (1) hide show
  1. server/chatStream.js +122 -25
server/chatStream.js CHANGED
@@ -409,12 +409,54 @@ function contentToText(content, { preview = false } = {}) {
409
  return pieces.join(preview ? " " : "\n");
410
  }
411
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
  function normalizeStoredToolCalls(toolCalls = []) {
413
  return toolCalls.map((call) => ({
414
  id: call.id || `call_${crypto.randomUUID()}`,
415
  type: "function",
416
  function: {
417
- name: call.name || call.function?.name || "unknown_tool",
418
  arguments: (() => {
419
  const rawArgs = call.args ?? call.function?.arguments ?? {};
420
  return typeof rawArgs === "string" ? rawArgs : JSON.stringify(rawArgs);
@@ -991,9 +1033,20 @@ async function websocketChatStreamWithRetry(body, headers, onToken, abortSignal)
991
  }
992
  }
993
 
 
 
 
 
 
 
 
 
 
 
994
  export async function websocketChatStream(body, headers, onToken, abortSignal) {
995
  const ws = await getSafeWebSocket();
996
  const currentRequestId = ++requestIdCounter;
 
997
  const safeParse = (str) => {
998
  try { return JSON.parse(str.startsWith("data: ") ? str.slice(6) : str); } catch { return null; }
999
  };
@@ -1007,11 +1060,7 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
1007
  if (!finished) {
1008
  finished = true;
1009
  cleanup();
1010
- const toolCalls = [...toolCallBuffer.values()].map((t) => ({
1011
- id: t.id || `call_${crypto.randomUUID()}`,
1012
- type: "function",
1013
- function: { name: t.name, arguments: t.arguments },
1014
- }));
1015
  resolve({ assistantText, toolCalls });
1016
  }
1017
  }, 120000);
@@ -1103,7 +1152,9 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
1103
  for (const call of delta.tool_calls) {
1104
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
1105
  if (call.id) entry.id = call.id;
1106
- if (call.function?.name) entry.name = call.function.name;
 
 
1107
  if (call.function?.arguments) entry.arguments += call.function.arguments;
1108
  toolCallBuffer.set(call.index, entry);
1109
  }
@@ -1112,11 +1163,7 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
1112
  if (payload.choices?.[0]?.finish_reason && !finished) {
1113
  finished = true;
1114
  cleanup();
1115
- const toolCalls = [...toolCallBuffer.values()].map((t) => ({
1116
- id: t.id || `call_${crypto.randomUUID()}`,
1117
- type: "function",
1118
- function: { name: t.name, arguments: t.arguments },
1119
- }));
1120
  resolve({ assistantText, toolCalls });
1121
  }
1122
  };
@@ -1227,6 +1274,37 @@ export async function streamChat({
1227
  }
1228
  }
1229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1230
  const sessionName = extractSessionName(assistantText);
1231
 
1232
  if (typeof onDone === "function") {
@@ -1252,10 +1330,20 @@ function normalizeMessage(msg) {
1252
  if (!VALID_ROLES.has(msg.role)) return null;
1253
 
1254
  if (msg.role === "assistant" && (msg.tool_calls || msg.toolCalls)) {
 
 
 
 
 
 
 
 
 
 
1255
  return {
1256
  role: "assistant",
1257
  content: msg.content ?? "",
1258
- tool_calls: msg.tool_calls || normalizeStoredToolCalls(msg.toolCalls),
1259
  };
1260
  }
1261
  if (Array.isArray(msg.content)) {
@@ -1452,6 +1540,7 @@ async function processToolCalls({
1452
  }) {
1453
  const nextMessages = [];
1454
  const authHeaders = {};
 
1455
  if (accessToken) {
1456
  authHeaders["Authorization"] = `Bearer ${accessToken}`;
1457
  }
@@ -1462,18 +1551,26 @@ async function processToolCalls({
1462
  for (const call of toolCalls) {
1463
  let args;
1464
  try { args = JSON.parse(call.function.arguments || "{}"); } catch { args = {}; }
 
 
 
 
1465
 
1466
- onToolCall({ id: call.id, name: call.function.name, state: "pending", args });
1467
 
1468
  let result = "Tool completed.";
1469
 
1470
  try {
1471
- if (call.function.name === "list_prompt_resources") {
 
 
 
 
1472
  const state = getPromptState(sessionId);
1473
  result = JSON.stringify(buildPromptResourceManifest(state), null, 2);
1474
  }
1475
 
1476
- else if (call.function.name === "read_prompt_chunk") {
1477
  const state = getPromptState(sessionId);
1478
  const chunkIndex = Number.isInteger(Number(args.chunk_index))
1479
  ? Number(args.chunk_index)
@@ -1500,7 +1597,7 @@ async function processToolCalls({
1500
  }
1501
  }
1502
 
1503
- else if (call.function.name === "load_prompt_images") {
1504
  const state = getPromptState(sessionId);
1505
  const resource = findPromptResource(state, args.resource_id);
1506
 
@@ -1533,22 +1630,22 @@ async function processToolCalls({
1533
  ...selectedImages,
1534
  ],
1535
  });
1536
- onToolCall({ id: call.id, name: call.function.name, state: "resolved", result });
1537
  continue;
1538
  }
1539
  }
1540
  }
1541
 
1542
- else if (call.function.name === "write_notes") {
1543
  appendAssistantNote(sessionId, args.note);
1544
  result = "Note stored for the rest of this response.";
1545
  }
1546
 
1547
- else if (call.function.name === "ollama_search") {
1548
  result = await gradioSearch(args.query);
1549
  }
1550
 
1551
- else if (call.function.name === "read_web_page") {
1552
  const { convert } = await import("html-to-text");
1553
  const res = await fetch(args.url, { signal: abortSignal });
1554
  if (!res.ok) {
@@ -1563,7 +1660,7 @@ async function processToolCalls({
1563
  }
1564
  }
1565
 
1566
- else if (call.function.name === "generate_image") {
1567
  const body = { prompt: args.prompt };
1568
  if (args.mode) body.mode = args.mode;
1569
  if (args.image_urls?.length) body.image_urls = args.image_urls;
@@ -1590,7 +1687,7 @@ async function processToolCalls({
1590
  }
1591
  }
1592
 
1593
- else if (call.function.name === "generate_video") {
1594
  const body = { prompt: args.prompt };
1595
  if (args.ratio) body.ratio = args.ratio;
1596
  if (args.mode) body.mode = args.mode;
@@ -1618,7 +1715,7 @@ async function processToolCalls({
1618
  }
1619
  }
1620
 
1621
- else if (call.function.name === "generate_audio") {
1622
  const res = await fetch(`${LIGHTNING_BASE}/gen/sfx`, {
1623
  method: "POST",
1624
  headers: { "Content-Type": "application/json", ...authHeaders },
@@ -1641,7 +1738,7 @@ async function processToolCalls({
1641
  result = `Tool error: ${String(err)}`;
1642
  }
1643
 
1644
- onToolCall({ id: call.id, name: call.function.name, state: "resolved", result });
1645
 
1646
  nextMessages.push({
1647
  role: "tool",
 
409
  return pieces.join(preview ? " " : "\n");
410
  }
411
 
412
+ function getAllowedToolNames(toolDefs = []) {
413
+ return toolDefs
414
+ .map((tool) => tool?.function?.name)
415
+ .filter(Boolean);
416
+ }
417
+
418
+ function sanitizeToolName(rawName, allowedToolNames = []) {
419
+ if (!rawName) return null;
420
+
421
+ let name = String(rawName).trim();
422
+ if (!name) return null;
423
+
424
+ name = name.replace(/<\|[\s\S]*$/, "").trim();
425
+ name = name.replace(/^["'`]+|["'`]+$/g, "");
426
+
427
+ const identifierMatch = name.match(/^[A-Za-z0-9_-]+/);
428
+ if (identifierMatch) {
429
+ name = identifierMatch[0];
430
+ }
431
+
432
+ if (!allowedToolNames.length) {
433
+ return name || null;
434
+ }
435
+
436
+ if (allowedToolNames.includes(name)) {
437
+ return name;
438
+ }
439
+
440
+ const normalizedRaw = String(rawName).toLowerCase();
441
+ const prefixedMatch = allowedToolNames.find((allowedName) =>
442
+ normalizedRaw.startsWith(allowedName.toLowerCase())
443
+ );
444
+ if (prefixedMatch) {
445
+ return prefixedMatch;
446
+ }
447
+
448
+ const embeddedMatch = allowedToolNames.find((allowedName) =>
449
+ normalizedRaw.includes(allowedName.toLowerCase())
450
+ );
451
+ return embeddedMatch || null;
452
+ }
453
+
454
  function normalizeStoredToolCalls(toolCalls = []) {
455
  return toolCalls.map((call) => ({
456
  id: call.id || `call_${crypto.randomUUID()}`,
457
  type: "function",
458
  function: {
459
+ name: sanitizeToolName(call.name || call.function?.name || "unknown_tool"),
460
  arguments: (() => {
461
  const rawArgs = call.args ?? call.function?.arguments ?? {};
462
  return typeof rawArgs === "string" ? rawArgs : JSON.stringify(rawArgs);
 
1033
  }
1034
  }
1035
 
1036
+ function serializeToolCalls(toolCallBuffer) {
1037
+ return [...toolCallBuffer.values()]
1038
+ .filter((toolCall) => toolCall?.name)
1039
+ .map((toolCall) => ({
1040
+ id: toolCall.id || `call_${crypto.randomUUID()}`,
1041
+ type: "function",
1042
+ function: { name: toolCall.name, arguments: toolCall.arguments },
1043
+ }));
1044
+ }
1045
+
1046
  export async function websocketChatStream(body, headers, onToken, abortSignal) {
1047
  const ws = await getSafeWebSocket();
1048
  const currentRequestId = ++requestIdCounter;
1049
+ const allowedToolNames = getAllowedToolNames(body?.tools);
1050
  const safeParse = (str) => {
1051
  try { return JSON.parse(str.startsWith("data: ") ? str.slice(6) : str); } catch { return null; }
1052
  };
 
1060
  if (!finished) {
1061
  finished = true;
1062
  cleanup();
1063
+ const toolCalls = serializeToolCalls(toolCallBuffer);
 
 
 
 
1064
  resolve({ assistantText, toolCalls });
1065
  }
1066
  }, 120000);
 
1152
  for (const call of delta.tool_calls) {
1153
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
1154
  if (call.id) entry.id = call.id;
1155
+ if (call.function?.name) {
1156
+ entry.name = sanitizeToolName(call.function.name, allowedToolNames);
1157
+ }
1158
  if (call.function?.arguments) entry.arguments += call.function.arguments;
1159
  toolCallBuffer.set(call.index, entry);
1160
  }
 
1163
  if (payload.choices?.[0]?.finish_reason && !finished) {
1164
  finished = true;
1165
  cleanup();
1166
+ const toolCalls = serializeToolCalls(toolCallBuffer);
 
 
 
 
1167
  resolve({ assistantText, toolCalls });
1168
  }
1169
  };
 
1274
  }
1275
  }
1276
 
1277
+ if (!finished) {
1278
+ const finalMessages = [
1279
+ ...buildModelMessages(baseMessages, workingMessages, sessionId),
1280
+ {
1281
+ role: "system",
1282
+ content: "Tool-use budget is exhausted for this response. Do not call tools. Answer directly using the information already gathered. If something is still missing, briefly say what is missing without calling tools.",
1283
+ },
1284
+ ];
1285
+
1286
+ const { assistantText: finalStepText } = await websocketChatStreamWithRetry(
1287
+ {
1288
+ model: model || "lightning",
1289
+ messages: finalMessages,
1290
+ stream: true,
1291
+ },
1292
+ headers,
1293
+ onToken,
1294
+ abortSignal
1295
+ );
1296
+
1297
+ if (finalStepText) {
1298
+ assistantText += finalStepText;
1299
+ workingMessages.push({ role: "assistant", content: finalStepText });
1300
+ }
1301
+ finished = true;
1302
+ }
1303
+
1304
+ if (!assistantText.trim()) {
1305
+ assistantText = "I wasn’t able to finish that response cleanly. Please try again.";
1306
+ }
1307
+
1308
  const sessionName = extractSessionName(assistantText);
1309
 
1310
  if (typeof onDone === "function") {
 
1330
  if (!VALID_ROLES.has(msg.role)) return null;
1331
 
1332
  if (msg.role === "assistant" && (msg.tool_calls || msg.toolCalls)) {
1333
+ const normalizedToolCalls = (msg.tool_calls || normalizeStoredToolCalls(msg.toolCalls))
1334
+ .map((call) => ({
1335
+ ...call,
1336
+ function: {
1337
+ ...call.function,
1338
+ name: sanitizeToolName(call.function?.name || call.name || "unknown_tool"),
1339
+ },
1340
+ }))
1341
+ .filter((call) => call.function?.name);
1342
+
1343
  return {
1344
  role: "assistant",
1345
  content: msg.content ?? "",
1346
+ ...(normalizedToolCalls.length ? { tool_calls: normalizedToolCalls } : {}),
1347
  };
1348
  }
1349
  if (Array.isArray(msg.content)) {
 
1540
  }) {
1541
  const nextMessages = [];
1542
  const authHeaders = {};
1543
+ const allowedToolNames = getAllowedToolNames(buildToolList(tools));
1544
  if (accessToken) {
1545
  authHeaders["Authorization"] = `Bearer ${accessToken}`;
1546
  }
 
1551
  for (const call of toolCalls) {
1552
  let args;
1553
  try { args = JSON.parse(call.function.arguments || "{}"); } catch { args = {}; }
1554
+ const toolName = sanitizeToolName(call.function?.name, allowedToolNames);
1555
+ if (toolName) {
1556
+ call.function.name = toolName;
1557
+ }
1558
 
1559
+ onToolCall({ id: call.id, name: toolName || call.function?.name, state: "pending", args });
1560
 
1561
  let result = "Tool completed.";
1562
 
1563
  try {
1564
+ if (!toolName || !allowedToolNames.includes(toolName)) {
1565
+ result = `Invalid tool name "${call.function?.name || "unknown"}".`;
1566
+ }
1567
+
1568
+ else if (toolName === "list_prompt_resources") {
1569
  const state = getPromptState(sessionId);
1570
  result = JSON.stringify(buildPromptResourceManifest(state), null, 2);
1571
  }
1572
 
1573
+ else if (toolName === "read_prompt_chunk") {
1574
  const state = getPromptState(sessionId);
1575
  const chunkIndex = Number.isInteger(Number(args.chunk_index))
1576
  ? Number(args.chunk_index)
 
1597
  }
1598
  }
1599
 
1600
+ else if (toolName === "load_prompt_images") {
1601
  const state = getPromptState(sessionId);
1602
  const resource = findPromptResource(state, args.resource_id);
1603
 
 
1630
  ...selectedImages,
1631
  ],
1632
  });
1633
+ onToolCall({ id: call.id, name: toolName, state: "resolved", result });
1634
  continue;
1635
  }
1636
  }
1637
  }
1638
 
1639
+ else if (toolName === "write_notes") {
1640
  appendAssistantNote(sessionId, args.note);
1641
  result = "Note stored for the rest of this response.";
1642
  }
1643
 
1644
+ else if (toolName === "ollama_search") {
1645
  result = await gradioSearch(args.query);
1646
  }
1647
 
1648
+ else if (toolName === "read_web_page") {
1649
  const { convert } = await import("html-to-text");
1650
  const res = await fetch(args.url, { signal: abortSignal });
1651
  if (!res.ok) {
 
1660
  }
1661
  }
1662
 
1663
+ else if (toolName === "generate_image") {
1664
  const body = { prompt: args.prompt };
1665
  if (args.mode) body.mode = args.mode;
1666
  if (args.image_urls?.length) body.image_urls = args.image_urls;
 
1687
  }
1688
  }
1689
 
1690
+ else if (toolName === "generate_video") {
1691
  const body = { prompt: args.prompt };
1692
  if (args.ratio) body.ratio = args.ratio;
1693
  if (args.mode) body.mode = args.mode;
 
1715
  }
1716
  }
1717
 
1718
+ else if (toolName === "generate_audio") {
1719
  const res = await fetch(`${LIGHTNING_BASE}/gen/sfx`, {
1720
  method: "POST",
1721
  headers: { "Content-Type": "application/json", ...authHeaders },
 
1738
  result = `Tool error: ${String(err)}`;
1739
  }
1740
 
1741
+ onToolCall({ id: call.id, name: toolName || call.function?.name, state: "resolved", result });
1742
 
1743
  nextMessages.push({
1744
  role: "tool",