| | import time |
| | import re |
| | import json |
| | import os |
| | from datetime import datetime |
| |
|
| | import gradio as gr |
| | import torch |
| |
|
| | import modules.shared as shared |
| | from modules import chat, ui as ui_module |
| | from modules.utils import gradio |
| | from modules.text_generation import generate_reply_HF, generate_reply_custom |
| | from .llm_web_search import get_webpage_content, langchain_search_duckduckgo, langchain_search_searxng, Generator |
| | from .langchain_websearch import LangchainCompressor |
| |
|
| |
|
# Default extension settings. setup() overlays these with any values found in
# settings.json, and the UI callbacks mutate the dict in place at runtime.
params = {
    "display_name": "LLM Web Search",  # name shown for the extension tab
    "is_tab": True,  # render the extension UI as its own tab
    "enable": True,
    "search results per query": 5,
    "langchain similarity score threshold": 0.5,
    "instant answers": True,
    "regular search results": True,
    "search command regex": "",  # empty -> fall back to the default below
    "default search command regex": r"Search_web\(\"(.*)\"\)",
    "open url command regex": "",  # empty -> fall back to the default below
    "default open url command regex": r"Open_url\(\"(.*)\"\)",
    "display search results in chat": True,
    "display extracted URL content in chat": True,
    "searxng url": "",  # empty -> DuckDuckGo is used instead of SearXNG
    "cpu only": True,
    "chunk size": 500,
    "duckduckgo results per query": 10,
    "append current datetime": False,
    "default system prompt filename": None,
    "force search prefix": "Search_web",
    "ensemble weighting": 0.5,  # per UI: smaller = more keyword-oriented, larger = more semantic
    "keyword retriever": "bm25",  # "bm25" or "splade" (see UI radio choices)
    "splade batch size": 2,
    "chunking method": "character-based",  # "character-based" or "semantic"
    "chunker breakpoint_threshold_amount": 30
}
# Filename of the currently loaded custom system prompt (None = none loaded).
custom_system_message_filename = None
# Absolute path of this extension's directory (settings, prompts, models live here).
extension_path = os.path.dirname(os.path.abspath(__file__))
# LangchainCompressor instance, created by toggle_extension() when enabled.
langchain_compressor = None
# Pending [user_message, reply] pair that history_modifier() appends to history.
update_history = None
# Set via the hidden UI checkbox; forces the next reply to begin with a search.
force_search = False
| |
|
| |
|
def setup():
    """
    Is executed when the extension gets imported.

    Loads any persisted settings from settings.json (rewriting the file so
    that newly-added default keys are persisted too), makes sure the
    system_prompts directory exists, and initialises the extension.
    :return:
    """
    global params
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    os.environ["QDRANT__TELEMETRY_DISABLED"] = "true"

    settings_path = os.path.join(extension_path, "settings.json")
    try:
        with open(settings_path, "r") as f:
            saved_params = json.load(f)
        params.update(saved_params)
        save_settings()
    except FileNotFoundError:
        # First run: persist the defaults.
        save_settings()
    except json.JSONDecodeError:
        # A corrupt settings file must not keep the whole extension from
        # loading; fall back to the defaults and rewrite the file.
        print("LLM_Web_search | settings.json could not be parsed, restoring defaults")
        save_settings()

    # exist_ok avoids the racy check-then-create pattern.
    os.makedirs(os.path.join(extension_path, "system_prompts"), exist_ok=True)

    toggle_extension(params["enable"])
| |
|
| |
|
def save_settings():
    """Persist the current ``params`` dict to settings.json.

    Returns a green HTML status element stamped with the time of the save,
    suitable for display next to the "Save settings" button.
    """
    global params
    settings_path = os.path.join(extension_path, "settings.json")
    with open(settings_path, "w") as settings_file:
        json.dump(params, settings_file, indent=4)
    current_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    status_html = f'<span style="color:lawngreen"> Settings were saved at {current_datetime}</span>'
    return gr.HTML(status_html, visible=True)
