| from langchain_core.prompts import ChatPromptTemplate |
|
|
| |
| from langchain_community.chat_message_histories import ChatMessageHistory |
| from modules.chat.base import BaseRAG |
| from langchain_core.prompts import PromptTemplate |
| from langchain.memory import ConversationBufferWindowMemory |
| from langchain_core.runnables.utils import ConfigurableFieldSpec |
| from .utils import ( |
| CustomConversationalRetrievalChain, |
| create_history_aware_retriever, |
| create_stuff_documents_chain, |
| create_retrieval_chain, |
| return_questions, |
| CustomRunnableWithHistory, |
| BaseChatMessageHistory, |
| InMemoryHistory, |
| ) |
|
|
|
|
| class Langchain_RAG_V1(BaseRAG): |
| def __init__( |
| self, |
| llm, |
| memory, |
| retriever, |
| qa_prompt: str, |
| rephrase_prompt: str, |
| config: dict, |
| callbacks=None, |
| ): |
| """ |
| Initialize the Langchain_RAG class. |
| |
| Args: |
| llm (LanguageModelLike): The language model instance. |
| memory (BaseChatMessageHistory): The chat message history instance. |
| retriever (BaseRetriever): The retriever instance. |
| qa_prompt (str): The QA prompt string. |
| rephrase_prompt (str): The rephrase prompt string. |
| """ |
| self.llm = llm |
| self.config = config |
| |
| self.memory = ConversationBufferWindowMemory( |
| k=self.config["llm_params"]["memory_window"], |
| memory_key="chat_history", |
| return_messages=True, |
| output_key="answer", |
| max_token_limit=128, |
| ) |
| self.retriever = retriever |
| self.qa_prompt = qa_prompt |
| self.rephrase_prompt = rephrase_prompt |
| self.store = {} |
|
|
| self.qa_prompt = PromptTemplate( |
| template=self.qa_prompt, |
| input_variables=["context", "chat_history", "input"], |
| ) |
|
|
| self.rag_chain = CustomConversationalRetrievalChain.from_llm( |
| llm=llm, |
| chain_type="stuff", |
| retriever=retriever, |
| return_source_documents=True, |
| memory=self.memory, |
| combine_docs_chain_kwargs={"prompt": self.qa_prompt}, |
| response_if_no_docs_found="No context found", |
| ) |
|
|
| def add_history_from_list(self, history_list): |
| """ |
| TODO: Add messages from a list to the chat history. |
| """ |
| history = [] |
|
|
| return history |
|
|
| async def invoke(self, user_query, config): |
| """ |
| Invoke the chain. |
| |
| Args: |
| kwargs: The input variables. |
| |
| Returns: |
| dict: The output variables. |
| """ |
| res = await self.rag_chain.acall(user_query["input"]) |
| return res |
|
|
|
|
| class QuestionGenerator: |
| """ |
| Generate a question from the LLMs response and users input and past conversations. |
| """ |
|
|
| def __init__(self): |
| pass |
|
|
| def generate_questions(self, query, response, chat_history, context, config): |
| questions = return_questions(query, response, chat_history, context, config) |
| return questions |
|
|
|
|
| class Langchain_RAG_V2(BaseRAG): |
| def __init__( |
| self, |
| llm, |
| memory, |
| retriever, |
| qa_prompt: str, |
| rephrase_prompt: str, |
| config: dict, |
| callbacks=None, |
| ): |
| """ |
| Initialize the Langchain_RAG class. |
| |
| Args: |
| llm (LanguageModelLike): The language model instance. |
| memory (BaseChatMessageHistory): The chat message history instance. |
| retriever (BaseRetriever): The retriever instance. |
| qa_prompt (str): The QA prompt string. |
| rephrase_prompt (str): The rephrase prompt string. |
| """ |
| self.llm = llm |
| self.memory = self.add_history_from_list(memory) |
| self.retriever = retriever |
| self.qa_prompt = qa_prompt |
| self.rephrase_prompt = rephrase_prompt |
| self.store = {} |
|
|
| |
| contextualize_q_system_prompt = rephrase_prompt or ( |
| "Given a chat history and the latest user question " |
| "which might reference context in the chat history, " |
| "formulate a standalone question which can be understood " |
| "without the chat history. Do NOT answer the question, just " |
| "reformulate it if needed and otherwise return it as is." |
| ) |
| self.contextualize_q_prompt = ChatPromptTemplate.from_template( |
| contextualize_q_system_prompt |
| ) |
|
|
| |
| self.history_aware_retriever = create_history_aware_retriever( |
| self.llm, self.retriever, self.contextualize_q_prompt |
| ) |
|
|
| |
| qa_system_prompt = qa_prompt or ( |
| "You are an assistant for question-answering tasks. Use " |
| "the following pieces of retrieved context to answer the " |
| "question. If you don't know the answer, just say that you " |
| "don't know. Use three sentences maximum and keep the answer " |
| "concise." |
| "\n\n" |
| "{context}" |
| ) |
| self.qa_prompt_template = ChatPromptTemplate.from_template(qa_system_prompt) |
|
|
| |
| self.question_answer_chain = create_stuff_documents_chain( |
| self.llm, self.qa_prompt_template |
| ) |
|
|
| |
| self.rag_chain = create_retrieval_chain( |
| self.history_aware_retriever, self.question_answer_chain |
| ) |
|
|
| self.rag_chain = CustomRunnableWithHistory( |
| self.rag_chain, |
| get_session_history=self.get_session_history, |
| input_messages_key="input", |
| history_messages_key="chat_history", |
| output_messages_key="answer", |
| history_factory_config=[ |
| ConfigurableFieldSpec( |
| id="user_id", |
| annotation=str, |
| name="User ID", |
| description="Unique identifier for the user.", |
| default="", |
| is_shared=True, |
| ), |
| ConfigurableFieldSpec( |
| id="conversation_id", |
| annotation=str, |
| name="Conversation ID", |
| description="Unique identifier for the conversation.", |
| default="", |
| is_shared=True, |
| ), |
| ConfigurableFieldSpec( |
| id="memory_window", |
| annotation=int, |
| name="Number of Conversations", |
| description="Number of conversations to consider for context.", |
| default=1, |
| is_shared=True, |
| ), |
| ], |
| ).with_config(run_name="Langchain_RAG_V2") |
|
|
| if callbacks is not None: |
| self.rag_chain = self.rag_chain.with_config(callbacks=callbacks) |
|
|
| def get_session_history( |
| self, user_id: str, conversation_id: str, memory_window: int |
| ) -> BaseChatMessageHistory: |
| """ |
| Get the session history for a user and conversation. |
| |
| Args: |
| user_id (str): The user identifier. |
| conversation_id (str): The conversation identifier. |
| memory_window (int): The number of conversations to consider for context. |
| |
| Returns: |
| BaseChatMessageHistory: The chat message history. |
| """ |
| if (user_id, conversation_id) not in self.store: |
| self.store[(user_id, conversation_id)] = InMemoryHistory() |
| self.store[(user_id, conversation_id)].add_messages( |
| self.memory.messages |
| ) |
| return self.store[(user_id, conversation_id)] |
|
|
| async def invoke(self, user_query, config, **kwargs): |
| """ |
| Invoke the chain. |
| |
| Args: |
| kwargs: The input variables. |
| |
| Returns: |
| dict: The output variables. |
| """ |
| res = await self.rag_chain.ainvoke(user_query, config, **kwargs) |
| res["rephrase_prompt"] = self.rephrase_prompt |
| res["qa_prompt"] = self.qa_prompt |
| return res |
|
|
| def stream(self, user_query, config): |
| res = self.rag_chain.stream(user_query, config) |
| return res |
|
|
| def add_history_from_list(self, conversation_list): |
| """ |
| Add messages from a list to the chat history. |
| |
| Args: |
| messages (list): The list of messages to add. |
| """ |
| history = ChatMessageHistory() |
|
|
| for idx, message in enumerate(conversation_list): |
| message_type = ( |
| message.get("type", None) |
| if isinstance(message, dict) |
| else getattr(message, "type", None) |
| ) |
|
|
| message_content = ( |
| message.get("content", None) |
| if isinstance(message, dict) |
| else getattr(message, "content", None) |
| ) |
|
|
| if message_type in ["human", "user_message"]: |
| history.add_user_message(message_content) |
| elif message_type in ["ai", "ai_message"]: |
| history.add_ai_message(message_content) |
|
|
| return history |
|
|