import OpenAI from "openai";
import { Worker } from "worker_threads";
import { fileURLToPath } from "url";
import path from "path";
import { LIGHTNING_BASE } from "./config.js";
import WebSocket from "ws";
import crypto from "crypto";

const __dirname = path.dirname(fileURLToPath(import.meta.url));
const WORKER_PATH = path.join(__dirname, "searchWorker.js");

// Timeouts (milliseconds) for the WebSocket handshake phases and for one chat stream.
const WS_CONNECT_TIMEOUT_MS = 5000;
const WS_AUTH_TIMEOUT_MS = 5000;
const STREAM_TIMEOUT_MS = 120000;

// Persistent WebSocket pool
let persistentWs = null; // authenticated socket, or null when none is usable
let wsAuthPromise = null; // in-flight connect+auth attempt shared by concurrent callers
let requestIdCounter = 0;
const activeStreamHandlers = new Map(); // Track active stream handlers by request ID
const errorHandlers = new Map(); // Track error handlers by request ID

/**
 * Parse one WebSocket frame as JSON, tolerating an SSE-style "data: " prefix.
 *
 * @param {string} str - Raw frame text.
 * @returns {*} The parsed value, or null when the text is not valid JSON.
 */
function safeParse(str) {
  try {
    const cleaned = str.startsWith("data: ") ? str.slice(6) : str;
    return JSON.parse(cleaned);
  } catch {
    return null;
  }
}

/**
 * Build the error used to signal cancellation. The message stays "AbortError"
 * (what existing callers compare against); the name is set as well so that
 * `err.name === "AbortError"` checks also match.
 *
 * @returns {Error}
 */
function createAbortError() {
  const err = new Error("AbortError");
  err.name = "AbortError";
  return err;
}

/**
 * Derive the chat WebSocket URL from the configured base URL.
 *
 * @returns {string} ws(s) URL ending in "/ws/chat".
 * @throws {Error} When no base URL is configured.
 */
function buildWsUrl() {
  // Prefer the environment variable (original behaviour) but fall back to the
  // imported constant instead of crashing on `undefined.startsWith`.
  const base = process.env.LIGHTNING_BASE || LIGHTNING_BASE;
  if (!base) {
    throw new Error("LIGHTNING_BASE is not configured");
  }
  const wsBase = base.startsWith("https")
    ? base.replace("https", "wss")
    : base.replace("http", "ws");
  return `${wsBase}/ws/chat`;
}

/**
 * Resolve once the socket emits "open"; reject on error or timeout.
 * All listeners and the timer are removed when the promise settles.
 *
 * @param {WebSocket} socket
 * @returns {Promise<void>}
 */
function waitForOpen(socket) {
  return new Promise((resolve, reject) => {
    const detach = () => {
      clearTimeout(timer);
      socket.removeListener("open", onOpen);
      socket.removeListener("error", onError);
    };
    const onOpen = () => {
      detach();
      resolve();
    };
    const onError = (err) => {
      console.error("[WS] Connection error", err);
      detach();
      reject(err);
    };
    const timer = setTimeout(() => {
      detach();
      reject(new Error("WS connection timeout"));
    }, WS_CONNECT_TIMEOUT_MS);

    socket.on("open", onOpen);
    socket.on("error", onError);
  });
}

/**
 * Send the key from WEBSOCKET_KEY and wait for `{ type: "auth", status: "ok" }`.
 * Rejects on a message carrying `error`, on a socket error/close, or on timeout.
 *
 * @param {WebSocket} socket - An open socket.
 * @returns {Promise<void>}
 */
function authenticate(socket) {
  return new Promise((resolve, reject) => {
    const detach = () => {
      clearTimeout(timer);
      socket.removeListener("message", onMessage);
      socket.removeListener("error", onError);
      socket.removeListener("close", onClose);
    };
    const onMessage = (data) => {
      const msg = safeParse(data.toString());
      if (!msg) return;
      if (msg.type === "auth" && msg.status === "ok") {
        detach();
        resolve();
        return;
      }
      if (msg.error) {
        console.error("[WS] Auth error", msg.error);
        detach();
        reject(new Error(`WS auth error: ${msg.error}`));
      }
    };
    const onError = (err) => {
      console.error("[WS] Auth error event", err);
      detach();
      reject(err);
    };
    const onClose = () => {
      detach();
      reject(new Error("WS closed during auth"));
    };
    const timer = setTimeout(() => {
      detach();
      reject(new Error("WS auth timeout"));
    }, WS_AUTH_TIMEOUT_MS);

    socket.on("message", onMessage);
    socket.on("error", onError);
    socket.on("close", onClose);

    try {
      socket.send(JSON.stringify({ key: process.env.WEBSOCKET_KEY }));
    } catch (err) {
      detach();
      reject(err);
    }
  });
}