| |
|
| |
|
def toggle_extension(_enable: bool):
    """
    Enable or disable the extension at runtime.

    When enabling, constructs the LangchainCompressor (loading embedding
    models into CPU or CUDA memory per the "cpu only" setting) and restores
    the default system prompt selection. When disabling, moves any
    GPU-resident retrieval models back to the CPU and frees CUDA memory.
    :param _enable: new enabled state
    :return: the new enabled state (echoed back to the UI checkbox)
    """
    global langchain_compressor, custom_system_message_filename
    if _enable:
        langchain_compressor = LangchainCompressor(device="cpu" if params["cpu only"] else "cuda",
                                                   keyword_retriever=params["keyword retriever"],
                                                   model_cache_dir=os.path.join(extension_path, "hf_models"))
        compressor_model = langchain_compressor.embeddings.client
        # NOTE(review): _target_device is a private attribute of the embedding
        # client (presumably sentence-transformers) — verify it still exists
        # when upgrading that dependency.
        compressor_model.to(compressor_model._target_device)
        custom_system_message_filename = params.get("default system prompt filename")
    else:
        # Only GPU runs need teardown; on CPU there is nothing to evict.
        if not params["cpu only"] and 'langchain_compressor' in globals():
            model_attrs = ["embeddings", "splade_doc_model", "splade_query_model"]
            for model_attr in model_attrs:
                if hasattr(langchain_compressor, model_attr):
                    model = getattr(langchain_compressor, model_attr)
                    if hasattr(model, "client"):
                        # Wrapped model (e.g. embeddings): move the inner
                        # client off the GPU before dropping the reference.
                        model.client.to("cpu")
                        del model.client
                    else:
                        if hasattr(model, "to"):
                            model.to("cpu")
                    del model
            # Release the now-unreferenced GPU allocations.
            torch.cuda.empty_cache()
    params.update({"enable": _enable})
    return _enable
| |
|
| |
|
def get_available_system_prompts():
    """Return "None" followed by the sorted filenames in system_prompts/.

    A missing directory is treated the same as an empty one.
    """
    prompts_dir = os.path.join(extension_path, "system_prompts")
    try:
        filenames = sorted(os.listdir(prompts_dir))
    except FileNotFoundError:
        filenames = []
    return ["None"] + filenames
| |
|
| |
|
def load_system_prompt(filename: "str | None"):
    """
    Load a custom system message from the system_prompts folder and make it
    the active one.

    Note: the original annotation was ``str or None``, which evaluates to
    plain ``str`` at definition time — replaced with a real optional type.

    :param filename: prompt filename, "None"/placeholder to clear the
        selection, or a falsy value to do nothing
    :return: the loaded prompt text, "" when cleared, or None for no-op
    """
    global custom_system_message_filename
    if not filename:
        return
    if filename in ("None", "Select custom system message to load..."):
        # Clearing the selection: forget the active prompt file.
        custom_system_message_filename = None
        return ""
    with open(os.path.join(extension_path, "system_prompts", filename), "r") as f:
        prompt_str = f.read()

    if params["append current datetime"]:
        prompt_str += f"\nDate and time of conversation: {datetime.now().strftime('%A %d %B %Y %H:%M')}"

    # Activate the prompt globally and remember which file it came from.
    shared.settings['custom_system_message'] = prompt_str
    custom_system_message_filename = filename
    return prompt_str
| |
|
| |
|
def save_system_prompt(filename, prompt):
    """Write *prompt* to system_prompts/<filename>.

    Does nothing for an empty filename; otherwise returns a green HTML
    success element for the UI.
    """
    if not filename:
        return

    target_path = os.path.join(extension_path, "system_prompts", filename)
    with open(target_path, "w") as out_file:
        out_file.write(prompt)

    return gr.HTML(f'<span style="color:lawngreen"> Saved successfully</span>', visible=True)
| |
|
| |
|
def check_file_exists(filename):
    """Return an HTML warning element when *filename* would overwrite an
    existing file in system_prompts/, otherwise a hidden empty element."""
    if filename == "":
        return gr.HTML("", visible=False)
    candidate = os.path.join(extension_path, "system_prompts", filename)
    if os.path.exists(candidate):
        return gr.HTML(f'<span style="color:orange"> Warning: Filename already exists</span>', visible=True)
    return gr.HTML("", visible=False)
