OnurKerimoglu commited on
Commit
46dfa3e
·
1 Parent(s): 2550c52

rag: use mdoel Mistral-Nemo-Base-2407 instead of zephyr; various minor fixes; added docstrings to the class methods

Browse files
Files changed (1) hide show
  1. src/rag.py +97 -24
src/rag.py CHANGED
@@ -1,18 +1,14 @@
1
 
2
  import dotenv
3
- import os
4
- from langchain_community.document_loaders import UnstructuredURLLoader, PyPDFLoader
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
 
7
- from langchain_huggingface import HuggingFaceEmbeddings
 
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_community.vectorstores import FAISS
10
  from langchain_openai import ChatOpenAI
11
- # from langchain_community.llms import HuggingFaceHub
12
- from langchain_huggingface import HuggingFaceEndpoint
13
  from langchain.chains import RetrievalQA
14
  from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
15
- from tqdm import tqdm
16
 
17
  class RAG():
18
  def __init__(
@@ -27,7 +23,8 @@ class RAG():
27
 
28
  # Constants
29
  # self.use_model = 'gpt-4o-mini'
30
- self.use_model = 'zephyr-7b-alpha'
 
31
 
32
  # self.use_vectordb = 'chroma'
33
  self.use_vectordb = 'faiss'
@@ -42,9 +39,17 @@ class RAG():
42
  self.QAbot = None
43
 
44
  # Setup the bots
45
- self.setup_rag_bots()
46
 
47
  def load_data(self, urls, pdfs):
 
 
 
 
 
 
 
 
48
  documents = []
49
  if urls:
50
  url_loader = UnstructuredURLLoader(urls=urls)
@@ -54,9 +59,15 @@ class RAG():
54
  documents.extend(pdf_loader.load())
55
  return documents
56
 
57
- def sources_to_texts(self, urls, pdfs):
58
-
59
- documents = self.load_data(urls, pdfs)
 
 
 
 
 
 
60
 
61
  # Retrieval system
62
  chunk_size = 1000
@@ -75,7 +86,14 @@ class RAG():
75
  return embeddings
76
 
77
  def create_retriever(self, texts, embeddings):
78
- # Create embeddings and vector store
 
 
 
 
 
 
 
79
  if self.use_vectordb == 'chroma':
80
  print ('Creating vectore store with Chroma')
81
  vectorstore = Chroma.from_documents(texts, embeddings)
@@ -87,23 +105,59 @@ class RAG():
87
  return retriever
88
 
89
  def create_llm(self):
90
- # Create the language model
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  if self.use_model == 'gpt-4o-mini':
92
  print(f'As llm, using OpenAI model: {self.use_model}')
93
  llm = ChatOpenAI(
94
  model_name="gpt-4o-mini",
95
  temperature=0)
96
- elif self.use_model == 'zephyr-7b-alpha':
97
- print(f'As llm, using HF-Endpint: {self.use_model}')
 
 
 
 
 
 
 
 
 
98
  llm = HuggingFaceEndpoint(
99
- repo_id=f"huggingfaceh4/{self.use_model}",
 
100
  temperature=0.1,
101
- max_new_tokens=512
102
- )
 
103
  return llm
104
 
105
  def create_QAbot(self, retriever, llm):
106
- # Create a QAbot
 
 
 
 
 
 
 
 
 
 
 
 
107
  # System prompt and prompt template
108
  system_template = """You are an AI assistant that answers questions based on the given context.
109
  Your responses should be informative and relevant to the question asked.
@@ -128,9 +182,20 @@ class RAG():
128
  )
129
  return QAbot
130
 
131
- def setup_rag_bots(self):
 
 
 
 
 
 
 
 
 
 
132
  # Initial data
133
- texts = self.sources_to_texts(self.urls, self.pdfs)
 
134
  # Create embeddings
135
  embeddings = self.create_embeddings()
136
  # Create the retriever
@@ -144,6 +209,14 @@ class RAG():
144
  )
145
 
146
  def ask_QAbot(self, question):
 
 
 
 
 
 
 
 
147
  result = self.QAbot.invoke({"query": question})
148
  sources = [doc.metadata.get('source', 'Unknown source') for doc in result["source_documents"]]
149
  response = {
@@ -159,8 +232,8 @@ if __name__ == "__main__":
159
  urls = [
160
  "https://en.wikipedia.org/wiki/Artificial_intelligence",
161
  "https://en.wikipedia.org/wiki/Machine_learning"
162
- ],
163
- pdfs = ["/home/onur/WORK/DS/repos/chat_with_docs/docs/the-big-book-of-mlops-v10-072023 - Databricks.pdf"]
164
  )
165
  response = rag.ask_QAbot("What is Machine Learning?")
166
  print(f"Question: {response['question']}")
 
1
 
2
  import dotenv
 
 
 
3
 
4
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.document_loaders import UnstructuredURLLoader, PyPDFLoader
6
  from langchain_community.vectorstores import Chroma
7
  from langchain_community.vectorstores import FAISS
8
  from langchain_openai import ChatOpenAI
9
+ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
 
10
  from langchain.chains import RetrievalQA
11
  from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
 
12
 
13
  class RAG():
14
  def __init__(
 
23
 
24
  # Constants
25
  # self.use_model = 'gpt-4o-mini'
26
+ # self.use_model = 'zephyr-7b-alpha'
27
+ self.use_model = 'Mistral-Nemo-Base-2407'
28
 
29
  # self.use_vectordb = 'chroma'
30
  self.use_vectordb = 'faiss'
 
39
  self.QAbot = None
40
 
41
  # Setup the bots
42
+ self.setup_rag_bot()
43
 
44
  def load_data(self, urls, pdfs):
45
+ """
46
+ Loads data from the input URLs and PDFs.
47
+ Args:
48
+ urls: List of URLs to load.
49
+ pdfs: List of PDF files to load.
50
+ Returns:
51
+ A list of Document objects loaded from the input URLs and PDFs.
52
+ """
53
  documents = []
54
  if urls:
55
  url_loader = UnstructuredURLLoader(urls=urls)
 
59
  documents.extend(pdf_loader.load())
60
  return documents
61
 
62
+ def sources_to_texts(self, documents):
63
+ """
64
+ Takes a list of URLs and PDFs and converts them into a list of text chunks.
65
+ The text chunks are split into chunks of a certain size with a certain amount of overlap.
66
+ Args:
67
+ documents: a list of document objects loaded from the input data
68
+ Returns:
69
+ A list of text chunks.
70
+ """
71
 
72
  # Retrieval system
73
  chunk_size = 1000
 
86
  return embeddings
87
 
88
  def create_retriever(self, texts, embeddings):
89
+ """
90
+ Creates a retriever from the given texts and embeddings.
91
+ Args:
92
+ texts: A list of text strings to encode in the vector store.
93
+ embeddings: An instance of langchain.Embeddings to use for encoding the texts.
94
+ Returns:
95
+ An instance of langchain.Retriever.
96
+ """
97
  if self.use_vectordb == 'chroma':
98
  print ('Creating vectore store with Chroma')
99
  vectorstore = Chroma.from_documents(texts, embeddings)
 
105
  return retriever
106
 
107
  def create_llm(self):
108
+ """
109
+ Instantiates a language model based on the specified model type.
110
+
111
+ This function supports two models:
112
+ - 'gpt-4o-mini' through the ChatOpenAI interface
113
+ - 'Mistral-Nemo-Base-2407' through the HuggingFaceEndpoint, with provider: novita
114
+ ('zephyr-7b-alpha' through the HuggingFaceEndpoint is being tested, but not working at the moment)
115
+ The model is determined by the `self.use_model` attribute.
116
+ Returns an instance of the selected language model.
117
+
118
+ Returns:
119
+ llm: An instance of the chosen language model, either ChatOpenAI or HuggingFaceEndpoint.
120
+ """
121
+
122
  if self.use_model == 'gpt-4o-mini':
123
  print(f'As llm, using OpenAI model: {self.use_model}')
124
  llm = ChatOpenAI(
125
  model_name="gpt-4o-mini",
126
  temperature=0)
127
+ # elif self.use_model == 'zephyr-7b-alpha':
128
+ # print(f'As llm, using HF-Endpint: {self.use_model}')
129
+ # llm = HuggingFaceEndpoint(
130
+ # repo_id=f"HuggingFaceH4/{self.use_model}",
131
+ # temperature=0.1,
132
+ # max_new_tokens=512,
133
+ # do_sample=False
134
+ # )
135
+ elif self.use_model == 'Mistral-Nemo-Base-2407':
136
+ provider = "novita"
137
+ print(f'As llm, using HF-Endpint: {self.use_model} through provider: {provider}')
138
  llm = HuggingFaceEndpoint(
139
+ repo_id="mistralai/Mistral-Nemo-Base-2407",
140
+ provider=provider,
141
  temperature=0.1,
142
+ max_new_tokens=512,
143
+ do_sample=False
144
+ )
145
  return llm
146
 
147
  def create_QAbot(self, retriever, llm):
148
+ """
149
+ Creates a QAbot (Question-Answering bot) from the given retriever and language model.
150
+
151
+ The QAbot is a type of RetrievalQA chain built with Langchain that, for a given question:
152
+ - uses the given retriever to get the relevant documents
153
+ - and the given language model to generate an answer.
154
+ Args:
155
+ retriever: An instance of langchain.Retriever.
156
+ llm: An instance of langchain.LLM.
157
+ Returns:
158
+ QAbot: An instance of langchain.RetrievalQA.
159
+ """
160
+
161
  # System prompt and prompt template
162
  system_template = """You are an AI assistant that answers questions based on the given context.
163
  Your responses should be informative and relevant to the question asked.
 
182
  )
