/*
 * chat-dev / server / chatStream.js
 * incognitolm
 * Update chatStream.js
 * d04bf99
 * raw
 * history blame
 * 23.1 kB
 */
import OpenAI from "openai";
import { Worker } from "worker_threads";
import { fileURLToPath } from "url";
import path from "path";
import { LIGHTNING_BASE } from "./config.js";
import WebSocket from "ws";
import crypto from "crypto";
// ES modules have no built-in __dirname, so derive this module's directory from import.meta.url.
const __dirname = path.dirname(fileURLToPath(import.meta.url));
// Absolute path of the sibling searchWorker.js script.
// NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
const WORKER_PATH = path.join(__dirname, "searchWorker.js");
// Shared WebSocket state: one connection is reused for every chat request.
// Socket shared by all requests; null when no usable socket is available.
let persistentWs = null;
// Promise for a connect+authenticate attempt that is currently in flight, so concurrent
// callers of getSafeWebSocket() wait on the same attempt instead of opening more sockets.
let wsAuthPromise = null;
// Incremented by websocketChatStream() to give each request the ID it matches replies against.
let requestIdCounter = 0;
// Request ID -> per-request message handler. Also holds "__"-prefixed metadata entries,
// which the message router skips.
let activeStreamHandlers = new Map();
// Request ID -> per-request error handler, invoked when the shared socket reports an error.
let errorHandlers = new Map();
/**
 * Parse one WebSocket text frame as JSON, tolerating an SSE-style "data: " prefix.
 * @param {string} str - Raw frame text.
 * @returns {any|null} The parsed value, or null when the text is not valid JSON.
 */
function parseWsFrame(str) {
  try {
    const cleaned = str.startsWith("data: ") ? str.slice(6) : str;
    return JSON.parse(cleaned);
  } catch {
    return null;
  }
}

/**
 * Resolve once `socket` emits "open"; reject on "error" or after `timeoutMs`.
 * The listeners it installs are always removed again before settling.
 */
function waitForSocketOpen(socket, timeoutMs = 5000) {
  return new Promise((resolve, reject) => {
    const finish = (settle, value) => {
      clearTimeout(timer);
      socket.removeListener("open", onOpen);
      socket.removeListener("error", onError);
      settle(value);
    };
    const onOpen = () => finish(resolve);
    const onError = (err) => {
      console.error("[WS] Connection error", err);
      finish(reject, err);
    };
    const timer = setTimeout(
      () => finish(reject, new Error("WS connection timeout")),
      timeoutMs,
    );
    socket.on("open", onOpen);
    socket.on("error", onError);
  });
}

/**
 * Send the auth key over `socket` and resolve when the server answers
 * `{ type: "auth", status: "ok" }`; reject on an `error` payload, a socket
 * error, a failed send, or after `timeoutMs`.
 */
function authenticateSocket(socket, key, timeoutMs = 5000) {
  return new Promise((resolve, reject) => {
    const finish = (settle, value) => {
      clearTimeout(timer);
      socket.removeListener("message", onMessage);
      socket.removeListener("error", onError);
      settle(value);
    };
    const onMessage = (data) => {
      const msg = parseWsFrame(data.toString());
      if (!msg) return;
      if (msg.type === "auth" && msg.status === "ok") {
        finish(resolve);
        return;
      }
      if (msg.error) {
        console.error("[WS] Auth error", msg.error);
        finish(reject, new Error(`WS auth error: ${msg.error}`));
      }
    };
    const onError = (err) => {
      console.error("[WS] Auth error event", err);
      finish(reject, err);
    };
    const timer = setTimeout(
      () => finish(reject, new Error("WS auth timeout")),
      timeoutMs,
    );
    // Listen before sending so the reply cannot be missed.
    socket.on("message", onMessage);
    socket.on("error", onError);
    try {
      socket.send(JSON.stringify({ key }));
    } catch (err) {
      finish(reject, err);
    }
  });
}

/**
 * Open a new socket to `<base>/ws/chat`, authenticate it, install the global
 * message/error routing and publish it as `persistentWs`.
 * On any failure the half-open socket is terminated and the error is rethrown.
 */
async function openAuthenticatedSocket() {
  // Prefer the environment value (original behaviour) but fall back to the
  // imported config constant instead of crashing when the variable is unset.
  const base = process.env.LIGHTNING_BASE || LIGHTNING_BASE;
  if (!base) {
    throw new Error("LIGHTNING_BASE is not configured");
  }
  // "http://…" -> "ws://…" and "https://…" -> "wss://…"; only a leading scheme is rewritten.
  const wsURL = `${String(base).replace(/^http/, "ws")}/ws/chat`;
  const socket = new WebSocket(wsURL);

  // An "error" event with no listener would crash the process, and terminating a
  // connecting socket emits one. Keep a no-op guard attached for the socket's lifetime.
  socket.on("error", () => {});

  try {
    await waitForSocketOpen(socket);
    await authenticateSocket(socket, process.env.WEBSOCKET_KEY);
  } catch (err) {
    socket.terminate();
    throw err;
  }

  // Fan every incoming frame out to all registered per-request handlers;
  // each handler filters on its own request ID.
  const globalMessageHandler = (data) => {
    const line = data.toString();
    for (const [id, handler] of activeStreamHandlers.entries()) {
      if (!String(id).startsWith("__")) {
        // Skip metadata keys
        handler(line);
      }
    }
  };
  const globalErrorHandler = (err) => {
    console.error("[WS ERROR]", err);
    // Stop handing out a socket that has errored; the next caller reconnects.
    if (persistentWs === socket) {
      persistentWs = null;
    }
    for (const handler of errorHandlers.values()) {
      handler(err);
    }
  };
  socket.on("message", globalMessageHandler);
  socket.on("error", globalErrorHandler);
  // NOTE(review): pending streams are not notified when the socket closes
  // without an error; they only end via their own timeout.
  socket.once("close", () => {
    if (persistentWs === socket) {
      persistentWs = null;
    }
  });
  activeStreamHandlers.set("__messageListener__", globalMessageHandler);
  activeStreamHandlers.set("__errorHandler__", globalErrorHandler);

  // Publish only after authentication so callers never receive a socket that
  // is open but not yet authenticated.
  persistentWs = socket;
  return socket;
}

/**
 * Return the shared, authenticated WebSocket, connecting first if necessary.
 *
 * Concurrent callers share a single connection attempt. A failed attempt
 * (including a timeout) is forgotten once it settles, so the next call retries
 * instead of receiving the same rejected promise forever.
 *
 * @returns {Promise<WebSocket>} The open, authenticated socket.
 * @throws {Error} When connecting or authenticating fails or times out.
 */
async function getSafeWebSocket() {
  // If we have a valid persistent connection, return it
  if (persistentWs && persistentWs.readyState === WebSocket.OPEN) {
    return persistentWs;
  }
  // If we're already authenticating, wait for it
  if (wsAuthPromise) {
    return wsAuthPromise;
  }
  const attempt = openAuthenticatedSocket();
  wsAuthPromise = attempt;
  // Clear the in-flight marker on success and on failure. This is attached
  // after the assignment so it also covers errors thrown before the first await.
  const release = () => {
    if (wsAuthPromise === attempt) {
      wsAuthPromise = null;
    }
  };
  attempt.then(release, release);
  return attempt;
}
/**
 * Query the hosted web-search endpoint and return its top result.
 *
 * @param {string} query - Search query text.
 * @param {{ signal?: AbortSignal }} [options] - Optional abort signal forwarded to fetch.
 * @returns {Promise<any|null>} The first entry of the response's `results` array,
 *   or null when the array is empty.
 * @throws {Error} "Search API error: …" when the HTTP status is not OK or the
 *   body has no `results` array (including non-JSON error pages).
 */
