from langchain_community.llms import Ollama
from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.schema import Document
from src.utils import load_config
from src.vectorstore import VectorDB
def format_docs(docs: "list[Document]"):
    """Concatenate the page contents of *docs*, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
class OllamaChain:
    """Plain conversational chain: an Ollama model plus a short memory window."""

    def __init__(self, chat_memory) -> None:
        """Wire up the prompt, the windowed memory and the LLM chain.

        chat_memory: LangChain message-history backend that stores prior turns.
        """
        # Llama-3 chat template: the model sees the recent history followed by
        # the new question.
        chat_prompt = PromptTemplate(
            input_variables=['chat_history', 'input'],
            template="""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a honest and unbiased AI assistant
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Previous conversation={chat_history}
Question: {input}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
        )
        # Keep only the last 3 exchanges to bound the prompt size.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True,
        )
        # Model name / sampling parameters come from the project config file.
        model = Ollama(**load_config()['chat_model'])
        self.llm_chain = LLMChain(
            prompt=chat_prompt,
            llm=model,
            memory=self.memory,
            output_parser=StrOutputParser(),
        )

    def run(self, user_input):
        """Send *user_input* through the chain and return the text answer."""
        return self.llm_chain.invoke(user_input)['text']
class OllamaRAGChain:
    """History-aware RAG pipeline: vector-store retrieval + an Ollama LLM.

    The chain first rewrites a history-dependent question into a standalone
    one, retrieves matching documents, then answers from that context while
    threading the chat history through ``RunnableWithMessageHistory``.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        """Build the retriever, the QA chain and the conversational wrapper.

        chat_memory: LangChain message-history backend holding prior turns.
        uploaded_file: optional document indexed into the vector store first.
        """
        # Load config once; it supplies both vector-db and chat-model settings.
        # (Previously it was loaded twice and re-imported inside this method.)
        config = load_config()
        vector_db_config = config.get('vector_database', {})
        db_name = 'pinecone' if 'pinecone' in vector_db_config else 'chroma'
        self.vector_db = VectorDB(db_name, 'default')
        if uploaded_file:
            self.update_knowledge_base(uploaded_file)

        self.llm = Ollama(**config['chat_model'])
        self.chat_memory = chat_memory

        # Prompt that turns a history-dependent question into a standalone one.
        contextual_q_system_prompt = """Given a chat history and the latest user question which might refer to context \
in the chat history. Check if the user's question refers to the chat history or not. If does, formulate a \
standalone question which is incorporated from the latest question and history and can be understood without \
the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""
        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        # Prompt that answers the (reformulated) question from retrieved context.
        qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved\
context to answer the question. If you don't know the answer, just say that you don't know.
Context: {context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )
        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        self._build_conversation_chain()

    def _build_conversation_chain(self):
        """(Re)build the retriever and RAG chain from the current vector store.

        Shared by ``__init__`` and ``update_chain`` so the construction logic
        lives in one place.
        """
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            # Single shared history regardless of session id.
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer',
        )

    def run(self, user_input):
        """Answer *user_input* with retrieval and history; return the answer text."""
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)
        return response['answer']

    def update_chain(self, uploaded_pdf):
        """Index *uploaded_pdf* and rebuild the chain over the enlarged store."""
        self.update_knowledge_base(uploaded_pdf)
        self._build_conversation_chain()

    def update_knowledge_base(self, uploaded_pdf):
        """Add *uploaded_pdf* to the vector store index."""
        self.vector_db.index(uploaded_pdf)