TomData commited on
Commit
7137e35
·
verified ·
1 Parent(s): 6d37989

Update src/chatbot.py

Browse files
Files changed (1) hide show
  1. src/chatbot.py +298 -296
src/chatbot.py CHANGED
@@ -1,296 +1,298 @@
1
- from langchain_core.prompts import ChatPromptTemplate
2
- from langchain_community.llms.huggingface_hub import HuggingFaceHub
3
- from langchain_community.embeddings import HuggingFaceEmbeddings
4
- from langchain_community.vectorstores import FAISS
5
-
6
-
7
- from langchain.chains.combine_documents import create_stuff_documents_chain
8
- from langchain.chains import create_retrieval_chain
9
-
10
- from langchain_community.docstore.in_memory import InMemoryDocstore
11
- from faiss import IndexFlatL2
12
-
13
- #import functools
14
- import pandas as pd
15
-
16
- # Load environmental variables from .env-file
17
- from dotenv import load_dotenv, find_dotenv
18
- load_dotenv(find_dotenv())
19
-
20
- # Define important variables
21
- embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2") # Remove embedding input parameter from functions?
22
- llm = HuggingFaceHub(
23
- # ToDo: Try different models here
24
- repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
25
- # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
26
- # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
27
- # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
28
- task="text-generation",
29
- model_kwargs={
30
- "max_new_tokens": 512,
31
- "top_k": 30,
32
- "temperature": 0.1,
33
- "repetition_penalty": 1.03,
34
- }
35
- )
36
- # ToDo: Experiment with different templates
37
- prompt_test = ChatPromptTemplate.from_template("""<s>[INST]
38
- Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
39
-
40
- Context: {context}
41
-
42
- Question: {input}
43
- [/INST]"""
44
-
45
- )
46
- prompt_de = ChatPromptTemplate.from_template("""Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
47
-
48
- <context>
49
- {context}
50
- </context>
51
-
52
- Frage: {input}
53
- """
54
- # Returns the answer in German
55
- )
56
- prompt_en = ChatPromptTemplate.from_template("""Answer the following question in English and solely based on the provided context:
57
-
58
- <context>
59
- {context}
60
- </context>
61
-
62
- Question: {input}
63
- """
64
- # Returns the answer in English
65
- )
66
-
67
- db_all = FAISS.load_local(folder_path="./src/FAISS", index_name="speeches_1949_09_12",
68
- embeddings=embeddings, allow_dangerous_deserialization=True)
69
-
70
- def get_vectorstore(inputs, embeddings):
71
- """
72
- Combine multiple FAISS vector stores into a single vector store based on the specified inputs.
73
-
74
- Parameters
75
- ----------
76
- inputs : list of str
77
- A list of strings specifying which vector stores to combine. Each string represents a specific
78
- index or a special keyword "All". If "All" is the first entry in the list,
79
- it directly return the pre-defined vectorstore for all speeches
80
-
81
- embeddings : Embeddings
82
- An instance of embeddings that will be used to load the vector stores. The specific type and
83
- structure of `embeddings` depend on the implementation of the `get_vectorstore` function.
84
-
85
- Returns
86
- -------
87
- FAISS
88
- A FAISS vector store that combines the specified indices into a single vector store.
89
-
90
- """
91
-
92
- # Default folder path
93
- folder_path = "./src/FAISS"
94
-
95
-
96
- if inputs[0] == "All" or inputs[0] is None:
97
- return db_all
98
-
99
- # Initialize empty db
100
- embedding_function = embeddings
101
- dimensions = len(embedding_function.embed_query("dummy"))
102
-
103
- db = FAISS(
104
- embedding_function=embedding_function,
105
- index=IndexFlatL2(dimensions),
106
- docstore=InMemoryDocstore(),
107
- index_to_docstore_id={},
108
- normalize_L2=False
109
- )
110
-
111
- # Retrieve inputs: 20. Legislaturperiode, 19. Legislaturperiode, ...
112
- for input in inputs:
113
- # Ignore if user also selected All among other legislatures
114
- if input == "All":
115
- continue
116
- # Retrieve selected index and merge vector stores
117
- index = input.split(".")[0]
118
- index_name = f'{index}_legislature'
119
- local_db = FAISS.load_local(folder_path=folder_path, index_name=index_name,
120
- embeddings=embeddings, allow_dangerous_deserialization=True)
121
- db.merge_from(local_db)
122
- print('Successfully merged inputs')
123
- return db
124
-
125
- def RAG(llm, prompt, db, question):
126
- """
127
- Apply Retrieval-Augmented Generation (RAG) by providing the context and the question to the
128
- language model using a predefined template.
129
-
130
- Parameters:
131
- ----------
132
- llm : LanguageModel
133
- An instance of the language model to be used for generating responses.
134
-
135
- prompt : str
136
- A predefined template or prompt that structures how the context and question are presented to the language model.
137
-
138
- db : VectorStore
139
- A vector store instance that supports retrieval of relevant documents based on the input question.
140
-
141
- question : str
142
- The question or query to be answered by the language model.
143
-
144
- Returns:
145
- -------
146
- str
147
- The response generated by the language model, based on the retrieved context and provided question.
148
- """
149
- # Create a document chain using the provided language model and prompt template
150
- document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)
151
- # Convert the vector store into a retriever
152
- retriever = db.as_retriever()
153
- # Create a retrieval chain that integrates the retriever with the document chain
154
- retrieval_chain = create_retrieval_chain(retriever, document_chain)
155
- # Invoke the retrieval chain with the input question to get the final response
156
- response = retrieval_chain.invoke({"input": question})
157
-
158
- return response
159
-
160
-
161
- def chatbot(message, history, db_inputs, prompt_language, llm=llm):
162
- """
163
- Generate a response from the chatbot based on the provided message, history, database inputs, prompt language, and LLM model.
164
-
165
- Parameters:
166
- -----------
167
- message : str
168
- The message or question to be answered by the chatbot.
169
-
170
- history : list
171
- The history of previous interactions or messages.
172
-
173
- db_inputs : list
174
- A list of strings specifying which vector stores to combine. Each string represents a specific index or a special keyword "All".
175
-
176
- prompt_language : str
177
- The language of the prompt to be used for generating the response. Should be either "DE" for German or "EN" for English.
178
-
179
- llm : LLM, optional
180
- An instance of the Language Model to be used for generating the response. Defaults to the global variable `llm`.
181
-
182
- Returns:
183
- --------
184
- str
185
- The response generated by the chatbot.
186
- """
187
-
188
- db = get_vectorstore(inputs = db_inputs, embeddings=embeddings)
189
-
190
- # Select prompt based on user input
191
- if prompt_language == "DE":
192
- prompt = prompt_de
193
- raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
194
- # Only necessary because mistral does include it´s json structure in the output including its input content
195
- try:
196
- response = raw_response['answer'].split("Antwort: ")[1]
197
- except:
198
- response = raw_response['answer']
199
- return response
200
- else:
201
- prompt = prompt_en
202
- raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
203
- # Only necessary because mistral does include it´s json structure in the output including its input content
204
- try:
205
- response = raw_response['answer'].split("Answer: ")[1]
206
- except:
207
- response = raw_response['answer']
208
-
209
- return response
210
-
211
-
212
- def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
213
- """
214
- Retrieve speech contents based on keywords using a specified method.
215
-
216
- Parameters:
217
- ----------
218
- db : FAISS
219
- The FAISS vector store containing speech embeddings.
220
-
221
- query : str
222
- The keyword(s) to search for in the speech contents.
223
-
224
- n : int, optional
225
- The number of speech contents to retrieve (default is 10).
226
-
227
- embeddings : Embeddings, optional
228
- An instance of embeddings used for embedding queries (default is embeddings).
229
-
230
- method : str, optional
231
- The method used for retrieving speech contents. Options are 'ss' (semantic search) and 'mmr'
232
- (maximal marginal relevance) (default is 'ss').
233
-
234
- party_filter : str, optional
235
- A filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
236
- speeches from all parties (default is 'All').
237
-
238
- Returns:
239
- -------
240
- pandas.DataFrame
241
- A DataFrame containing the speech contents, dates, and party affiliations.
242
-
243
- Notes:
244
- -----
245
- - The `db` parameter should be a FAISS vector store containing speech embeddings.
246
- - The `query` parameter specifies the keyword(s) to search for in the speech contents.
247
- - The `n` parameter determines the number of speech contents to retrieve (default is 10).
248
- - The `embeddings` parameter is an instance of embeddings used for embedding queries (default is embeddings).
249
- - The `method` parameter specifies the method used for retrieving speech contents. Options are 'ss' (semantic search)
250
- and 'mmr' (maximal marginal relevance) (default is 'ss').
251
- - The `party_filter` parameter is a filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
252
- speeches from all parties (default is 'All').
253
- """
254
-
255
- db = get_vectorstore(inputs=["All"], embeddings=embeddings)
256
- query_embedding = embeddings.embed_query(query)
257
-
258
- # Maximal Marginal Relevance
259
- if method == "mmr":
260
- df_res = pd.DataFrame(columns=['Speech Content', 'Date', 'Party', 'Relevance'])
261
- results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
262
- for doc in results:
263
- party = doc[0].metadata["party"]
264
- if party != party_filter and party_filter != 'All':
265
- continue
266
- speech_content = doc[0].page_content
267
- speech_date = doc[0].metadata["date"]
268
- score = round(doc[1], ndigits=2)
269
- df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
270
- 'Date': [speech_date],
271
- 'Party': [party],
272
- 'Relevance': [score]})], ignore_index=True)
273
- df_res.sort_values('Relevance', inplace=True, ascending=True)
274
-
275
- # Similarity Search
276
- elif method == "ss":
277
- kws_data = []
278
- results = db.similarity_search_by_vector(query_embedding, k=n)
279
- for doc in results:
280
- party = doc.metadata["party"]
281
- if party != party_filter and party_filter != 'All':
282
- continue
283
- speech_content = doc.page_content
284
- speech_date = doc.metadata["date"]
285
- speech_date = speech_date.strftime("%Y-%m-%d")
286
- print(speech_date)
287
- # Error here?
288
- kws_entry = {'Speech Content': speech_content,
289
- 'Date': speech_date,
290
- 'Party': party}
291
-
292
- kws_data.append(kws_entry)
293
-
294
- df_res = pd.DataFrame(kws_data)
295
-
296
- return df_res
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_community.llms.huggingface_hub import HuggingFaceHub
3
+ from langchain_community.embeddings import HuggingFaceEmbeddings
4
+ from langchain_community.vectorstores import FAISS
5
+
6
+
7
+ from langchain.chains.combine_documents import create_stuff_documents_chain
8
+ from langchain.chains import create_retrieval_chain
9
+
10
+ from langchain_community.docstore.in_memory import InMemoryDocstore
11
+ from faiss import IndexFlatL2
12
+
13
+ #import functools
14
+ import pandas as pd
15
+
16
+ # Load environmental variables from .env-file
17
+ from dotenv import load_dotenv, find_dotenv
18
+ load_dotenv(find_dotenv())
19
+
20
+ # Define important variables
21
+ embeddings = HuggingFaceEmbeddings(model_name="paraphrase-multilingual-MiniLM-L12-v2") # Remove embedding input parameter from functions?
22
+ llm = HuggingFaceHub(
23
+ # ToDo: Try different models here
24
+ repo_id = "mistralai/Ministral-8B-Instruct-2410"
25
+ #repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
26
+ # repo_id="CohereForAI/c4ai-command-r-v01", # too large 69gb
27
+ # repo_id="CohereForAI/c4ai-command-r-v01-4bit", # too large 22gb
28
+ # repo_id="meta-llama/Meta-Llama-3-8B", # too large 16 gb
29
+ task="text-generation",
30
+ model_kwargs={
31
+ "max_new_tokens": 512,
32
+ "top_k": 30,
33
+ "temperature": 0.1,
34
+ "repetition_penalty": 1.03,
35
+ }
36
+ )
37
+ # ToDo: Experiment with different templates
38
+ prompt_test = ChatPromptTemplate.from_template("""<s>[INST]
39
+ Instruction: Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
40
+
41
+ Context: {context}
42
+
43
+ Question: {input}
44
+ [/INST]"""
45
+
46
+ )
47
+ prompt_de = ChatPromptTemplate.from_template("""Beantworte die folgende Frage auf deutsch und nur auf der Grundlage des angegebenen Kontexts:
48
+
49
+ <context>
50
+ {context}
51
+ </context>
52
+
53
+ Frage: {input}
54
+ """
55
+ # Returns the answer in German
56
+ )
57
+ prompt_en = ChatPromptTemplate.from_template("""Answer the following question in English and solely based on the provided context:
58
+
59
+ <context>
60
+ {context}
61
+ </context>
62
+
63
+ Question: {input}
64
+ """
65
+ # Returns the answer in English
66
+ )
67
+
68
+ db_all = FAISS.load_local(folder_path="./src/FAISS", index_name="speeches_1949_09_12",
69
+ embeddings=embeddings, allow_dangerous_deserialization=True)
70
+
71
+ def get_vectorstore(inputs, embeddings):
72
+ """
73
+ Combine multiple FAISS vector stores into a single vector store based on the specified inputs.
74
+
75
+ Parameters
76
+ ----------
77
+ inputs : list of str
78
+ A list of strings specifying which vector stores to combine. Each string represents a specific
79
+ index or a special keyword "All". If "All" is the first entry in the list,
80
+ it directly return the pre-defined vectorstore for all speeches
81
+
82
+ embeddings : Embeddings
83
+ An instance of embeddings that will be used to load the vector stores. The specific type and
84
+ structure of `embeddings` depend on the implementation of the `get_vectorstore` function.
85
+
86
+ Returns
87
+ -------
88
+ FAISS
89
+ A FAISS vector store that combines the specified indices into a single vector store.
90
+
91
+ """
92
+
93
+ # Default folder path
94
+ folder_path = "./src/FAISS"
95
+
96
+
97
+ if inputs[0] == "All" or inputs[0] is None:
98
+ return db_all
99
+
100
+ # Initialize empty db
101
+ embedding_function = embeddings
102
+ dimensions = len(embedding_function.embed_query("dummy"))
103
+
104
+ db = FAISS(
105
+ embedding_function=embedding_function,
106
+ index=IndexFlatL2(dimensions),
107
+ docstore=InMemoryDocstore(),
108
+ index_to_docstore_id={},
109
+ normalize_L2=False
110
+ )
111
+
112
+ # Retrieve inputs: 20. Legislaturperiode, 19. Legislaturperiode, ...
113
+ for input in inputs:
114
+ # Ignore if user also selected All among other legislatures
115
+ if input == "All":
116
+ continue
117
+ # Retrieve selected index and merge vector stores
118
+ index = input.split(".")[0]
119
+ index_name = f'{index}_legislature'
120
+ local_db = FAISS.load_local(folder_path=folder_path, index_name=index_name,
121
+ embeddings=embeddings, allow_dangerous_deserialization=True)
122
+ db.merge_from(local_db)
123
+ print('Successfully merged inputs')
124
+ return db
125
+
126
+
127
+ def RAG(llm, prompt, db, question):
128
+ """
129
+ Apply Retrieval-Augmented Generation (RAG) by providing the context and the question to the
130
+ language model using a predefined template.
131
+
132
+ Parameters:
133
+ ----------
134
+ llm : LanguageModel
135
+ An instance of the language model to be used for generating responses.
136
+
137
+ prompt : str
138
+ A predefined template or prompt that structures how the context and question are presented to the language model.
139
+
140
+ db : VectorStore
141
+ A vector store instance that supports retrieval of relevant documents based on the input question.
142
+
143
+ question : str
144
+ The question or query to be answered by the language model.
145
+
146
+ Returns:
147
+ -------
148
+ str
149
+ The response generated by the language model, based on the retrieved context and provided question.
150
+ """
151
+ # Create a document chain using the provided language model and prompt template
152
+ document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)
153
+ # Convert the vector store into a retriever
154
+ retriever = db.as_retriever()
155
+ # Create a retrieval chain that integrates the retriever with the document chain
156
+ retrieval_chain = create_retrieval_chain(retriever, document_chain)
157
+ # Invoke the retrieval chain with the input question to get the final response
158
+ response = retrieval_chain.invoke({"input": question})
159
+
160
+ return response
161
+
162
+
163
+ def chatbot(message, history, db_inputs, prompt_language, llm=llm):
164
+ """
165
+ Generate a response from the chatbot based on the provided message, history, database inputs, prompt language, and LLM model.
166
+
167
+ Parameters:
168
+ -----------
169
+ message : str
170
+ The message or question to be answered by the chatbot.
171
+
172
+ history : list
173
+ The history of previous interactions or messages.
174
+
175
+ db_inputs : list
176
+ A list of strings specifying which vector stores to combine. Each string represents a specific index or a special keyword "All".
177
+
178
+ prompt_language : str
179
+ The language of the prompt to be used for generating the response. Should be either "DE" for German or "EN" for English.
180
+
181
+ llm : LLM, optional
182
+ An instance of the Language Model to be used for generating the response. Defaults to the global variable `llm`.
183
+
184
+ Returns:
185
+ --------
186
+ str
187
+ The response generated by the chatbot.
188
+ """
189
+
190
+ db = get_vectorstore(inputs = db_inputs, embeddings=embeddings)
191
+
192
+ # Select prompt based on user input
193
+ if prompt_language == "DE":
194
+ prompt = prompt_de
195
+ raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
196
+ # Only necessary because mistral does include it´s json structure in the output including its input content
197
+ try:
198
+ response = raw_response['answer'].split("Antwort: ")[1]
199
+ except:
200
+ response = raw_response['answer']
201
+ return response
202
+ else:
203
+ prompt = prompt_en
204
+ raw_response = RAG(llm=llm, prompt=prompt, db=db, question=message)
205
+ # Only necessary because mistral does include it´s json structure in the output including its input content
206
+ try:
207
+ response = raw_response['answer'].split("Answer: ")[1]
208
+ except:
209
+ response = raw_response['answer']
210
+
211
+ return response
212
+
213
+
214
+ def keyword_search(query, n=10, embeddings=embeddings, method="ss", party_filter="All"):
215
+ """
216
+ Retrieve speech contents based on keywords using a specified method.
217
+
218
+ Parameters:
219
+ ----------
220
+ db : FAISS
221
+ The FAISS vector store containing speech embeddings.
222
+
223
+ query : str
224
+ The keyword(s) to search for in the speech contents.
225
+
226
+ n : int, optional
227
+ The number of speech contents to retrieve (default is 10).
228
+
229
+ embeddings : Embeddings, optional
230
+ An instance of embeddings used for embedding queries (default is embeddings).
231
+
232
+ method : str, optional
233
+ The method used for retrieving speech contents. Options are 'ss' (semantic search) and 'mmr'
234
+ (maximal marginal relevance) (default is 'ss').
235
+
236
+ party_filter : str, optional
237
+ A filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
238
+ speeches from all parties (default is 'All').
239
+
240
+ Returns:
241
+ -------
242
+ pandas.DataFrame
243
+ A DataFrame containing the speech contents, dates, and party affiliations.
244
+
245
+ Notes:
246
+ -----
247
+ - The `db` parameter should be a FAISS vector store containing speech embeddings.
248
+ - The `query` parameter specifies the keyword(s) to search for in the speech contents.
249
+ - The `n` parameter determines the number of speech contents to retrieve (default is 10).
250
+ - The `embeddings` parameter is an instance of embeddings used for embedding queries (default is embeddings).
251
+ - The `method` parameter specifies the method used for retrieving speech contents. Options are 'ss' (semantic search)
252
+ and 'mmr' (maximal marginal relevance) (default is 'ss').
253
+ - The `party_filter` parameter is a filter for retrieving speech contents by party affiliation. Specify 'All' to retrieve
254
+ speeches from all parties (default is 'All').
255
+ """
256
+
257
+ db = get_vectorstore(inputs=["All"], embeddings=embeddings)
258
+ query_embedding = embeddings.embed_query(query)
259
+
260
+ # Maximal Marginal Relevance
261
+ if method == "mmr":
262
+ df_res = pd.DataFrame(columns=['Speech Content', 'Date', 'Party', 'Relevance'])
263
+ results = db.max_marginal_relevance_search_with_score_by_vector(query_embedding, k=n)
264
+ for doc in results:
265
+ party = doc[0].metadata["party"]
266
+ if party != party_filter and party_filter != 'All':
267
+ continue
268
+ speech_content = doc[0].page_content
269
+ speech_date = doc[0].metadata["date"]
270
+ score = round(doc[1], ndigits=2)
271
+ df_res = pd.concat([df_res, pd.DataFrame({'Speech Content': [speech_content],
272
+ 'Date': [speech_date],
273
+ 'Party': [party],
274
+ 'Relevance': [score]})], ignore_index=True)
275
+ df_res.sort_values('Relevance', inplace=True, ascending=True)
276
+
277
+ # Similarity Search
278
+ elif method == "ss":
279
+ kws_data = []
280
+ results = db.similarity_search_by_vector(query_embedding, k=n)
281
+ for doc in results:
282
+ party = doc.metadata["party"]
283
+ if party != party_filter and party_filter != 'All':
284
+ continue
285
+ speech_content = doc.page_content
286
+ speech_date = doc.metadata["date"]
287
+ speech_date = speech_date.strftime("%Y-%m-%d")
288
+ print(speech_date)
289
+ # Error here?
290
+ kws_entry = {'Speech Content': speech_content,
291
+ 'Date': speech_date,
292
+ 'Party': party}
293
+
294
+ kws_data.append(kws_entry)
295
+
296
+ df_res = pd.DataFrame(kws_data)
297
+
298
+ return df_res