from langchain_community.llms import Ollama
from langchain.chains import (
    LLMChain,
    create_history_aware_retriever,
    create_retrieval_chain,
)
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory

from src.utils import load_config
from src.vectorstore import VectorDB


def format_docs(docs: list[Document]) -> str:
    """Concatenate retrieved documents into one context string.

    Documents are separated by a blank line so the LLM can tell them apart.
    """
    return '\n\n'.join(doc.page_content for doc in docs)


class OllamaChain:
    """Plain conversational chain (no retrieval) over an Ollama model.

    The last ``k=3`` exchanges are kept in a windowed buffer memory and
    injected into the prompt as ``chat_history``.
    """

    def __init__(self, chat_memory) -> None:
        """Build the LLM chain.

        Args:
            chat_memory: message-history store backing the memory buffer
                (persists across chain instances).
        """
        # Llama-3 chat template: system block, then the user turn carrying
        # the rolling history, then the assistant header the model completes.
        # (Fixed "a honest" -> "an honest".)
        prompt = PromptTemplate(
            template=(
                "<|begin_of_text|>\n"
                "<|start_header_id|>system<|end_header_id|>\n"
                "You are an honest and unbiased AI assistant\n"
                "<|eot_id|>\n"
                "<|start_header_id|>user<|end_header_id|>\n"
                "Previous conversation={chat_history}\n"
                "Question: {input}\n"
                "Answer:\n"
                "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
            ),
            input_variables=['chat_history', 'input'],
        )
        # Only the 3 most recent turns are replayed into the prompt.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True,
        )
        config = load_config()
        llm = Ollama(**config['chat_model'])
        self.llm_chain = LLMChain(
            prompt=prompt,
            llm=llm,
            memory=self.memory,
            output_parser=StrOutputParser(),
        )

    def run(self, user_input):
        """Run one conversation turn and return the model's text answer."""
        response = self.llm_chain.invoke(user_input)
        return response['text']


class OllamaRAGChain:
    """Retrieval-augmented conversational chain over an Ollama model.

    Pipeline: a history-aware retriever first reformulates the user's
    question against the chat history, then a stuff-documents QA chain
    answers it from the retrieved context. Message history is provided by
    the single ``chat_memory`` store shared across sessions.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        """Assemble the RAG pipeline.

        Args:
            chat_memory: message-history store shared by all sessions.
            uploaded_file: optional document to index into the vector store
                before the retriever is built.
        """
        # Load config once and reuse it (was loaded twice, and `load_config`
        # was redundantly re-imported inside this method).
        config = load_config()

        # --- vector store -------------------------------------------------
        vector_db_config = config.get('vector_database', {})
        # Prefer Pinecone when configured; otherwise fall back to Chroma.
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        self.vector_db = VectorDB(db_name, 'default')
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)

        # --- llm & memory -------------------------------------------------
        self.llm = Ollama(**config['chat_model'])
        self.chat_memory = chat_memory

        # --- history-aware question reformulation -------------------------
        contextual_q_system_prompt = (
            "Given a chat history and the latest user question which might refer to context "
            "in the chat history. Check if the user's question refers to the chat history or "
            "not. If does, formulate a standalone question which is incorporated from the "
            "latest question and history and can be understood without the chat history. "
            "Do NOT answer the question, just reformulate it if needed and otherwise return "
            "it as is."
        )
        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )

        # --- QA over retrieved context ------------------------------------
        # NOTE: the original prompt's backslash line-continuation fused
        # "retrieved" and "context" into "retrievedcontext"; space restored.
        qa_system_prompt = (
            "You are an assistant for question-answering tasks. Use the following pieces of "
            "retrieved context to answer the question. If you don't know the answer, just "
            "say that you don't know.\n"
            "Context: {context}"
        )
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )
        self.conversation_rag_chain = self._with_history(rag_chain)

    def _with_history(self, rag_chain):
        """Wrap a retrieval chain with the shared per-session message history.

        Centralizes the ``RunnableWithMessageHistory`` construction that was
        previously duplicated in ``__init__`` and ``update_chain`` (with an
        inconsistent memory reference between the two copies).
        """
        return RunnableWithMessageHistory(
            rag_chain,
            # Every session_id maps to the single shared memory store.
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer',
        )

    def run(self, user_input):
        """Answer one user question and return the answer text.

        The session_id is a fixed dummy value because a single shared
        memory store is used for all sessions.
        """
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)
        return response['answer']

    def update_chain(self, uploaded_pdf):
        """Index a new document and rebuild the retrieval pipeline around
        the refreshed vector store."""
        self.update_knowledge_base(uploaded_pdf)
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        rag_chain = create_retrieval_chain(
            self.history_aware_retriever, self.question_answer_chain
        )
        self.conversation_rag_chain = self._with_history(rag_chain)

    def update_knowledge_base(self, uploaded_pdf):
        """Add the uploaded document to the vector store index."""
        self.vector_db.index(uploaded_pdf)