import { isWebGPUAvailable } from "./webGpu";
import {
  updatePrompt,
  updateSearchResults,
  getDisableAiResponseSetting,
  getSummarizeLinksSetting,
  getUseLargerModelSetting,
  updateResponse,
  getSearchResults,
  updateUrlsDescriptions,
  getUrlsDescriptions,
  getDisableWebGpuUsageSetting,
} from "./pubSub";
import { SearchResults, search } from "./search";
import { query, debug } from "./urlParams";
import toast from "react-hot-toast";
import { isRunningOnMobile } from "./mobileDetection";

/** Id shared by every loading toast so each update replaces the previous one. */
const loadingToastId = "text-generation-loading-toast";

/**
 * Entry point: searches the web for the `query` URL parameter, publishes the
 * results, and then (unless disabled by the settings) generates an AI response
 * and/or per-link summaries.
 *
 * Engine selection: when WebGPU is available and not disabled, Ratchet is
 * tried first (WebLLM first when the "larger model" setting is on), each one
 * falling back to the other; any failure in that stage falls back to Wllama.
 * A failure of the final Wllama attempt is reported with an error toast
 * instead of being thrown.
 */
export async function prepareTextGeneration() {
  if (query === null) return;

  document.title = query;

  updatePrompt(query);

  updateLoadingToast("Searching the web...");

  // Very long queries are reduced to keywords before being searched.
  let searchResults = await search(
    query.length > 2000 ? (await getKeywords(query, 20)).join(" ") : query,
    30,
  );

  // Retry with a keyword-only query when the first search found nothing.
  if (searchResults.length === 0) {
    const queryKeywords = await getKeywords(query, 10);

    searchResults = await search(queryKeywords.join(" "), 30);
  }

  if (searchResults.length === 0) {
    toast(
      "It looks like your current search did not return any results. Try refining your search by adding more keywords or rephrasing your query.",
      {
        position: "bottom-center",
        duration: 10000,
        icon: "💡",
      },
    );
  }

  updateSearchResults(searchResults);

  // Seed each link's description with its search snippet (summarization may
  // overwrite these later). Built in a single pass rather than re-spreading an
  // accumulator per result; for duplicate URLs the last snippet still wins.
  updateUrlsDescriptions(
    Object.fromEntries(searchResults.map(([, snippet, url]) => [url, snippet])),
  );

  dismissLoadingToast();

  if (getDisableAiResponseSetting() && !getSummarizeLinksSetting()) return;

  if (debug) console.time("Response Generation Time");

  updateLoadingToast("Loading AI model...");

  try {
    try {
      if (!isWebGPUAvailable) throw Error("WebGPU is not available.");

      if (getDisableWebGpuUsageSetting()) throw Error("WebGPU is disabled.");

      if (getUseLargerModelSetting()) {
        try {
          await generateTextWithWebLlm();
        } catch (error) {
          logFallback("WebLLM", error);
          await generateTextWithRatchet();
        }
      } else {
        try {
          await generateTextWithRatchet();
        } catch (error) {
          logFallback("Ratchet", error);
          await generateTextWithWebLlm();
        }
      }
    } catch (error) {
      logFallback("WebGPU", error);
      // CPU-based last resort.
      await generateTextWithWllama();
    }
  } catch (error) {
    console.error("Error while generating response with wllama:", error);

    toast.error(
      "Could not generate response. The browser may be out of memory. Please close this tab and run this search again in a new one.",
      { duration: 10000 },
    );
  } finally {
    dismissLoadingToast();
  }

  if (debug) {
    console.timeEnd("Response Generation Time");
  }
}

/**
 * In debug mode, records why a generation stage failed before the next engine
 * is tried, so fallbacks are not completely silent.
 */
function logFallback(stage: string, error: unknown) {
  if (debug) console.warn(`${stage} text generation failed; falling back.`, error);
}

/** Shows (or replaces) the single bottom-center loading toast. */
function updateLoadingToast(text: string) {
  toast.loading(text, {
    id: loadingToastId,
    position: "bottom-center",
  });
}

/** Removes the loading toast shown by `updateLoadingToast`. */
function dismissLoadingToast() {
  toast.dismiss(loadingToastId);
}

/**
 * Generates the response and/or link summaries with WebLLM (WebGPU).
 *
 * Uses Llama 3 8B when the "larger model" setting is on, otherwise Gemma 2B.
 * The engine runs in a Web Worker when the `Worker` global exists, otherwise on
 * the main thread. The engine is always unloaded, even when generation throws,
 * so a fallback engine does not have to share memory with it.
 */
