import { type Message } from "@huggingface/transformers";
import {
  executeToolCall,
  splitResponse,
  webMCPToolToChatTemplateTool,
} from "@utils/webMcp";
import type { WebMCPTool } from "@utils/webMcp/types.ts";

import {
  type ChatMessage,
  type ChatMessageAssistant,
  type ChatMessageAssistantResponse,
  type ChatMessageAssistantTool,
  type GenerationMetadata,
  type Request,
  RequestType,
  type Response,
  ResponseType,
} from "./types.ts";

/**
 * Client-side facade over the text-generation Web Worker.
 *
 * Every request posted to the worker carries a fresh `requestId`. Each
 * request-issuing method registers its own `message` listener that ignores
 * responses for other ids and removes itself once its request settles.
 *
 * Two histories are kept:
 * - `messages`: what is sent to the worker on every generation request.
 * - `chatMessages`: a UI-oriented transcript. Subscribers registered via
 *   `onChatMessageUpdate` are notified whenever it is reassigned.
 */
export default class TextGeneration {
  private readonly worker: Worker;
  private requestId = 0;
  private modelKey: string | null = null;
  private tools: Array<WebMCPTool> | null = null;
  private temperature: number | null = null;
  private enableThinking: boolean | null = null;
  /** History posted to the worker as `messages` on each generation request. */
  private messages: Array<Message> = [];
  /** UI transcript; this class only ever replaces it through the setter. */
  private _chatMessages: Array<ChatMessage> = [];
  private chatMessagesListener: Array<
    (chatMessages: Array<ChatMessage>) => void
  > = [];

  constructor() {
    this.worker = new Worker(
      new URL("./worker/textGenerationWorker.ts", import.meta.url),
      {
        type: "module",
      }
    );
  }

  get chatMessages(): Array<ChatMessage> {
    return this._chatMessages;
  }

  /** Replaces the transcript and synchronously notifies every subscriber. */
  set chatMessages(chatMessages: Array<ChatMessage>) {
    this._chatMessages = chatMessages;
    this.chatMessagesListener.forEach((listener) => listener(chatMessages));
  }

  /**
   * Subscribes to transcript changes.
   *
   * @param callback Invoked with the new transcript on every reassignment.
   * @returns A function that removes this subscription.
   */
  public onChatMessageUpdate = (
    callback: (messages: Array<ChatMessage>) => void
  ) => {
    this.chatMessagesListener.push(callback);
    return () => {
      this.chatMessagesListener = this.chatMessagesListener.filter(
        (listener) => listener !== callback
      );
    };
  };

  private postWorkerMessage = (request: Request) =>
    this.worker.postMessage(request);

  private addWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.addEventListener("message", listener);

  private removeWorkerEventListener = (
    listener: (ev: MessageEvent<Response>) => void
  ) => this.worker.removeEventListener("message", listener);

