Update app.py
app.py
CHANGED
@@ -7,7 +7,7 @@ import requests
 import random
 import urllib.parse
 from tempfile import NamedTemporaryFile
-from typing import List
+from typing import List, Dict
 from bs4 import BeautifulSoup
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
@@ -17,10 +17,72 @@ from langchain_community.document_loaders import PyPDFLoader
 from langchain_core.output_parsers import StrOutputParser
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceHub
-from langchain_core.documents import Document
+from langchain_core.documents import Document
 
 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 
+class Agent1:
+    def __init__(self, model):
+        self.model = model
+
+    def rephrase_and_split(self, user_input: str) -> List[str]:
+        rephrase_prompt = PromptTemplate(
+            input_variables=["query"],
+            template="""
+            Your task is to rephrase the given query into one or more concise, search-engine-friendly formats.
+            If the query contains multiple distinct questions, split them.
+            Provide ONLY the rephrased queries without any additional text or explanations, one per line.
+
+            Query: {query}
+
+            Rephrased queries:"""
+        )
+
+        chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
+        response = chain.run(query=user_input).strip()
+
+        return [q.strip() for q in response.split('\n') if q.strip()]
+
+    def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
+        queries = self.rephrase_and_split(user_input)
+        results = {}
+        for query in queries:
+            results[query] = google_search(query)
+        return results
+
+class Agent2:
+    def __init__(self, model):
+        self.model = model
+
+    def validate_response(self, user_query: str, response: str) -> bool:
+        validation_prompt = PromptTemplate(
+            input_variables=["query", "response"],
+            template="""
+            Evaluate if the following response fully answers the user's query.
+            User query: {query}
+            Response: {response}
+
+            Does the response fully answer the query? Answer with Yes or No:"""
+        )
+
+        chain = LLMChain(llm=self.model, prompt=validation_prompt)
+        result = chain.run(query=user_query, response=response).strip().lower()
+        return result == 'yes'
+
+    def generate_follow_up_query(self, user_query: str, response: str) -> str:
+        follow_up_prompt = PromptTemplate(
+            input_variables=["query", "response"],
+            template="""
+            The following response did not fully answer the user's query.
+            User query: {query}
+            Response: {response}
+
+            Generate a follow-up query to get more relevant information:"""
+        )
+
+        chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
+        return chain.run(query=user_query, response=response).strip()
+
 def load_document(file: NamedTemporaryFile) -> List[Document]:
     """Loads and splits the document into pages."""
     loader = PyPDFLoader(file.name)
@@ -207,6 +269,8 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
 
     model = get_model(temperature, top_p, repetition_penalty)
     embed = get_embeddings()
+    agent1 = Agent1(model)
+    agent2 = Agent2(model)
 
     if os.path.exists("faiss_database"):
         database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
@@ -219,16 +283,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
     for attempt in range(max_attempts):
         try:
             if web_search:
-
-
-
-
-
-                if rephrased_query == original_query:
-                    print("Warning: Query was not rephrased. Using original query for search.")
-
-                search_results = google_search(rephrased_query)
-                web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
+                search_results = agent1.process(question)
+                web_docs = []
+                for query, results in search_results.items():
+                    web_docs.extend([Document(page_content=result["text"], metadata={"source": result["link"], "query": query}) for result in results if result["text"]])
 
                 if database is None:
                     database = FAISS.from_documents(web_docs, embed)
@@ -237,20 +295,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
 
                 database.save_local("faiss_database")
 
-                context_str = "\n".join([f"
+                context_str = "\n".join([f"Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
 
                 prompt_template = """
                 Answer the question based on the following web search results:
                 Web Search Results:
                 {context}
-                Original Question: {original_question}
-                Rephrased Search Query: {rephrased_query}
+                Original Question: {question}
                 If the web search results don't contain relevant information, state that the information is not available in the search results.
                 Provide a concise and direct answer to the original question without mentioning the web search or these instructions.
                 Do not include any source information in your answer.
                 """
-                prompt_val = ChatPromptTemplate.from_template(prompt_template)
-                formatted_prompt = prompt_val.format(context=context_str, original_question=question, rephrased_query=rephrased_query)
             else:
                 if database is None:
                     return "No documents available. Please upload documents or enable web search to answer questions."
@@ -259,7 +314,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
                 relevant_docs = retriever.get_relevant_documents(question)
                 context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
-                # Reduce context if we're not on the first attempt
                 if attempt > 0:
                     words = context_str.split()
                     context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
@@ -273,8 +327,9 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
                 Provide a concise and direct answer to the question.
                 Do not include any source information in your answer.
                 """
-
-
+
+            prompt_val = ChatPromptTemplate.from_template(prompt_template)
+            formatted_prompt = prompt_val.format(context=context_str, question=question)
 
             full_response = generate_chunked_response(model, formatted_prompt)
 
@@ -294,7 +349,16 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
             else:
                 answer = full_response.strip()
 
-
+            if not agent2.validate_response(question, answer):
+                follow_up_query = agent2.generate_follow_up_query(question, answer)
+                follow_up_results = agent1.process(follow_up_query)
+                follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
+                database.add_documents(follow_up_docs)
+                context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
+                formatted_prompt = prompt_val.format(context=context_str, question=question)
+                full_response = generate_chunked_response(model, formatted_prompt)
+                answer = full_response.strip()
+
             if web_search:
                 sources = set(doc.metadata['source'] for doc in web_docs)
                 sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
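Below is a minimal, hypothetical smoke test for the Agent1/Agent2 pattern this commit introduces. It is not part of the commit: it assumes it runs inside app.py (so the module-level google_search() and the two classes above are in scope) and swaps the HuggingFaceHub model for LangChain's FakeListLLM so no token or real model call is needed; agent1.process() would still hit the real search backend.

from langchain_community.llms.fake import FakeListLLM

# Canned responses are consumed in call order:
# 1) Agent1.rephrase_and_split, 2) Agent2.validate_response, 3) Agent2.generate_follow_up_query
fake_llm = FakeListLLM(responses=[
    "best budget laptops 2024",
    "No",
    "best budget laptops 2024 under $800",
])

agent1 = Agent1(fake_llm)
agent2 = Agent2(fake_llm)

user_question = "What laptop should I buy?"
search_results = agent1.process(user_question)   # {rephrased query: google_search() hits}
draft_answer = "I am not sure."

if not agent2.validate_response(user_question, draft_answer):
    # In ask_question() this follow-up query feeds another agent1.process() round.
    print(agent2.generate_follow_up_query(user_question, draft_answer))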