/**
 * Reject every in-flight stream with the given error.
 *
 * @param {Error} err
 */
function failActiveStreams(err) {
  // Snapshot first: each handler unregisters itself while we iterate.
  for (const handler of [...errorHandlers.values()]) {
    handler(err);
  }
}

/**
 * Install the long-lived listeners that fan incoming frames out to the
 * registered per-request handlers and propagate socket failures to them.
 *
 * @param {WebSocket} socket - An authenticated socket.
 */
function attachRouting(socket) {
  socket.on("message", (data) => {
    const line = data.toString();
    // Every handler sees every frame; each one filters on its own request id.
    for (const handler of [...activeStreamHandlers.values()]) {
      handler(line);
    }
  });

  socket.on("error", (err) => {
    console.error("[WS ERROR]", err);
    if (persistentWs === socket) persistentWs = null;
    failActiveStreams(err);
  });

  socket.on("close", () => {
    if (persistentWs === socket) persistentWs = null;
    // Without this, streams on a dropped connection would only end at the
    // stream timeout.
    failActiveStreams(new Error("WebSocket closed"));
  });
}

/**
 * Open a new socket, authenticate it and wire up message routing.
 *
 * @returns {Promise<WebSocket>} The authenticated socket.
 */
async function openAuthenticatedSocket() {
  const socket = new WebSocket(buildWsUrl());
  // Guard listener: an "error" event emitted while no phase-specific listener
  // is attached would otherwise be an unhandled EventEmitter error.
  socket.on("error", () => {});

  try {
    await waitForOpen(socket);
    await authenticate(socket);
  } catch (err) {
    // Do not leave a half-open socket behind after a failed attempt.
    socket.terminate();
    throw err;
  }

  attachRouting(socket);
  return socket;
}

/**
 * Return the shared authenticated WebSocket, connecting if necessary.
 * Concurrent callers share a single connection attempt. A failed attempt
 * (including a timeout) is forgotten so that the next call retries.
 *
 * @returns {Promise<WebSocket>}
 */
async function getSafeWebSocket() {
  // If we have a valid persistent connection, return it
  if (persistentWs && persistentWs.readyState === WebSocket.OPEN) {
    return persistentWs;
  }

  // If we're already authenticating, wait for it
  if (wsAuthPromise) {
    return wsAuthPromise;
  }

  // Create and authenticate a new websocket. `persistentWs` is only published
  // once authentication succeeded, so no caller can observe a socket that is
  // open but not yet authenticated.
  wsAuthPromise = (async () => {
    try {
      const socket = await openAuthenticatedSocket();
      persistentWs = socket;
      return socket;
    } finally {
      wsAuthPromise = null;
    }
  })();

  return wsAuthPromise;
}

/**
 * Query the hosted search endpoint and return its first result entry.
 *
 * @param {string} query - Search query.
 * @param {AbortSignal} [signal] - Optional cancellation signal.
 * @returns {Promise<*>} `results[0]` from the response body.
 * @throws {Error} When the request fails or the body has no `results`.
 */
async function gradioSearch(query, signal) {
  const response = await fetch("https://incognitolm-Web-Search.hf.space/api/search", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query }),
    signal,
  });

  // An error response may not be JSON at all; treat an unparsable body as
  // "no payload" so the caller gets the status-based message below.
  let payload = null;
  try {
    payload = await response.json();
  } catch {
    payload = null;
  }

  if (!response.ok || !payload?.results) {
    throw new Error(`Search API error: ${payload?.error || response.statusText}`);
  }
  return payload.results[0];
}