  /**
   * Loads a model in the worker.
   *
   * @param modelKey Identifier forwarded to the worker. It is remembered for
   *   later generation requests once loading completes.
   * @param onDownload Called with every progress value the worker reports
   *   for this request, including the final one.
   * @returns Resolves with the final progress value when the worker reports
   *   `done`; rejects with the worker's error message.
   */
  public async initializeModel(
    modelKey: string,
    onDownload: (percentage: number) => void
  ): Promise<number> {
    return new Promise<number>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type !== ResponseType.INITIALIZE_MODEL) return;
        if (data.done) {
          this.removeWorkerEventListener(listener);
          this.modelKey = modelKey;
          resolve(data.progress);
        }
        onDownload(data.progress);
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.INITIALIZE_MODEL,
        modelKey,
        requestId,
      });
    });
  }

  /**
   * Starts a new conversation.
   *
   * Stores the generation settings and resets both histories to just the
   * system prompt. Until this has been called, generation rejects with
   * "Conversation not initialized".
   *
   * @param tools Tools advertised to the model and available to `runAgent`.
   * @param temperature Sampling temperature forwarded to the worker.
   * @param enableThinking Flag forwarded to the worker.
   * @param systemPrompt Content of the initial `system` message.
   */
  public initializeConversation(
    tools: Array<WebMCPTool> = [],
    temperature: number,
    enableThinking: boolean,
    systemPrompt: string
  ) {
    this.tools = tools;
    this.temperature = temperature;
    this.enableThinking = enableThinking;
    this.messages = [{ role: "system", content: systemPrompt }];
    this.chatMessages = [{ role: "system", content: systemPrompt }];
  }

  /**
   * Asks the worker to abort the running generation.
   *
   * @returns Resolves when the worker answers this request with
   *   `GENERATE_TEXT_ABORTED`; rejects with the worker's error message.
   */
  public async abort(): Promise<void> {
    return new Promise<void>((resolve, reject) => {
      const requestId = this.requestId++;
      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_ABORTED) {
          this.removeWorkerEventListener(listener);
          resolve();
        }
      };
      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE_ABORT,
        requestId,
      });
    });
  }

  /**
   * Appends `prompt` to the model history and asks the worker to generate a
   * reply.
   *
   * @param prompt Content of the new message.
   * @param role Role of the new message.
   * @param onResponseUpdate Called with the accumulated response text after
   *   every streamed chunk.
   * @returns The final response, its metadata and the worker's `interrupted`
   *   flag. On success the assistant reply is also appended to the history.
   *   Rejects if the model or conversation is not initialized, or with the
   *   worker's error message.
   *
   * NOTE(review): on a worker error the new `role` message has already been
   * appended to `this.messages` and is left there; confirm that is intended.
   */
  private generateText = (
    prompt: string,
    role: "user" | "tool",
    onResponseUpdate: (response: string) => void = () => {}
  ) => {
    return new Promise<{
      response: string;
      metadata: GenerationMetadata;
      interrupted: boolean;
    }>((resolve, reject) => {
      if (this.modelKey === null) {
        reject("Model not initialized");
        return;
      }
      if (
        this.tools === null ||
        this.temperature === null ||
        this.enableThinking === null
      ) {
        reject("Conversation not initialized");
        return;
      }

      const requestId = this.requestId++;
      this.messages = [...this.messages, { role, content: prompt }];
      let response = "";

      const listener = ({ data }: MessageEvent<Response>) => {
        if (data.requestId !== requestId) return;
        if (data.type === ResponseType.ERROR) {
          this.removeWorkerEventListener(listener);
          reject(data.message);
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_DONE) {
          this.removeWorkerEventListener(listener);
          this.messages.push({ role: "assistant", content: data.response });
          resolve({
            response: data.response,
            metadata: data.metadata,
            interrupted: data.interrupted,
          });
          return;
        }
        if (data.type === ResponseType.GENERATE_TEXT_CHUNK) {
          response = response + data.chunk;
          onResponseUpdate(response);
        }
      };

      this.addWorkerEventListener(listener);
      this.postWorkerMessage({
        type: RequestType.GENERATE_MESSAGE,
        modelKey: this.modelKey,
        messages: this.messages,
        tools: this.tools.map(webMCPToolToChatTemplateTool),
        requestId,
        temperature: this.temperature,
        enableThinking: this.enableThinking,
      });
    });
  };

  /**
   * Converts a (possibly partial) raw model response into transcript parts.
   * Plain-text segments become `response` parts; tool calls become `tool`
   * parts with no result yet.
   */
  private static toAssistantContent(
    response: string
  ): ChatMessageAssistant["content"] {
    return splitResponse(response).map((part) =>
      typeof part === "string"
        ? ({
            type: "response",
            content: part,
          } as ChatMessageAssistantResponse)
        : ({
            type: "tool",
            result: null,
            time: null,
            functionSignature: `${part.name}(${JSON.stringify(
              part.arguments
            )})`,
            ...part,
          } as ChatMessageAssistantTool)
    );
  }

  /**
   * Runs the agent loop for one user prompt.
   *
   * 1. The prompt is added to the transcript and sent to the model as a
   *    `user` message.
   * 2. If the reply contains tool calls, they are executed concurrently.
   *    Their results are joined with "\n" and sent back as a `tool` message.
   * 3. Step 2 repeats until a reply has no tool calls, or until the joined
   *    tool output is empty.
   *
   * Each generation gets its own assistant entry in `chatMessages`. That
   * entry is republished as a new object on every update rather than being
   * mutated in place.
   */
  public async runAgent(prompt: string): Promise<void> {
    let role: "user" | "tool" = "user";
    this.chatMessages = [
      ...this.chatMessages,
      { role: "user", content: prompt },
    ];

    while (prompt) {
      const prevChatMessages = this.chatMessages;
      let assistantMessage: ChatMessageAssistant = {
        role: "assistant",
        content: [],
        interrupted: false,
      };
      // Always publish a fresh object: earlier transcript snapshots (and
      // whoever received them) must not observe later in-place mutations.
      const publish = (next: ChatMessageAssistant) => {
        assistantMessage = next;
        this.chatMessages = [...prevChatMessages, next];
      };
      publish(assistantMessage);

      const { interrupted, metadata } = await this.generateText(
        prompt,
        role,
        (partialResponse) =>
          publish({
            ...assistantMessage,
            content: TextGeneration.toAssistantContent(partialResponse),
          })
      );
      role = "tool";
      publish({ ...assistantMessage, metadata, interrupted });

      // NOTE(review): tool calls parsed from an interrupted generation are
      // still executed and the loop continues; confirm that is intended.
      const toolCalls = assistantMessage.content.filter(
        (part): part is ChatMessageAssistantTool => part.type === "tool"
      );
      if (toolCalls.length === 0) break;

      const toolResponses = await Promise.all(
        toolCalls.map((tool) =>
          executeToolCall(
            {
              name: tool.name,
              arguments: tool.arguments,
              id: tool.id,
            },
            this.tools ?? []
          )
        )
      );

      publish({
        ...assistantMessage,
        content: assistantMessage.content.map((part) => {
          if (part.type !== "tool") return part;
          const toolResponse = toolResponses.find(
            (response) => response.id === part.id
          );
          return toolResponse
            ? { ...part, result: toolResponse.result, time: toolResponse.time }
            : part;
        }),
      });

      prompt = toolResponses.map(({ result }) => result).join("\n");
    }
  }
}