Ulaşcan Akbulut committed on
Commit
05caa09
·
1 Parent(s): 1766eea

Add Rag file

Browse files
Files changed (1) hide show
  1. RAG_public.py +234 -0
RAG_public.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import
2
+ import os
3
+ #from dotenv import load_dotenv
4
+ from langchain_openai import ChatOpenAI
5
+ from pymilvus import connections, utility
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langchain_milvus.vectorstores import Milvus
8
+ from langchain.chains import create_retrieval_chain
9
+ from langchain.chains import create_history_aware_retriever
10
+ from langchain_core.chat_history import BaseChatMessageHistory
11
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
12
+ from langchain_core.runnables.history import RunnableWithMessageHistory
13
+ from langchain_community.chat_message_histories import ChatMessageHistory
14
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
15
+ from langchain.chains.combine_documents import create_stuff_documents_chain
16
+
17
# Environment Settings
# Credentials are read from the process environment; any of these may be
# None if the variable is unset.
#load_dotenv()  # enable when loading credentials from a local .env file
openai_api_key = os.getenv("OPENAI_API_KEY")  # OpenAI key for LLM + embeddings
cloud_api_key = os.getenv("CLOUD_API_KEY")    # Milvus cloud access token
cloud_uri = os.getenv("URI")                  # Milvus cloud endpoint URI
22
+
23
# Database Connection
class DatabaseManagement:
    """Establishes the connection to the Milvus server.

    Instantiating this class opens a connection under the ``default``
    alias using the cloud URI and API token read from the environment.
    """

    def __init__(self):
        """Open the Milvus connection (alias ``default``, 120 s timeout)."""
        connections.connect(
            alias="default",
            uri=cloud_uri,
            token=cloud_api_key,
            timeout=120,
        )
        print("Connected to the Milvus Server")
35
+
36
# Manages vectorstore
class VectorStoreManagement:
    """Creates or loads a Milvus vector store for the given documents.

    If the target collection does not yet exist, the documents are split
    into overlapping chunks, embedded with OpenAI embeddings, and inserted
    into a new collection. Otherwise the existing collection is wrapped
    without re-inserting any data.

    Methods
    -------
    create_vectorstore()
        Build or load the Milvus vector store into ``self.vectorstore``.
    """

    def __init__(self, document, collection_name="RAG_Milvus"):
        """Initialize state and build/load the vector store.

        Parameters
        ----------
        document : list
            Documents (``langchain_core.documents.Document``) to index.
        collection_name : str, optional
            Name of the Milvus collection to create or load
            (default ``"RAG_Milvus"``, matching the original behavior).
        """
        self.document = document
        self.collection_name = collection_name
        self.vectorstore = None
        self.create_vectorstore()

    def create_vectorstore(self):
        """Create the collection if it is missing, otherwise load it.

        When the collection is missing, the documents are split into
        1000-character chunks with 200-character overlap and inserted;
        when it exists, the collection is wrapped as-is.
        """
        embedding = OpenAIEmbeddings(openai_api_key=openai_api_key)
        connection_args = {"uri": cloud_uri, "token": cloud_api_key}

        if self.collection_name not in utility.list_collections():
            print(f"{self.collection_name} collection does not exist under the ChatRAG database")
            # Split the documents into smaller overlapping chunks
            textsplitter = RecursiveCharacterTextSplitter(
                chunk_size=1000, chunk_overlap=200, length_function=len)
            chunks_data = textsplitter.split_documents(documents=self.document)

            # Create the vector store (embeds and inserts the chunks)
            self.vectorstore = Milvus.from_documents(
                documents=chunks_data,
                embedding=embedding,
                collection_name=self.collection_name,
                connection_args=connection_args)
            print(f"{self.collection_name} collection is created under ChatRAG database")
        else:
            print(f"{self.collection_name} collection already exist")
            # Wrap the existing collection without re-inserting documents
            self.vectorstore = Milvus(
                embedding_function=embedding,
                collection_name=self.collection_name,
                connection_args=connection_args)