async function generateTextWithWebLlm() {
  const { CreateWebWorkerEngine, CreateEngine, hasModelInCache } = await import(
    "@mlc-ai/web-llm"
  );

  const availableModels = {
    Llama: "Llama-3-8B-Instruct-q4f16_1",
    Mistral: "Mistral-7B-Instruct-v0.2-q4f16_1",
    Gemma: "gemma-2b-it-q4f16_1",
    Phi: "Phi2-q4f16_1",
    TinyLlama: "TinyLlama-1.1B-Chat-v0.4-q0f16",
  };

  const selectedModel = getUseLargerModelSetting()
    ? availableModels.Llama
    : availableModels.Gemma;

  const isModelCached = await hasModelInCache(selectedModel);

  let initProgressCallback:
    | import("@mlc-ai/web-llm").InitProgressCallback
    | undefined;

  if (isModelCached) {
    updateLoadingToast("Generating response...");
  } else {
    initProgressCallback = (report) => {
      updateLoadingToast(
        `Loading: ${report.text.replaceAll("[", "(").replaceAll("]", ")")}`,
      );
    };
  }

  // `typeof` guard: reading an undeclared global `Worker` directly would throw
  // a ReferenceError instead of falling back to the main-thread engine.
  const engine =
    typeof Worker !== "undefined"
      ? await CreateWebWorkerEngine(
          new Worker(new URL("./webLlmWorker.ts", import.meta.url), {
            type: "module",
          }),
          selectedModel,
          { initProgressCallback },
        )
      : await CreateEngine(selectedModel, { initProgressCallback });

  try {
    if (!getDisableAiResponseSetting()) {
      updateLoadingToast("Generating response...");

      const completion = await engine.chat.completions.create({
        stream: true,
        messages: [{ role: "user", content: getMainPrompt() }],
        max_gen_len: 768,
      });

      let streamedMessage = "";

      for await (const chunk of completion) {
        const deltaContent = chunk.choices[0]?.delta.content;

        if (deltaContent) streamedMessage += deltaContent;

        updateResponse(streamedMessage);
      }
    }

    // Each request below is a fresh single-turn conversation.
    await engine.resetChat();

    if (getSummarizeLinksSetting()) {
      updateLoadingToast("Summarizing links...");

      for (const [title, snippet, url] of getSearchResults()) {
        const completion = await engine.chat.completions.create({
          stream: true,
          messages: [
            {
              role: "user",
              content: await getLinkSummarizationPrompt([title, snippet, url]),
            },
          ],
          max_gen_len: 768,
        });

        let streamedMessage = "";

        for await (const chunk of completion) {
          const deltaContent = chunk.choices[0]?.delta.content;

          if (deltaContent) streamedMessage += deltaContent;

          updateUrlsDescriptions({
            ...getUrlsDescriptions(),
            [url]: streamedMessage,
          });
        }

        await engine.resetChat();
      }
    }

    if (debug) {
      console.info(await engine.runtimeStatsText());
    }
  } finally {
    await engine.unload();
  }
}

/**
 * Generates the response and/or link summaries with Wllama (CPU, GGUF models).
 *
 * The model is picked from the device type (mobile vs desktop) and the
 * "larger model" setting. Prompts are assembled by hand from each model's chat
 * prefixes/suffixes, with a short scripted exchange before the real question.
 * `exitWllama` always runs once initialization succeeded, even on error.
 *
 * @throws Error("Query is empty.") when a response is requested without a query.
 */
