Update app.py

app.py CHANGED
@@ -358,6 +358,36 @@ def truncate_text(text, max_tokens):
         return text
     return ' '.join(words[:max_tokens])
 
+def estimate_tokens(text):
+    return len(text.split())
+
+def truncate_text(text, max_tokens):
+    words = text.split()
+    if len(words) <= max_tokens:
+        return text
+    return ' '.join(words[:max_tokens])
+
+def rerank_documents(query: str, documents: List[Document], top_k: int = 5) -> List[Document]:
+    query_embedding = sentence_model.encode([query])[0]
+    doc_embeddings = sentence_model.encode([doc.page_content for doc in documents])
+
+    similarities = cosine_similarity([query_embedding], doc_embeddings)[0]
+
+    ranked_indices = similarities.argsort()[::-1][:top_k]
+    return [documents[i] for i in ranked_indices]
+
+def prepare_context(query: str, documents: List[Document], max_tokens: int) -> str:
+    reranked_docs = rerank_documents(query, documents)
+
+    context = ""
+    for doc in reranked_docs:
+        doc_content = f"Source: {doc.metadata.get('source', 'Unknown')}\nContent: {doc.page_content}\n\n"
+        if estimate_tokens(context + doc_content) > max_tokens:
+            break
+        context += doc_content
+
+    return truncate_text(context, max_tokens)
+
 def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     if not question:
         return "Please enter a question."
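The two helpers added above implement a rerank-then-pack strategy: embed the query and the candidate documents, order the candidates by cosine similarity, then greedily pack the best ones into a token budget. Below is a minimal, self-contained sketch of the ranking step; the `Document` stand-in and the "all-MiniLM-L6-v2" checkpoint are assumptions, filling in for the app's `sentence_model` and LangChain's document class.

```python
# Sketch of rerank_documents in isolation. Document and the model name are
# stand-ins for the app's own objects.
from dataclasses import dataclass, field

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

@dataclass
class Document:
    page_content: str
    metadata: dict = field(default_factory=dict)

sentence_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed checkpoint

docs = [
    Document("FAISS indexes embeddings for fast similarity search.", {"source": "a.pdf"}),
    Document("Paris has mild weather in spring.", {"source": "b.pdf"}),
]

query_embedding = sentence_model.encode(["What is FAISS used for?"])[0]
doc_embeddings = sentence_model.encode([d.page_content for d in docs])

# argsort()[::-1] puts the highest-similarity document first.
similarities = cosine_similarity([query_embedding], doc_embeddings)[0]
ranked = [docs[i] for i in similarities.argsort()[::-1]]
print([d.metadata["source"] for d in ranked])  # -> ['a.pdf', 'b.pdf']
```

One caveat worth noting: `prepare_context` ends with `truncate_text(context, max_tokens)`, but the packing loop already breaks before exceeding the budget, so that final call is only a belt-and-braces guard.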
@@ -375,7 +405,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
     database = None
 
     max_attempts = 3
-    context_reduction_factor = 0.7
     max_input_tokens = 31000  # Leave room for the model's response
     max_output_tokens = 1000
 
@@ -386,7 +415,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
         print(f"User Instructions: {user_instructions}")
 
         try:
-            search_results = google_search(contextualized_question, num_results=3)
+            search_results = google_search(contextualized_question, num_results=5)  # Increased from 3 to 5
         except Exception as e:
             print(f"Error in web search: {e}")
             return f"I apologize, but I encountered an error while searching for information: {str(e)}"
@@ -407,7 +436,8 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
             database.save_local("faiss_database")
 
-
+        # Prepare context using reranking
+        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
 
         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
 
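With `max_input_tokens = 31000`, the `// 2` above reserves 15,500 whitespace-delimited tokens for web context. A tiny sketch of the greedy packing `prepare_context` performs on those documents, with `estimate_tokens` copied from the first hunk and a deliberately small budget so the cut-off is visible; the document list is made up:

```python
# Greedy packing under a token budget, as prepare_context does.
def estimate_tokens(text):
    return len(text.split())  # whitespace count as a rough token proxy

max_tokens = 12  # tiny budget so the second document gets rejected
context = ""
for content, source in [("FAISS indexes embeddings for search.", "a.pdf"),
                        ("Paris has mild weather in spring.", "b.pdf")]:
    doc_content = f"Source: {source}\nContent: {content}\n\n"
    if estimate_tokens(context + doc_content) > max_tokens:
        break  # stop before overflowing the budget
    context += doc_content
print(estimate_tokens(context))  # -> 8, within the 12-token budget
```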
@@ -425,33 +455,20 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
         prompt_val = ChatPromptTemplate.from_template(prompt_template)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            estimated_tokens = estimate_tokens(formatted_prompt)
-
-            if estimated_tokens <= max_input_tokens:
-                break
-
-            # Reduce context sizes
-            current_context = truncate_text(current_context, int(estimate_tokens(current_context) * context_reduction_factor))
-            current_conv_context = truncate_text(current_conv_context, int(estimate_tokens(current_conv_context) * context_reduction_factor))
-            current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
-            current_entities = {k: v[:max(1, int(len(v) * context_reduction_factor))] for k, v in current_entities.items()}
-
-            if estimate_tokens(current_context) + estimate_tokens(current_conv_context) + estimate_tokens(", ".join(current_topics)) + estimate_tokens(json.dumps(current_entities)) < 100:
-                raise ValueError("Context reduced too much. Unable to process the query.")
+        current_conv_context = truncate_text(chatbot.get_context(), max_input_tokens // 4)  # Use quarter of max_input_tokens for conversation context
+        current_topics = topics[:5]  # Limit to top 5 topics
+        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}  # Limit to top 3 entities per type
+
+        formatted_prompt = prompt_val.format(
+            context=context_str,
+            conv_context=current_conv_context,
+            question=question,
+            topics=", ".join(current_topics),
+            entities=json.dumps(current_entities)
+        )
+
+        if estimate_tokens(formatted_prompt) > max_input_tokens:
+            formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)
 
         try:
             full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
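The block deleted above was an adaptive retry scheme: shrink each context piece by `context_reduction_factor = 0.7`, re-estimate, and bail out with a ValueError once everything fell below 100 tokens. The new code swaps that for fixed one-shot budgets (half the input window for retrieved context, a quarter for conversation history) plus a single whole-prompt truncation. As a rough worked example of what the old loop cost, assuming it kept reducing until the prompt fit:

```python
# Rounds the deleted 0.7-factor loop would need for a prompt 4x over budget.
# This models the removed code; it is not a measurement of the app.
import math

over_budget_ratio = 4.0
reduction_factor = 0.7
rounds = math.ceil(math.log(1 / over_budget_ratio) / math.log(reduction_factor))
print(rounds)  # -> 4 reduction rounds before a 4x-too-long prompt fits
```

The one-shot split pays for its simplicity by truncating blindly at the budget boundary, which can cut a document or the conversation mid-sentence.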
@@ -463,11 +480,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
                 if attempt == max_attempts - 1:
                     all_answers.append(f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question.")
 
-            except ValueError as ve:
-                print(f"Error in ask_question (attempt {attempt + 1}): {ve}")
-                if attempt == max_attempts - 1:
-                    all_answers.append(f"I apologize, but I'm having trouble processing the query due to its length or complexity. Could you please try asking a more specific or shorter question?")
-
             except Exception as e:
                 print(f"Error in ask_question (attempt {attempt + 1}): {e}")
                 if attempt == max_attempts - 1:
@@ -488,9 +500,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
         if database is None:
             return "No documents available. Please upload PDF documents to answer questions."
 
-        retriever = database.as_retriever()
+        retriever = database.as_retriever(search_kwargs={"k": 10})  # Retrieve more documents for reranking
         relevant_docs = retriever.get_relevant_documents(question)
-
+
+        # Prepare context using reranking
+        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
 
         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
 
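Raising the retriever's fetch size to `k=10` while `rerank_documents` keeps only `top_k=5` is the classic over-fetch-then-rerank pattern: a cheap, wide first pass from the vector store followed by a precise, narrow second pass. A toy, self-contained sketch of the pattern; the `fetch` and `score` callables are stand-ins for FAISS retrieval and the sentence-model similarity:

```python
# Over-fetch then rerank, with toy stand-ins for the vector store and scorer.
from typing import Callable, List

def overfetch_then_rerank(
    fetch: Callable[[str, int], List[str]],
    score: Callable[[str, str], float],
    query: str,
    k_fetch: int = 10,
    k_final: int = 5,
) -> List[str]:
    candidates = fetch(query, k_fetch)  # wide first pass (FAISS in the app)
    candidates.sort(key=lambda d: score(query, d), reverse=True)
    return candidates[:k_final]  # narrow second pass (cosine rerank in the app)

# Toy corpus and scorer: score counts words shared with the query.
corpus = [f"doc {i} about {'faiss' if i % 2 else 'weather'}" for i in range(20)]
fetch = lambda q, k: corpus[:k]
score = lambda q, d: len(set(q.split()) & set(d.split()))
print(overfetch_then_rerank(fetch, score, "faiss"))  # five 'faiss' docs
```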
@@ -503,19 +517,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
         Provide a summarized and direct answer to the question.
         """
 
-
-
-        formatted_prompt = prompt_val.format(context=context_str, question=question)
-
-        estimated_tokens = estimate_tokens(formatted_prompt)
-
-        if estimated_tokens <= max_input_tokens:
-            break
-
-        context_str = truncate_text(context_str, int(estimate_tokens(context_str) * context_reduction_factor))
+        prompt_val = ChatPromptTemplate.from_template(prompt_template)
+        formatted_prompt = prompt_val.format(context=context_str, question=question)
 
-
-
+        if estimate_tokens(formatted_prompt) > max_input_tokens:
+            formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)
 
         try:
             full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
@@ -526,11 +532,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
             if attempt == max_attempts - 1:
                 return f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question."
 
-        except ValueError as ve:
-            print(f"Error in ask_question (attempt {attempt + 1}): {ve}")
-            if attempt == max_attempts - 1:
-                return f"I apologize, but I'm having trouble processing your question due to the complexity of the document. Could you please try asking a more specific or shorter question?"
-
         except Exception as e:
             print(f"Error in ask_question (attempt {attempt + 1}): {e}")
             if attempt == max_attempts - 1: