incognitolm committed on
Commit
748cfb2
·
1 Parent(s): c239331

Update chatStream.js

Browse files
Files changed (1) hide show
  1. server/chatStream.js +49 -84
server/chatStream.js CHANGED
@@ -1,55 +1,23 @@
1
  import OpenAI from "openai";
2
- import { Worker } from "worker_threads";
3
- import { fileURLToPath } from "url";
4
- import path from "path";
5
- import { LIGHTNING_BASE } from "./config.js";
6
-
7
- // ── Web Search via an isolated Worker thread ──────────────────────────────
8
- //
9
- // The @gradio/client library opens a persistent SSE (Server-Sent Events)
10
- // fetch stream for its session queue. Even after client.close() is called,
11
- // the SSE response-body reader keeps an async iterator alive in the current
12
- // event loop. When that iterator eventually settles (stream closed/errored
13
- // by the remote server), it triggers internal callbacks that emit events
14
- // on objects the Node `ws` library also watches — causing the browser to
15
- // see a "connection lost" message immediately after every web search.
16
- //
17
- // The only reliable fix is to run the Gradio client in a worker_thread so
18
- // it gets its own V8 context and event loop. When the worker exits (via
19
- // process.exit(0) in searchWorker.js), every handle it opened — SSE stream,
20
- // heartbeat timer, etc. — is torn down with it, leaving the main thread's
21
- // WS server completely untouched.
22
-
23
- const __dirname = path.dirname(fileURLToPath(import.meta.url));
24
- const WORKER_PATH = path.join(__dirname, "searchWorker.js");
25
-
26
- function gradioSearch(query) {
27
- return new Promise((resolve, reject) => {
28
- const worker = new Worker(WORKER_PATH, { workerData: { query } });
29
-
30
- const timeout = setTimeout(() => {
31
- worker.terminate();
32
- reject(new Error("Search timed out after 45s"));
33
- }, 45_000);
34
-
35
- worker.on("message", (msg) => {
36
- clearTimeout(timeout);
37
- if (msg.ok) resolve(msg.result);
38
- else reject(new Error(msg.error));
39
- });
40
-
41
- worker.on("error", (err) => {
42
- clearTimeout(timeout);
43
- reject(err);
44
- });
45
 
46
- worker.on("exit", (code) => {
47
- // By the time this fires the promise is already settled via "message".
48
- // Only reject if the worker crashed without posting anything.
49
- clearTimeout(timeout);
50
- if (code !== 0) reject(new Error(`Search worker exited with code ${code}`));
51
- });
 
 
 
 
52
  });
 
 
 
 
 
 
53
  }
54
 
55
  const SYSTEM_PROMPT =
@@ -68,26 +36,17 @@ const SYSTEM_PROMPT =
68
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
69
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.";
70
 
71
- /**
72
- * Build a per-request OpenAI client pointed at the Lightning backend.
73
- * A new client is created each call so per-user auth headers are always fresh.
74
- */
75
  function makeClient(accessToken, clientId) {
76
  return new OpenAI({
77
  apiKey: accessToken || "no-key",
78
  baseURL: `${LIGHTNING_BASE}/gen`,
79
  defaultHeaders: {
80
  ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
81
- ...(clientId ? { "X-Client-ID": clientId } : {}),
82
  },
83
  });
84
  }
85
 