async function generateTextWithWllama() {
  const { initializeWllama, runCompletion, exitWllama } = await import(
    "./wllama"
  );

  const commonSamplingConfig: import("@wllama/wllama").SamplingConfig = {
    temp: 0.35,
    dynatemp_range: 0.25,
    top_k: 0,
    top_p: 1,
    min_p: 0.05,
    tfs_z: 0.95,
    typical_p: 0.85,
    penalty_freq: 0.5,
    penalty_repeat: 1.176,
    penalty_last_n: -1,
    mirostat: 2,
    mirostat_tau: 3.5,
  };

  const availableModels: {
    [key in
      | "mobileDefault"
      | "mobileLarger"
      | "desktopDefault"
      | "desktopLarger"]: {
      url: string;
      userPrefix: string;
      assistantPrefix: string;
      messageSuffix: string;
      sampling: import("@wllama/wllama").SamplingConfig;
    };
  } = {
    mobileDefault: {
      url: "https://huggingface.co/Felladrin/gguf-vicuna-160m/resolve/main/vicuna-160m.Q8_0.gguf",
      userPrefix: "USER:\n",
      assistantPrefix: "ASSISTANT:\n",
      messageSuffix: "\n",
      sampling: commonSamplingConfig,
    },
    mobileLarger: {
      url: "https://huggingface.co/Felladrin/gguf-zephyr-220m-dpo-full/resolve/main/zephyr-220m-dpo-full.Q8_0.gguf",
      userPrefix: "<|user|>\n",
      assistantPrefix: "<|assistant|>\n",
      messageSuffix: "\n",
      sampling: commonSamplingConfig,
    },
    desktopDefault: {
      url: "https://huggingface.co/Felladrin/gguf-TinyLlama-1.1B-1T-OpenOrca/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf",
      userPrefix: "<|im_start|>user\n",
      assistantPrefix: "<|im_start|>assistant\n",
      messageSuffix: "<|im_end|>\n",
      sampling: commonSamplingConfig,
    },
    desktopLarger: {
      url: "https://huggingface.co/Felladrin/gguf-stablelm-2-1_6b-chat/resolve/main/stablelm-2-1_6b-chat.Q8_0.gguf",
      userPrefix: "<|im_start|>user\n",
      assistantPrefix: "<|im_start|>assistant\n",
      messageSuffix: "<|im_end|>\n",
      sampling: commonSamplingConfig,
    },
  };

  const defaultModel = isRunningOnMobile
    ? availableModels.mobileDefault
    : availableModels.desktopDefault;

  const largerModel = isRunningOnMobile
    ? availableModels.mobileLarger
    : availableModels.desktopLarger;

  const selectedModel = getUseLargerModelSetting() ? largerModel : defaultModel;

  await initializeWllama({
    modelUrl: selectedModel.url,
    modelConfig: {
      n_ctx: 2048,
    },
  });

  try {
    if (!getDisableAiResponseSetting()) {
      // Validate before doing any prompt-building work.
      if (!query) throw Error("Query is empty.");

      const prompt = [
        selectedModel.userPrefix,
        "Hello!",
        selectedModel.messageSuffix,
        selectedModel.assistantPrefix,
        "Hi! How can I help you?",
        selectedModel.messageSuffix,
        selectedModel.userPrefix,
        ["Take a look at this info:", getFormattedSearchResults(5)].join("\n\n"),
        selectedModel.messageSuffix,
        selectedModel.assistantPrefix,
        "Alright!",
        selectedModel.messageSuffix,
        selectedModel.userPrefix,
        "Now I'm going to write my question, and if this info is useful you can use them in your answer. Ready?",
        selectedModel.messageSuffix,
        selectedModel.assistantPrefix,
        "I'm ready to answer!",
        selectedModel.messageSuffix,
        selectedModel.userPrefix,
        query,
        selectedModel.messageSuffix,
        selectedModel.assistantPrefix,
      ].join("");

      updateLoadingToast("Generating response...");

      const completion = await runCompletion({
        prompt,
        nPredict: 768,
        sampling: selectedModel.sampling,
        onNewToken: (_token, _piece, currentText) => {
          updateResponse(currentText);
        },
      });

      updateResponse(completion);
    }

    if (getSummarizeLinksSetting()) {
      updateLoadingToast("Summarizing links...");

      for (const [title, snippet, url] of getSearchResults()) {
        // The assistant turn is pre-filled with "This text is about" so the
        // model only completes the sentence; the prefix is re-added below.
        const prompt = [
          selectedModel.userPrefix,
          "Hello!",
          selectedModel.messageSuffix,
          selectedModel.assistantPrefix,
          "Hi! How can I help you?",
          selectedModel.messageSuffix,
          selectedModel.userPrefix,
          ["Context:", `${title}: ${snippet}`].join("\n"),
          "\n",
          ["Question:", "What is this text about?"].join("\n"),
          selectedModel.messageSuffix,
          selectedModel.assistantPrefix,
          ["Answer:", "This text is about"].join("\n"),
        ].join("");

        const completion = await runCompletion({
          prompt,
          nPredict: 128,
          sampling: selectedModel.sampling,
          onNewToken: (_token, _piece, currentText) => {
            updateUrlsDescriptions({
              ...getUrlsDescriptions(),
              [url]: `This link is about ${currentText}`,
            });
          },
        });

        updateUrlsDescriptions({
          ...getUrlsDescriptions(),
          [url]: `This link is about ${completion}`,
        });
      }
    }
  } finally {
    await exitWllama();
  }
}

/**
 * Generates the response and/or link summaries with Ratchet (WebGPU).
 *
 * A trailing "." is appended to any output that does not already end with
 * ".", "!" or "?". `exitRatchet` always runs once initialization succeeded,
 * even on error.
 *
 * @throws Error("Query is empty.") when a response is requested without a query.
 */