93
+
94
# RAG class to retrieve ai response for a given user query
class RAG:
    """Retrieval Augmented Generation chat over documents stored in Milvus.

    Typical usage::

        rag = RAG(documents)
        rag.model()                                   # build the chain first
        answer, urls = rag.conversational_rag_chain("my question")

    Methods
    -------
    model()
        Build ``self.rag_chain``: an LLM (gpt-3.5-turbo-0125, temperature 0),
        a similarity retriever over the vector store, a history-aware
        retriever that rephrases the latest question using the chat history,
        and a stuff-documents question/answer chain.
    get_session_history(session_id)
        Return (creating if needed) the chat history for a session.
    conversational_rag_chain(input, session_id="6161")
        Run the chain on a user query and return the answer plus the
        ``pdf_url`` of each retrieved source document.
    """

    def __init__(self, document):
        """Connect to Milvus, build the vector store and init chat storage.

        Parameters
        ----------
        document : list
            Documents (``langchain_core.documents.Document``) to index.
        """
        self.document = document
        self.database_manager = DatabaseManagement()
        self.vectorstore_manager = VectorStoreManagement(self.document)
        # Maps session_id -> ChatMessageHistory; filled lazily by
        # get_session_history().
        self.store = {}

    # RAG model
    def model(self):
        """Build ``self.rag_chain`` from the vector store.

        Creates the LLM and a similarity retriever, combines them with a
        question-rephrasing prompt into a history-aware retriever, builds a
        question/answer chain from a QA prompt, and joins both into
        ``self.rag_chain``. Must be called before
        :meth:`conversational_rag_chain`, which reads ``self.rag_chain``.
        """
        # temperature=0 for deterministic answers
        llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)

        # Retriever returns the 6 most similar chunks for a query
        retriever = self.vectorstore_manager.vectorstore.as_retriever(
            search_type="similarity", search_kwargs={"k": 6})

        # Instructs the model to rewrite the latest user question so it is
        # understandable without the preceding chat history (no answering).
        contextualize_q_system_prompt = """Given a chat history and the latest user question \
which might reference context in the chat history, formulate a standalone question \
which can be understood without the chat history. Do NOT answer the question, \
just reformulate it if needed and otherwise return it as is."""

        contextualize_q_prompt = ChatPromptTemplate.from_messages([
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

        # Retriever that first folds the chat history into the query so the
        # retrieved documents stay relevant to the whole conversation.
        history_aware_retriever = create_history_aware_retriever(
            llm, retriever, contextualize_q_prompt)

        # QA prompt: answer concisely from the retrieved {context} only.
        qa_system_prompt = """You are an assistant for question-answering tasks. \
Use the following pieces of retrieved context to answer the question. \
If you don't know the answer, just say that you don't know. \
Use three sentences maximum and keep the answer concise. \

{context}"""

        qa_prompt = ChatPromptTemplate.from_messages([
            ("system", qa_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ])

        # Chain that stuffs the retrieved documents into the QA prompt.
        question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)

        # Full RAG chain: history-aware retrieval followed by answering.
        self.rag_chain = create_retrieval_chain(
            history_aware_retriever, question_answer_chain)

    # Method/function to store chat history
    def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
        """Return the chat history for *session_id*, creating it if new.

        Parameters
        ----------
        session_id : str
            Identifier of the conversation session.

        Returns
        -------
        BaseChatMessageHistory
            The (possibly freshly created) history for this session.
        """
        if session_id not in self.store:
            self.store[session_id] = ChatMessageHistory()
        return self.store[session_id]

    # Create conversational RAG chain
    def conversational_rag_chain(self, input, session_id="6161"):
        """Run the conversational RAG chain on a user query.

        ``self.model()`` must have been called beforehand so that
        ``self.rag_chain`` exists.

        Parameters
        ----------
        input : str
            User's query.
        session_id : str, optional
            Session whose chat history is used (default ``"6161"``,
            matching the original hard-coded behavior).

        Returns
        -------
        tuple
            ``(answer, urls)`` where ``answer`` is the AI response string
            and ``urls`` lists the ``pdf_url`` metadata of each retrieved
            source document.
        """
        chain_with_history = RunnableWithMessageHistory(
            self.rag_chain,
            self.get_session_history,
            input_messages_key="input",
            history_messages_key="chat_history",
            output_messages_key="answer")

        result = chain_with_history.invoke(
            {"input": str(input)},
            config={"configurable": {"session_id": session_id}})

        # Collect the source PDF URL of every retrieved context document
        source_urls = [doc.metadata["pdf_url"] for doc in result["context"]]

        return result["answer"], source_urls