| |
|
| |
|
def timeout_save_message():
    """Hide the transient "saved" status element after a two-second delay.

    Used as a chained Gradio callback so the success message fades out.
    """
    time.sleep(2)
    return gr.HTML("", visible=False)
| |
|
| |
|
def deactivate_system_prompt():
    """Clear the active custom system message and reset the dropdown.

    :return: the dropdown value "None" marking no selection
    """
    shared.settings['custom_system_message'] = None
    return "None"
| |
|
| |
|
def toggle_forced_search(value):
    """Record whether the next generation must be forced to run a search.

    Bound to the hidden "Force-search-checkbox" element; mirrors its state
    into the module-level ``force_search`` flag.
    """
    global force_search
    force_search = value
| |
|
| |
|
def ui():
    """
    Creates custom gradio elements when the UI is launched.

    Builds the settings tab (toggles, regex inputs, system prompt
    management, advanced retrieval settings) and wires every element's
    change/click events to the corresponding ``params`` update.
    :return:
    """
    # Re-apply the persisted custom system message so it survives UI reloads.
    shared.gradio['custom_system_message'].value = load_system_prompt(custom_system_message_filename)

    def update_result_type_setting(choice: str):
        # Map the radio-button label onto the two boolean result-type params.
        if choice == "Instant answers":
            params.update({"instant answers": True})
            params.update({"regular search results": False})
        elif choice == "Regular results":
            params.update({"instant answers": False})
            params.update({"regular search results": True})
        elif choice == "Regular results and instant answers":
            params.update({"instant answers": True})
            params.update({"regular search results": True})

    def update_regex_setting(input_str: str, setting_key: str, error_html_element: gr.component):
        # An empty textbox means "use the built-in default regex".
        if input_str == "":
            params.update({setting_key: params[f"default {setting_key}"]})
            return {error_html_element: gr.HTML("", visible=False)}
        try:
            compiled = re.compile(input_str)
            # Exactly one capturing group is required: it extracts the
            # search query / URL from the model's command.
            if compiled.groups > 1:
                raise re.error(f"Only 1 capturing group allowed in regex, but there are {compiled.groups}.")
            params.update({setting_key: input_str})
            return {error_html_element: gr.HTML("", visible=False)}
        except re.error as e:
            # Surface the regex error to the user instead of saving a broken pattern.
            return {error_html_element: gr.HTML(f'<span style="color:red"> Invalid regex. {str(e).capitalize()}</span>',
                                                visible=True)}

    def update_default_custom_system_message(check: bool):
        # Checkbox decides whether the currently loaded prompt is loaded on startup.
        if check:
            params.update({"default system prompt filename": custom_system_message_filename})
        else:
            params.update({"default system prompt filename": None})

    # --- Top row: master toggles and save button -------------------------
    with gr.Row():
        enable = gr.Checkbox(value=lambda: params['enable'], label='Enable LLM web search')
        use_cpu_only = gr.Checkbox(value=lambda: params['cpu only'],
                                   label='Run extension on CPU only '
                                         '(Save settings and restart for the change to take effect)')
        with gr.Column():
            save_settings_btn = gr.Button("Save settings")
            saved_success_elem = gr.HTML("", visible=False)

    # --- Search behavior: result type, command regexes, display toggles --
    with gr.Row():
        result_radio = gr.Radio(
            ["Regular results", "Regular results and instant answers"],
            label="What kind of search results should be returned?",
            value=lambda: "Regular results and instant answers" if
            (params["regular search results"] and params["instant answers"]) else "Regular results"
        )
        with gr.Column():
            search_command_regex = gr.Textbox(label="Search command regex string",
                                              placeholder=params["default search command regex"],
                                              value=lambda: params["search command regex"])
            search_command_regex_error_label = gr.HTML("", visible=False)

        with gr.Column():
            open_url_command_regex = gr.Textbox(label="Open URL command regex string",
                                                placeholder=params["default open url command regex"],
                                                value=lambda: params["open url command regex"])
            open_url_command_regex_error_label = gr.HTML("", visible=False)

        with gr.Column():
            show_results = gr.Checkbox(value=lambda: params['display search results in chat'],
                                       label='Display search results in chat')
            show_url_content = gr.Checkbox(value=lambda: params['display extracted URL content in chat'],
                                           label='Display extracted URL content in chat')
    gr.Markdown(value='---')
    # --- System prompt management: load/delete and create/save -----------
    with gr.Row():
        with gr.Column():
            gr.Markdown(value='#### Load custom system message\n'
                              'Select a saved custom system message from within the system_prompts folder or "None" '
                              'to clear the selection')
            system_prompt = gr.Dropdown(
                choices=get_available_system_prompts(), label="Select custom system message",
                value=lambda: 'Select custom system message to load...' if custom_system_message_filename is None else
                custom_system_message_filename, elem_classes='slim-dropdown')
            with gr.Row():
                set_system_message_as_default = gr.Checkbox(
                    value=lambda: custom_system_message_filename == params["default system prompt filename"],
                    label='Set this custom system message as the default')
                refresh_button = ui_module.create_refresh_button(system_prompt, lambda: None,
                                                                 lambda: {'choices': get_available_system_prompts()},
                                                                 'refresh-button', interactive=True)
                # Fixed element id so the save/delete JS callbacks below can
                # programmatically click the refresh button.
                refresh_button.elem_id = "custom-sysprompt-refresh"
                delete_button = gr.Button('🗑️', elem_classes='refresh-button', interactive=True)
            append_datetime = gr.Checkbox(value=lambda: params['append current datetime'],
                                          label='Append current date and time when loading custom system message')
        with gr.Column():
            gr.Markdown(value='#### Create custom system message')
            system_prompt_text = gr.Textbox(label="Custom system message", lines=3,
                                            value=lambda: load_system_prompt(custom_system_message_filename))
            sys_prompt_filename = gr.Text(label="Filename")
            sys_prompt_save_button = gr.Button("Save Custom system message")
            system_prompt_saved_success_elem = gr.HTML("", visible=False)

    gr.Markdown(value='---')
    # --- Advanced retrieval / chunking settings --------------------------
    with gr.Accordion("Advanced settings", open=False):
        ensemble_weighting = gr.Slider(minimum=0, maximum=1, step=0.05, value=lambda: params["ensemble weighting"],
                                       label="Ensemble Weighting", info="Smaller values = More keyword oriented, "
                                                                        "Larger values = More focus on semantic similarity")
        with gr.Row():
            keyword_retriever = gr.Radio([("Okapi BM25", "bm25"), ("SPLADE", "splade")], label="Sparse keyword retriever",
                                         info="For change to take effect, toggle the extension off and on again",
                                         value=lambda: params["keyword retriever"])
            splade_batch_size = gr.Slider(minimum=2, maximum=256, step=2, value=lambda: params["splade batch size"],
                                          label="SPLADE batch size",
                                          info="Smaller values = Slower retrieval (but lower VRAM usage), "
                                               "Larger values = Faster retrieval (but higher VRAM usage). "
                                               "A good trade-off seems to be setting it = 8",
                                          precision=0)
        with gr.Row():
            chunker = gr.Radio([("Character-based", "character-based"),
                                ("Semantic", "semantic")], label="Chunking method",
                               value=lambda: params["chunking method"])
            chunker_breakpoint_threshold_amount = gr.Slider(minimum=1, maximum=100, step=1,
                                                            value=lambda: params["chunker breakpoint_threshold_amount"],
                                                            label="Semantic chunking: sentence split threshold (%)",
                                                            info="Defines how different two consecutive sentences have"
                                                                 " to be for them to be split into separate chunks",
                                                            precision=0)
        gr.Markdown("**Note: Changing the following might result in DuckDuckGo rate limiting or the LM being overwhelmed**")
        num_search_results = gr.Number(label="Max. search results to return per query", minimum=1, maximum=100,
                                       value=lambda: params["search results per query"], precision=0)
        num_process_search_results = gr.Number(label="Number of search results to process per query", minimum=1,
                                               maximum=100, value=lambda: params["duckduckgo results per query"],
                                               precision=0)
        langchain_similarity_threshold = gr.Number(label="Langchain Similarity Score Threshold", minimum=0., maximum=1.,
                                                   value=lambda: params["langchain similarity score threshold"])
        chunk_size = gr.Number(label="Max. chunk size", info="The maximal size of the individual chunks that each webpage will"
                                                             " be split into, in characters", minimum=2, maximum=10000,
                               value=lambda: params["chunk size"], precision=0)

        with gr.Row():
            searxng_url = gr.Textbox(label="SearXNG URL",
                                     value=lambda: params["searxng url"])

    # --- Event wiring: each element mirrors its value into params --------
    enable.input(toggle_extension, enable, enable)
    use_cpu_only.change(lambda x: params.update({"cpu only": x}), use_cpu_only, None)
    save_settings_btn.click(save_settings, None, [saved_success_elem])
    ensemble_weighting.change(lambda x: params.update({"ensemble weighting": x}), ensemble_weighting, None)
    keyword_retriever.change(lambda x: params.update({"keyword retriever": x}), keyword_retriever, None)
    splade_batch_size.change(lambda x: params.update({"splade batch size": x}), splade_batch_size, None)
    chunker.change(lambda x: params.update({"chunking method": x}), chunker, None)
    chunker_breakpoint_threshold_amount.change(lambda x: params.update({"chunker breakpoint_threshold_amount": x}),
                                               chunker_breakpoint_threshold_amount, None)
    num_search_results.change(lambda x: params.update({"search results per query": x}), num_search_results, None)
    num_process_search_results.change(lambda x: params.update({"duckduckgo results per query": x}),
                                      num_process_search_results, None)
    langchain_similarity_threshold.change(lambda x: params.update({"langchain similarity score threshold": x}),
                                          langchain_similarity_threshold, None)
    chunk_size.change(lambda x: params.update({"chunk size": x}), chunk_size, None)
    result_radio.change(update_result_type_setting, result_radio, None)

    search_command_regex.change(lambda x: update_regex_setting(x, "search command regex",
                                                               search_command_regex_error_label),
                                search_command_regex, search_command_regex_error_label, show_progress="hidden")

    open_url_command_regex.change(lambda x: update_regex_setting(x, "open url command regex",
                                                                 open_url_command_regex_error_label),
                                  open_url_command_regex, open_url_command_regex_error_label, show_progress="hidden")

    show_results.change(lambda x: params.update({"display search results in chat": x}), show_results, None)
    show_url_content.change(lambda x: params.update({"display extracted URL content in chat": x}), show_url_content,
                            None)
    searxng_url.change(lambda x: params.update({"searxng url": x}), searxng_url, None)

    # Deleting a prompt re-uses the web UI's generic file-deletion dialog
    # (shared.gradio elements), then resets the dropdown and refreshes it
    # via the JS click on the refresh button.
    delete_button.click(
        lambda x: x, system_prompt, gradio('delete_filename')).then(
        lambda: os.path.join(extension_path, "system_prompts", ""), None, gradio('delete_root')).then(
        lambda: gr.update(visible=True), None, gradio('file_deleter'))
    shared.gradio['delete_confirm'].click(
        lambda: "None", None, system_prompt).then(
        None, None, None, _js="() => { document.getElementById('custom-sysprompt-refresh').click() }")
    system_prompt.change(load_system_prompt, system_prompt, shared.gradio['custom_system_message'])
    system_prompt.change(load_system_prompt, system_prompt, system_prompt_text)

    # Keep the "default" checkbox consistent with the selected prompt.
    system_prompt.change(lambda x: x == params["default system prompt filename"], system_prompt,
                         set_system_message_as_default)
    sys_prompt_filename.change(check_file_exists, sys_prompt_filename, system_prompt_saved_success_elem)
    # Save, show the success message briefly, refresh the dropdown, then
    # clear the filename textbox.
    sys_prompt_save_button.click(save_system_prompt, [sys_prompt_filename, system_prompt_text],
                                 system_prompt_saved_success_elem,
                                 show_progress="hidden").then(timeout_save_message,
                                                              None,
                                                              system_prompt_saved_success_elem,
                                                              _js="() => { document.getElementById('custom-sysprompt-refresh').click() }",
                                                              show_progress="hidden").then(lambda: "", None,
                                                                                           sys_prompt_filename,
                                                                                           show_progress="hidden")
    append_datetime.change(lambda x: params.update({"append current datetime": x}), append_datetime, None)

    set_system_message_as_default.input(update_default_custom_system_message, set_system_message_as_default, None)

    # Hidden checkbox toggled from script.js; forces a web search on the
    # next generation.
    force_search_checkbox = gr.Checkbox(value=False, visible=False, elem_id="Force-search-checkbox")
    force_search_checkbox.change(toggle_forced_search, force_search_checkbox, None)