// System prompt sent as the first message of every request.
// The session-naming instructions spell out the literal <session_name> tag,
// because that is the tag extractSessionName() looks for.
// NOTE(review): the colouring rules refer to "HTML tags" / "tags" without
// naming one — confirm whether a tag name is missing from these sentences.
const SYSTEM_PROMPT =
  "CRITICAL RULE: Every response MUST use HTML tags to color main points and headings. " +
  "COLORS MUST HAVE MEANING AND CONSISTENCY ACROSS THE ENTIRE CONVERSATION. " +
  "You may ONLY use the following semantic color names: green, pink, blue, red, orange, yellow, purple, teal, gold, coral. " +
  "Never output text formatted with explicit black or white colors. " +
  "Use a variety of colors throughout every response to distinguish headings, sections, and key terms. " +
  "Keep code blocks plain, but color headings and important points in surrounding text. " +
  "Do not over-color responses. Use color intentionally and sparingly. " +
  "CRITICAL RULE: MARKDOWN FORMATTING SUCH AS #, ##, ###, **, * MUST BE PLACED OUTSIDE tags. Use the same colors for similar meanings." +
  "You are a helpful, friendly AI assistant. Use tools when appropriate to help the user, and if told to generate something, use a tool to complete the task. " +
  "When generating media, do not include URLs — it is displayed automatically. " +
  "You can render SVG images by outputting SVG code in a code block tagged exactly as:\n```svg\n...\n```\n" +
  "Never use single backslashes. You may use emojis where appropriate. " +
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.\n\n" +
  "SESSION NAMING: After you have fully responded to the user, append a session name tag on its own line at the very end of your response (NEVER inside a code block). Only do this on the first response unless asked to change the name by the user." +
  "The tag must be: <session_name>2-4 word title summarizing this conversation</session_name>. " +
  "Example: <session_name>React State Management</session_name>. " +
  "This tag is hidden from the user and used only to name the chat. Do not mention it." +
  "Make sure your responses are always accurate. If you are not completely sure about something, search the web.";

/**
 * Create an OpenAI SDK client pointed at `${LIGHTNING_BASE}/gen`.
 *
 * @param {string} [accessToken] - Bearer token; "no-key" is used as apiKey when absent.
 * @param {string} [clientId] - Sent as the X-Client-ID header when present.
 * @returns {OpenAI}
 */
function makeClient(accessToken, clientId) {
  const defaultHeaders = {};
  if (accessToken) defaultHeaders.Authorization = `Bearer ${accessToken}`;
  if (clientId) defaultHeaders["X-Client-ID"] = clientId;

  return new OpenAI({
    apiKey: accessToken || "no-key",
    baseURL: `${LIGHTNING_BASE}/gen`,
    defaultHeaders,
  });
}

/**
 * Turn the accumulated streaming tool-call fragments into complete tool calls.
 *
 * @param {Map<*, {id?: string, name?: string, arguments: string}>} buffer
 * @returns {Array<{id: string, type: "function", function: {name: string, arguments: string}}>}
 */
function collectToolCalls(buffer) {
  return [...buffer.values()].map((entry) => ({
    // A call that arrived without an id gets a locally generated one.
    id: entry.id || `call_${crypto.randomUUID()}`,
    type: "function",
    function: { name: entry.name, arguments: entry.arguments },
  }));
}

/**
 * Send one chat request over the shared WebSocket and collect the streamed reply.
 *
 * Frames are expected in the form "<requestId>:<json>"; frames without a colon
 * or with a different id are ignored. The promise resolves on a
 * `finish_reason`, on a structured error payload, or after the stream timeout,
 * always with whatever text and tool calls were received so far.
 *
 * @param {object} body - Chat completion request body.
 * @param {object} headers - Headers forwarded inside the WebSocket message.
 * @param {(token: string) => void} [onToken] - Called for each content delta.
 * @param {AbortSignal} [abortSignal] - Cancels the stream.
 * @returns {Promise<{assistantText: string, toolCalls: Array<object>}>}
 * @throws {Error} "AbortError" when cancelled; the socket error when the
 *   connection fails or closes mid-stream.
 */