86
- /**
87
- * Consume an OpenAI streaming response, firing onToken for each text delta
88
- * and collecting any tool-call chunks into a finished toolCalls array.
89
- * Returns { assistantText, toolCalls }.
90
- */
91
  async function consumeStream(stream, onToken) {
92
  let assistantText = "";
93
  const toolCallBuffer = new Map();
@@ -104,17 +63,17 @@ async function consumeStream(stream, onToken) {
104
  if (delta.tool_calls) {
105
  for (const call of delta.tool_calls) {
106
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
107
- if (call.id) entry.id = call.id;
108
- if (call.function?.name) entry.name = call.function.name;
109
  if (call.function?.arguments) entry.arguments += call.function.arguments;
110
  toolCallBuffer.set(call.index, entry);
111
  }
112
  }
113
  }
114
 
115
- const toolCalls = [...toolCallBuffer.values()].map(t => ({
116
- id: t.id || `call_${crypto.randomUUID()}`,
117
- type: "function",
118
  function: { name: t.name, arguments: t.arguments },
119
  }));
120
 
@@ -137,7 +96,7 @@ export async function streamChat(ws, {
137
  abortSignal,
138
  }) {
139
  if (!accessToken) console.log("Missing access token");
140
- const client = makeClient(accessToken, clientId);
141
  const enabledTools = buildToolList(tools);
142
 
143
  const messages = [
@@ -147,32 +106,30 @@ export async function streamChat(ws, {
147
  ];
148
 
149
  try {
150
- // ── First stream ────────────────────────────────────────────────────────
151
  const stream = await client.chat.completions.create({
152
- model: model || "lightning",
153
  messages,
154
- tools: enabledTools.length > 0 ? enabledTools : undefined,
155
  stream: true,
156
  }, { signal: abortSignal });
157
 
158
  let { assistantText, toolCalls } = await consumeStream(stream, onToken);
159
 
160
- // ── Tool calls → follow-up stream ───────────────────────────────────────
161
  if (toolCalls.length > 0) {
162
  const toolResults = await processToolCalls(
163
  ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset,
164
  );
165
 
166
  const followUpMessages = [
167
- { role: "system", content: SYSTEM_PROMPT },
168
  ...history.map(normalizeMessage).filter(Boolean),
169
- { role: "user", content: userMessage },
170
  { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
171
  ...toolResults,
172
  ];
173
 
174
  const followUpStream = await client.chat.completions.create({
175
- model: model || "lightning",
176
  messages: followUpMessages,
177
  stream: true,
178
  }, { signal: abortSignal });
@@ -184,7 +141,7 @@ export async function streamChat(ws, {
184
  onDone(assistantText, toolCalls);
185
  } catch (err) {
186
  if (err.name === "AbortError" || err.constructor?.name === "APIUserAbortError") {
187
- onDone(null, null, true); // aborted
188
  } else {
189
  onError(String(err));
190
  }
@@ -199,6 +156,7 @@ function normalizeMessage(msg) {
199
  if (msg.role === "assistant" && msg.tool_calls) {
200
  return { role: "assistant", content: "", tool_calls: msg.tool_calls };
201
  }
 
202
  if (Array.isArray(msg.content)) {
203
  const textOnly = msg.content
204
  .filter(b => b.type === "text")
@@ -206,12 +164,14 @@ function normalizeMessage(msg) {
206
  .join("\n");
207
  return { role: msg.role, content: textOnly || "" };
208
  }
 
209
  return { role: msg.role, content: msg.content };
210
  }
211
 
212
  function buildToolList(tools) {
213
  if (!tools) return [];
214
  const list = [];
 
215
  if (tools.webSearch) {
216
  list.push({
217
  type: "function",
@@ -225,6 +185,7 @@ function buildToolList(tools) {
225
  },
226
  },
227
  });
 
228
  list.push({
229
  type: "function",
230
  function: {
@@ -238,6 +199,7 @@ function buildToolList(tools) {
238
  },
239
  });
240
  }
 
241
  if (tools.imageGen) {
242
  list.push({
243
  type: "function",
@@ -256,6 +218,7 @@ function buildToolList(tools) {
256
  },
257
  });
258
  }
 
259
  if (tools.videoGen) {
260
  list.push({
261
  type: "function",
@@ -276,6 +239,7 @@ function buildToolList(tools) {
276
  },
277
  });
278
  }
 
279
  if (tools.audioGen) {
280
  list.push({
281
  type: "function",
@@ -290,22 +254,16 @@ function buildToolList(tools) {
290
  },
291
  });
292
  }
 
293
  return list;
294
  }
295
 
296
  async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
297
  const toolResults = [];
298
  const authHeaders = {};
299
- if (accessToken) {
300
- authHeaders["Authorization"] = `Bearer ${accessToken}`;
301
- } else {
302
- console.log("No access token");
303
- }
304
- if (clientId) {
305
- authHeaders["X-Client-ID"] = clientId;
306
- } else {
307
- console.log("No Client ID");
308
- }
309
 
310
  for (const call of toolCalls) {
311
  let args;
@@ -316,18 +274,21 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
316
  let result = "Tool completed.";
317
 
318
  try {
 
319
  if (call.function.name === "ollama_search") {
320
- result = await gradioSearch(args.query);
321
  }
322
 
323
  else if (call.function.name === "read_web_page") {
324
  const { convert } = await import("html-to-text");
325
  const res = await fetch(args.url, { signal: abortSignal });
 
326
  if (!res.ok) {
327
  result = `Failed to fetch: ${res.status}`;
328
  } else {
329
  const html = await res.text();
330
  const titleMatch = html.match(/<title>(.*?)<\/title>/i);
 
331
  result = JSON.stringify({
332
  title: titleMatch?.[1] || "No title",
333
  content: convert(html, { wordwrap: false }).slice(0, 8000),
@@ -346,6 +307,7 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
346
  body: JSON.stringify(body),
347
  signal: abortSignal,
348
  });
 
349
  if (res.ok) {
350
  const buf = await res.arrayBuffer();
351
  const ct = res.headers.get("content-type") || "image/png";
@@ -375,6 +337,7 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
375
  body: JSON.stringify(body),
376
  signal: abortSignal,
377
  });
 
378
  if (res.ok) {
379
  const buf = await res.arrayBuffer();
380
  const b64 = Buffer.from(buf).toString("base64");
@@ -397,6 +360,7 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
397
  body: JSON.stringify({ prompt: args.prompt }),
398
  signal: abortSignal,
399
  });
 
400
  if (res.ok) {
401
  const buf = await res.arrayBuffer();
402
  const b64 = Buffer.from(buf).toString("base64");
@@ -409,6 +373,7 @@ async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abo
409
  result = `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
410
  }
411
  }
 
412
  } catch (err) {
413
  console.log(`Tool error: ${String(err)}`);
414
  result = `Tool error: ${String(err)}`;
 
1
  import OpenAI from "openai";
2
+ import { LIGHTNING_BASE, SEARCH_API_BASE } from "./config.js";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
/**
 * Run a web search through the HTTP search API and return its decoded JSON reply.
 *
 * Sends `POST ${SEARCH_API_BASE}/api/search` with a JSON body of the form
 * `{ query, max_results }`. `SEARCH_API_BASE` is imported from ./config.js.
 *
 * @param {string} query - Search text. Must not be null/undefined or blank.
 * @param {object} [options]
 * @param {AbortSignal} [options.signal] - Optional signal used to cancel the
 *   request, matching the other tool fetches in this file which forward the
 *   chat's abort signal.
 * @param {number} [options.maxResults=1] - Value sent as `max_results`.
 * @returns {Promise<any>} The parsed JSON response body, returned unchanged.
 * @throws {TypeError} If `query` is missing or blank (no request is made).
 * @throws {Error} `Search API error: <status>` when the response is not OK.
 */
async function apiSearch(query, { signal, maxResults = 1 } = {}) {
  // JSON.stringify silently drops an undefined `query`, so without this guard a
  // malformed tool call would reach the server as `{ "max_results": 1 }`.
  if (query == null || String(query).trim() === "") {
    throw new TypeError("Search query must be a non-empty string");
  }

  const res = await fetch(`${SEARCH_API_BASE}/api/search`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      query,
      max_results: maxResults,
    }),
    signal,
  });

  if (!res.ok) {
    throw new Error(`Search API error: ${res.status}`);
  }

  // NOTE(review): this resolves to a parsed object, while the sibling tool
  // branches (e.g. read_web_page) assign strings to `result`. Confirm that the
  // caller serializes the value before using it as tool-message content.
  return await res.json();
}
22
 
23
  const SYSTEM_PROMPT =
 
36
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
37
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.";
38
 
 
 
 
 
39
  function makeClient(accessToken, clientId) {
40
  return new OpenAI({
41
  apiKey: accessToken || "no-key",
42
  baseURL: `${LIGHTNING_BASE}/gen`,
43
  defaultHeaders: {
44
  ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
45
+ ...(clientId ? { "X-Client-ID": clientId } : {}),
46
  },
47
  });
48
  }
49
 
 
 
 
 
 
50
  async function consumeStream(stream, onToken) {
51
  let assistantText = "";
52
  const toolCallBuffer = new Map();
 
63
  if (delta.tool_calls) {
64
  for (const call of delta.tool_calls) {
65
  const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
66
+ if (call.id) entry.id = call.id;
67
+ if (call.function?.name) entry.name = call.function.name;
68
  if (call.function?.arguments) entry.arguments += call.function.arguments;
69
  toolCallBuffer.set(call.index, entry);
70
  }
71
  }
72
  }
73
 
74
+ const toolCalls = [...toolCallBuffer.values()].map((t) => ({
75
+ id: t.id || `call_${crypto.randomUUID()}`,
76
+ type: "function",
77
  function: { name: t.name, arguments: t.arguments },
78
  }));
79
 
 
96
  abortSignal,
97
  }) {
98
  if (!accessToken) console.log("Missing access token");
99
+ const client = makeClient(accessToken, clientId);
100
  const enabledTools = buildToolList(tools);
101
 
102
  const messages = [
 
106
  ];
107
 
108
  try {
 
109
  const stream = await client.chat.completions.create({
110
+ model: model || "lightning",
111
  messages,
112
+ tools: enabledTools.length > 0 ? enabledTools : undefined,
113
  stream: true,
114
  }, { signal: abortSignal });
115
 
116
  let { assistantText, toolCalls } = await consumeStream(stream, onToken);
117
 
 
118
  if (toolCalls.length > 0) {
119
  const toolResults = await processToolCalls(
120
  ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset,
121
  );
122
 
123
  const followUpMessages = [
124
+ { role: "system", content: SYSTEM_PROMPT },
125
  ...history.map(normalizeMessage).filter(Boolean),
126
+ { role: "user", content: userMessage },
127
  { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
128
  ...toolResults,
129
  ];
130
 
131
  const followUpStream = await client.chat.completions.create({
132
+ model: model || "lightning",
133
  messages: followUpMessages,
134
  stream: true,
135
  }, { signal: abortSignal });
 
141
  onDone(assistantText, toolCalls);
142
  } catch (err) {
143
  if (err.name === "AbortError" || err.constructor?.name === "APIUserAbortError") {
144
+ onDone(null, null, true);
145
  } else {
146
  onError(String(err));
147
  }
 
156
  if (msg.role === "assistant" && msg.tool_calls) {
157
  return { role: "assistant", content: "", tool_calls: msg.tool_calls };
158
  }
159
+
160
  if (Array.isArray(msg.content)) {
161
  const textOnly = msg.content
162
  .filter(b => b.type === "text")
 
164
  .join("\n");
165
  return { role: msg.role, content: textOnly || "" };
166
  }
167
+
168
  return { role: msg.role, content: msg.content };
169
  }
170
 
171
  function buildToolList(tools) {
172
  if (!tools) return [];
173
  const list = [];
174
+
175
  if (tools.webSearch) {
176
  list.push({
177
  type: "function",
 
185
  },
186
  },
187
  });
188
+
189
  list.push({
190
  type: "function",
191
  function: {
 
199
  },
200
  });
201
  }
202
+
203
  if (tools.imageGen) {
204
  list.push({
205
  type: "function",
 
218
  },
219
  });
220
  }
221
+
222
  if (tools.videoGen) {
223
  list.push({
224
  type: "function",
 
239
  },
240
  });
241
  }
242
+
243
  if (tools.audioGen) {
244
  list.push({
245
  type: "function",
 
254
  },
255
  });
256
  }
257
+
258
  return list;
259
  }
260
 
261
  async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
262
  const toolResults = [];
263
  const authHeaders = {};
264
+
265
+ if (accessToken) authHeaders["Authorization"] = `Bearer ${accessToken}`;
266
+ if (clientId) authHeaders["X-Client-ID"] = clientId;
 
 
 
 
 
 
 
267
 
268
  for (const call of toolCalls) {
269
  let args;
 
274
  let result = "Tool completed.";
275
 
276
  try {
277
+
278
  if (call.function.name === "ollama_search") {
279
+ result = await apiSearch(args.query);
280
  }
281
 
282
  else if (call.function.name === "read_web_page") {
283
  const { convert } = await import("html-to-text");
284
  const res = await fetch(args.url, { signal: abortSignal });
285
+
286
  if (!res.ok) {
287
  result = `Failed to fetch: ${res.status}`;
288
  } else {
289
  const html = await res.text();
290
  const titleMatch = html.match(/<title>(.*?)<\/title>/i);
291
+
292
  result = JSON.stringify({
293
  title: titleMatch?.[1] || "No title",
294
  content: convert(html, { wordwrap: false }).slice(0, 8000),
 
307
  body: JSON.stringify(body),
308
  signal: abortSignal,
309
  });
310
+
311
  if (res.ok) {
312
  const buf = await res.arrayBuffer();
313
  const ct = res.headers.get("content-type") || "image/png";
 
337
  body: JSON.stringify(body),
338
  signal: abortSignal,
339
  });
340
+
341
  if (res.ok) {
342
  const buf = await res.arrayBuffer();
343
  const b64 = Buffer.from(buf).toString("base64");
 
360
  body: JSON.stringify({ prompt: args.prompt }),
361
  signal: abortSignal,
362
  });
363
+
364
  if (res.ok) {
365
  const buf = await res.arrayBuffer();
366
  const b64 = Buffer.from(buf).toString("base64");
 
373
  result = `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
374
  }
375
  }
376
+
377
  } catch (err) {
378
  console.log(`Tool error: ${String(err)}`);
379
  result = `Tool error: ${String(err)}`;