| |
|
| |
|
def custom_generate_reply(question, original_question, seed, state, stopping_strings, is_chat):
    """
    Overrides the main text generation function.

    Streams the model's reply while watching it for a search or open-URL
    command (matched by the configured regexes). When a command appears,
    generation stops, the matching tool runs, its output is appended to the
    reply, and a second generation pass lets the model answer using the
    tool results.
    :return:
    """
    global update_history, langchain_compressor
    # Backends that need the non-HuggingFace generation path.
    if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'Exllamav2Model',
                                           'CtransformersModel']:
        generate_func = generate_reply_custom
    else:
        generate_func = generate_reply_HF

    # Extension disabled: pass straight through to the normal generator.
    if not params['enable']:
        for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
            yield reply
        return

    web_search = False
    read_webpage = False
    max_search_results = int(params["search results per query"])
    instant_answers = params["instant answers"]

    # Push the current UI settings onto the compressor before any search.
    langchain_compressor.num_results = int(params["duckduckgo results per query"])
    langchain_compressor.similarity_threshold = params["langchain similarity score threshold"]
    langchain_compressor.chunk_size = params["chunk size"]
    langchain_compressor.ensemble_weighting = params["ensemble weighting"]
    langchain_compressor.splade_batch_size = params["splade batch size"]
    langchain_compressor.chunking_method = params["chunking method"]
    langchain_compressor.chunker_breakpoint_threshold_amount = params["chunker breakpoint_threshold_amount"]

    search_command_regex = params["search command regex"]
    open_url_command_regex = params["open url command regex"]
    searxng_url = params["searxng url"]
    display_search_results = params["display search results in chat"]
    display_webpage_content = params["display extracted URL content in chat"]

    # Empty regex settings fall back to the built-in defaults.
    if search_command_regex == "":
        search_command_regex = params["default search command regex"]
    if open_url_command_regex == "":
        open_url_command_regex = params["default open url command regex"]

    compiled_search_command_regex = re.compile(search_command_regex)
    compiled_open_url_command_regex = re.compile(open_url_command_regex)

    # Forced search: seed the prompt so the model is nudged into emitting
    # a search command (the prefix is prepended to its reply below).
    if force_search:
        question += f" {params['force search prefix']}"

    reply = None
    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):

        if force_search:
            reply = params["force search prefix"] + reply

        search_re_match = compiled_search_command_regex.search(reply)
        if search_re_match is not None:
            yield reply
            # Keep the untouched model reply around so status messages (and
            # the "hidden results" display mode) can use it.
            original_model_reply = reply
            web_search = True
            search_term = search_re_match.group(1)
            print(f"LLM_Web_search | Searching for {search_term}...")
            reply += "\n```plaintext"
            reply += "\nSearch tool:\n"
            # Empty SearXNG URL means DuckDuckGo is the search backend.
            if searxng_url == "":
                search_generator = Generator(langchain_search_duckduckgo(search_term,
                                                                         langchain_compressor,
                                                                         max_search_results,
                                                                         instant_answers))
            else:
                search_generator = Generator(langchain_search_searxng(search_term,
                                                                      searxng_url,
                                                                      langchain_compressor,
                                                                      max_search_results))
            try:
                # Stream intermediate status messages while the search runs.
                for status_message in search_generator:
                    yield original_model_reply + f"\n*{status_message}*"
                search_results = search_generator.value
            except Exception as exc:
                exception_message = str(exc)
                reply += f"The search tool encountered an error: {exception_message}"
                print(f'LLM_Web_search | {search_term} generated an exception: {exception_message}')
            else:
                if search_results != "":
                    reply += search_results
                else:
                    reply += f"\nThe search tool did not return any results."
            reply += "```"
            if display_search_results:
                yield reply
            # Stop this generation pass; results are fed back in below.
            break

        open_url_re_match = compiled_open_url_command_regex.search(reply)
        if open_url_re_match is not None:
            yield reply
            original_model_reply = reply
            read_webpage = True
            url = open_url_re_match.group(1)
            print(f"LLM_Web_search | Reading {url}...")
            reply += "\n```plaintext"
            reply += "\nURL opener tool:\n"
            try:
                webpage_content = get_webpage_content(url)
            except Exception as exc:
                reply += f"Couldn't open {url}. Error message: {str(exc)}"
                print(f'LLM_Web_search | {url} generated an exception: {str(exc)}')
            else:
                reply += f"\nText content of {url}:\n"
                reply += webpage_content
            reply += "```\n"
            if display_webpage_content:
                yield reply
            break
        # No command detected yet: stream the partial reply as usual.
        yield reply

    if web_search or read_webpage:
        # Whether the tool output should be visible in chat at all.
        display_results = web_search and display_search_results or read_webpage and display_webpage_content

        # Second pass: feed the reply plus tool output back to the model.
        new_question = chat.generate_chat_prompt(f"{question}{reply}", state)
        new_reply = ""
        for new_reply in generate_func(new_question, new_question, seed, state,
                                       stopping_strings, is_chat=is_chat):
            if display_results:
                yield f"{reply}\n{new_reply}"
            else:
                # Hide the tool output from the visible chat, but record the
                # full exchange so history_modifier() can store it internally.
                yield f"{original_model_reply}\n{new_reply}"

        if not display_results:
            update_history = [state["textbox"], f"{reply}\n{new_reply}"]
