Spaces:
Runtime error
Runtime error
incognitolm commited on
Commit ·
c70b348
1
Parent(s): 0b4a6bb
Update chatStream.js
Browse files- server/chatStream.js +122 -25
server/chatStream.js
CHANGED
|
@@ -409,12 +409,54 @@ function contentToText(content, { preview = false } = {}) {
|
|
| 409 |
return pieces.join(preview ? " " : "\n");
|
| 410 |
}
|
| 411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
function normalizeStoredToolCalls(toolCalls = []) {
|
| 413 |
return toolCalls.map((call) => ({
|
| 414 |
id: call.id || `call_${crypto.randomUUID()}`,
|
| 415 |
type: "function",
|
| 416 |
function: {
|
| 417 |
-
name: call.name || call.function?.name || "unknown_tool",
|
| 418 |
arguments: (() => {
|
| 419 |
const rawArgs = call.args ?? call.function?.arguments ?? {};
|
| 420 |
return typeof rawArgs === "string" ? rawArgs : JSON.stringify(rawArgs);
|
|
@@ -991,9 +1033,20 @@ async function websocketChatStreamWithRetry(body, headers, onToken, abortSignal)
|
|
| 991 |
}
|
| 992 |
}
|
| 993 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 994 |
export async function websocketChatStream(body, headers, onToken, abortSignal) {
|
| 995 |
const ws = await getSafeWebSocket();
|
| 996 |
const currentRequestId = ++requestIdCounter;
|
|
|
|
| 997 |
const safeParse = (str) => {
|
| 998 |
try { return JSON.parse(str.startsWith("data: ") ? str.slice(6) : str); } catch { return null; }
|
| 999 |
};
|
|
@@ -1007,11 +1060,7 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
|
|
| 1007 |
if (!finished) {
|
| 1008 |
finished = true;
|
| 1009 |
cleanup();
|
| 1010 |
-
const toolCalls =
|
| 1011 |
-
id: t.id || `call_${crypto.randomUUID()}`,
|
| 1012 |
-
type: "function",
|
| 1013 |
-
function: { name: t.name, arguments: t.arguments },
|
| 1014 |
-
}));
|
| 1015 |
resolve({ assistantText, toolCalls });
|
| 1016 |
}
|
| 1017 |
}, 120000);
|
|
@@ -1103,7 +1152,9 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
|
|
| 1103 |
for (const call of delta.tool_calls) {
|
| 1104 |
const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
|
| 1105 |
if (call.id) entry.id = call.id;
|
| 1106 |
-
if (call.function?.name)
|
|
|
|
|
|
|
| 1107 |
if (call.function?.arguments) entry.arguments += call.function.arguments;
|
| 1108 |
toolCallBuffer.set(call.index, entry);
|
| 1109 |
}
|
|
@@ -1112,11 +1163,7 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
|
|
| 1112 |
if (payload.choices?.[0]?.finish_reason && !finished) {
|
| 1113 |
finished = true;
|
| 1114 |
cleanup();
|
| 1115 |
-
const toolCalls =
|
| 1116 |
-
id: t.id || `call_${crypto.randomUUID()}`,
|
| 1117 |
-
type: "function",
|
| 1118 |
-
function: { name: t.name, arguments: t.arguments },
|
| 1119 |
-
}));
|
| 1120 |
resolve({ assistantText, toolCalls });
|
| 1121 |
}
|
| 1122 |
};
|
|
@@ -1227,6 +1274,37 @@ export async function streamChat({
|
|
| 1227 |
}
|
| 1228 |
}
|
| 1229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1230 |
const sessionName = extractSessionName(assistantText);
|
| 1231 |
|
| 1232 |
if (typeof onDone === "function") {
|
|
@@ -1252,10 +1330,20 @@ function normalizeMessage(msg) {
|
|
| 1252 |
if (!VALID_ROLES.has(msg.role)) return null;
|
| 1253 |
|
| 1254 |
if (msg.role === "assistant" && (msg.tool_calls || msg.toolCalls)) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1255 |
return {
|
| 1256 |
role: "assistant",
|
| 1257 |
content: msg.content ?? "",
|
| 1258 |
-
tool_calls:
|
| 1259 |
};
|
| 1260 |
}
|
| 1261 |
if (Array.isArray(msg.content)) {
|
|
@@ -1452,6 +1540,7 @@ async function processToolCalls({
|
|
| 1452 |
}) {
|
| 1453 |
const nextMessages = [];
|
| 1454 |
const authHeaders = {};
|
|
|
|
| 1455 |
if (accessToken) {
|
| 1456 |
authHeaders["Authorization"] = `Bearer ${accessToken}`;
|
| 1457 |
}
|
|
@@ -1462,18 +1551,26 @@ async function processToolCalls({
|
|
| 1462 |
for (const call of toolCalls) {
|
| 1463 |
let args;
|
| 1464 |
try { args = JSON.parse(call.function.arguments || "{}"); } catch { args = {}; }
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1465 |
|
| 1466 |
-
onToolCall({ id: call.id, name: call.function.name, state: "pending", args });
|
| 1467 |
|
| 1468 |
let result = "Tool completed.";
|
| 1469 |
|
| 1470 |
try {
|
| 1471 |
-
if (
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1472 |
const state = getPromptState(sessionId);
|
| 1473 |
result = JSON.stringify(buildPromptResourceManifest(state), null, 2);
|
| 1474 |
}
|
| 1475 |
|
| 1476 |
-
else if (
|
| 1477 |
const state = getPromptState(sessionId);
|
| 1478 |
const chunkIndex = Number.isInteger(Number(args.chunk_index))
|
| 1479 |
? Number(args.chunk_index)
|
|
@@ -1500,7 +1597,7 @@ async function processToolCalls({
|
|
| 1500 |
}
|
| 1501 |
}
|
| 1502 |
|
| 1503 |
-
else if (
|
| 1504 |
const state = getPromptState(sessionId);
|
| 1505 |
const resource = findPromptResource(state, args.resource_id);
|
| 1506 |
|
|
@@ -1533,22 +1630,22 @@ async function processToolCalls({
|
|
| 1533 |
...selectedImages,
|
| 1534 |
],
|
| 1535 |
});
|
| 1536 |
-
onToolCall({ id: call.id, name:
|
| 1537 |
continue;
|
| 1538 |
}
|
| 1539 |
}
|
| 1540 |
}
|
| 1541 |
|
| 1542 |
-
else if (
|
| 1543 |
appendAssistantNote(sessionId, args.note);
|
| 1544 |
result = "Note stored for the rest of this response.";
|
| 1545 |
}
|
| 1546 |
|
| 1547 |
-
else if (
|
| 1548 |
result = await gradioSearch(args.query);
|
| 1549 |
}
|
| 1550 |
|
| 1551 |
-
else if (
|
| 1552 |
const { convert } = await import("html-to-text");
|
| 1553 |
const res = await fetch(args.url, { signal: abortSignal });
|
| 1554 |
if (!res.ok) {
|
|
@@ -1563,7 +1660,7 @@ async function processToolCalls({
|
|
| 1563 |
}
|
| 1564 |
}
|
| 1565 |
|
| 1566 |
-
else if (
|
| 1567 |
const body = { prompt: args.prompt };
|
| 1568 |
if (args.mode) body.mode = args.mode;
|
| 1569 |
if (args.image_urls?.length) body.image_urls = args.image_urls;
|
|
@@ -1590,7 +1687,7 @@ async function processToolCalls({
|
|
| 1590 |
}
|
| 1591 |
}
|
| 1592 |
|
| 1593 |
-
else if (
|
| 1594 |
const body = { prompt: args.prompt };
|
| 1595 |
if (args.ratio) body.ratio = args.ratio;
|
| 1596 |
if (args.mode) body.mode = args.mode;
|
|
@@ -1618,7 +1715,7 @@ async function processToolCalls({
|
|
| 1618 |
}
|
| 1619 |
}
|
| 1620 |
|
| 1621 |
-
else if (
|
| 1622 |
const res = await fetch(`${LIGHTNING_BASE}/gen/sfx`, {
|
| 1623 |
method: "POST",
|
| 1624 |
headers: { "Content-Type": "application/json", ...authHeaders },
|
|
@@ -1641,7 +1738,7 @@ async function processToolCalls({
|
|
| 1641 |
result = `Tool error: ${String(err)}`;
|
| 1642 |
}
|
| 1643 |
|
| 1644 |
-
onToolCall({ id: call.id, name: call.function.name, state: "resolved", result });
|
| 1645 |
|
| 1646 |
nextMessages.push({
|
| 1647 |
role: "tool",
|
|
|
|
| 409 |
return pieces.join(preview ? " " : "\n");
|
| 410 |
}
|
| 411 |
|
| 412 |
+
function getAllowedToolNames(toolDefs = []) {
|
| 413 |
+
return toolDefs
|
| 414 |
+
.map((tool) => tool?.function?.name)
|
| 415 |
+
.filter(Boolean);
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
function sanitizeToolName(rawName, allowedToolNames = []) {
|
| 419 |
+
if (!rawName) return null;
|
| 420 |
+
|
| 421 |
+
let name = String(rawName).trim();
|
| 422 |
+
if (!name) return null;
|
| 423 |
+
|
| 424 |
+
name = name.replace(/<\|[\s\S]*$/, "").trim();
|
| 425 |
+
name = name.replace(/^["'`]+|["'`]+$/g, "");
|
| 426 |
+
|
| 427 |
+
const identifierMatch = name.match(/^[A-Za-z0-9_-]+/);
|
| 428 |
+
if (identifierMatch) {
|
| 429 |
+
name = identifierMatch[0];
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
if (!allowedToolNames.length) {
|
| 433 |
+
return name || null;
|
| 434 |
+
}
|
| 435 |
+
|
| 436 |
+
if (allowedToolNames.includes(name)) {
|
| 437 |
+
return name;
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
const normalizedRaw = String(rawName).toLowerCase();
|
| 441 |
+
const prefixedMatch = allowedToolNames.find((allowedName) =>
|
| 442 |
+
normalizedRaw.startsWith(allowedName.toLowerCase())
|
| 443 |
+
);
|
| 444 |
+
if (prefixedMatch) {
|
| 445 |
+
return prefixedMatch;
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
const embeddedMatch = allowedToolNames.find((allowedName) =>
|
| 449 |
+
normalizedRaw.includes(allowedName.toLowerCase())
|
| 450 |
+
);
|
| 451 |
+
return embeddedMatch || null;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
function normalizeStoredToolCalls(toolCalls = []) {
|
| 455 |
return toolCalls.map((call) => ({
|
| 456 |
id: call.id || `call_${crypto.randomUUID()}`,
|
| 457 |
type: "function",
|
| 458 |
function: {
|
| 459 |
+
name: sanitizeToolName(call.name || call.function?.name || "unknown_tool"),
|
| 460 |
arguments: (() => {
|
| 461 |
const rawArgs = call.args ?? call.function?.arguments ?? {};
|
| 462 |
return typeof rawArgs === "string" ? rawArgs : JSON.stringify(rawArgs);
|
|
|
|
| 1033 |
}
|
| 1034 |
}
|
| 1035 |
|
| 1036 |
+
function serializeToolCalls(toolCallBuffer) {
|
| 1037 |
+
return [...toolCallBuffer.values()]
|
| 1038 |
+
.filter((toolCall) => toolCall?.name)
|
| 1039 |
+
.map((toolCall) => ({
|
| 1040 |
+
id: toolCall.id || `call_${crypto.randomUUID()}`,
|
| 1041 |
+
type: "function",
|
| 1042 |
+
function: { name: toolCall.name, arguments: toolCall.arguments },
|
| 1043 |
+
}));
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
export async function websocketChatStream(body, headers, onToken, abortSignal) {
|
| 1047 |
const ws = await getSafeWebSocket();
|
| 1048 |
const currentRequestId = ++requestIdCounter;
|
| 1049 |
+
const allowedToolNames = getAllowedToolNames(body?.tools);
|
| 1050 |
const safeParse = (str) => {
|
| 1051 |
try { return JSON.parse(str.startsWith("data: ") ? str.slice(6) : str); } catch { return null; }
|
| 1052 |
};
|
|
|
|
| 1060 |
if (!finished) {
|
| 1061 |
finished = true;
|
| 1062 |
cleanup();
|
| 1063 |
+
const toolCalls = serializeToolCalls(toolCallBuffer);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1064 |
resolve({ assistantText, toolCalls });
|
| 1065 |
}
|
| 1066 |
}, 120000);
|
|
|
|
| 1152 |
for (const call of delta.tool_calls) {
|
| 1153 |
const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
|
| 1154 |
if (call.id) entry.id = call.id;
|
| 1155 |
+
if (call.function?.name) {
|
| 1156 |
+
entry.name = sanitizeToolName(call.function.name, allowedToolNames);
|
| 1157 |
+
}
|
| 1158 |
if (call.function?.arguments) entry.arguments += call.function.arguments;
|
| 1159 |
toolCallBuffer.set(call.index, entry);
|
| 1160 |
}
|
|
|
|
| 1163 |
if (payload.choices?.[0]?.finish_reason && !finished) {
|
| 1164 |
finished = true;
|
| 1165 |
cleanup();
|
| 1166 |
+
const toolCalls = serializeToolCalls(toolCallBuffer);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1167 |
resolve({ assistantText, toolCalls });
|
| 1168 |
}
|
| 1169 |
};
|
|
|
|
| 1274 |
}
|
| 1275 |
}
|
| 1276 |
|
| 1277 |
+
if (!finished) {
|
| 1278 |
+
const finalMessages = [
|
| 1279 |
+
...buildModelMessages(baseMessages, workingMessages, sessionId),
|
| 1280 |
+
{
|
| 1281 |
+
role: "system",
|
| 1282 |
+
content: "Tool-use budget is exhausted for this response. Do not call tools. Answer directly using the information already gathered. If something is still missing, briefly say what is missing without calling tools.",
|
| 1283 |
+
},
|
| 1284 |
+
];
|
| 1285 |
+
|
| 1286 |
+
const { assistantText: finalStepText } = await websocketChatStreamWithRetry(
|
| 1287 |
+
{
|
| 1288 |
+
model: model || "lightning",
|
| 1289 |
+
messages: finalMessages,
|
| 1290 |
+
stream: true,
|
| 1291 |
+
},
|
| 1292 |
+
headers,
|
| 1293 |
+
onToken,
|
| 1294 |
+
abortSignal
|
| 1295 |
+
);
|
| 1296 |
+
|
| 1297 |
+
if (finalStepText) {
|
| 1298 |
+
assistantText += finalStepText;
|
| 1299 |
+
workingMessages.push({ role: "assistant", content: finalStepText });
|
| 1300 |
+
}
|
| 1301 |
+
finished = true;
|
| 1302 |
+
}
|
| 1303 |
+
|
| 1304 |
+
if (!assistantText.trim()) {
|
| 1305 |
+
assistantText = "I wasn’t able to finish that response cleanly. Please try again.";
|
| 1306 |
+
}
|
| 1307 |
+
|
| 1308 |
const sessionName = extractSessionName(assistantText);
|
| 1309 |
|
| 1310 |
if (typeof onDone === "function") {
|
|
|
|
| 1330 |
if (!VALID_ROLES.has(msg.role)) return null;
|
| 1331 |
|
| 1332 |
if (msg.role === "assistant" && (msg.tool_calls || msg.toolCalls)) {
|
| 1333 |
+
const normalizedToolCalls = (msg.tool_calls || normalizeStoredToolCalls(msg.toolCalls))
|
| 1334 |
+
.map((call) => ({
|
| 1335 |
+
...call,
|
| 1336 |
+
function: {
|
| 1337 |
+
...call.function,
|
| 1338 |
+
name: sanitizeToolName(call.function?.name || call.name || "unknown_tool"),
|
| 1339 |
+
},
|
| 1340 |
+
}))
|
| 1341 |
+
.filter((call) => call.function?.name);
|
| 1342 |
+
|
| 1343 |
return {
|
| 1344 |
role: "assistant",
|
| 1345 |
content: msg.content ?? "",
|
| 1346 |
+
...(normalizedToolCalls.length ? { tool_calls: normalizedToolCalls } : {}),
|
| 1347 |
};
|
| 1348 |
}
|
| 1349 |
if (Array.isArray(msg.content)) {
|
|
|
|
| 1540 |
}) {
|
| 1541 |
const nextMessages = [];
|
| 1542 |
const authHeaders = {};
|
| 1543 |
+
const allowedToolNames = getAllowedToolNames(buildToolList(tools));
|
| 1544 |
if (accessToken) {
|
| 1545 |
authHeaders["Authorization"] = `Bearer ${accessToken}`;
|
| 1546 |
}
|
|
|
|
| 1551 |
for (const call of toolCalls) {
|
| 1552 |
let args;
|
| 1553 |
try { args = JSON.parse(call.function.arguments || "{}"); } catch { args = {}; }
|
| 1554 |
+
const toolName = sanitizeToolName(call.function?.name, allowedToolNames);
|
| 1555 |
+
if (toolName) {
|
| 1556 |
+
call.function.name = toolName;
|
| 1557 |
+
}
|
| 1558 |
|
| 1559 |
+
onToolCall({ id: call.id, name: toolName || call.function?.name, state: "pending", args });
|
| 1560 |
|
| 1561 |
let result = "Tool completed.";
|
| 1562 |
|
| 1563 |
try {
|
| 1564 |
+
if (!toolName || !allowedToolNames.includes(toolName)) {
|
| 1565 |
+
result = `Invalid tool name "${call.function?.name || "unknown"}".`;
|
| 1566 |
+
}
|
| 1567 |
+
|
| 1568 |
+
else if (toolName === "list_prompt_resources") {
|
| 1569 |
const state = getPromptState(sessionId);
|
| 1570 |
result = JSON.stringify(buildPromptResourceManifest(state), null, 2);
|
| 1571 |
}
|
| 1572 |
|
| 1573 |
+
else if (toolName === "read_prompt_chunk") {
|
| 1574 |
const state = getPromptState(sessionId);
|
| 1575 |
const chunkIndex = Number.isInteger(Number(args.chunk_index))
|
| 1576 |
? Number(args.chunk_index)
|
|
|
|
| 1597 |
}
|
| 1598 |
}
|
| 1599 |
|
| 1600 |
+
else if (toolName === "load_prompt_images") {
|
| 1601 |
const state = getPromptState(sessionId);
|
| 1602 |
const resource = findPromptResource(state, args.resource_id);
|
| 1603 |
|
|
|
|
| 1630 |
...selectedImages,
|
| 1631 |
],
|
| 1632 |
});
|
| 1633 |
+
onToolCall({ id: call.id, name: toolName, state: "resolved", result });
|
| 1634 |
continue;
|
| 1635 |
}
|
| 1636 |
}
|
| 1637 |
}
|
| 1638 |
|
| 1639 |
+
else if (toolName === "write_notes") {
|
| 1640 |
appendAssistantNote(sessionId, args.note);
|
| 1641 |
result = "Note stored for the rest of this response.";
|
| 1642 |
}
|
| 1643 |
|
| 1644 |
+
else if (toolName === "ollama_search") {
|
| 1645 |
result = await gradioSearch(args.query);
|
| 1646 |
}
|
| 1647 |
|
| 1648 |
+
else if (toolName === "read_web_page") {
|
| 1649 |
const { convert } = await import("html-to-text");
|
| 1650 |
const res = await fetch(args.url, { signal: abortSignal });
|
| 1651 |
if (!res.ok) {
|
|
|
|
| 1660 |
}
|
| 1661 |
}
|
| 1662 |
|
| 1663 |
+
else if (toolName === "generate_image") {
|
| 1664 |
const body = { prompt: args.prompt };
|
| 1665 |
if (args.mode) body.mode = args.mode;
|
| 1666 |
if (args.image_urls?.length) body.image_urls = args.image_urls;
|
|
|
|
| 1687 |
}
|
| 1688 |
}
|
| 1689 |
|
| 1690 |
+
else if (toolName === "generate_video") {
|
| 1691 |
const body = { prompt: args.prompt };
|
| 1692 |
if (args.ratio) body.ratio = args.ratio;
|
| 1693 |
if (args.mode) body.mode = args.mode;
|
|
|
|
| 1715 |
}
|
| 1716 |
}
|
| 1717 |
|
| 1718 |
+
else if (toolName === "generate_audio") {
|
| 1719 |
const res = await fetch(`${LIGHTNING_BASE}/gen/sfx`, {
|
| 1720 |
method: "POST",
|
| 1721 |
headers: { "Content-Type": "application/json", ...authHeaders },
|
|
|
|
| 1738 |
result = `Tool error: ${String(err)}`;
|
| 1739 |
}
|
| 1740 |
|
| 1741 |
+
onToolCall({ id: call.id, name: toolName || call.function?.name, state: "resolved", result });
|
| 1742 |
|
| 1743 |
nextMessages.push({
|
| 1744 |
role: "tool",
|