Spaces:
Running
Running
| # Implement Classification | |
| import os | |
| from langchain.prompts.chat import ChatPromptTemplate | |
| from langchain.memory import ConversationBufferMemory | |
| from generator import load_llm | |
| from langchain.prompts import PromptTemplate | |
| from retrieverV2 import process_pdf_document, create_vectorstore, rag_retriever | |
| from langchain.schema import format_document | |
| from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string | |
| from langchain_core.runnables import RunnableParallel | |
| from langchain_core.runnables import RunnableLambda, RunnablePassthrough | |
| from operator import itemgetter | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
class ModelPipeLine:
    """Conversational RAG pipeline over a small PDF knowledge base.

    Construction indexes the knowledge PDFs, builds a retriever, loads the
    LLM and creates a conversation buffer memory. Build the runnable with
    :meth:`create_final_chain`, then ask questions through
    :meth:`call_conversational_rag`, which also records each turn in memory.
    """

    # Formats each retrieved document as just its ``page_content``.
    DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

    def __init__(self):
        # NOTE(review): ``curr_dir`` is not used below; the knowledge and
        # prompt directories are resolved relative to the process working
        # directory, not this module. Consider joining them with ``curr_dir``
        # if the app may be launched from another directory.
        self.curr_dir = os.path.dirname(__file__)
        self.knowledge_dir = "knowledge"
        print("Knowledge Directory:", self.knowledge_dir)
        self.prompt_dir = 'prompts'
        # Child/parent splitters handed to ``rag_retriever``; presumably small
        # child chunks are indexed and larger parent chunks returned —
        # confirm against retrieverV2.
        self.child_splitter = RecursiveCharacterTextSplitter(chunk_size=200)
        self.parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)
        self.documents = process_pdf_document([
            os.path.join(self.knowledge_dir, 'depression_1.pdf'),
            os.path.join(self.knowledge_dir, 'depression_2.pdf'),
        ])
        self.vectorstore, self.store = create_vectorstore()
        # Create the retriever
        self.retriever = rag_retriever(
            self.vectorstore, self.store, self.documents,
            self.parent_splitter, self.child_splitter,
        )
        self.llm = load_llm()  # Load the LLM model
        # Memory keys match the chain: inputs are stored from "question" and
        # outputs from "answer"; return_messages=True yields a message list,
        # which create_final_chain renders with get_buffer_string.
        self.memory = ConversationBufferMemory(
            return_messages=True,
            output_key="answer",
            input_key="question",
        )

    def get_prompts(self, system_file_path='system_prompt_template.txt',
                    condense_file_path='condense_question_prompt_template.txt'):
        """Load the answer and condense-question prompt templates from disk.

        Args:
            system_file_path: File name, inside ``self.prompt_dir``, of the
                template used to answer from the retrieved context.
            condense_file_path: File name, inside ``self.prompt_dir``, of the
                template used to rewrite a follow-up into a standalone question.

        Returns:
            ``(answer_prompt, condense_question_prompt)`` — a
            ``ChatPromptTemplate`` and a ``PromptTemplate`` respectively.

        Raises:
            OSError: If either template file cannot be opened or read.
        """
        with open(os.path.join(self.prompt_dir, system_file_path), 'r', encoding='utf-8') as f:
            system_prompt_template = f.read()
        with open(os.path.join(self.prompt_dir, condense_file_path), 'r', encoding='utf-8') as f:
            condense_question_template = f.read()
        answer_prompt = ChatPromptTemplate.from_template(system_prompt_template)
        condense_question_prompt = PromptTemplate.from_template(condense_question_template)
        return answer_prompt, condense_question_prompt

    def _combine_documents(self, docs, document_prompt=DEFAULT_DOCUMENT_PROMPT,
                           document_separator="\n\n"):
        """Render each of ``docs`` with ``document_prompt`` and join them.

        Args:
            docs: Iterable of retrieved documents.
            document_prompt: Template applied to each document.
            document_separator: String placed between rendered documents.

        Returns:
            str: The rendered documents joined by ``document_separator``.
        """
        return document_separator.join(
            format_document(doc, document_prompt) for doc in docs
        )

    def create_final_chain(self):
        """Build the end-to-end conversational RAG runnable.

        The chain runs four stages in sequence:

        1. load the stored conversation from ``self.memory`` as ``chat_history``;
        2. condense the question plus history into a standalone question;
        3. retrieve documents for that standalone question;
        4. answer the standalone question from the retrieved context.

        Returns:
            A runnable taking ``{"question": str}`` and returning a dict with
            ``"answer"`` (answer text) and ``"docs"`` (retrieved documents).
        """
        # Local import keeps the module's top-level import block unchanged;
        # langchain_core is already a dependency of this file.
        from langchain_core.output_parsers import StrOutputParser

        answer_prompt, condense_question_prompt = self.get_prompts()
        # Debugging outputs
        print("Condense Question Prompt:", condense_question_prompt)
        print("LLM:", self.llm)

        # Pass the incoming keys through and add the memory's "history"
        # entry under "chat_history".
        loaded_memory = RunnablePassthrough.assign(
            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
        )

        # StrOutputParser normalises LLM output to plain text: a pass-through
        # for string-returning LLMs, and it extracts ``.content`` from chat
        # model messages. The retriever query and memory.save_context both
        # need text, so without it a chat-model LLM breaks the chain.
        standalone_question = {
            "standalone_question": {
                "question": lambda x: x["question"],
                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
            }
            | condense_question_prompt
            | self.llm
            | StrOutputParser(),
        }

        retrieved_documents = {
            "docs": itemgetter("standalone_question") | self.retriever,
            "question": lambda x: x["standalone_question"],
        }

        final_inputs = {
            "context": lambda x: self._combine_documents(x["docs"]),
            "question": itemgetter("question"),
        }

        answer = {
            "answer": final_inputs | answer_prompt | self.llm | StrOutputParser(),
            "docs": itemgetter("docs"),
        }

        return loaded_memory | standalone_question | retrieved_documents | answer

    def call_conversational_rag(self, question, chain):
        """Answer ``question`` with ``chain`` and record the exchange in memory.

        The question/answer pair is saved to ``self.memory`` so later calls
        see it as chat history.

        Args:
            question (str): The question to answer.
            chain: The runnable to invoke, normally the one returned by
                :meth:`create_final_chain`.

        Returns:
            dict: The chain's output. For a chain from ``create_final_chain``
            it holds the answer under ``"answer"`` and the retrieved documents
            under ``"docs"``.
        """
        inputs = {"question": question}
        result = chain.invoke(inputs)
        # Save the current question and its answer for future context.
        self.memory.save_context(inputs, {"answer": result["answer"]})
        return result
# Module-level smoke run: builds the full pipeline (PDF indexing, LLM load),
# answers one sample question and prints the answer. Because these statements
# sit at module level they run on every import as well as on direct execution.
# NOTE(review): if no other module imports ``ml_pipeline``/``final_chain``,
# consider moving this under an ``if __name__ == "__main__":`` guard.
ml_pipeline = ModelPipeLine()
final_chain = ml_pipeline.create_final_chain()
question = "i am feeling sad"
res = ml_pipeline.call_conversational_rag(question=question, chain=final_chain)
print(res["answer"])