| |
|
| |
|
def output_modifier(string, state, is_chat=False):
    """
    Modifies the output string before it is presented in the UI. In chat
    mode, it is applied to the bot's reply. Otherwise, it is applied to the
    entire output.

    This extension performs no output rewriting, so the text is returned
    unchanged.
    :param string:
    :param state:
    :param is_chat:
    :return: the unmodified input string
    """
    return string
| |
|
| |
|
def custom_css():
    """
    Returns custom CSS as a string. It is applied whenever the web UI is
    loaded. This extension needs no extra styling, so the string is empty.
    :return: an empty string
    """
    return ''
| |
|
| |
|
def custom_js():
    """
    Returns custom javascript as a string. It is applied whenever the web
    UI is loaded. The script is read from the extension's script.js file.
    :return: the contents of script.js
    """
    script_path = os.path.join(extension_path, "script.js")
    with open(script_path, "r") as script_file:
        return script_file.read()
| |
|
| |
|
def chat_input_modifier(text, visible_text, state):
    """
    Modifies both the visible and internal inputs in chat mode. Can be used
    to hijack the chat input with custom content.

    This extension leaves the chat input untouched.
    :param text:
    :param visible_text:
    :param state:
    :return: the (internal, visible) input pair, unmodified
    """
    return text, visible_text
| |
|
| |
|
def state_modifier(state):
    """
    Modifies the dictionary containing the UI input parameters before it is
    used by the text generation functions.

    No state changes are needed by this extension.
    :param state:
    :return: the state dict, unmodified
    """
    return state
| |
|
| |
|
def history_modifier(history):
    """
    Modifies the chat history before the text generation in chat mode begins.

    When custom_generate_reply() hid tool output from the visible chat, it
    stashes the full exchange in ``update_history``; this hook appends that
    pair to the internal history and clears the stash.
    :param history:
    :return: the (possibly extended) history dict
    """
    global update_history
    if update_history:
        history["internal"].append(update_history)
        update_history = None
    return history