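// Source: chat-dev / server / chatStream.js
// Last commit: "Update chatStream.js" (ff954f1) by incognitolm
// Repository viewer residue ("raw", "history blame", size 19.1 kB) kept here as a comment so the module parses.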
import OpenAI from "openai";
import { Worker } from "worker_threads";
import { fileURLToPath } from "url";
import path from "path";
import { LIGHTNING_BASE } from "./config.js";
import WebSocket from "ws";
// ES modules do not define __dirname; derive this module's directory from import.meta.url.
const __dirname = path.dirname(fileURLToPath(import.meta.url));
// Path to searchWorker.js in the same directory as this module.
// NOTE(review): not referenced anywhere in the visible code — presumably meant for the imported Worker; confirm before removing.
const WORKER_PATH = path.join(__dirname, "searchWorker.js");
/**
 * Run a web search through the hosted search endpoint and return the top hit.
 *
 * @param {string} query - Search text; sent as `{ query }` in the JSON body.
 * @param {{ signal?: AbortSignal }} [options] - Optional; `signal` cancels the request.
 * @returns {Promise<*>} The first entry of the response's `results`, or the
 *   string "No results found." when `results` is present but empty.
 * @throws {Error} "Search API error: …" when the HTTP status is not OK, the
 *   body is not valid JSON, or the payload carries no `results`.
 */
async function gradioSearch(query, { signal } = {}) {
  const response = await fetch("https://incognitolm-Web-Search.hf.space/api/search", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query }),
    signal,
  });

  // Error responses are not guaranteed to be JSON; a failed parse must not
  // hide the HTTP status behind a bare SyntaxError.
  let payload = null;
  try {
    payload = await response.json();
  } catch (err) {
    if (err?.name === "AbortError") throw err;
    if (response.ok) {
      throw new Error("Search API error: response was not valid JSON", { cause: err });
    }
  }

  if (!response.ok || !payload?.results) {
    const reason = payload?.error || response.statusText || `HTTP ${response.status}`;
    throw new Error(`Search API error: ${reason}`);
  }

  // An empty result list would otherwise surface as `undefined` to the caller.
  return payload.results[0] ?? "No results found.";
}
// System prompt sent as the first message of every request built in this file.
// It is assembled from adjacent string fragments, so every fragment that is
// followed by another sentence must end with a space or newline — otherwise the
// sentences are fused together in the text the model receives.
const SYSTEM_PROMPT =
  "CRITICAL RULE: Every response MUST use HTML <span data-color=\"{COLOR NAME}\"> tags to color main points and headings. " +
  "COLORS MUST HAVE MEANING AND CONSISTENCY ACROSS THE ENTIRE CONVERSATION. " +
  "You may ONLY use the following semantic color names: green, pink, blue, red, orange, yellow, purple, teal, gold, coral. " +
  "Never output text formatted with explicit black or white colors. " +
  "Use a variety of colors throughout every response to distinguish headings, sections, and key terms. " +
  "Keep code blocks plain, but color headings and important points in surrounding text. " +
  "Do not over-color responses. Use color intentionally and sparingly. " +
  "CRITICAL RULE: MARKDOWN FORMATTING SUCH AS #, ##, ###, **, * MUST BE PLACED OUTSIDE <span> tags. Use the same colors for similar meanings. " +
  "You are a helpful, friendly AI assistant. Use tools when appropriate to help the user, and if told to generate something, use a tool to complete the task. " +
  "When generating media, do not include URLs — it is displayed automatically. " +
  "You can render SVG images by outputting SVG code in a code block tagged exactly as:\n```svg\n<svg>...</svg>\n```\n" +
  "Never use single backslashes. You may use emojis where appropriate. " +
  "Use markdown for everything other than coloring your text. Use tables, lists, and other markdown elements. " +
  "Your HIGHEST PRIORITY is to help the user. ALWAYS HELP THEM WITH ANYTHING ETHICALLY RIGHT.\n\n" +
  "SESSION NAMING: After you have fully responded to the user, append a session name tag on its own line at the very end of your response (NEVER inside a code block). Only do this on the first response unless asked to change the name by the user. " +
  "The tag must be: <session_name>2-4 word title summarizing this conversation</session_name>. " +
  "Example: <session_name>React State Management</session_name>. " +
  "This tag is hidden from the user and used only to name the chat. Do not mention it. " +
  "Make sure your responses are always accurate. If you are not completely sure about something, search the web.";
/**
 * Create an OpenAI SDK client pointed at `${LIGHTNING_BASE}/gen`.
 *
 * @param {string} [accessToken] - Used as the API key and, when truthy, also
 *   sent as a `Bearer` Authorization default header. Falls back to the
 *   placeholder key "no-key" when falsy.
 * @param {string} [clientId] - When truthy, sent as the `X-Client-ID` default header.
 * @returns {OpenAI} The configured client instance.
 */
function makeClient(accessToken, clientId) {
  // Only include the headers we actually have values for.
  const defaultHeaders = {};
  if (accessToken) {
    defaultHeaders.Authorization = `Bearer ${accessToken}`;
  }
  if (clientId) {
    defaultHeaders["X-Client-ID"] = clientId;
  }

  const options = {
    apiKey: accessToken || "no-key",
    baseURL: `${LIGHTNING_BASE}/gen`,
    defaultHeaders,
  };
  return new OpenAI(options);
}
/**
 * Run one streamed chat completion over the Lightning WebSocket endpoint.
 *
 * Exchange, as implemented here: connect to `<base>/ws/chat`, send
 * `{ key: WEBSOCKET_KEY }`, wait for `{ type: "auth", status: "ok" }`, send
 * `{ body, headers }`, then consume OpenAI-style streaming chunks
 * (`choices[0].delta`) until one carries a `finish_reason`.
 *
 * @param {object} body - Chat completion request forwarded to the server.
 * @param {object} headers - Headers forwarded alongside `body`.
 * @param {(token: string) => void} [onToken] - Receives each content delta,
 *   and `[ERROR] …` text when the server sends a payload with an `error` field.
 * @param {AbortSignal} [abortSignal] - Cancels the exchange at any stage.
 * @returns {Promise<{ assistantText: string, toolCalls: Array<object> }>}
 *   The concatenated content and the tool calls assembled from the deltas.
 * @throws {Error} With message "AbortError" on abort; "WS connection timeout"
 *   or "WS auth timeout" after 5 s without progress in that phase;
 *   "WS auth error: …"; "WebSocket closed prematurely"; or the socket's own error.
 */
export async function websocketChatStream(body, headers, onToken, abortSignal) {
  const HANDSHAKE_TIMEOUT_MS = 5000;

  // A signal that fired before this call never emits "abort" again, so it has
  // to be checked explicitly or the request would run to completion.
  if (abortSignal?.aborted) throw new Error("AbortError");

  // The environment value keeps precedence (previous behaviour); the shared
  // config constant used by the rest of this file is the fallback.
  const base = process.env.LIGHTNING_BASE || LIGHTNING_BASE;
  if (typeof base !== "string" || base === "") {
    throw new Error("LIGHTNING_BASE is not configured");
  }
  // Anchored replace maps http -> ws and https -> wss in one step.
  const ws = new WebSocket(`${base.replace(/^http/, "ws")}/ws/chat`);

  const safeParse = (text) => {
    try {
      return JSON.parse(text);
    } catch {
      return null;
    }
  };
  const emitToken = (text) => {
    if (typeof onToken === "function") onToken(text);
  };

  return new Promise((resolve, reject) => {
    let phase = "connecting"; // -> "authenticating" -> "streaming"
    let settled = false;
    let timer;
    let assistantText = "";
    const toolCallBuffer = new Map();

    // Settle exactly once, releasing the timer and the abort listener. Socket
    // listeners are left attached on purpose: they are guarded by `settled`,
    // and the "error" listener must survive so that an error emitted while the
    // socket shuts down is logged instead of becoming an unhandled event.
    const settle = (settleFn, value, closeSocket) => {
      if (settled) return;
      settled = true;
      clearTimeout(timer);
      abortSignal?.removeEventListener("abort", onAbort);
      if (closeSocket) ws.close();
      settleFn(value);
    };
    const fail = (err, closeSocket = true) => settle(reject, err, closeSocket);

    const startTimer = (message) => {
      clearTimeout(timer);
      timer = setTimeout(() => fail(new Error(message)), HANDSHAKE_TIMEOUT_MS);
    };

    const onAbort = () => {
      console.log("[WS] Aborted by signal");
      // The message (not the name) is what identifies this as an abort.
      fail(new Error("AbortError"));
    };
    abortSignal?.addEventListener("abort", onAbort, { once: true });

    const handleAuth = (msg) => {
      if (msg.type === "auth" && msg.status === "ok") {
        console.log("[WS] Auth successful");
        clearTimeout(timer);
        phase = "streaming";
        ws.send(JSON.stringify({ body, headers }));
        return;
      }
      if (msg.error) {
        console.error("[WS] Auth error:", msg.error);
        fail(new Error(`WS auth error: ${msg.error}`));
      }
    };

    const handleChunk = (payload) => {
      // Server-reported errors are surfaced to the caller as token text; the
      // stream is left open and ends via a finish chunk or the socket closing.
      if (payload.error) {
        console.error("[WS PAYLOAD ERROR]", payload.error);
        emitToken(`[ERROR] ${payload.error}`);
        return;
      }

      const choice = payload.choices?.[0];
      if (!choice) return;
      const delta = choice.delta;

      if (delta?.content) {
        assistantText += delta.content;
        emitToken(delta.content);
      }

      // Tool calls arrive in fragments keyed by `index`; stitch them together.
      for (const call of delta?.tool_calls ?? []) {
        const entry = toolCallBuffer.get(call.index) ?? { arguments: "" };
        if (call.id) entry.id = call.id;
        if (call.function?.name) entry.name = call.function.name;
        if (call.function?.arguments) entry.arguments += call.function.arguments;
        toolCallBuffer.set(call.index, entry);
      }

      // Checked even when the chunk has no delta, so a bare finish chunk still
      // completes the stream instead of waiting for the socket to close.
      if (choice.finish_reason) {
        const toolCalls = [...toolCallBuffer.values()].map((t) => ({
          id: t.id || `call_${crypto.randomUUID()}`,
          type: "function",
          function: { name: t.name, arguments: t.arguments },
        }));
        console.log("[WS] Finished streaming");
        settle(resolve, { assistantText, toolCalls }, true);
      }
    };

    ws.on("open", () => {
      if (settled) return;
      console.log("[WS] Opened connection");
      phase = "authenticating";
      startTimer("WS auth timeout");
      ws.send(JSON.stringify({ key: process.env.WEBSOCKET_KEY }));
    });

    ws.on("message", (data) => {
      if (settled) return;
      const raw = data.toString();
      console.log("[WS MESSAGE RECEIVED]", raw);
      const payload = safeParse(raw);
      if (!payload) return;
      if (phase === "authenticating") handleAuth(payload);
      else if (phase === "streaming") handleChunk(payload);
    });

    ws.on("error", (err) => {
      console.error("[WS ERROR]", err);
      fail(err, false);
    });

    ws.on("close", (code, reason) => {
      console.log(`[WS CLOSED] Code: ${code}, Reason: ${reason}`);
      fail(new Error("WebSocket closed prematurely"), false);
    });

    startTimer("WS connection timeout");
  });
}
/**
 * Pull the chat title out of a `<session_name>…</session_name>` tag.
 * Tags that appear inside a triple-backtick fenced block are ignored.
 *
 * @param {string} text - Full assistant response.
 * @returns {string|null} The trimmed title, or null when there is no tag or
 *   the title is empty, longer than 80 characters, or spans several lines.
 */
export function extractSessionName(text) {
  if (!text) {
    return null;
  }

  const fencePattern = /```[\s\S]*?```/g;
  const tagPattern = /<session_name>([\s\S]*?)<\/session_name>/i;

  // Strip fenced blocks up front so a tag quoted inside one can never match.
  const prose = text.replace(fencePattern, '');
  const found = tagPattern.exec(prose);
  if (found === null) {
    return null;
  }

  const [, rawName] = found;
  const candidate = rawName.trim();
  const isUsable = candidate !== '' && candidate.length <= 80 && !/\n/.test(candidate);
  return isUsable ? candidate : null;
}
/**
 * Reduce an incoming user message to the content that is sent to the model.
 * Non-array values pass through untouched. An array containing images keeps
 * its array form (gaining a "[Image(s) attached]" text part when it has no
 * non-blank text); any other array is flattened to its joined, trimmed text.
 */
function normalizeUserContent(userMessage) {
  if (!Array.isArray(userMessage)) return userMessage;

  const images = userMessage.filter((item) => item.type === "image_url");
  if (images.length > 0) {
    const hasText = userMessage.some((item) => item.type === "text" && item.text?.trim());
    return hasText ? userMessage : [{ type: "text", text: "[Image(s) attached]" }, ...images];
  }

  return userMessage
    .filter((item) => item.type === "text")
    .map((item) => item.text)
    .join("\n")
    .trim();
}

/**
 * Run one chat turn: stream the model's reply, execute any tool calls it
 * requested, then stream a follow-up reply that can see the tool results.
 *
 * @param {object} params
 * @param {string} [params.sessionId] - Accepted for the caller's convenience; not read here.
 * @param {string} [params.model] - Model name; defaults to "lightning".
 * @param {Array<object>} [params.history] - Prior messages, passed through `normalizeMessage`.
 * @param {string|Array<object>} [params.userMessage] - New user turn; omitted from the
 *   request when it is missing or contains neither non-blank text nor images.
 * @param {object} [params.tools] - Tool toggles handed to `buildToolList`.
 * @param {string} [params.accessToken] - Sent as a Bearer Authorization header when set.
 * @param {string} [params.clientId] - Sent as the X-Client-ID header when set.
 * @param {(token: string) => void} [params.onToken] - Receives streamed content.
 * @param {(text: ?string, toolCalls: ?Array<object>, aborted: boolean, sessionName: ?string) => void} params.onDone
 *   Called once on success, or with `(null, null, true, null)` on abort.
 * @param {(message: string) => void} params.onError - Called with the stringified error on failure.
 * @param {Function} [params.onToolCall] - Forwarded to `processToolCalls`.
 * @param {Function} [params.onNewAsset] - Forwarded to `processToolCalls`.
 * @param {AbortSignal} [params.abortSignal] - Cancels the turn.
 * @returns {Promise<void>} Outcomes are reported through the callbacks.
 */
export async function streamChat({
  sessionId,
  model,
  history = [],
  userMessage,
  tools,
  accessToken,
  clientId,
  onToken,
  onDone,
  onError,
  onToolCall,
  onNewAsset,
  abortSignal,
}) {
  const enabledTools = buildToolList(tools);
  const userContent = normalizeUserContent(userMessage);

  // Decide on the normalized content so that a blank text-only array is
  // skipped exactly like a blank string, instead of sending an empty user turn.
  const hasUserMessage =
    typeof userContent === "string"
      ? userContent.trim() !== ""
      : Array.isArray(userContent) && userContent.length > 0;

  const messages = [
    { role: "system", content: SYSTEM_PROMPT },
    ...history.map(normalizeMessage).filter(Boolean),
  ];
  if (hasUserMessage) {
    messages.push({ role: "user", content: userContent });
  }

  const headers = {
    ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
    ...(clientId ? { "X-Client-ID": clientId } : {}),
  };
  const modelName = model || "lightning";

  try {
    const firstPass = await websocketChatStream(
      {
        model: modelName,
        messages,
        tools: enabledTools.length ? enabledTools : undefined,
        stream: true,
      },
      headers,
      onToken,
      abortSignal
    );
    let assistantText = firstPass.assistantText;
    const toolCalls = firstPass.toolCalls;

    if (toolCalls.length > 0) {
      const toolResults = await processToolCalls(
        null,
        toolCalls,
        tools,
        accessToken,
        clientId,
        abortSignal,
        onToolCall,
        onNewAsset
      );

      // Tool failures, including aborted fetches, come back as tool results
      // rather than exceptions, so the signal has to be re-checked here or the
      // follow-up request would still be sent after the user cancelled.
      if (abortSignal?.aborted) throw new Error("AbortError");

      // The follow-up replays the same conversation plus the assistant's tool
      // request and the tool outputs. No `tools` are offered on this pass.
      const followUp = await websocketChatStream(
        {
          model: modelName,
          messages: [
            ...messages,
            { role: "assistant", content: assistantText || "", tool_calls: toolCalls },
            ...toolResults,
          ],
          stream: true,
        },
        headers,
        onToken,
        abortSignal
      );
      assistantText += followUp.assistantText;
    }

    onDone(assistantText, toolCalls, false, extractSessionName(assistantText));
  } catch (err) {
    // Aborts are identified by name (fetch) or by message (websocketChatStream).
    if (err?.name === "AbortError" || err?.message === "AbortError") {
      onDone(null, null, true, null);
    } else {
      console.error("streamChat error:", err);
      onError(String(err));
    }
  }
}
// Roles that may be forwarded to the model; anything else is dropped.
const VALID_ROLES = new Set(["system", "user", "assistant", "tool"]);

/**
 * Convert a stored history entry into the message shape sent to the model.
 *
 * - Entries that are missing or have an unknown role yield `null` (callers filter these out).
 * - Assistant entries with a non-empty `tool_calls` array keep the calls and send empty content.
 * - Array content containing an `image_url` part is kept as-is; other array
 *   content is flattened to its text parts joined by newlines.
 * - `tool` entries keep their `tool_call_id` so the result stays linked to the
 *   call that produced it.
 *
 * @param {?object} msg - History entry with `role`, `content` and optionally
 *   `tool_calls` / `tool_call_id`.
 * @returns {?object} The normalized message, or null when the entry is unusable.
 */
function normalizeMessage(msg) {
  if (!msg || !VALID_ROLES.has(msg.role)) return null;

  // An empty tool_calls array is not a tool-call turn; fall through so the
  // entry's actual content is preserved instead of being blanked.
  if (msg.role === "assistant" && Array.isArray(msg.tool_calls) && msg.tool_calls.length > 0) {
    return { role: "assistant", content: "", tool_calls: msg.tool_calls };
  }

  let content;
  if (Array.isArray(msg.content)) {
    const hasImages = msg.content.some((part) => part?.type === "image_url");
    content = hasImages
      ? msg.content
      : msg.content
          .filter((part) => part?.type === "text")
          .map((part) => part.text)
          .join("\n");
  } else {
    content = msg.content ?? "";
  }

  const normalized = { role: msg.role, content };
  if (msg.role === "tool" && msg.tool_call_id) {
    normalized.tool_call_id = msg.tool_call_id;
  }
  return normalized;
}
/**
 * Translate the caller's tool toggles into function-tool definitions.
 *
 * @param {?object} tools - Flags `webSearch`, `imageGen`, `videoGen`, `audioGen`;
 *   each truthy flag contributes its tool(s). A falsy `tools` yields no tools.
 * @returns {Array<object>} Definitions in the order: search, page reader,
 *   image, video, audio.
 */
function buildToolList(tools) {
  if (!tools) {
    return [];
  }

  // Every definition shares the same envelope; only these four parts vary.
  const describeTool = (name, description, properties, required) => ({
    type: "function",
    function: {
      name,
      description,
      parameters: { type: "object", properties, required },
    },
  });
  const plainString = () => ({ type: "string" });
  const stringChoice = (options) => ({ type: "string", enum: options });
  const stringArray = () => ({ type: "array", items: { type: "string" } });

  const catalog = [];

  if (tools.webSearch) {
    catalog.push(
      describeTool(
        "ollama_search",
        "Search the web for current information",
        { query: { type: "string", description: "Search query" } },
        ["query"]
      ),
      describeTool(
        "read_web_page",
        "Read the content of a web page by URL",
        { url: { type: "string", description: "URL to fetch" } },
        ["url"]
      )
    );
  }

  if (tools.imageGen) {
    catalog.push(
      describeTool(
        "generate_image",
        "Generate an image from a prompt",
        {
          prompt: plainString(),
          mode: stringChoice(["auto", "fantasy", "realistic"]),
          image_urls: stringArray(),
        },
        ["prompt"]
      )
    );
  }

  if (tools.videoGen) {
    catalog.push(
      describeTool(
        "generate_video",
        "Generate a video from a prompt",
        {
          prompt: plainString(),
          ratio: stringChoice(["3:2", "2:3", "1:1"]),
          mode: stringChoice(["normal", "fun"]),
          duration: { type: "number" },
          image_urls: stringArray(),
        },
        ["prompt"]
      )
    );
  }

  if (tools.audioGen) {
    catalog.push(
      describeTool(
        "generate_audio",
        "Generate music or sound effects from a prompt",
        { prompt: plainString() },
        ["prompt"]
      )
    );
  }

  return catalog;
}
// Per-tool settings for the three media generators, which share one
// request/response flow and differ only in the values below.
const MEDIA_TOOL_SPECS = new Map([
  [
    "generate_image",
    {
      path: "/gen/image",
      role: "image",
      label: "Image",
      reportsPlanLimit: true,
      failureNote: "",
      mimeType: (res) => res.headers.get("content-type") || "image/png",
      buildBody: (args) => {
        const body = { prompt: args.prompt };
        if (args.mode) body.mode = args.mode;
        if (args.image_urls?.length) body.image_urls = args.image_urls;
        return body;
      },
    },
  ],
  [
    "generate_video",
    {
      path: "/gen/video",
      role: "video",
      label: "Video",
      reportsPlanLimit: true,
      failureNote: "",
      mimeType: () => "video/mp4",
      buildBody: (args) => {
        const body = { prompt: args.prompt };
        if (args.ratio) body.ratio = args.ratio;
        if (args.mode) body.mode = args.mode;
        if (args.duration) body.duration = args.duration;
        if (args.image_urls?.length) body.image_urls = args.image_urls;
        return body;
      },
    },
  ],
  [
    "generate_audio",
    {
      path: "/gen/sfx",
      role: "audio",
      label: "Audio",
      // The audio endpoint has no dedicated 402 message; it falls through to the generic failure text.
      reportsPlanLimit: false,
      failureNote: ". This is most likely an upstream provider error.",
      mimeType: () => "audio/mpeg",
      buildBody: (args) => ({ prompt: args.prompt }),
    },
  ],
]);

/** Call one media endpoint, hand the asset to `onNewAsset` as a data URL, and return the text reported to the model. */
async function runMediaTool(spec, args, authHeaders, abortSignal, onNewAsset) {
  const res = await fetch(`${LIGHTNING_BASE}${spec.path}`, {
    method: "POST",
    headers: { "Content-Type": "application/json", ...authHeaders },
    body: JSON.stringify(spec.buildBody(args)),
    signal: abortSignal,
  });

  if (res.ok) {
    const b64 = Buffer.from(await res.arrayBuffer()).toString("base64");
    onNewAsset?.({ role: spec.role, content: `data:${spec.mimeType(res)};base64,${b64}` });
    return `${spec.label} generated successfully and shown to the user.`;
  }
  if (res.status === 402 && spec.reportsPlanLimit) {
    return "An upgraded plan is required for higher limits.";
  }
  if (res.status === 429) {
    return "Too many requests. Try again later.";
  }
  return `${spec.label} generation failed: ${res.status}${spec.failureNote}`;
}

/** Fetch a page and return a JSON string with its title and up to 8000 characters of text. */
async function readWebPage(url, abortSignal) {
  // The URL is chosen by the model, so only plain web schemes are accepted.
  // NOTE(review): this still allows requests to internal hosts (SSRF); add a
  // host allow/deny policy if this server can reach private networks.
  const target = new URL(url);
  if (target.protocol !== "http:" && target.protocol !== "https:") {
    return `Failed to fetch: unsupported URL scheme ${target.protocol}`;
  }

  const { convert } = await import("html-to-text");
  const res = await fetch(target, { signal: abortSignal });
  if (!res.ok) {
    return `Failed to fetch: ${res.status}`;
  }

  const html = await res.text();
  // Tolerate attributes on <title> and titles that span several lines.
  const titleMatch = html.match(/<title[^>]*>([\s\S]*?)<\/title>/i);
  return JSON.stringify({
    title: titleMatch?.[1]?.trim() || "No title",
    content: convert(html, { wordwrap: false }).slice(0, 8000),
  });
}

/**
 * Execute the tool calls requested by the model, one after another, and return
 * one `tool` message per call for the follow-up request.
 *
 * A failing tool never throws out of this function: the error text becomes that
 * tool's result. Unknown tool names resolve to "Tool completed.".
 *
 * @param {*} ws - Unused; kept so existing callers keep working.
 * @param {Array<object>} toolCalls - Calls shaped `{ id, function: { name, arguments } }`,
 *   where `arguments` is a JSON string.
 * @param {object} tools - Unused here; kept for signature compatibility.
 * @param {string} [accessToken] - Sent as a Bearer Authorization header to the media endpoints.
 * @param {string} [clientId] - Sent as X-Client-ID to the media endpoints.
 * @param {AbortSignal} [abortSignal] - Passed to every outgoing request.
 * @param {(event: object) => void} [onToolCall] - Notified with a "pending" event
 *   before and a "resolved" event after each call.
 * @param {(asset: { role: string, content: string }) => void} [onNewAsset] -
 *   Receives generated media as a data URL.
 * @returns {Promise<Array<{ role: "tool", tool_call_id: string, content: string }>>}
 */
async function processToolCalls(ws, toolCalls, tools, accessToken, clientId, abortSignal, onToolCall, onNewAsset) {
  const authHeaders = {};
  if (accessToken) {
    authHeaders["Authorization"] = `Bearer ${accessToken}`;
  } else {
    console.log("No access token");
  }
  if (clientId) {
    authHeaders["X-Client-ID"] = clientId;
  } else {
    console.log("No Client ID");
  }

  const toolResults = [];
  for (const call of toolCalls) {
    const { name } = call.function;

    // Malformed or non-object arguments degrade to an empty object rather than failing the call.
    let args;
    try {
      args = JSON.parse(call.function.arguments || "{}");
    } catch {
      args = {};
    }
    if (args === null || typeof args !== "object") args = {};

    onToolCall?.({ id: call.id, name, state: "pending", args });

    let result = "Tool completed.";
    try {
      if (name === "ollama_search") {
        result = await gradioSearch(args.query, { signal: abortSignal });
      } else if (name === "read_web_page") {
        result = await readWebPage(args.url, abortSignal);
      } else if (MEDIA_TOOL_SPECS.has(name)) {
        result = await runMediaTool(MEDIA_TOOL_SPECS.get(name), args, authHeaders, abortSignal, onNewAsset);
      }
    } catch (err) {
      console.log(`Tool error: ${String(err)}`);
      result = `Tool error: ${String(err)}`;
    }

    onToolCall?.({ id: call.id, name, state: "resolved", result });
    toolResults.push({
      role: "tool",
      tool_call_id: call.id,
      // `?? null` keeps `content` a string even if a tool produced undefined.
      content: typeof result === "string" ? result : JSON.stringify(result ?? null),
    });
  }
  return toolResults;
}