incognitolm commited on
Commit
466bda3
·
1 Parent(s): 8270b87

Update chatStream.js

Browse files
Files changed (1) hide show
  1. server/chatStream.js +51 -26
server/chatStream.js CHANGED
@@ -13,6 +13,8 @@ const WORKER_PATH = path.join(__dirname, "searchWorker.js");
13
  let persistentWs = null;
14
  let wsAuthPromise = null;
15
  let requestIdCounter = 0;
 
 
16
 
17
  async function getSafeWebSocket() {
18
  // If we have a valid persistent connection, return it
@@ -93,6 +95,30 @@ async function getSafeWebSocket() {
93
  });
94
  });
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  wsAuthPromise = null;
97
  return persistentWs;
98
  })();
@@ -166,9 +192,15 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
166
  let finished = false;
167
 
168
  return new Promise((resolve, reject) => {
169
- const messageHandler = (data) => {
170
- const line = data.toString();
171
-
 
 
 
 
 
 
172
  // Parse request ID prefix (format: "id:payload")
173
  const colonIdx = line.indexOf(':');
174
  if (colonIdx === -1) {
@@ -196,14 +228,12 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
196
  // Mark as finished to end the stream
197
  if (!finished) {
198
  finished = true;
 
199
  const toolCalls = [...toolCallBuffer.values()].map((t) => ({
200
  id: t.id || `call_${crypto.randomUUID()}`,
201
  type: "function",
202
  function: { name: t.name, arguments: t.arguments },
203
  }));
204
- ws.removeListener("message", messageHandler);
205
- ws.removeListener("error", errorHandler);
206
- if (abortSignal) abortSignal.removeEventListener("abort", abortHandler);
207
  resolve({ assistantText, toolCalls });
208
  }
209
  return;
@@ -226,28 +256,23 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
226
  }
227
 
228
  if (payload.choices?.[0]?.finish_reason) {
229
- finished = true;
230
- const toolCalls = [...toolCallBuffer.values()].map((t) => ({
231
- id: t.id || `call_${crypto.randomUUID()}`,
232
- type: "function",
233
- function: { name: t.name, arguments: t.arguments },
234
- }));
235
-
236
- // Remove listener and resolve
237
- ws.removeListener("message", messageHandler);
238
- ws.removeListener("error", errorHandler);
239
- if (abortSignal) abortSignal.removeEventListener("abort", abortHandler);
240
- resolve({ assistantText, toolCalls });
241
  }
242
  };
243
 
244
  const errorHandler = (err) => {
245
- console.error("[WS ERROR]", err);
246
  if (!finished) {
247
  finished = true;
248
- ws.removeListener("message", messageHandler);
249
- ws.removeListener("error", errorHandler);
250
- if (abortSignal) abortSignal.removeEventListener("abort", abortHandler);
251
  reject(err);
252
  }
253
  };
@@ -255,14 +280,14 @@ export async function websocketChatStream(body, headers, onToken, abortSignal) {
255
  const abortHandler = () => {
256
  if (!finished) {
257
  finished = true;
258
- ws.removeListener("message", messageHandler);
259
- ws.removeListener("error", errorHandler);
260
  reject(new Error("AbortError"));
261
  }
262
  };
263
 
264
- ws.on("message", messageHandler);
265
- ws.on("error", errorHandler);
 
266
 
267
  if (abortSignal) {
268
  abortSignal.addEventListener("abort", abortHandler);
 
13
  let persistentWs = null;
14
  let wsAuthPromise = null;
15
  let requestIdCounter = 0;
16
+ let activeStreamHandlers = new Map(); // Track active stream handlers by request ID
17
+ let errorHandlers = new Map(); // Track error handlers by request ID
18
 
19
  async function getSafeWebSocket() {
20
  // If we have a valid persistent connection, return it
 
95
  });
96
  });
97
 
98
+ // Set up the global message and error routing after authentication
99
+ const globalMessageHandler = (data) => {
100
+ const line = data.toString();
101
+ // Route to all active stream handlers
102
+ for (const [id, handler] of activeStreamHandlers.entries()) {
103
+ if (!id.startsWith('__')) { // Skip metadata keys
104
+ handler(line);
105
+ }
106
+ }
107
+ };
108
+
109
+ const globalErrorHandler = (err) => {
110
+ console.error("[WS ERROR]", err);
111
+ // Notify all active streams
112
+ for (const [id, handler] of errorHandlers.entries()) {
113
+ handler(err);
114
+ }
115
+ };
116
+
117
+ persistentWs.on("message", globalMessageHandler);
118
+ persistentWs.on("error", globalErrorHandler);
119
+ activeStreamHandlers.set("__messageListener__", globalMessageHandler);
120
+ activeStreamHandlers.set("__errorHandler__", globalErrorHandler);
121
+
122
  wsAuthPromise = null;
123
  return persistentWs;
124
  })();
 
192
  let finished = false;
193
 
194
  return new Promise((resolve, reject) => {
195
+ const cleanup = () => {
196
+ activeStreamHandlers.delete(currentRequestId);
197
+ errorHandlers.delete(currentRequestId);
198
+ if (abortSignal) {
199
+ abortSignal.removeEventListener("abort", abortHandler);
200
+ }
201
+ };
202
+
203
+ const messageHandler = (line) => {
204
  // Parse request ID prefix (format: "id:payload")
205
  const colonIdx = line.indexOf(':');
206
  if (colonIdx === -1) {
 
228
  // Mark as finished to end the stream
229
  if (!finished) {
230
  finished = true;
231
+ cleanup();
232
  const toolCalls = [...toolCallBuffer.values()].map((t) => ({
233
  id: t.id || `call_${crypto.randomUUID()}`,
234
  type: "function",
235
  function: { name: t.name, arguments: t.arguments },
236
  }));
 
 
 
237
  resolve({ assistantText, toolCalls });
238
  }
239
  return;
 
256
  }
257
 
258
  if (payload.choices?.[0]?.finish_reason) {
259
+ if (!finished) {
260
+ finished = true;
261
+ cleanup();
262
+ const toolCalls = [...toolCallBuffer.values()].map((t) => ({
263
+ id: t.id || `call_${crypto.randomUUID()}`,
264
+ type: "function",
265
+ function: { name: t.name, arguments: t.arguments },
266
+ }));
267
+ resolve({ assistantText, toolCalls });
268
+ }
 
 
269
  }
270
  };
271
 
272
  const errorHandler = (err) => {
 
273
  if (!finished) {
274
  finished = true;
275
+ cleanup();
 
 
276
  reject(err);
277
  }
278
  };
 
280
  const abortHandler = () => {
281
  if (!finished) {
282
  finished = true;
283
+ cleanup();
 
284
  reject(new Error("AbortError"));
285
  }
286
  };
287
 
288
+ // Register handlers for this request
289
+ activeStreamHandlers.set(currentRequestId, messageHandler);
290
+ errorHandlers.set(currentRequestId, errorHandler);
291
 
292
  if (abortSignal) {
293
  abortSignal.addEventListener("abort", abortHandler);