export async function websocketChatStream(body, headers, onToken, abortSignal) {
  // An "abort" listener never fires for a signal that is already aborted,
  // so check explicitly before and after waiting for the socket.
  if (abortSignal?.aborted) throw createAbortError();
  const ws = await getSafeWebSocket();
  if (abortSignal?.aborted) throw createAbortError();

  // NOTE(review): this id is never transmitted to the server. Presumably the
  // server numbers requests in arrival order so that both sides agree —
  // confirm, in particular after a reconnect (the counter is not reset).
  const currentRequestId = ++requestIdCounter;
  const ownId = String(currentRequestId);

  const toolCallBuffer = new Map();
  let assistantText = "";

  return new Promise((resolve, reject) => {
    let finished = false;
    let timeoutId = null;

    const cleanup = () => {
      activeStreamHandlers.delete(currentRequestId);
      errorHandlers.delete(currentRequestId);
      clearTimeout(timeoutId);
      abortSignal?.removeEventListener("abort", abortHandler);
    };

    // Settle at most once, always releasing handlers and the timer first.
    const settle = (action) => {
      if (finished) return;
      finished = true;
      cleanup();
      action();
    };
    const succeed = () =>
      settle(() => resolve({ assistantText, toolCalls: collectToolCalls(toolCallBuffer) }));
    const fail = (err) => settle(() => reject(err));
    const abortHandler = () => fail(createAbortError());

    const messageHandler = (line) => {
      // Parse request ID prefix (format: "id:payload")
      const colonIdx = line.indexOf(":");
      if (colonIdx === -1) return; // Ignore messages without request ID format

      // Check the id before parsing so frames for other requests cost nothing.
      if (line.slice(0, colonIdx) !== ownId) return;

      const payload = safeParse(line.slice(colonIdx + 1));
      if (!payload) {
        console.warn("[WS] Failed to parse JSON:", line);
        return;
      }

      // Only a structured error response (no `choices`) is fatal.
      if (payload.error && !payload.choices) {
        console.error("[WS PAYLOAD ERROR]", payload.error);
        const detail =
          typeof payload.error === "string"
            ? payload.error
            : payload.error.message ?? JSON.stringify(payload.error);
        if (onToken) onToken(`[ERROR] ${detail}`);
        succeed();
        return;
      }

      const choice = payload.choices?.[0];
      const delta = choice?.delta;

      if (delta?.content) {
        assistantText += delta.content;
        if (onToken) onToken(delta.content);
      }

      if (delta?.tool_calls) {
        // Tool calls arrive in fragments keyed by `index`; merge them.
        for (const call of delta.tool_calls) {
          const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
          if (call.id) entry.id = call.id;
          if (call.function?.name) entry.name = call.function.name;
          if (call.function?.arguments) entry.arguments += call.function.arguments;
          toolCallBuffer.set(call.index, entry);
        }
      }

      if (choice?.finish_reason) succeed();
    };

    // Register handlers for this request
    activeStreamHandlers.set(currentRequestId, messageHandler);
    errorHandlers.set(currentRequestId, fail);
    abortSignal?.addEventListener("abort", abortHandler);

    // Timeout fallback - resolve with what we have if no finish_reason arrives
    timeoutId = setTimeout(succeed, STREAM_TIMEOUT_MS);

    // Send the chat body
    try {
      ws.send(JSON.stringify({ body, headers }));
    } catch (err) {
      fail(err);
    }
  });
}

// Fenced code blocks, removed before looking for the session tag.
const CODE_FENCE_RE = /```[\s\S]*?```/g;
// The opening tag is required: without it the capture group would swallow all
// text that precedes the closing tag.
const SESSION_NAME_RE = /<session_name>([\s\S]*?)<\/session_name>/i;

/**
 * Extract session name from <session_name>...</session_name> tag.
 * Must NOT be inside any kind of code block.
 *
 * @param {string} text - Full assistant response.
 * @returns {string|null} The trimmed name, or null when there is no tag
 *   outside code blocks or the name is empty, longer than 80 characters, or
 *   contains a newline.
 */