183
  return QAbot
184
 
185
+ def setup_rag_bot(self):
186
+ """
187
+ Sets up the RAG bot by:
188
+ - loading the data from the input URLs and PDFs
189
+ - splitting the data into chunks of text
190
+ - creating embeddings for the text chunks
191
+ - creating a retriever using the embeddings
192
+ - creating a language model and prompts
193
+ - and creating a QA bot (Question-Answering bot) using the retriever and language model.
194
+ """
195
+
196
  # Initial data
197
+ documents = self.load_data(self.urls, self.pdfs)
198
+ texts = self.sources_to_texts(documents)
199
  # Create embeddings
200
  embeddings = self.create_embeddings()
201
  # Create the retriever
 
209
  )
210
 
211
  def ask_QAbot(self, question):
212
+ """
213
+ Queries the QA bot with a specified question and retrieves the answer along with the sources.
214
+ Args:
215
+ question (str): The question to be asked to the QA bot.
216
+ Returns:
217
+ dict: A dictionary containing the question, answer, and sources.
218
+ """
219
+
220
  result = self.QAbot.invoke({"query": question})
221
  sources = [doc.metadata.get('source', 'Unknown source') for doc in result["source_documents"]]
222
  response = {
 
232
  urls = [
233
  "https://en.wikipedia.org/wiki/Artificial_intelligence",
234
  "https://en.wikipedia.org/wiki/Machine_learning"
235
+ ]
236
+ # pdfs = ["/home/onur/WORK/DS/repos/chat_with_docs/docs/the-big-book-of-mlops-v10-072023 - Databricks.pdf"]
237
  )
238
  response = rag.ask_QAbot("What is Machine Learning?")
239
  print(f"Question: {response['question']}")