Al1Abdullah commited on
Commit
b29ce3b
·
verified ·
1 Parent(s): ca1edba

Create ollama_chain.py

Browse files
Files changed (1) hide show
  1. src/ollama_chain.py +129 -0
src/ollama_chain.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.llms import Ollama
2
+ from langchain.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
3
+ from langchain.memory import ConversationBufferWindowMemory
4
+ from langchain.chains import LLMChain, create_history_aware_retriever, create_retrieval_chain
5
+ from langchain.chains.combine_documents import create_stuff_documents_chain
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain_core.runnables.history import RunnableWithMessageHistory
8
+ from langchain.schema import Document
9
+
10
+ from src.utils import load_config
11
+ from src.vectorstore import VectorDB
12
+
13
+
14
def format_docs(docs: list[Document]):
    """Merge retrieved documents into one prompt-ready string.

    Each document's ``page_content`` is kept verbatim; documents are
    separated by a blank line.
    """
    sections = [doc.page_content for doc in docs]
    return '\n\n'.join(sections)
16
+
17
+
18
class OllamaChain:
    """Plain conversational chain over a local Ollama model.

    Wires a Llama-3-style prompt template, a sliding-window memory
    (last ``k=3`` turns, stored in the caller-supplied ``chat_memory``)
    and an ``LLMChain`` into a single ``run(user_input) -> str`` API.
    """

    def __init__(self, chat_memory) -> None:
        # Llama-3 chat template; {chat_history} and {input} are filled in
        # by the chain on every invocation.
        prompt = PromptTemplate(
            template="""<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a honest and unbiased AI assistant
<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Previous conversation={chat_history}
Question: {input}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
            input_variables=['chat_history', 'input']
        )

        # Keep only the last 3 exchanges to bound prompt size; history is
        # persisted in the externally provided chat_memory backend.
        self.memory = ConversationBufferWindowMemory(
            memory_key='chat_history',
            chat_memory=chat_memory,
            k=3,
            return_messages=True
        )

        # Model name and generation params come from config['chat_model'].
        config = load_config()
        llm = Ollama(**config['chat_model'])
        # llm = Ollama(model='llama3:latest', temperature=0.75, num_gpu=1)

        self.llm_chain = LLMChain(prompt=prompt, llm=llm, memory=self.memory, output_parser=StrOutputParser())
        # runnable = prompt | llm

    def run(self, user_input):
        """Send *user_input* through the chain and return the model's reply text."""
        # LLMChain.invoke returns a dict; the generated answer is under 'text'.
        response = self.llm_chain.invoke(user_input)

        return response['text']
50
+
51
+
52
class OllamaRAGChain:
    """History-aware RAG chain over a Pinecone-backed vector store.

    A contextualizing sub-chain rewrites the latest user question using the
    chat history, a retriever fetches supporting documents, and a
    stuff-documents QA chain produces the final answer with an Ollama LLM.
    Public API: ``run``, ``update_chain``, ``update_knowledge_base``.
    """

    def __init__(self, chat_memory, uploaded_file=None):
        # initialize vector db
        self.vector_db = VectorDB('pinecone', 'any')
        if uploaded_file:
            # BUG FIX: update_knowledge_base requires the file argument; the
            # original called it with no args, raising TypeError whenever a
            # file was supplied at construction time.
            self.update_knowledge_base(uploaded_file)

        # initialize llm
        config = load_config()
        self.llm = Ollama(**config['chat_model'])

        # initialize memory (shared message store, keyed per session below)
        self.chat_memory = chat_memory

        # initialize sub chain with history message
        contextual_q_system_prompt = """Given a chat history and the latest user question which might refer to context \
in the chat history. Check if the user's question refers to the chat history or not. If does, formulate a \
standalone question which is incorporated from the latest question and history and can be understood without \
the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is."""

        self.contextual_q_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', contextual_q_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        # initialize qa chain
        qa_system_prompt = """You are an assistant for question-answering tasks. Use the following pieces of retrieved\
context to answer the question. If you don't know the answer, just say that you don't know.
Context: {context}"""
        qa_prompt = ChatPromptTemplate.from_messages(
            [
                ('system', qa_system_prompt),
                MessagesPlaceholder('chat_history'),
                ('human', '{input}'),
            ]
        )

        self.question_answer_chain = create_stuff_documents_chain(self.llm, qa_prompt)

        # build retriever + conversational RAG runnable (also used after
        # knowledge-base updates, so it lives in a helper)
        self._rebuild_chain()

    def _rebuild_chain(self):
        """(Re)create the history-aware retriever and conversational RAG runnable
        from the current vector-store contents."""
        self.history_aware_retriever = create_history_aware_retriever(
            self.llm, self.vector_db.as_retriever(), self.contextual_q_prompt
        )
        rag_chain = create_retrieval_chain(self.history_aware_retriever, self.question_answer_chain)
        # NOTE: every session_id maps to the single shared chat_memory, so all
        # callers share one history — presumably a single-user app; verify.
        self.conversation_rag_chain = RunnableWithMessageHistory(
            rag_chain,
            lambda session_id: self.chat_memory,
            input_messages_key='input',
            history_messages_key='chat_history',
            output_messages_key='answer'
        )

    def run(self, user_input):
        """Answer *user_input* with retrieval + history; returns the answer text."""
        config = {"configurable": {"session_id": "any"}}
        response = self.conversation_rag_chain.invoke({'input': user_input}, config)

        return response['answer']

    def update_chain(self, uploaded_pdf):
        """Index *uploaded_pdf* and rebuild the RAG chain over the new retriever."""
        self.update_knowledge_base(uploaded_pdf)
        self._rebuild_chain()

    def update_knowledge_base(self, uploaded_pdf):
        """Add *uploaded_pdf* to the vector store's index."""
        self.vector_db.index(uploaded_pdf)