from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.schema import Document

from src.utils import load_config
from src.vectorstore import VectorDB


def format_docs(docs: list[Document]) -> str:
    """Join the retrieved documents into a single context string."""
    return '\n\n'.join(doc.page_content for doc in docs)


class OllamaChain:
    """Plain conversational chain over a local Ollama model with windowed memory."""

    def __init__(self, chat_memory) -> None:
        # Llama 3 chat template; the recent turns are injected as {chat_history}.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are an honest and unbiased AI assistant.
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Previous conversation: {chat_history}
Question: {input}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input'],
        )
        # Keep only the last k=3 exchanges so the prompt stays bounded.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True,
        )
        config = load_config()
        llm = Ollama(**config['chat_model'])
        # e.g. llm = Ollama(model='llama3:latest', temperature=0.75, num_gpu=1)
        self.llm_chain = LLMChain(prompt=prompt, llm=llm, memory=self.memory, output_parser=StrOutputParser())
        # LCEL alternative: runnable = prompt | llm | StrOutputParser()

    def run(self, user_input):
        response = self.llm_chain.invoke(user_input)
        return response['text']
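

# A minimal usage sketch, assuming an in-memory ChatMessageHistory; the real app
# may pass any other BaseChatMessageHistory (e.g. a Streamlit-backed one):
#
#     from langchain_community.chat_message_histories import ChatMessageHistory
#     chain = OllamaChain(chat_memory=ChatMessageHistory())
#     print(chain.run('What can you help me with?'))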


class OllamaRAGChain:
    """RAG chain: history-aware retrieval over a vector store, then QA over the hits."""

    def __init__(self, chat_memory, uploaded_file=None):
        config = load_config()
        # Initialize the vector db from config: prefer Pinecone when configured,
        # otherwise fall back to a local Chroma index.
        vector_db_config = config.get('vector_database', {})
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        index_name = 'default'
        self.vector_db = VectorDB(db_name, index_name)
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)
        # Initialize the llm.
        self.llm = Ollama(**config['chat_model'])
        # Keep a handle to the chat history for the runnable wrapper below.
        self.chat_memory = chat_memory
        # Sub-chain that rewrites the latest question using the chat history.
        contextual_q_system_prompt = (
            'Given a chat history and the latest user question, which might refer to '
            'context in the chat history, check whether the question depends on that '
            'history. If it does, reformulate it as a standalone question that '
            'incorporates the needed context from the history and can be understood '
            'without it. Do NOT answer the question; just reformulate it if needed '
            'and otherwise return it as is.'
        )
        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        # QA chain that stuffs the retrieved documents into the prompt.
        qa_system_prompt = (
            'You are an assistant for question-answering tasks. Use the following '
            "pieces of retrieved context to answer the question. If you don't know "
            "the answer, just say that you don't know.\n"
            'Context: {context}'
        )
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)
        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)
        # Wrap the chain so message history is read from, and written back to, chat_memory.
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            lambda session_id: chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer',
        )

    def run(self, user_input):
        # A single shared history is used, so the session id is arbitrary.
        config = {'configurable': {'session_id': 'any'}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)
        return response['answer']

    def update_chain(self, uploaded_pdf):
        # Re-index the new document and rebuild the retrieval pipeline on top of it.
        self.update_knowledge_base(uploaded_pdf)
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        self.conversation_rag_chain = RunnableWithMessageHistory(
            create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain),
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer',
        )

    def update_knowledge_base(self, uploaded_pdf):
        # Chunk and embed the uploaded PDF into the vector store.
        self.vector_db.index(uploaded_pdf)
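

# A minimal smoke test, assuming a running local Ollama server, the project's
# config file, and an already-populated vector store. ChatMessageHistory is an
# assumption here, standing in for whatever history object the app supplies.
if __name__ == '__main__':
    from langchain_community.chat_message_histories import ChatMessageHistory

    history = ChatMessageHistory()
    rag = OllamaRAGChain(chat_memory=history)
    print(rag.run('What is the uploaded document about?'))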