async function gradioSearch(query, { signal } = {}) {
  const response = await fetch("https://incognitolm-Web-Search.hf.space/api/search", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query }),
    signal,
  });

  let payload = null;
  try {
    payload = await response.json();
  } catch (err) {
    // An abort must surface as an abort, not as a search failure.
    if (signal?.aborted) throw err;
    // Otherwise the body was not JSON (e.g. an HTML error page); fall through
    // so the HTTP status is reported instead of a JSON syntax error.
    payload = null;
  }

  if (!response.ok || !Array.isArray(payload?.results)) {
    const reason = payload?.error || response.statusText || `HTTP ${response.status}`;
    throw new Error(`Search API error: ${reason}`);
  }
  // null (not undefined) for "no hits" so the value survives JSON.stringify.
  return payload.results[0] ?? null;
}
// System prompt sent as the first message of every chat request.
// Each fragment ends with a space or newline so that sentences do not run
// together when the pieces are concatenated.
const SYSTEM_PROMPT =
"CRITICAL RULE: Every response MUST use HTML <span data-color=\"{COLOR NAME}\"> tags to color main points and headings. " +
"COLORS MUST HAVE MEANING AND CONSISTENCY ACROSS THE ENTIRE CONVERSATION. " +
"You may ONLY use the following semantic color names: green, pink, blue, red, orange, yellow, purple, teal, gold, coral. " +
"Never output text formatted with explicit black or white colors. " +
"Use a variety of colors throughout every response to distinguish headings, sections, and key terms. " +
"Keep code blocks plain, but color headings and important points in surrounding text. " +
"Do not over-color responses. Use color intentionally and sparingly. " +
"CRITICAL RULE: MARKDOWN FORMATTING SUCH AS #, ##, ###, **, * MUST BE PLACED OUTSIDE <span> tags. Use the same colors for similar meanings. " +
"You are a helpful, friendly AI assistant. Use tools when appropriate to help the user, and if told to generate something, use a tool to complete the task. " +
"When generating media, do not include URLs — it is displayed automatically. " +
"You can render SVG images by outputting SVG code in a code block tagged exactly as:\n```svg\n<svg>...</svg>\n```\n" +
"Never use single backslashes. You may use emojis where appropriate. " +
"Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
"Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.\n\n" +
"SESSION NAMING: After you have fully responded to the user, append a session name tag on its own line at the very end of your response (NEVER inside a code block). Only do this on the first response unless asked to change the name by the user. " +
"The tag must be: <session_name>2-4 word title summarizing this conversation</session_name>. " +
"Example: <session_name>React State Management</session_name>. " +
"This tag is hidden from the user and used only to name the chat. Do not mention it. " +
"Make sure your responses are always accurate. If you are not completely sure about something, search the web.";
/**
 * Build an OpenAI SDK client that talks to the Lightning `/gen` endpoint.
 *
 * @param {string} [accessToken] - Used as the API key and as a Bearer
 *   Authorization header; the key falls back to "no-key" when falsy.
 * @param {string} [clientId] - Sent as the X-Client-ID header when truthy.
 * @returns {OpenAI} A configured client instance.
 */
function makeClient(accessToken, clientId) {
  const defaultHeaders = {};
  if (accessToken) {
    defaultHeaders.Authorization = `Bearer ${accessToken}`;
  }
  if (clientId) {
    defaultHeaders["X-Client-ID"] = clientId;
  }

  const apiKey = accessToken ? accessToken : "no-key";
  const baseURL = `${LIGHTNING_BASE}/gen`;
  return new OpenAI({ apiKey, baseURL, defaultHeaders });
}
/**
 * Parse a stream payload as JSON, tolerating an SSE-style "data: " prefix.
 * @returns {any|null} Parsed value, or null when the text is not valid JSON.
 */
function parseStreamPayload(str) {
  try {
    const cleaned = str.startsWith("data: ") ? str.slice(6) : str;
    return JSON.parse(cleaned);
  } catch {
    return null;
  }
}

/** Convert the per-index tool-call accumulator into OpenAI-style tool_call objects. */
function collectStreamToolCalls(toolCallBuffer) {
  return [...toolCallBuffer.values()].map((t) => ({
    id: t.id || `call_${crypto.randomUUID()}`,
    type: "function",
    function: { name: t.name, arguments: t.arguments },
  }));
}

/** Build the error used to signal cancellation; both `name` and `message` are "AbortError". */
function createStreamAbortError() {
  const err = new Error("AbortError");
  err.name = "AbortError";
  return err;
}

/** Render a payload `error` field as text (string as-is, else its message, else JSON). */
function describeStreamError(error) {
  if (typeof error === "string") return error;
  return error?.message ?? JSON.stringify(error);
}

