from modules.chat.helpers import get_prompt
from modules.chat.chat_model_loader import ChatModelLoader
from modules.vectorstore.store_manager import VectorStoreManager
from modules.retriever.retriever import Retriever
from modules.chat.langchain.langchain_rag import (
    Langchain_RAG_V2,
    QuestionGenerator,
)


class LLMTutor:
    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
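
        Example (illustrative sketch; assumes ``config`` follows the schema
        read by this module and the vector store is already populated):
            tutor = LLMTutor(config, user="student_1")
            chain = tutor.qa_bot()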
| """ |
| self.config = config |
| self.llm = self.load_llm() |
| self.user = user |
| self.logger = logger |
| self.vector_db = VectorStoreManager(config, logger=self.logger).load_database() |
| self.qa_prompt = get_prompt(config, "qa") |
| self.rephrase_prompt = get_prompt( |
| config, "rephrase" |
| ) |
|
|
| |
| |
| |
| |
|
|
    def update_llm(self, old_config, new_config):
        """
        Update the LLM and VectorStoreManager based on new configuration.

        Args:
            old_config (dict): Previous configuration dictionary.
            new_config (dict): New configuration dictionary.
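
        Example (illustrative sketch; ``"openai"`` is a hypothetical
        ``llm_loader`` value, not a documented option):
            new_config = copy.deepcopy(old_config)
            new_config["llm_params"]["llm_loader"] = "openai"
            tutor.update_llm(old_config, new_config)  # triggers load_llm()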
| """ |
| changes = self.get_config_changes(old_config, new_config) |
|
|
| if "llm_params.llm_loader" in changes: |
| self.llm = self.load_llm() |
|
|
| if "vectorstore.db_option" in changes: |
| self.vector_db = VectorStoreManager( |
| self.config, logger=self.logger |
| ).load_database() |

        if "llm_params.llm_style" in changes:
            self.qa_prompt = get_prompt(self.config, "qa")

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Mapping of dotted config keys to ``(old_value, new_value)``
                tuples; keys removed in ``new_config`` map to ``(old_value, None)``.
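
        Example (reflects the dotted-key format produced below):
            >>> tutor.get_config_changes(
            ...     {"llm_params": {"llm_loader": "a"}},
            ...     {"llm_params": {"llm_loader": "b"}},
            ... )
            {'llm_params.llm_loader': ('a', 'b')}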
| """ |
| changes = {} |
|
|
| def compare_dicts(old, new, parent_key=""): |
| for key in new: |
| full_key = f"{parent_key}.{key}" if parent_key else key |
| if isinstance(new[key], dict) and isinstance(old.get(key), dict): |
| compare_dicts(old.get(key, {}), new[key], full_key) |
| elif old.get(key) != new[key]: |
| changes[full_key] = (old.get(key), new[key]) |
| |
| for key in old: |
| if key not in new: |
| full_key = f"{parent_key}.{key}" if parent_key else key |
| changes[full_key] = (old[key], None) |
|
|
| compare_dicts(old_config, new_config) |
| return changes |
|
|
    def retrieval_qa_chain(
        self, llm, qa_prompt, rephrase_prompt, db, memory=None, callbacks=None
    ):
        """
        Create a Retrieval QA Chain.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers passed to the chain.
                Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.
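
        Example (illustrative; only the "langchain" architecture is handled):
            chain = tutor.retrieval_qa_chain(
                tutor.llm, tutor.qa_prompt, tutor.rephrase_prompt, tutor.vector_db
            )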
| """ |
| retriever = Retriever(self.config)._return_retriever(db) |
|
|
| if self.config["llm_params"]["llm_arch"] == "langchain": |
| self.qa_chain = Langchain_RAG_V2( |
| llm=llm, |
| memory=memory, |
| retriever=retriever, |
| qa_prompt=qa_prompt, |
| rephrase_prompt=rephrase_prompt, |
| config=self.config, |
| callbacks=callbacks, |
| ) |
|
|
| self.question_generator = QuestionGenerator() |
| else: |
| raise ValueError( |
| f"Invalid LLM Architecture: {self.config['llm_params']['llm_arch']}" |
| ) |
| return self.qa_chain |
|
|
    def load_llm(self):
        """
        Load the language model.

        Returns:
            LLM: The loaded language model instance.
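
        Example (illustrative; model selection is delegated to ChatModelLoader):
            llm = tutor.load_llm()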
| """ |
| chat_model_loader = ChatModelLoader(self.config) |
| llm = chat_model_loader.load_chat_model() |
| return llm |
|
|
    def qa_bot(self, memory=None, callbacks=None):
        """
        Create a QA bot instance.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            callbacks (list, optional): Callback handlers passed to the chain.
                Defaults to None.

        Returns:
            Chain: The QA bot chain instance.
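
        Example (illustrative; raises ValueError if the vector store is empty):
            chain = tutor.qa_bot(memory=None, callbacks=None)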
| """ |
| |
| if len(self.vector_db) == 0: |
| raise ValueError( |
| "No documents in the database. Populate the database first." |
| ) |
|
|
| qa = self.retrieval_qa_chain( |
| self.llm, |
| self.qa_prompt, |
| self.rephrase_prompt, |
| self.vector_db, |
| memory, |
| callbacks=callbacks, |
| ) |
|
|
| return qa |
|
|