export function extractSessionName(text) {
  if (!text) return null;

  // Remove all code blocks first (``` ... ```) so we don't match tags inside them
  const withoutCode = text.replace(CODE_FENCE_RE, "");
  const match = withoutCode.match(SESSION_NAME_RE);
  if (!match) return null;

  const name = match[1].trim();
  // Sanity check: 1-80 chars, no newlines
  if (!name || name.length > 80 || /\n/.test(name)) return null;
  return name;
}

/**
 * Normalise the incoming user message for the chat API.
 * - Non-array values are returned unchanged.
 * - Arrays with images are kept as-is, unless they carry no non-blank text, in
 *   which case a placeholder text part is put in front of the image parts.
 * - Arrays without images collapse to their text parts joined by newlines.
 *
 * @param {string|Array<object>|null|undefined} userMessage
 * @returns {string|Array<object>|null|undefined}
 */
function normalizeUserMessage(userMessage) {
  if (!Array.isArray(userMessage)) return userMessage;

  const images = userMessage.filter((item) => item?.type === "image_url");
  if (images.length === 0) {
    return userMessage
      .filter((item) => item?.type === "text")
      .map((item) => item.text)
      .join("\n")
      .trim();
  }

  const hasText = userMessage.some((item) => item?.type === "text" && item.text?.trim());
  if (hasText) return userMessage;
  return [{ type: "text", text: "[Image(s) attached]" }, ...images];
}

/**
 * Run one chat turn: stream the model reply, execute any requested tools, then
 * stream a single follow-up reply that can see the tool results.
 *
 * Outcomes are reported through callbacks rather than by throwing:
 * `onDone(text, toolCalls, false, sessionName)` on success,
 * `onDone(null, null, true, null)` when aborted, `onError(message)` otherwise.
 *
 * @param {object} options
 * @param {string} [options.sessionId] - Accepted for interface compatibility; not used here.
 * @param {string} [options.model] - Model name; defaults to "lightning".
 * @param {Array<object>} [options.history] - Prior messages.
 * @param {string|Array<object>} [options.userMessage] - New user message.
 * @param {object} [options.tools] - Feature flags (webSearch, imageGen, videoGen, audioGen).
 * @param {string} [options.accessToken]
 * @param {string} [options.clientId]
 * @param {Function} [options.onToken]
 * @param {Function} [options.onDone]
 * @param {Function} [options.onError]
 * @param {Function} [options.onToolCall]
 * @param {Function} [options.onNewAsset]
 * @param {AbortSignal} [options.abortSignal]
 * @returns {Promise<void>}
 */