/**
 * Send one chat request over the shared WebSocket and collect the streamed reply.
 *
 * Incoming frames are expected as "<requestId>:<json>"; frames for other
 * request IDs are ignored.
 * NOTE(review): the request ID is generated locally and is not included in the
 * outgoing message — confirm how the server derives the ID it prefixes.
 *
 * @param {object} body - Chat completion request body.
 * @param {object} headers - Headers forwarded to the server inside the message.
 * @param {(token: string) => void} [onToken] - Called with each content delta
 *   (and with "[ERROR] …" on a fatal error payload).
 * @param {AbortSignal} [abortSignal] - Cancels the stream; an already-aborted
 *   signal rejects immediately without sending anything.
 * @param {number} [timeoutMs=120000] - After this long without a finish_reason
 *   the promise resolves with whatever has been received so far.
 * @returns {Promise<{assistantText: string, toolCalls: object[]}>}
 * @throws {Error} With message "AbortError" on cancellation, or the socket/send error.
 */
export async function websocketChatStream(body, headers, onToken, abortSignal, timeoutMs = 120000) {
  const ws = await getSafeWebSocket();

  // addEventListener("abort") never fires for a signal that is already
  // aborted, so that case has to be checked explicitly.
  if (abortSignal?.aborted) {
    throw createStreamAbortError();
  }

  const currentRequestId = ++requestIdCounter;
  const requestKey = String(currentRequestId);
  const toolCallBuffer = new Map();
  let assistantText = "";
  let finished = false;

  return new Promise((resolve, reject) => {
    let timeoutId = null;

    const cleanup = () => {
      activeStreamHandlers.delete(currentRequestId);
      errorHandlers.delete(currentRequestId);
      clearTimeout(timeoutId);
      if (abortSignal) {
        abortSignal.removeEventListener("abort", abortHandler);
      }
    };
    // Settle at most once, always releasing handlers and the timer first.
    const settle = (action) => {
      if (finished) return;
      finished = true;
      cleanup();
      action();
    };
    const finishWithResult = () =>
      settle(() => resolve({ assistantText, toolCalls: collectStreamToolCalls(toolCallBuffer) }));
    const failWith = (err) => settle(() => reject(err));
    const abortHandler = () => failWith(createStreamAbortError());

    const messageHandler = (line) => {
      if (finished) return;
      // Parse request ID prefix (format: "id:payload")
      const colonIdx = line.indexOf(":");
      if (colonIdx === -1) {
        return; // Ignore messages without request ID format
      }
      // Ignore messages from other requests before paying for JSON parsing.
      if (line.substring(0, colonIdx) !== requestKey) {
        return;
      }
      const payload = parseStreamPayload(line.substring(colonIdx + 1));
      if (!payload) {
        console.warn("[WS] Failed to parse JSON:", line);
        return;
      }
      // Only treat as fatal error if it's a structured error response (not delta data with error field)
      if (payload.error && !payload.choices) {
        console.error("[WS PAYLOAD ERROR]", payload.error);
        if (onToken) onToken(`[ERROR] ${describeStreamError(payload.error)}`);
        finishWithResult();
        return;
      }
      const choice = payload.choices?.[0];
      const delta = choice?.delta;
      if (delta?.content) {
        assistantText += delta.content;
        if (onToken) onToken(delta.content);
      }
      if (delta?.tool_calls) {
        // Tool calls arrive in fragments keyed by index; stitch them together.
        for (const call of delta.tool_calls) {
          const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
          if (call.id) entry.id = call.id;
          if (call.function?.name) entry.name = call.function.name;
          if (call.function?.arguments) entry.arguments += call.function.arguments;
          toolCallBuffer.set(call.index, entry);
        }
      }
      if (choice?.finish_reason) {
        finishWithResult();
      }
    };

    // Register handlers for this request before sending so no reply is missed
    // and a failed send is cleaned up like any other failure.
    activeStreamHandlers.set(currentRequestId, messageHandler);
    errorHandlers.set(currentRequestId, failWith);
    if (abortSignal) {
      abortSignal.addEventListener("abort", abortHandler);
    }
    // Timeout fallback - resolve with partial output if no finish_reason arrives
    timeoutId = setTimeout(finishWithResult, timeoutMs);

    try {
      // Send the chat body
      ws.send(JSON.stringify({ body, headers }));
    } catch (err) {
      failWith(err);
    }
  });
}
/**
 * Extract the session name from a <session_name>...</session_name> tag.
 *
 * Tags inside fenced code blocks (``` ... ```) are ignored, including a fence
 * that is opened but never closed (for example when the response was cut off).
 *
 * @param {string} text - Full assistant response.
 * @returns {string|null} The trimmed name (1-80 chars, single line), or null
 *   when the input is not a non-empty string or no valid tag is found.
 */
