OnurKerimoglu committed on
Commit
c6fa9db
·
1 Parent(s): 3446837

introduced src.rag.py

Browse files
Files changed (1) hide show
  1. src/rag.py +143 -0
src/rag.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import dotenv
3
+ import os
4
+ from langchain.document_loaders import UnstructuredURLLoader, PyPDFLoader
5
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
6
+
7
+ # from langchain.embeddings import OpenAIEmbeddings
8
+ from langchain_community.embeddings import HuggingFaceEmbeddings
9
+
10
+ from langchain.vectorstores import Chroma
11
+ from langchain.chat_models import ChatOpenAI
12
+ from langchain.chains import RetrievalQA
13
+ from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
14
+ from tqdm import tqdm
15
+
16
class RAG:
    """Retrieval-Augmented Generation pipeline over URLs and PDFs.

    Loads the given sources, splits them into overlapping text chunks,
    embeds them into a Chroma vector store, and wires a RetrievalQA chain
    (ChatOpenAI + custom prompt) that answers questions with sources.
    """

    def __init__(self, urls=None, pdfs=None, k=3):
        """Build the full RAG stack from the given sources.

        Args:
            urls: source URLs to encode in the vector store (default: none).
            pdfs: source PDF file paths to encode in the vector store
                (default: none).
            k: number of relevant chunks to retrieve per query.
        """
        # Normalize defaults here rather than using mutable default
        # arguments ([]), which would be shared across instances.
        self.urls = urls if urls is not None else []
        self.pdfs = pdfs if pdfs is not None else []
        # Bug fix: the original hard-coded `self.k = 3`, silently
        # ignoring the caller-supplied `k`.
        self.k = k

        # Load environment variables that should contain an 'OPENAI_API_KEY'
        dotenv.load_dotenv(dotenv.find_dotenv())

        # Placeholder; populated by setup_rag_bots() below.
        self.QAbot = None

        # Setup the bots
        self.setup_rag_bots()

    def load_data(self, urls, pdfs):
        """Load raw documents from URLs and PDF paths.

        Returns a flat list of LangChain Document objects.
        """
        documents = []
        if urls:
            url_loader = UnstructuredURLLoader(urls=urls)
            documents.extend(url_loader.load())
        for pdf in pdfs:
            pdf_loader = PyPDFLoader(pdf)
            documents.extend(pdf_loader.load())
        return documents

    def sources_to_texts(self, urls, pdfs):
        """Load the sources and split them into retrieval-sized chunks."""
        documents = self.load_data(urls, pdfs)

        # Chunking parameters for the retrieval system; overlap keeps
        # context that straddles chunk boundaries retrievable.
        chunk_size = 1000
        chunk_overlap = 200

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap)
        texts = text_splitter.split_documents(documents)
        return texts

    def create_embeddings(self):
        """Return the embedding model used for the vector store.

        Uses a local HuggingFace sentence-transformer instead of
        OpenAIEmbeddings, so embedding needs no OpenAI API calls.
        """
        embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
        return embeddings

    def create_retriever(self, texts, embeddings):
        """Embed the chunks into Chroma and return a top-k retriever."""
        vectorstore = Chroma.from_documents(texts, embeddings)
        retriever = vectorstore.as_retriever(search_kwargs={"k": self.k})
        return retriever

    def create_llm(self):
        """Create the chat language model (deterministic: temperature=0)."""
        llm = ChatOpenAI(
            model_name="gpt-4o-mini",
            temperature=0)
        return llm

    def create_QAbot(self, retriever, llm):
        """Assemble the RetrievalQA chain with a grounded-answer prompt."""
        # System prompt instructs the model to stay within the retrieved
        # context and admit when the answer is not present.
        system_template = """You are an AI assistant that answers questions based on the given context.
Your responses should be informative and relevant to the question asked.
If you don't know the answer or if the information is not present in the context, say so."""

        human_template = """Context: {context}

Question: {question}

Answer: """

        # Create the prompt
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
        QAbot = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,  # needed so ask_QAbot can report sources
            chain_type_kwargs={"prompt": prompt}
        )
        return QAbot

    def setup_rag_bots(self):
        """Run the full setup: texts -> embeddings -> retriever -> QA chain."""
        # Initial data
        texts = self.sources_to_texts(self.urls, self.pdfs)
        # Create embeddings
        embeddings = self.create_embeddings()
        # Create the retriever
        retriever = self.create_retriever(texts, embeddings)
        # Create the llm and prompts
        llm = self.create_llm()
        # Create a QA bot
        self.QAbot = self.create_QAbot(
            retriever,
            llm
        )

    def ask_QAbot(self, question):
        """Answer `question` with the QA chain.

        Returns:
            dict with keys 'question', 'answer', and 'sources' (list of
            source identifiers of the retrieved documents).

        Raises:
            RuntimeError: if the QA chain has not been initialized.
        """
        # Fail loudly instead of with an opaque 'NoneType not callable'.
        if self.QAbot is None:
            raise RuntimeError("QAbot is not initialized; run setup_rag_bots() first.")
        result = self.QAbot({"query": question})
        sources = [doc.metadata.get('source', 'Unknown source') for doc in result["source_documents"]]
        response = {
            "question": question,
            "answer": result["result"],
            "sources": sources
        }
        return response
128
+
129
+
130
if __name__ == "__main__":
    # Demo: index two Wikipedia pages plus a local PDF, then ask one
    # question and print the answer with its sources.
    demo_urls = [
        "https://en.wikipedia.org/wiki/Artificial_intelligence",
        "https://en.wikipedia.org/wiki/Machine_learning"
    ]
    demo_pdfs = [
        "/home/onur/WORK/DS/repos/chat_with_docs/docs/the-big-book-of-mlops-v10-072023 - Databricks.pdf"
    ]
    rag = RAG(urls=demo_urls, pdfs=demo_pdfs)
    response = rag.ask_QAbot("What is Machine Learning?")
    print(f"Question: {response['question']}")
    print(f"Answer: {response['answer']}")
    print("Sources:")
    for src in response['sources']:
        print(f"- {src}")