Spaces:
Build error
Build error
Ulaşcan Akbulut committed on
Commit ·
05caa09
1
Parent(s): 1766eea
Add Rag file
Browse files- RAG_public.py +234 -0
RAG_public.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Import
|
| 2 |
+
import os
|
| 3 |
+
#from dotenv import load_dotenv
|
| 4 |
+
from langchain_openai import ChatOpenAI
|
| 5 |
+
from pymilvus import connections, utility
|
| 6 |
+
from langchain_openai import OpenAIEmbeddings
|
| 7 |
+
from langchain_milvus.vectorstores import Milvus
|
| 8 |
+
from langchain.chains import create_retrieval_chain
|
| 9 |
+
from langchain.chains import create_history_aware_retriever
|
| 10 |
+
from langchain_core.chat_history import BaseChatMessageHistory
|
| 11 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 12 |
+
from langchain_core.runnables.history import RunnableWithMessageHistory
|
| 13 |
+
from langchain_community.chat_message_histories import ChatMessageHistory
|
| 14 |
+
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 15 |
+
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 16 |
+
|
| 17 |
+
# Environment Settings
|
| 18 |
+
#load_dotenv()
|
| 19 |
+
openai_api_key = os.getenv("OPENAI_API_KEY")
|
| 20 |
+
cloud_api_key = os.getenv("CLOUD_API_KEY")
|
| 21 |
+
cloud_uri = os.getenv("URI")
|
| 22 |
+
|
| 23 |
+
# Database Connection
|
| 24 |
+
class DatabaseManagement:
|
| 25 |
+
"""
|
| 26 |
+
Connects Milvus database
|
| 27 |
+
"""
|
| 28 |
+
def __init__(self):
|
| 29 |
+
"""
|
| 30 |
+
Connects to Milvus server and calls initiliaze_database function
|
| 31 |
+
"""
|
| 32 |
+
# Connects to Milvus server
|
| 33 |
+
connections.connect(alias="default", uri=cloud_uri, token=cloud_api_key, timeout=120)
|
| 34 |
+
print("Connected to the Milvus Server")
|
| 35 |
+
|
| 36 |
+
# Manages vectorstore
|
| 37 |
+
class VectorStoreManagement:
|
| 38 |
+
"""
|
| 39 |
+
Creates vectorstore from Milvus if vectorstore is not defined or defined as None
|
| 40 |
+
|
| 41 |
+
Methods
|
| 42 |
+
------
|
| 43 |
+
|
| 44 |
+
create_vectorstore()
|
| 45 |
+
Checks whether vectorstore is defined or not defined. If is defined, splits the data into
|
| 46 |
+
smaller chunks and creates vectorstore from Milvus
|
| 47 |
+
"""
|
| 48 |
+
def __init__(self, document):
|
| 49 |
+
"""
|
| 50 |
+
Initialize document, embedding and vectorstore and calls create_vectorstore function
|
| 51 |
+
|
| 52 |
+
Parameters
|
| 53 |
+
----------
|
| 54 |
+
document: list
|
| 55 |
+
Document from langchain_core.documents inside a list
|
| 56 |
+
embedding:
|
| 57 |
+
Openai embeddings
|
| 58 |
+
"""
|
| 59 |
+
self.document = document
|
| 60 |
+
self.vectorstore = None
|
| 61 |
+
self.create_vectorstore()
|
| 62 |
+
|
| 63 |
+
def create_vectorstore(self):
|
| 64 |
+
"""
|
| 65 |
+
create_vectorstore()
|
| 66 |
+
Checks whether vectorstore is defined or not defined. If it is defined, splits the data into
|
| 67 |
+
smaller chunks and creates vectorstore from Milvus
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
# Define collection name
|
| 71 |
+
collection_name = "RAG_Milvus"
|
| 72 |
+
|
| 73 |
+
# Creates collection under ChatRAG database
|
| 74 |
+
if collection_name not in utility.list_collections():
|
| 75 |
+
print("RAG_Milvus collection does not exist under the ChatRAG database")
|
| 76 |
+
# Split the string data into smaller chunks
|
| 77 |
+
textsplitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200, length_function=len)
|
| 78 |
+
chunks_data = textsplitter.split_documents(documents=self.document)
|
| 79 |
+
|
| 80 |
+
# Create vectorstore from Milvus
|
| 81 |
+
self.vectorstore = Milvus.from_documents(documents=chunks_data,
|
| 82 |
+
embedding=OpenAIEmbeddings(openai_api_key=openai_api_key),
|
| 83 |
+
collection_name=collection_name,
|
| 84 |
+
connection_args={"uri":cloud_uri,
|
| 85 |
+
"token":cloud_api_key})
|
| 86 |
+
print("RAG_Milvus collection is created under ChatRAG database")
|
| 87 |
+
else:
|
| 88 |
+
print("RAG_Milvus collection already exist")
|
| 89 |
+
self.vectorstore = Milvus(embedding_function=OpenAIEmbeddings(openai_api_key=openai_api_key),
|
| 90 |
+
collection_name=collection_name,
|
| 91 |
+
connection_args={"uri":cloud_uri,
|
| 92 |
+
"token":cloud_api_key})
|
| 93 |
+
|
| 94 |
+
# RAG class to retrieve ai response for a given user query
|
| 95 |
+
class RAG:
|
| 96 |
+
"""
|
| 97 |
+
ChatRAG that uses Retrieval Augmented Generation model for large language model
|
| 98 |
+
with the langchain
|
| 99 |
+
|
| 100 |
+
Methods
|
| 101 |
+
-------
|
| 102 |
+
|
| 103 |
+
model():
|
| 104 |
+
Creates llm from openai. Uses the model gpt-3.5-turbo-0125 with temperature=0
|
| 105 |
+
Creates retriever from vectorstore
|
| 106 |
+
Defines contextualize_q_prompt to use it in history_aware_retriever where llm, retriever and contextualize_q_prompt is combined
|
| 107 |
+
Defines qa_prompt (question/answer) to use it in create_stuff_documents_chain where llm and qa_prompt is combined for question_answer_chain
|
| 108 |
+
Defines rag chain by combining history_aware_retriever and question_answer_chain
|
| 109 |
+
|
| 110 |
+
get_session_history(session_id):
|
| 111 |
+
Stores chat history and session_id in a dictionary
|
| 112 |
+
|
| 113 |
+
conversational_rag_chain(input):
|
| 114 |
+
Creates conversational rag chain and invokes the ai response
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
def __init__(self, document):
|
| 118 |
+
"""
|
| 119 |
+
Initilization of document and store to store the chat history
|
| 120 |
+
|
| 121 |
+
Parameters
|
| 122 |
+
----------
|
| 123 |
+
document: list
|
| 124 |
+
Document from langchain.schema inside a list
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
self.document = document
|
| 128 |
+
self.database_manager = DatabaseManagement()
|
| 129 |
+
self.vectorstore_manager = VectorStoreManagement(self.document)
|
| 130 |
+
self.store = {}
|
| 131 |
+
|
| 132 |
+
# RAG model
|
| 133 |
+
def model(self):
|
| 134 |
+
"""
|
| 135 |
+
Creates llm from openai. Uses the model gpt-3.5-turbo-0125 with temperature=0
|
| 136 |
+
Creates retriever from vectorstore
|
| 137 |
+
Defines contextualize_q_prompt to use it in history_aware_retriever where llm, retriever and contextualize_q_prompt is combined
|
| 138 |
+
Defines qa_prompt (question/answer) to use it in create_stuff_documents_chain where llm and qa_prompt is combined for question_answer_chain
|
| 139 |
+
Defines rag chain by combining history_aware_retriever and question_answer_chain
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
# Create llm from chatopenai
|
| 143 |
+
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
| 144 |
+
|
| 145 |
+
# Create retriever. Its function is to return relevant documents from documents with respect to similarity search and user input.
|
| 146 |
+
retriever = self.vectorstore_manager.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})
|
| 147 |
+
|
| 148 |
+
# System prompt that tells the language model on how to handle the latest user query in the context of the entire conversation history
|
| 149 |
+
# It tells the model to take the chat history and the latest user question and rephrase the question so it can be understood independently
|
| 150 |
+
# of the history
|
| 151 |
+
contextualize_q_system_prompt = """Given a chat history and the latest user question \
|
| 152 |
+
which might reference context in the chat history, formulate a standalone question \
|
| 153 |
+
which can be understood without the chat history. Do NOT answer the question, \
|
| 154 |
+
just reformulate it if needed and otherwise return it as is."""
|
| 155 |
+
|
| 156 |
+
# Create customized Chat Prompt Template with a customized system prompt
|
| 157 |
+
contextualize_q_prompt = ChatPromptTemplate.from_messages([
|
| 158 |
+
("system", contextualize_q_system_prompt),
|
| 159 |
+
MessagesPlaceholder("chat_history"),
|
| 160 |
+
("human", "{input}"),])
|
| 161 |
+
|
| 162 |
+
# Create history aware retriever. It combines current user query with the chat history so that
|
| 163 |
+
# ai response is relevant to the previous question/answer
|
| 164 |
+
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
|
| 165 |
+
|
| 166 |
+
# Create custom question/answer prompt
|
| 167 |
+
qa_system_prompt = """You are an assistant for question-answering tasks. \
|
| 168 |
+
Use the following pieces of retrieved context to answer the question. \
|
| 169 |
+
If you don't know the answer, just say that you don't know. \
|
| 170 |
+
Use three sentences maximum and keep the answer concise. \
|
| 171 |
+
|
| 172 |
+
{context}"""
|
| 173 |
+
|
| 174 |
+
# Create custom question answer Chat Prompt
|
| 175 |
+
qa_prompt = ChatPromptTemplate.from_messages([
|
| 176 |
+
("system", qa_system_prompt),
|
| 177 |
+
MessagesPlaceholder("chat_history"),
|
| 178 |
+
("human", "{input}"),])
|
| 179 |
+
|
| 180 |
+
# Create question/answer chain. It combines llm and qa_prompt.
|
| 181 |
+
# It uses llm and retrieved context to asnwer question.
|
| 182 |
+
question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
|
| 183 |
+
|
| 184 |
+
# RAG chain that combines the history aware retriever and question/answer chain
|
| 185 |
+
# It makes sure that that retrieved documents are related to the chat history and user query
|
| 186 |
+
self.rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
|
| 187 |
+
|
| 188 |
+
# Method/function to store chat history
|
| 189 |
+
def get_session_history(self, session_id: str) -> BaseChatMessageHistory:
|
| 190 |
+
"""
|
| 191 |
+
Stores chat history and session_id in a dictionary
|
| 192 |
+
|
| 193 |
+
Parameters
|
| 194 |
+
----------
|
| 195 |
+
session_id: str
|
| 196 |
+
session_id in string format
|
| 197 |
+
Returns
|
| 198 |
+
-------
|
| 199 |
+
store: dict
|
| 200 |
+
Dictionary that has key: session_id and value: chat history
|
| 201 |
+
"""
|
| 202 |
+
if session_id not in self.store:
|
| 203 |
+
self.store[session_id] = ChatMessageHistory()
|
| 204 |
+
return self.store[session_id]
|
| 205 |
+
|
| 206 |
+
#Create conversational RAG chain
|
| 207 |
+
def conversational_rag_chain(self, input):
|
| 208 |
+
"""
|
| 209 |
+
Creates conversational rag chain and invokes it
|
| 210 |
+
|
| 211 |
+
Parameters
|
| 212 |
+
----------
|
| 213 |
+
input: str
|
| 214 |
+
User's query
|
| 215 |
+
Returns
|
| 216 |
+
-------
|
| 217 |
+
str
|
| 218 |
+
AI response
|
| 219 |
+
"""
|
| 220 |
+
conversational_rag_chain = RunnableWithMessageHistory(
|
| 221 |
+
self.rag_chain,
|
| 222 |
+
self.get_session_history,
|
| 223 |
+
input_messages_key="input",
|
| 224 |
+
history_messages_key="chat_history",
|
| 225 |
+
output_messages_key="answer")
|
| 226 |
+
|
| 227 |
+
result = conversational_rag_chain.invoke({"input": str(input)},
|
| 228 |
+
config={"configurable": {"session_id": "6161"}})
|
| 229 |
+
|
| 230 |
+
l = []
|
| 231 |
+
for doc in result["context"]:
|
| 232 |
+
l.append(doc.metadata["pdf_url"])
|
| 233 |
+
|
| 234 |
+
return result["answer"], l
|