export async function streamChat({
  sessionId,
  model,
  history = [],
  userMessage,
  tools,
  accessToken,
  clientId,
  onToken = () => {},
  onDone = () => {},
  onError = () => {},
  onToolCall = () => {},
  onNewAsset = () => {},
  abortSignal,
}) {
  const enabledTools = buildToolList(tools);
  const normalizedUserMessage = normalizeUserMessage(userMessage);

  const hasUserMessage =
    userMessage !== undefined &&
    userMessage !== null &&
    (typeof userMessage === "string"
      ? userMessage.trim() !== ""
      : Array.isArray(userMessage) && userMessage.length > 0);

  // Shared by the first request and the tool follow-up.
  const baseMessages = [
    { role: "system", content: SYSTEM_PROMPT },
    ...(history ?? []).map(normalizeMessage).filter(Boolean),
  ];
  if (hasUserMessage) {
    baseMessages.push({ role: "user", content: normalizedUserMessage });
  }

  const headers = {
    ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
    ...(clientId ? { "X-Client-ID": clientId } : {}),
  };
  const modelName = model || "lightning";

  try {
    const firstReply = await websocketChatStream(
      {
        model: modelName,
        messages: baseMessages,
        tools: enabledTools.length ? enabledTools : undefined,
        stream: true,
      },
      headers,
      onToken,
      abortSignal
    );

    let assistantText = firstReply.assistantText;
    const toolCalls = firstReply.toolCalls;

    if (toolCalls.length > 0) {
      const toolResults = await processToolCalls(
        null,
        toolCalls,
        tools,
        accessToken,
        clientId,
        abortSignal,
        onToolCall,
        onNewAsset
      );

      // The follow-up carries no `tools`, so this is a single tool round.
      const followUp = await websocketChatStream(
        {
          model: modelName,
          messages: [
            ...baseMessages,
            { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
            ...toolResults,
          ],
          stream: true,
        },
        headers,
        onToken,
        abortSignal
      );
      assistantText += followUp.assistantText;
    }

    const sessionName = extractSessionName(assistantText);
    if (typeof onDone === "function") {
      onDone(assistantText, toolCalls, false, sessionName);
    }
  } catch (err) {
    const wasAborted = err?.name === "AbortError" || err?.message === "AbortError";
    if (wasAborted) {
      // Cancellation is an expected outcome, not an error worth logging.
      if (typeof onDone === "function") {
        onDone(null, null, true, null);
      }
      return;
    }

    console.error("streamChat error:", err);
    if (typeof onError === "function") {
      onError(String(err));
    }
  }
}

const VALID_ROLES = new Set(["system", "user", "assistant", "tool"]);

/**
 * Reduce a stored history entry to the fields sent to the chat API.
 *
 * @param {object} msg - History entry.
 * @returns {object|null} The normalised message, or null for a missing entry
 *   or an unknown role.
 */
function normalizeMessage(msg) {
  if (!msg || !VALID_ROLES.has(msg.role)) return null;

  // Assistant turns that requested tools are sent with empty content.
  // An empty `tool_calls` array is treated as "no tool calls" so the text
  // content of such a message is not thrown away.
  if (msg.role === "assistant" && msg.tool_calls?.length) {
    return { role: "assistant", content: "", tool_calls: msg.tool_calls };
  }

  // A tool result must keep the id linking it to the assistant's tool call.
  const linkage =
    msg.role === "tool" && msg.tool_call_id ? { tool_call_id: msg.tool_call_id } : {};

  if (!Array.isArray(msg.content)) {
    return { role: msg.role, content: msg.content ?? "", ...linkage };
  }

  // If the array contains images, preserve the full array format
  if (msg.content.some((item) => item?.type === "image_url")) {
    return { role: msg.role, content: msg.content, ...linkage };
  }

  // Otherwise extract text only
  const textOnly = msg.content
    .filter((part) => part?.type === "text")
    .map((part) => part.text)
    .join("\n");
  return { role: msg.role, content: textOnly, ...linkage };
}

/**
 * Build one function-tool definition in the chat API's tool schema.
 *
 * @param {string} name
 * @param {string} description
 * @param {object} properties - JSON-schema properties of the arguments object.
 * @param {string[]} required - Names of required arguments.
 * @returns {object}
 */
function functionTool(name, description, properties, required) {
  return {
    type: "function",
    function: {
      name,
      description,
      parameters: { type: "object", properties, required },
    },
  };
}

/**
 * Translate the feature flags into the tool definitions offered to the model.
 *
 * @param {object} [tools] - Flags: webSearch, imageGen, videoGen, audioGen.
 * @returns {Array<object>} Tool definitions; empty when `tools` is falsy.
 */
function buildToolList(tools) {
  if (!tools) return [];

  const stringList = { type: "array", items: { type: "string" } };
  const list = [];

  if (tools.webSearch) {
    list.push(
      functionTool(
        "ollama_search",
        "Search the web for current information",
        { query: { type: "string", description: "Search query" } },
        ["query"]
      ),
      functionTool(
        "read_web_page",
        "Read the content of a web page by URL",
        { url: { type: "string", description: "URL to fetch" } },
        ["url"]
      )
    );
  }

  if (tools.imageGen) {
    list.push(
      functionTool(
        "generate_image",
        "Generate an image from a prompt",
        {
          prompt: { type: "string" },
          mode: { type: "string", enum: ["auto", "fantasy", "realistic"] },
          image_urls: stringList,
        },
        ["prompt"]
      )
    );
  }

  if (tools.videoGen) {
    list.push(
      functionTool(
        "generate_video",
        "Generate a video from a prompt",
        {
          prompt: { type: "string" },
          ratio: { type: "string", enum: ["3:2", "2:3", "1:1"] },
          mode: { type: "string", enum: ["normal", "fun"] },
          duration: { type: "number" },
          image_urls: stringList,
        },
        ["prompt"]
      )
    );
  }

  if (tools.audioGen) {
    list.push(
      functionTool(
        "generate_audio",
        "Generate music or sound effects from a prompt",
        { prompt: { type: "string" } },
        ["prompt"]
      )
    );
  }

  return list;
}

/**
 * Parse the JSON argument string of a tool call.
 *
 * @param {string} [raw]
 * @returns {object} The parsed arguments; `{}` when the text is missing,
 *   invalid, or not a JSON object.
 */
function parseToolArguments(raw) {
  try {
    const parsed = JSON.parse(raw || "{}");
    return parsed !== null && typeof parsed === "object" ? parsed : {};
  } catch {
    return {};
  }
}

/**
 * POST a JSON payload with the caller's auth headers.
 *
 * @param {string} url
 * @param {object} payload
 * @param {{authHeaders: object, abortSignal?: AbortSignal}} ctx
 * @returns {Promise<Response>}
 */
function postJson(url, payload, { authHeaders, abortSignal }) {
  return fetch(url, {
    method: "POST",
    headers: { "Content-Type": "application/json", ...authHeaders },
    body: JSON.stringify(payload),
    signal: abortSignal,
  });
}

/**
 * Read a response body and encode it as a base64 data URL.
 *
 * @param {Response} res
 * @param {string} mimeType
 * @returns {Promise<string>}
 */
async function responseToDataUrl(res, mimeType) {
  const encoded = Buffer.from(await res.arrayBuffer()).toString("base64");
  return `data:${mimeType};base64,${encoded}`;
}

/**
 * Fetch a web page and return its title and text content as a JSON string.
 *
 * @param {string} rawUrl - URL supplied by the model.
 * @param {AbortSignal} [signal]
 * @returns {Promise<string>}
 * @throws {Error} For an invalid URL or a non-http(s) protocol.
 */
async function readWebPage(rawUrl, signal) {
  // The URL is model-generated, i.e. untrusted: accept http(s) only.
  // NOTE(review): this still allows requests to internal network addresses;
  // consider a host allow/deny list if this runs inside a private network.
  const target = new URL(rawUrl);
  if (target.protocol !== "http:" && target.protocol !== "https:") {
    throw new Error(`Unsupported URL protocol: ${target.protocol}`);
  }

  const { convert } = await import("html-to-text");
  const res = await fetch(target, { signal });
  if (!res.ok) {
    return `Failed to fetch: ${res.status}`;
  }

  const html = await res.text();
  // Anchor on the opening tag so only the title text itself is captured.
  const titleMatch = html.match(/<title[^>]*>([\s\S]*?)<\/title>/i);
  return JSON.stringify({
    title: titleMatch?.[1]?.trim() || "No title",
    content: convert(html, { wordwrap: false }).slice(0, 8000),
  });
}

/**
 * Call the image endpoint and hand the result to `onNewAsset`.
 *
 * @param {object} args - Tool arguments (prompt, mode, image_urls).
 * @param {object} ctx - { authHeaders, abortSignal, onNewAsset }.
 * @returns {Promise<string>} Status text for the model.
 */
async function generateImage(args, ctx) {
  const body = { prompt: args.prompt };
  if (args.mode) body.mode = args.mode;
  if (args.image_urls?.length) body.image_urls = args.image_urls;

  const res = await postJson(`${LIGHTNING_BASE}/gen/image`, body, ctx);
  if (res.ok) {
    const mimeType = res.headers.get("content-type") || "image/png";
    ctx.onNewAsset({ role: "image", content: await responseToDataUrl(res, mimeType) });
    return "Image generated successfully and shown to the user.";
  }
  if (res.status === 402) return "An upgraded plan is required for higher limits.";
  if (res.status === 429) return "Too many requests. Try again later.";
  return `Image generation failed: ${res.status}`;
}

/**
 * Call the video endpoint and hand the result to `onNewAsset`.
 *
 * @param {object} args - Tool arguments (prompt, ratio, mode, duration, image_urls).
 * @param {object} ctx - { authHeaders, abortSignal, onNewAsset }.
 * @returns {Promise<string>} Status text for the model.
 */
async function generateVideo(args, ctx) {
  const body = { prompt: args.prompt };
  if (args.ratio) body.ratio = args.ratio;
  if (args.mode) body.mode = args.mode;
  if (args.duration) body.duration = args.duration;
  if (args.image_urls?.length) body.image_urls = args.image_urls;

  const res = await postJson(`${LIGHTNING_BASE}/gen/video`, body, ctx);
  if (res.ok) {
    ctx.onNewAsset({ role: "video", content: await responseToDataUrl(res, "video/mp4") });
    return "Video generated successfully and shown to the user.";
  }
  if (res.status === 402) return "An upgraded plan is required for higher limits.";
  if (res.status === 429) return "Too many requests. Try again later.";
  return `Video generation failed: ${res.status}`;
}

/**
 * Call the sound-effect endpoint and hand the result to `onNewAsset`.
 *
 * @param {object} args - Tool arguments (prompt).
 * @param {object} ctx - { authHeaders, abortSignal, onNewAsset }.
 * @returns {Promise<string>} Status text for the model.
 */
async function generateAudio(args, ctx) {
  const res = await postJson(`${LIGHTNING_BASE}/gen/sfx`, { prompt: args.prompt }, ctx);
  if (res.ok) {
    ctx.onNewAsset({ role: "audio", content: await responseToDataUrl(res, "audio/mpeg") });
    return "Audio generated successfully and shown to the user.";
  }
  if (res.status === 429) return "Too many requests. Try again later.";
  return `Audio generation failed: ${res.status}. This is most likely an upstream provider error.`;
}

/**
 * Dispatch one tool call by name.
 *
 * @param {string} name - Tool name requested by the model.
 * @param {object} args - Parsed tool arguments.
 * @param {object} ctx - { authHeaders, abortSignal, onNewAsset }.
 * @returns {Promise<*>} Tool result; "Tool completed." for an unknown tool.
 */
async function runTool(name, args, ctx) {
  switch (name) {
    case "ollama_search":
      return gradioSearch(args.query, ctx.abortSignal);
    case "read_web_page":
      return readWebPage(args.url, ctx.abortSignal);
    case "generate_image":
      return generateImage(args, ctx);
    case "generate_video":
      return generateVideo(args, ctx);
    case "generate_audio":
      return generateAudio(args, ctx);
    default:
      return "Tool completed.";
  }
}

/**
 * Convert a tool result into the string content of a "tool" message.
 *
 * @param {*} result
 * @returns {string}
 */
function serializeToolResult(result) {
  if (typeof result === "string") return result;
  // JSON.stringify(undefined) yields undefined, which is not valid content.
  return JSON.stringify(result) ?? "Tool returned no result.";
}

/**
 * Execute the model's tool calls one after another and build the matching
 * "tool" messages. A failing tool does not stop the loop; its error text
 * becomes that tool's result.
 *
 * @param {*} ws - Unused; kept so the signature stays unchanged.
 * @param {Array<object>} toolCalls - Tool calls from the assistant message.
 * @param {object} tools - Unused here; kept so the signature stays unchanged.
 * @param {string} [accessToken] - Forwarded as a Bearer token.
 * @param {string} [clientId] - Forwarded as X-Client-ID.
 * @param {AbortSignal} [abortSignal] - Passed to every outgoing request.
 * @param {Function} onToolCall - Receives a "pending" then a "resolved" state per call.
 * @param {Function} onNewAsset - Receives generated media as `{ role, content }`.
 * @returns {Promise<Array<{role: "tool", tool_call_id: string, content: string}>>}
 */
async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
  const authHeaders = {};
  if (accessToken) authHeaders["Authorization"] = `Bearer ${accessToken}`;
  if (clientId) authHeaders["X-Client-ID"] = clientId;

  const ctx = { authHeaders, abortSignal, onNewAsset };
  const toolResults = [];

  for (const call of toolCalls) {
    const name = call.function.name;
    const args = parseToolArguments(call.function.arguments);
    onToolCall({ id: call.id, name, state: "pending", args });

    let result;
    try {
      result = await runTool(name, args, ctx);
    } catch (err) {
      result = `Tool error: ${String(err)}`;
    }

    onToolCall({ id: call.id, name, state: "resolved", result });
    toolResults.push({
      role: "tool",
      tool_call_id: call.id,
      content: serializeToolResult(result),
    });
  }

  return toolResults;
}