export function extractSessionName(text) {
  if (typeof text !== "string" || text === "") return null;
  // Remove fenced code first so tags inside it never match. The `|$`
  // alternative also swallows an unterminated fence through end of text.
  const withoutCode = text.replace(/```[\s\S]*?(?:```|$)/g, "");
  const match = withoutCode.match(/<session_name>([\s\S]*?)<\/session_name>/i);
  if (!match) return null;
  const name = match[1].trim();
  // Sanity check: 1-80 chars, no newlines
  if (!name || name.length > 80 || /\n/.test(name)) return null;
  return name;
}
/**
 * Normalise the caller's user message into chat `content`.
 * - Non-arrays are returned unchanged.
 * - Arrays containing images are kept as-is, unless they have no non-blank
 *   text part, in which case a "[Image(s) attached]" text part is put first.
 * - Arrays without images are flattened to their text parts joined by newlines.
 */
function normalizeUserContent(userMessage) {
  if (!Array.isArray(userMessage)) return userMessage;

  const images = userMessage.filter((item) => item.type === "image_url");
  if (images.length > 0) {
    const hasText = userMessage.some((item) => item.type === "text" && item.text?.trim());
    return hasText ? userMessage : [{ type: "text", text: "[Image(s) attached]" }, ...images];
  }
  return userMessage
    .filter((part) => part.type === "text")
    .map((part) => part.text)
    .join("\n")
    .trim();
}

/** True for cancellation errors, which carry "AbortError" as their name or message. */
function isStreamAbortError(err) {
  return err?.name === "AbortError" || err?.message === "AbortError";
}

/**
 * Run one chat turn: stream the model reply, execute any requested tool calls,
 * stream a single follow-up reply that sees the tool results, then report the
 * outcome through callbacks. This function does not reject for stream errors;
 * they are routed to `onError` (or `onDone` with `aborted = true`).
 *
 * @param {object} options
 * @param {string} [options.sessionId] - Accepted for API compatibility; not used here.
 * @param {string} [options.model] - Model name; defaults to "lightning".
 * @param {object[]} [options.history] - Prior messages, passed through normalizeMessage.
 * @param {string|object[]} [options.userMessage] - New user turn; omitted from
 *   the request when it normalises to empty content.
 * @param {object} [options.tools] - Tool enable flags for buildToolList.
 * @param {string} [options.accessToken] - Sent as a Bearer Authorization header.
 * @param {string} [options.clientId] - Sent as the X-Client-ID header.
 * @param {(token: string) => void} [options.onToken]
 * @param {(text: string|null, toolCalls: object[]|null, aborted: boolean, sessionName: string|null) => void} [options.onDone]
 * @param {(message: string) => void} [options.onError]
 * @param {(event: object) => void} [options.onToolCall]
 * @param {(asset: object) => void} [options.onNewAsset]
 * @param {AbortSignal} [options.abortSignal]
 * @returns {Promise<void>}
 */