async function generateTextWithRatchet() {
  const { initializeRatchet, runCompletion, exitRatchet } = await import(
    "./ratchet"
  );

  await initializeRatchet((loadingProgressPercentage) =>
    updateLoadingToast(`Loading: ${Math.floor(loadingProgressPercentage)}%`),
  );

  try {
    if (!getDisableAiResponseSetting()) {
      if (!query) throw Error("Query is empty.");

      updateLoadingToast("Generating response...");

      let response = "";

      await runCompletion(getMainPrompt(), (completionChunk) => {
        response += completionChunk;
        updateResponse(response);
      });

      if (!endsWithASign(response)) {
        response += ".";
        updateResponse(response);
      }
    }

    if (getSummarizeLinksSetting()) {
      updateLoadingToast("Summarizing links...");

      for (const [title, snippet, url] of getSearchResults()) {
        let response = "";

        await runCompletion(
          await getLinkSummarizationPrompt([title, snippet, url]),
          (completionChunk) => {
            response += completionChunk;
            updateUrlsDescriptions({
              ...getUrlsDescriptions(),
              [url]: response,
            });
          },
        );

        if (!endsWithASign(response)) {
          response += ".";
          updateUrlsDescriptions({
            ...getUrlsDescriptions(),
            [url]: response,
          });
        }
      }
    }
  } finally {
    await exitRatchet();
  }
}

/**
 * Fetches the page at `url` through the r.jina.ai reader proxy and returns its
 * trimmed text, truncated to `options.maxLength` characters when given.
 *
 * @throws Error on a non-2xx HTTP status (network failures reject from `fetch`).
 */
async function fetchPageContent(
  url: string,
  options?: {
    maxLength?: number;
  },
) {
  const response = await fetch(`https://r.jina.ai/${url}`);

  if (!response.ok) {
    throw new Error(`HTTP error! status: ${response.status}`);
  }

  const text = await response.text();

  return text.trim().substring(0, options?.maxLength);
}

/** True when `text` ends with ".", "!" or "?". */
function endsWithASign(text: string) {
  return text.endsWith(".") || text.endsWith("!") || text.endsWith("?");
}

/** Builds the main instruction prompt: top 5 search results plus the query. */
function getMainPrompt() {
  return [
    "Provide a concise response to the request below.",
    "If the information from the Web Search Results below is useful, you can use it to complement your response. Otherwise, ignore it.",
    "",
    "Web Search Results:",
    "",
    getFormattedSearchResults(5),
    "",
    "Request:",
    "",
    query,
  ].join("\n");
}

/**
 * Builds the prompt asking the model to summarize one search result.
 *
 * Prefers the page's fetched content (first 2500 characters); if fetching
 * fails for any reason, falls back to a prompt built from title, URL and
 * snippet only.
 */
async function getLinkSummarizationPrompt([
  title,
  snippet,
  url,
]: SearchResults[0]) {
  let prompt = "";

  try {
    const pageContent = await fetchPageContent(url, { maxLength: 2500 });

    prompt = [
      `The context below is related to a link found when searching for "${query}":`,
      "",
      "[BEGIN OF CONTEXT]",
      `Snippet: ${snippet}`,
      "",
      pageContent,
      "[END OF CONTEXT]",
      "",
      "Now, tell me: What is this link about and how is it related to the search?",
      "",
      "Note: Don't cite the link in your response. Just write a few sentences to indicate if it's worth visiting.",
    ].join("\n");
  } catch {
    // Best effort: page fetch failed, so summarize from search metadata alone.
    prompt = [
      `When searching for "${query}", this link was found: [${title}](${url} "${snippet}")`,
      "",
      "Now, tell me: What is this link about and how is it related to the search?",
      "",
      "Note: Don't cite the link in your response. Just write a few sentences to indicate if it's worth visiting.",
    ].join("\n");
  }

  return prompt;
}

/**
 * Formats the stored search results as "title\nurl\nsnippet" entries separated
 * by blank lines, keeping at most `limit` entries (all when omitted).
 */
function getFormattedSearchResults(limit?: number) {
  return getSearchResults()
    .slice(0, limit)
    .map(([title, snippet, url]) => `${title}\n${url}\n${snippet}`)
    .join("\n\n");
}

/** Extracts up to `limit` English keywords from `text` (all when omitted). */
async function getKeywords(text: string, limit?: number) {
  return (await import("keyword-extractor")).default
    .extract(text, { language: "english" })
    .slice(0, limit);
}