| | from modules.chat.helpers import get_prompt |
| | from modules.chat.chat_model_loader import ChatModelLoader |
| | from modules.vectorstore.store_manager import VectorStoreManager |
| | from modules.retriever.retriever import Retriever |
| | from modules.chat.langchain.langchain_rag import CustomConversationalRetrievalChain |
| |
|
| |
|
class LLMTutor:
    """Wire together the chat LLM, the vector store, and the retrieval QA chain."""

    def __init__(self, config, user, logger=None):
        """
        Initialize the LLMTutor class.

        Loads the chat model, creates the vector store manager and, when
        ``config["vectorstore"]["embedd_files"]`` is truthy, builds and saves
        the vector database.

        Args:
            config (dict): Configuration dictionary.
            user (str): User identifier.
            logger (Logger, optional): Logger instance. Defaults to None.
        """
        self.config = config
        self.user = user
        # Set the logger before calling any helper so every helper can use it.
        self.logger = logger
        self.llm = self.load_llm()
        self._reset_vector_db()

    def _reset_vector_db(self):
        """(Re)create ``self.vector_db`` from ``self.config``; embed and save files if enabled."""
        self.vector_db = VectorStoreManager(self.config, logger=self.logger)
        if self.config["vectorstore"]["embedd_files"]:
            self.vector_db.create_database()
            self.vector_db.save_database()

    def update_llm(self, new_config):
        """
        Update the LLM and VectorStoreManager based on new configuration.

        ``new_config`` replaces ``self.config`` unconditionally. The LLM is
        reloaded only if the top-level ``"chat_model"`` entry changed. The
        vector store is rebuilt only if the top-level ``"vectorstore"`` entry
        changed.

        Args:
            new_config (dict): New configuration dictionary.
        """
        changes = self.get_config_changes(self.config, new_config)
        self.config = new_config

        # NOTE(review): only a top-level "chat_model" key triggers an LLM
        # reload; confirm this matches the config schema actually in use.
        if "chat_model" in changes:
            self.llm = self.load_llm()

        if "vectorstore" in changes:
            self._reset_vector_db()

    def get_config_changes(self, old_config, new_config):
        """
        Get the changes between the old and new configuration.

        The comparison is shallow. Only the top-level keys of ``new_config``
        are examined, and a nested difference is reported under its top-level
        key. Keys present only in ``old_config`` are not reported.

        Args:
            old_config (dict): Old configuration dictionary.
            new_config (dict): New configuration dictionary.

        Returns:
            dict: Maps each changed key to ``(old_value, new_value)``.
                ``old_value`` is None when the key is absent from
                ``old_config``.
        """
        return {
            key: (old_config.get(key), new_value)
            for key, new_value in new_config.items()
            if old_config.get(key) != new_value
        }

    def retrieval_qa_chain(self, llm, qa_prompt, rephrase_prompt, db, memory=None):
        """
        Create a Retrieval QA Chain.

        Args:
            llm (LLM): The language model instance.
            qa_prompt (str): The QA prompt string.
            rephrase_prompt (str): The rephrase prompt string.
            db (VectorStore): The vector store instance.
            memory (Memory, optional): Memory instance. Defaults to None.

        Returns:
            Chain: The retrieval QA chain instance.

        Raises:
            NotImplementedError: If ``config["llm_params"]["use_history"]`` is
                falsy. Only the history-aware chain is implemented.
        """
        # Check before building the retriever. Otherwise the unsupported
        # configuration would end in an UnboundLocalError on the return.
        if not self.config["llm_params"]["use_history"]:
            raise NotImplementedError(
                "retrieval_qa_chain only supports llm_params.use_history=True."
            )

        retriever = Retriever(self.config)._return_retriever(db)
        return CustomConversationalRetrievalChain(
            llm=llm,
            memory=memory,
            retriever=retriever,
            qa_prompt=qa_prompt,
            rephrase_prompt=rephrase_prompt,
        )

    def load_llm(self):
        """
        Load the language model.

        Returns:
            LLM: The loaded language model instance.
        """
        return ChatModelLoader(self.config).load_chat_model()

    def qa_bot(self, memory=None, qa_prompt=None, rephrase_prompt=None):
        """
        Create a QA bot instance.

        Missing prompts are filled in from the config via ``get_prompt``.

        Args:
            memory (Memory, optional): Memory instance. Defaults to None.
            qa_prompt (str, optional): QA prompt string. Defaults to None.
            rephrase_prompt (str, optional): Rephrase prompt string. Defaults to None.

        Returns:
            Chain: The QA bot chain instance.

        Raises:
            ValueError: If the loaded database contains no documents.
            NotImplementedError: Propagated from ``retrieval_qa_chain`` when
                history is disabled in the config.
        """
        if qa_prompt is None:
            qa_prompt = get_prompt(self.config, "qa")
        if rephrase_prompt is None:
            rephrase_prompt = get_prompt(self.config, "rephrase")

        db = self.vector_db.load_database()
        if len(db) == 0:
            raise ValueError(
                "No documents in the database. Populate the database first."
            )

        return self.retrieval_qa_chain(self.llm, qa_prompt, rephrase_prompt, db, memory)
| |
|
def final_result(query, tutor=None):
    """
    Get the final result for a given query.

    Args:
        query (str): The query string.
        tutor (LLMTutor, optional): Object whose ``qa_bot()`` method builds the
            QA chain. It is required in practice. It defaults to None only so
            the original one-argument signature still exists.

    Returns:
        The response returned by the chain for ``{"query": query}``.

    Raises:
        ValueError: If ``tutor`` is None.
    """
    if tutor is None:
        # ``qa_bot`` is a method of LLMTutor, not a module-level function, so
        # the chain has to come from an instance.
        raise ValueError("final_result requires an LLMTutor instance via `tutor`.")

    qa_chain = tutor.qa_bot()
    # NOTE(review): the "query" input key is kept from the original; confirm it
    # matches the input key expected by the chain that qa_bot() returns.
    return qa_chain({"query": query})
| |
|