export async function streamChat({
  sessionId,
  model,
  history = [],
  userMessage,
  tools,
  accessToken,
  clientId,
  onToken = () => {},
  onDone = () => {},
  onError = () => {},
  onToolCall = () => {},
  onNewAsset = () => {},
  abortSignal,
}) {
  const enabledTools = buildToolList(tools);
  const normalizedUserMessage = normalizeUserContent(userMessage);

  // Decide from the normalised content, so an array holding only blank text
  // is not sent as an empty user turn.
  const hasUserMessage =
    typeof normalizedUserMessage === "string"
      ? normalizedUserMessage.trim() !== ""
      : Array.isArray(normalizedUserMessage) && normalizedUserMessage.length > 0;

  const messages = [
    { role: "system", content: SYSTEM_PROMPT },
    ...history.map(normalizeMessage).filter(Boolean),
  ];
  if (hasUserMessage) {
    messages.push({ role: "user", content: normalizedUserMessage });
  }

  const headers = {
    ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
    ...(clientId ? { "X-Client-ID": clientId } : {}),
  };
  const modelName = model || "lightning";

  try {
    const firstPass = await websocketChatStream(
      {
        model: modelName,
        messages,
        tools: enabledTools.length ? enabledTools : undefined,
        stream: true,
      },
      headers,
      onToken,
      abortSignal,
    );
    let assistantText = firstPass.assistantText;
    const toolCalls = firstPass.toolCalls;

    if (toolCalls.length > 0) {
      const toolResults = await processToolCalls(
        null,
        toolCalls,
        tools,
        accessToken,
        clientId,
        abortSignal,
        onToolCall,
        onNewAsset,
      );
      // Same conversation, plus the assistant's tool request and the tool
      // outputs. The follow-up deliberately offers no tools.
      const followUpMessages = [
        ...messages,
        { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
        ...toolResults,
      ];
      const followUp = await websocketChatStream(
        { model: modelName, messages: followUpMessages, stream: true },
        headers,
        onToken,
        abortSignal,
      );
      assistantText += followUp.assistantText;
    }

    const sessionName = extractSessionName(assistantText);
    if (typeof onDone === "function") {
      onDone(assistantText, toolCalls, false, sessionName);
    }
  } catch (err) {
    if (isStreamAbortError(err)) {
      // Cancellation is an expected outcome, not an error worth logging.
      if (typeof onDone === "function") {
        onDone(null, null, true, null);
      }
      return;
    }
    console.error("streamChat error:", err);
    if (typeof onError === "function") {
      onError(String(err));
    }
  }
}
// Roles accepted in chat history; messages with any other role are dropped.
const VALID_ROLES = new Set(["system", "user", "assistant", "tool"]);

/**
 * Reduce a stored history message to the fields sent to the chat API.
 *
 * - Missing messages and unknown roles yield null (callers filter these out).
 * - An assistant message with a non-empty `tool_calls` array keeps the calls
 *   and sends empty content.
 * - Array content is kept whole when it contains images, otherwise flattened
 *   to its text parts joined by newlines.
 * - Tool messages keep their `tool_call_id`, which the chat API needs to match
 *   a result to the call that produced it.
 *
 * @param {object|null|undefined} msg - History entry.
 * @returns {object|null} Normalised message, or null when it should be skipped.
 */
function normalizeMessage(msg) {
  if (!msg || !VALID_ROLES.has(msg.role)) return null;

  // Require a non-empty array: an empty `tool_calls` would otherwise discard
  // the assistant's text and send an empty tool_calls list.
  if (msg.role === "assistant" && Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0) {
    return { role: "assistant", content: "", tool_calls: msg.tool_calls };
  }

  let content;
  if (Array.isArray(msg.content)) {
    const hasImages = msg.content.some((item) => item.type === "image_url");
    content = hasImages
      ? msg.content // preserve the full multi-part format when images are present
      : msg.content
          .filter((part) => part.type === "text")
          .map((part) => part.text)
          .join("\n");
  } else {
    content = msg.content ?? "";
  }

  const normalized = { role: msg.role, content };
  if (msg.role === "tool" && msg.tool_call_id != null) {
    normalized.tool_call_id = msg.tool_call_id;
  }
  return normalized;
}
/**
 * Translate the tool enable flags into OpenAI-style function tool definitions.
 *
 * Recognised flags, in output order: `webSearch` (adds "ollama_search" and
 * "read_web_page"), `imageGen`, `videoGen`, `audioGen`.
 *
 * @param {object|null|undefined} tools - Flag object; falsy yields an empty list.
 * @returns {object[]} Freshly built tool definitions for every truthy flag.
 */
function buildToolList(tools) {
  if (!tools) return [];

  // Wraps a name/description/schema triple in the function-tool envelope.
  const defineTool = (name, description, properties, required) => ({
    type: "function",
    function: {
      name,
      description,
      parameters: { type: "object", properties, required },
    },
  });
  const stringList = () => ({ type: "array", items: { type: "string" } });

  // Builders run lazily so each call returns brand-new objects.
  const catalog = [
    [
      "webSearch",
      () => [
        defineTool(
          "ollama_search",
          "Search the web for current information",
          { query: { type: "string", description: "Search query" } },
          ["query"],
        ),
        defineTool(
          "read_web_page",
          "Read the content of a web page by URL",
          { url: { type: "string", description: "URL to fetch" } },
          ["url"],
        ),
      ],
    ],
    [
      "imageGen",
      () => [
        defineTool(
          "generate_image",
          "Generate an image from a prompt",
          {
            prompt: { type: "string" },
            mode: { type: "string", enum: ["auto", "fantasy", "realistic"] },
            image_urls: stringList(),
          },
          ["prompt"],
        ),
      ],
    ],
    [
      "videoGen",
      () => [
        defineTool(
          "generate_video",
          "Generate a video from a prompt",
          {
            prompt: { type: "string" },
            ratio: { type: "string", enum: ["3:2", "2:3", "1:1"] },
            mode: { type: "string", enum: ["normal", "fun"] },
            duration: { type: "number" },
            image_urls: stringList(),
          },
          ["prompt"],
        ),
      ],
    ],
    [
      "audioGen",
      () => [
        defineTool(
          "generate_audio",
          "Generate music or sound effects from a prompt",
          { prompt: { type: "string" } },
          ["prompt"],
        ),
      ],
    ],
  ];

  const enabled = [];
  for (const [flag, build] of catalog) {
    if (tools[flag]) {
      enabled.push(...build());
    }
  }
  return enabled;
}
/** Build the cancellation error (name and message "AbortError"), keeping the trigger as `cause`. */
function createToolAbortError(cause) {
  const err = new Error("AbortError", { cause });
  err.name = "AbortError";
  return err;
}

/** POST a JSON payload to a Lightning generation endpoint and return the raw response. */
function postGenerationRequest(path, payload, authHeaders, abortSignal) {
  return fetch(`${LIGHTNING_BASE}${path}`, {
    method: "POST",
    headers: { "Content-Type": "application/json", ...authHeaders },
    body: JSON.stringify(payload),
    signal: abortSignal,
  });
}

/** Read a binary response body and encode it as a base64 data URL of the given MIME type. */
async function responseToDataUrl(res, mimeType) {
  const b64 = Buffer.from(await res.arrayBuffer()).toString("base64");
  return `data:${mimeType};base64,${b64}`;
}

/**
 * Execute a single tool by name and return the result to report to the model.
 * Unrecognised tool names yield "Tool completed.".
 */
async function runToolCall(name, args, { authHeaders, abortSignal, onNewAsset }) {
  switch (name) {
    case "ollama_search": {
      const hit = await gradioSearch(args.query, { signal: abortSignal });
      // A missing hit would otherwise serialise to `undefined` content.
      return hit ?? "No results found.";
    }

    case "read_web_page": {
      // The URL comes from model output, so treat it as untrusted.
      // NOTE(review): this still permits requests to internal/private hosts
      // (SSRF); consider an allow/deny list if this server can reach such hosts.
      let target;
      try {
        target = new URL(args.url);
      } catch {
        return "Failed to fetch: invalid URL";
      }
      if (target.protocol !== "http:" && target.protocol !== "https:") {
        return "Failed to fetch: only http(s) URLs are supported";
      }
      const { convert } = await import("html-to-text");
      const res = await fetch(target, { signal: abortSignal });
      if (!res.ok) {
        return `Failed to fetch: ${res.status}`;
      }
      const html = await res.text();
      // Allow attributes on <title> and titles that span lines.
      const titleMatch = html.match(/<title[^>]*>([\s\S]*?)<\/title>/i);
      return JSON.stringify({
        title: titleMatch?.[1]?.trim() || "No title",
        content: convert(html, { wordwrap: false }).slice(0, 8000),
      });
    }

    case "generate_image": {
      const body = { prompt: args.prompt };
      if (args.mode) body.mode = args.mode;
      if (args.image_urls?.length) body.image_urls = args.image_urls;
      const res = await postGenerationRequest("/gen/image", body, authHeaders, abortSignal);
      if (res.ok) {
        const mimeType = res.headers.get("content-type") || "image/png";
        onNewAsset({ role: "image", content: await responseToDataUrl(res, mimeType) });
        return "Image generated successfully and shown to the user.";
      }
      if (res.status === 402) return "An upgraded plan is required for higher limits.";
      if (res.status === 429) return "Too many requests. Try again later.";
      return `Image generation failed: ${res.status}`;
    }

    case "generate_video": {
      const body = { prompt: args.prompt };
      if (args.ratio) body.ratio = args.ratio;
      if (args.mode) body.mode = args.mode;
      if (args.duration) body.duration = args.duration;
      if (args.image_urls?.length) body.image_urls = args.image_urls;
      const res = await postGenerationRequest("/gen/video", body, authHeaders, abortSignal);
      if (res.ok) {
        onNewAsset({ role: "video", content: await responseToDataUrl(res, "video/mp4") });
        return "Video generated successfully and shown to the user.";
      }
      if (res.status === 402) return "An upgraded plan is required for higher limits.";
      if (res.status === 429) return "Too many requests. Try again later.";
      return `Video generation failed: ${res.status}`;
    }

    case "generate_audio": {
      const res = await postGenerationRequest(
        "/gen/sfx",
        { prompt: args.prompt },
        authHeaders,
        abortSignal,
      );
      if (res.ok) {
        onNewAsset({ role: "audio", content: await responseToDataUrl(res, "audio/mpeg") });
        return "Audio generated successfully and shown to the user.";
      }
      if (res.status === 429) return "Too many requests. Try again later.";
      return `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
    }

    default:
      return "Tool completed.";
  }
}

/**
 * Execute the model's tool calls one after another and return the matching
 * `role: "tool"` messages.
 *
 * Each call is announced through `onToolCall` as "pending" and then
 * "resolved". A failing tool is reported to the model as a "Tool error: …"
 * result rather than aborting the batch, except for cancellation: once
 * `abortSignal` is aborted the function rejects so the caller stops instead of
 * continuing with a follow-up request.
 *
 * @param {*} ws - Unused; kept for signature compatibility.
 * @param {object[]} toolCalls - OpenAI-style tool calls ({ id, function: { name, arguments } }).
 * @param {object} tools - Unused; kept for signature compatibility.
 * @param {string} [accessToken] - Sent as a Bearer Authorization header to generation endpoints.
 * @param {string} [clientId] - Sent as the X-Client-ID header to generation endpoints.
 * @param {AbortSignal} [abortSignal] - Cancels in-flight tool requests.
 * @param {(event: object) => void} onToolCall - Progress callback.
 * @param {(asset: {role: string, content: string}) => void} onNewAsset - Receives generated media as data URLs.
 * @returns {Promise<object[]>} One `{ role: "tool", tool_call_id, content }` per call.
 * @throws {Error} With name/message "AbortError" when cancelled.
 */
async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
  const toolResults = [];
  const authHeaders = {};
  if (accessToken) {
    authHeaders["Authorization"] = `Bearer ${accessToken}`;
  }
  if (clientId) {
    authHeaders["X-Client-ID"] = clientId;
  }

  for (const call of toolCalls) {
    if (abortSignal?.aborted) {
      throw createToolAbortError();
    }

    let args;
    try {
      args = JSON.parse(call.function.arguments || "{}");
    } catch {
      // Malformed arguments from the model: run the tool with no arguments.
      args = {};
    }
    if (args === null || typeof args !== "object") {
      args = {};
    }

    onToolCall({ id: call.id, name: call.function.name, state: "pending", args });

    let result;
    try {
      result = await runToolCall(call.function.name, args, { authHeaders, abortSignal, onNewAsset });
    } catch (err) {
      // Cancellation must propagate; turning it into a tool result would let
      // the caller carry on with a follow-up request the user already cancelled.
      if (abortSignal?.aborted) {
        throw createToolAbortError(err);
      }
      result = `Tool error: ${String(err)}`;
    }

    onToolCall({ id: call.id, name: call.function.name, state: "resolved", result });
    toolResults.push({
      role: "tool",
      tool_call_id: call.id,
      content: typeof result === "string" ? result : JSON.stringify(result),
    });
  }
  return toolResults;
}