incognitolm commited on
Commit
2ec195c
Β·
1 Parent(s): 748cfb2

Update chatStream.js

Browse files
Files changed (1) hide show
  1. server/chatStream.js +84 -49
server/chatStream.js CHANGED
@@ -1,23 +1,55 @@
1
  import OpenAI from "openai";
2
- import { LIGHTNING_BASE, SEARCH_API_BASE } from "./config.js";
3
-
4
- async function apiSearch(query) {
5
- const res = await fetch(`${SEARCH_API_BASE}/api/search`, {
6
- method: "POST",
7
- headers: {
8
- "Content-Type": "application/json",
9
- },
10
- body: JSON.stringify({
11
- query,
12
- max_results: 1,
13
- }),
14
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- if (!res.ok) {
17
- throw new Error(`Search API error: ${res.status}`);
18
- }
 
19
 
20
- return await res.json();
 
 
 
 
 
 
21
  }
22
 
23
  const SYSTEM_PROMPT =
@@ -36,17 +68,26 @@ const SYSTEM_PROMPT =
36
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
37
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.";
38
 
 
 
 
 
39
  function makeClient(accessToken, clientId) {
40
  return new OpenAI({
41
  apiKey: accessToken || "no-key",
42
  baseURL: `${LIGHTNING_BASE}/gen`,
43
  defaultHeaders: {
44
  ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
45
- ...(clientId ? { "X-Client-ID": clientId } : {}),
46
  },
47
  });
48
  }
49
 
 
 
 
 
 
50
  async function consumeStream(stream, onToken) {
51
  let assistantText = "";
52
  const toolCallBuffer = new Map();
@@ -63,17 +104,17 @@ async function consumeStream(stream, onToken) {
63
  if (delta.tool_calls) {
64
  for (const call of delta.tool_calls) {
65
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
66
- if (call.id) entry.id = call.id;
67
- if (call.function?.name) entry.name = call.function.name;
68
  if (call.function?.arguments) entry.arguments += call.function.arguments;
69
  toolCallBuffer.set(call.index, entry);
70
  }
71
  }
72
  }
73
 
74
- const toolCalls = [...toolCallBuffer.values()].map((t) => ({
75
- id: t.id || `call_${crypto.randomUUID()}`,
76
- type: "function",
77
  function: { name: t.name, arguments: t.arguments },
78
  }));
79
 
@@ -96,7 +137,7 @@ export async function streamChat(ws, {
96
  abortSignal,
97
  }) {
98
  if (!accessToken) console.log("Missing access token");
99
- const client = makeClient(accessToken, clientId);
100
  const enabledTools = buildToolList(tools);
101
 
102
  const messages = [
@@ -106,30 +147,32 @@ export async function streamChat(ws, {
106
  ];
107
 
108
  try {
 
109
  const stream = await client.chat.completions.create({
110
- model: model || "lightning",
111
  messages,
112
- tools: enabledTools.length > 0 ? enabledTools : undefined,
113
  stream: true,
114
  }, { signal: abortSignal });
115
 
116
  let { assistantText, toolCalls } = await consumeStream(stream, onToken);
117
 
 
118
  if (toolCalls.length > 0) {
119
  const toolResults = await processToolCalls(
120
  ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset,
121
  );
122
 
123
  const followUpMessages = [
124
- { role: "system", content: SYSTEM_PROMPT },
125
  ...history.map(normalizeMessage).filter(Boolean),
126
- { role: "user", content: userMessage },
127
  { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
128
  ...toolResults,
129
  ];
130
 
131
  const followUpStream = await client.chat.completions.create({
132
- model: model || "lightning",
133
  messages: followUpMessages,
134
  stream: true,
135
  }, { signal: abortSignal });
@@ -141,7 +184,7 @@ export async function streamChat(ws, {
141
  onDone(assistantText, toolCalls);
142
  } catch (err) {
143
  if (err.name === "AbortError" || err.constructor?.name === "APIUserAbortError") {
144
- onDone(null, null, true);
145
  } else {
146
  onError(String(err));
147
  }
@@ -156,7 +199,6 @@ function normalizeMessage(msg) {
156
  if (msg.role === "assistant" && msg.tool_calls) {
157
  return { role: "assistant", content: "", tool_calls: msg.tool_calls };
158
  }
159
-
160
  if (Array.isArray(msg.content)) {
161
  const textOnly = msg.content
162
  .filter(b => b.type === "text")
@@ -164,14 +206,12 @@ function normalizeMessage(msg) {
164
  .join("\n");
165
  return { role: msg.role, content: textOnly || "" };
166
  }
167
-
168
  return { role: msg.role, content: msg.content };
169
  }
170
 
171
  function buildToolList(tools) {
172
  if (!tools) return [];
173
  const list = [];
174
-
175
  if (tools.webSearch) {
176
  list.push({
177
  type: "function",
@@ -185,7 +225,6 @@ function buildToolList(tools) {
185
  },
186
  },
187
  });
188
-
189
  list.push({
190
  type: "function",
191
  function: {
@@ -199,7 +238,6 @@ function buildToolList(tools) {
199
  },
200
  });
201
  }
202
-
203
  if (tools.imageGen) {
204
  list.push({
205
  type: "function",
@@ -218,7 +256,6 @@ function buildToolList(tools) {
218
  },
219
  });
220
  }
221
-
222
  if (tools.videoGen) {
223
  list.push({
224
  type: "function",
@@ -239,7 +276,6 @@ function buildToolList(tools) {
239
  },
240
  });
241
  }
242
-
243
  if (tools.audioGen) {
244
  list.push({
245
  type: "function",
@@ -254,16 +290,22 @@ function buildToolList(tools) {
254
  },
255
  });
256
  }
257
-
258
  return list;
259
  }
260
 
261
  async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
262
  const toolResults = [];
263
  const authHeaders = {};
264
-
265
- if (accessToken) authHeaders["Authorization"] = `Bearer ${accessToken}`;
266
- if (clientId) authHeaders["X-Client-ID"] = clientId;
 
 
 
 
 
 
 
267
 
268
  for (const call of toolCalls) {
269
  let args;
@@ -274,21 +316,18 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
274
  let result = "Tool completed.";
275
 
276
  try {
277
-
278
  if (call.function.name === "ollama_search") {
279
- result = await apiSearch(args.query);
280
  }
281
 
282
  else if (call.function.name === "read_web_page") {
283
  const { convert } = await import("html-to-text");
284
  const res = await fetch(args.url, { signal: abortSignal });
285
-
286
  if (!res.ok) {
287
  result = `Failed to fetch: ${res.status}`;
288
  } else {
289
  const html = await res.text();
290
  const titleMatch = html.match(/<title>(.*?)<\/title>/i);
291
-
292
  result = JSON.stringify({
293
  title: titleMatch?.[1] || "No title",
294
  content: convert(html, { wordwrap: false }).slice(0, 8000),
@@ -307,7 +346,6 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
307
  body: JSON.stringify(body),
308
  signal: abortSignal,
309
  });
310
-
311
  if (res.ok) {
312
  const buf = await res.arrayBuffer();
313
  const ct = res.headers.get("content-type") || "image/png";
@@ -337,7 +375,6 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
337
  body: JSON.stringify(body),
338
  signal: abortSignal,
339
  });
340
-
341
  if (res.ok) {
342
  const buf = await res.arrayBuffer();
343
  const b64 = Buffer.from(buf).toString("base64");
@@ -360,7 +397,6 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
360
  body: JSON.stringify({ prompt: args.prompt }),
361
  signal: abortSignal,
362
  });
363
-
364
  if (res.ok) {
365
  const buf = await res.arrayBuffer();
366
  const b64 = Buffer.from(buf).toString("base64");
@@ -373,7 +409,6 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
373
  result = `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
374
  }
375
  }
376
-
377
  } catch (err) {
378
  console.log(`Tool error: ${String(err)}`);
379
  result = `Tool error: ${String(err)}`;
 
1
  import OpenAI from "openai";
2
+ import { Worker } from "worker_threads";
3
+ import { fileURLToPath } from "url";
4
+ import path from "path";
5
+ import { LIGHTNING_BASE } from "./config.js";
6
+
7
+ // ── Web Search via an isolated Worker thread ──────────────────────────────
8
+ //
9
+ // The @gradio/client library opens a persistent SSE (Server-Sent Events)
10
+ // fetch stream for its session queue. Even after client.close() is called,
11
+ // the SSE response-body reader keeps an async iterator alive in the current
12
+ // event loop. When that iterator eventually settles (stream closed/errored
13
+ // by the remote server), it triggers internal callbacks that emit events
14
+ // on objects the Node `ws` library also watches β€” causing the browser to
15
+ // see a "connection lost" message immediately after every web search.
16
+ //
17
+ // The only reliable fix is to run the Gradio client in a worker_thread so
18
+ // it gets its own V8 context and event loop. When the worker exits (via
19
+ // process.exit(0) in searchWorker.js), every handle it opened β€” SSE stream,
20
+ // heartbeat timer, etc. β€” is torn down with it, leaving the main thread's
21
+ // WS server completely untouched.
22
+
23
+ const __dirname = path.dirname(fileURLToPath(import.meta.url));
24
+ const WORKER_PATH = path.join(__dirname, "searchWorker.js");
25
+
26
+ function gradioSearch(query) {
27
+ return new Promise((resolve, reject) => {
28
+ const worker = new Worker(WORKER_PATH, { workerData: { query } });
29
+
30
+ const timeout = setTimeout(() => {
31
+ worker.terminate();
32
+ reject(new Error("Search timed out after 45s"));
33
+ }, 45_000);
34
+
35
+ worker.on("message", (msg) => {
36
+ clearTimeout(timeout);
37
+ if (msg.ok) resolve(msg.result);
38
+ else reject(new Error(msg.error));
39
+ });
40
 
41
+ worker.on("error", (err) => {
42
+ clearTimeout(timeout);
43
+ reject(err);
44
+ });
45
 
46
+ worker.on("exit", (code) => {
47
+ // By the time this fires the promise is already settled via "message".
48
+ // Only reject if the worker crashed without posting anything.
49
+ clearTimeout(timeout);
50
+ if (code !== 0) reject(new Error(`Search worker exited with code ${code}`));
51
+ });
52
+ });
53
  }
54
 
55
  const SYSTEM_PROMPT =
 
68
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
69
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.";
70
 
71
+ /**
72
+ * Build a per-request OpenAI client pointed at the Lightning backend.
73
+ * A new client is created each call so per-user auth headers are always fresh.
74
+ */
75
  function makeClient(accessToken, clientId) {
76
  return new OpenAI({
77
  apiKey: accessToken || "no-key",
78
  baseURL: `${LIGHTNING_BASE}/gen`,
79
  defaultHeaders: {
80
  ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
81
+ ...(clientId ? { "X-Client-ID": clientId } : {}),
82
  },
83
  });
84
  }
85
 
86
+ /**
87
+ * Consume an OpenAI streaming response, firing onToken for each text delta
88
+ * and collecting any tool-call chunks into a finished toolCalls array.
89
+ * Returns { assistantText, toolCalls }.
90
+ */
91
  async function consumeStream(stream, onToken) {
92
  let assistantText = "";
93
  const toolCallBuffer = new Map();
 
104
  if (delta.tool_calls) {
105
  for (const call of delta.tool_calls) {
106
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
107
+ if (call.id) entry.id = call.id;
108
+ if (call.function?.name) entry.name = call.function.name;
109
  if (call.function?.arguments) entry.arguments += call.function.arguments;
110
  toolCallBuffer.set(call.index, entry);
111
  }
112
  }
113
  }
114
 
115
+ const toolCalls = [...toolCallBuffer.values()].map(t => ({
116
+ id: t.id || `call_${crypto.randomUUID()}`,
117
+ type: "function",
118
  function: { name: t.name, arguments: t.arguments },
119
  }));
120
 
 
137
  abortSignal,
138
  }) {
139
  if (!accessToken) console.log("Missing access token");
140
+ const client = makeClient(accessToken, clientId);
141
  const enabledTools = buildToolList(tools);
142
 
143
  const messages = [
 
147
  ];
148
 
149
  try {
150
+ // ── First stream ────────────────────────────────────────────────────────
151
  const stream = await client.chat.completions.create({
152
+ model: model || "lightning",
153
  messages,
154
+ tools: enabledTools.length > 0 ? enabledTools : undefined,
155
  stream: true,
156
  }, { signal: abortSignal });
157
 
158
  let { assistantText, toolCalls } = await consumeStream(stream, onToken);
159
 
160
+ // ── Tool calls β†’ follow-up stream ───────────────────────────────────────
161
  if (toolCalls.length > 0) {
162
  const toolResults = await processToolCalls(
163
  ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset,
164
  );
165
 
166
  const followUpMessages = [
167
+ { role: "system", content: SYSTEM_PROMPT },
168
  ...history.map(normalizeMessage).filter(Boolean),
169
+ { role: "user", content: userMessage },
170
  { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
171
  ...toolResults,
172
  ];
173
 
174
  const followUpStream = await client.chat.completions.create({
175
+ model: model || "lightning",
176
  messages: followUpMessages,
177
  stream: true,
178
  }, { signal: abortSignal });
 
184
  onDone(assistantText, toolCalls);
185
  } catch (err) {
186
  if (err.name === "AbortError" || err.constructor?.name === "APIUserAbortError") {
187
+ onDone(null, null, true); // aborted
188
  } else {
189
  onError(String(err));
190
  }
 
199
  if (msg.role === "assistant" && msg.tool_calls) {
200
  return { role: "assistant", content: "", tool_calls: msg.tool_calls };
201
  }
 
202
  if (Array.isArray(msg.content)) {
203
  const textOnly = msg.content
204
  .filter(b => b.type === "text")
 
206
  .join("\n");
207
  return { role: msg.role, content: textOnly || "" };
208
  }
 
209
  return { role: msg.role, content: msg.content };
210
  }
211
 
212
  function buildToolList(tools) {
213
  if (!tools) return [];
214
  const list = [];
 
215
  if (tools.webSearch) {
216
  list.push({
217
  type: "function",
 
225
  },
226
  },
227
  });
 
228
  list.push({
229
  type: "function",
230
  function: {
 
238
  },
239
  });
240
  }
 
241
  if (tools.imageGen) {
242
  list.push({
243
  type: "function",
 
256
  },
257
  });
258
  }
 
259
  if (tools.videoGen) {
260
  list.push({
261
  type: "function",
 
276
  },
277
  });
278
  }
 
279
  if (tools.audioGen) {
280
  list.push({
281
  type: "function",
 
290
  },
291
  });
292
  }
 
293
  return list;
294
  }
295
 
296
  async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
297
  const toolResults = [];
298
  const authHeaders = {};
299
+ if (accessToken) {
300
+ authHeaders["Authorization"] = `Bearer ${accessToken}`;
301
+ } else {
302
+ console.log("No access token");
303
+ }
304
+ if (clientId) {
305
+ authHeaders["X-Client-ID"] = clientId;
306
+ } else {
307
+ console.log("No Client ID");
308
+ }
309
 
310
  for (const call of toolCalls) {
311
  let args;
 
316
  let result = "Tool completed.";
317
 
318
  try {
 
319
  if (call.function.name === "ollama_search") {
320
+ result = await gradioSearch(args.query);
321
  }
322
 
323
  else if (call.function.name === "read_web_page") {
324
  const { convert } = await import("html-to-text");
325
  const res = await fetch(args.url, { signal: abortSignal });
 
326
  if (!res.ok) {
327
  result = `Failed to fetch: ${res.status}`;
328
  } else {
329
  const html = await res.text();
330
  const titleMatch = html.match(/<title>(.*?)<\/title>/i);
 
331
  result = JSON.stringify({
332
  title: titleMatch?.[1] || "No title",
333
  content: convert(html, { wordwrap: false }).slice(0, 8000),
 
346
  body: JSON.stringify(body),
347
  signal: abortSignal,
348
  });
 
349
  if (res.ok) {
350
  const buf = await res.arrayBuffer();
351
  const ct = res.headers.get("content-type") || "image/png";
 
375
  body: JSON.stringify(body),
376
  signal: abortSignal,
377
  });
 
378
  if (res.ok) {
379
  const buf = await res.arrayBuffer();
380
  const b64 = Buffer.from(buf).toString("base64");
 
397
  body: JSON.stringify({ prompt: args.prompt }),
398
  signal: abortSignal,
399
  });
 
400
  if (res.ok) {
401
  const buf = await res.arrayBuffer();
402
  const b64 = Buffer.from(buf).toString("base64");
 
409
  result = `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
410
  }
411
  }
 
412
  } catch (err) {
413
  console.log(`Tool error: ${String(err)}`);
414
  result = `Tool error: ${String(err)}`;