Update app.py
Browse files
app.py
CHANGED
|
@@ -41,7 +41,11 @@ class Agent1:
|
|
| 41 |
chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
|
| 42 |
response = chain.run(query=user_input).strip()
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
|
| 47 |
queries = self.rephrase_and_split(user_input)
|
|
@@ -50,39 +54,6 @@ class Agent1:
|
|
| 50 |
results[query] = google_search(query)
|
| 51 |
return results
|
| 52 |
|
| 53 |
-
class Agent2:
    """Second-stage agent: judges whether an answer fully addresses the
    user's query and, when it does not, drafts a follow-up search query
    so more context can be retrieved.

    Relies on LangChain's ``PromptTemplate`` + ``LLMChain`` (imported at
    file level) and an LLM instance supplied by the caller.
    """

    def __init__(self, model):
        # model: any LLM object accepted by LLMChain(llm=...).
        self.model = model

    def validate_response(self, user_query: str, response: str) -> bool:
        """Ask the model whether ``response`` fully answers ``user_query``.

        Returns:
            True when the model's verdict starts with "yes" (case-insensitive),
            False otherwise.
        """
        validation_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
            Evaluate if the following response fully answers the user's query.
            User query: {query}
            Response: {response}

            Does the response fully answer the query? Answer with Yes or No:"""
        )

        chain = LLMChain(llm=self.model, prompt=validation_prompt)
        result = chain.run(query=user_query, response=response).strip().lower()
        # Fix: use startswith() rather than strict equality — LLMs frequently
        # reply "Yes." or "Yes, it does", which `result == 'yes'` would
        # wrongly treat as a failed validation.
        return result.startswith('yes')

    def generate_follow_up_query(self, user_query: str, response: str) -> str:
        """Produce a follow-up query when ``response`` was judged incomplete.

        Returns:
            The model's suggested follow-up query, stripped of surrounding
            whitespace.
        """
        follow_up_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
            The following response did not fully answer the user's query.
            User query: {query}
            Response: {response}

            Generate a follow-up query to get more relevant information:"""
        )

        chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
        return chain.run(query=user_query, response=response).strip()
|
| 86 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
| 87 |
"""Loads and splits the document into pages."""
|
| 88 |
loader = PyPDFLoader(file.name)
|
|
@@ -270,7 +241,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
| 270 |
model = get_model(temperature, top_p, repetition_penalty)
|
| 271 |
embed = get_embeddings()
|
| 272 |
agent1 = Agent1(model)
|
| 273 |
-
agent2 = Agent2(model)
|
| 274 |
|
| 275 |
if os.path.exists("faiss_database"):
|
| 276 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
|
@@ -279,7 +249,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
| 279 |
|
| 280 |
max_attempts = 3
|
| 281 |
context_reduction_factor = 0.7
|
| 282 |
-
agent2_max_attempts = 2
|
| 283 |
|
| 284 |
for attempt in range(max_attempts):
|
| 285 |
try:
|
|
@@ -350,20 +319,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
| 350 |
else:
|
| 351 |
answer = full_response.strip()
|
| 352 |
|
| 353 |
-
for agent2_attempt in range(agent2_max_attempts):
|
| 354 |
-
if agent2.validate_response(question, answer):
|
| 355 |
-
break
|
| 356 |
-
|
| 357 |
-
if agent2_attempt < agent2_max_attempts - 1:
|
| 358 |
-
follow_up_query = agent2.generate_follow_up_query(question, answer)
|
| 359 |
-
follow_up_results = agent1.process(follow_up_query)
|
| 360 |
-
follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
|
| 361 |
-
database.add_documents(follow_up_docs)
|
| 362 |
-
context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
|
| 363 |
-
formatted_prompt = prompt_val.format(context=context_str, question=question)
|
| 364 |
-
full_response = generate_chunked_response(model, formatted_prompt)
|
| 365 |
-
answer = full_response.strip()
|
| 366 |
-
|
| 367 |
if web_search:
|
| 368 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
| 369 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|
|
|
|
| 41 |
chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
|
| 42 |
response = chain.run(query=user_input).strip()
|
| 43 |
|
| 44 |
+
# Remove any lines that contain instructions or explanations
|
| 45 |
+
rephrased_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.startswith("Rephrase") and "query" not in q.lower()]
|
| 46 |
+
|
| 47 |
+
# If no valid rephrased queries, return the original input
|
| 48 |
+
return rephrased_queries if rephrased_queries else [user_input]
|
| 49 |
|
| 50 |
def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
|
| 51 |
queries = self.rephrase_and_split(user_input)
|
|
|
|
| 54 |
results[query] = google_search(query)
|
| 55 |
return results
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
| 58 |
"""Loads and splits the document into pages."""
|
| 59 |
loader = PyPDFLoader(file.name)
|
|
|
|
| 241 |
model = get_model(temperature, top_p, repetition_penalty)
|
| 242 |
embed = get_embeddings()
|
| 243 |
agent1 = Agent1(model)
|
|
|
|
| 244 |
|
| 245 |
if os.path.exists("faiss_database"):
|
| 246 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
|
|
|
| 249 |
|
| 250 |
max_attempts = 3
|
| 251 |
context_reduction_factor = 0.7
|
|
|
|
| 252 |
|
| 253 |
for attempt in range(max_attempts):
|
| 254 |
try:
|
|
|
|
| 319 |
else:
|
| 320 |
answer = full_response.strip()
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
if web_search:
|
| 323 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
| 324 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|