Update app.py
app.py CHANGED
@@ -22,7 +22,12 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceHub
 from langchain_core.documents import Document
 from sentence_transformers import SentenceTransformer
-from llama_parse import LlamaParse
+from llama_parse import LlamaParse
+from llama_cpp import Llama
+from llama_cpp_agent.llm_agent import LlamaCppAgent
+from llama_cpp_agent.messages_formatter import MessagesFormatterType
+from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings
+
 
 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
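The new llama_cpp_agent imports talk to a llama.cpp server over HTTP instead of loading a model in-process, so the endpoint URL below only works if that server is already running. A minimal preflight sketch; the helper is hypothetical and not part of this commit, and it assumes a llama.cpp server build that exposes the /health route:

import urllib.request

def llama_server_is_up(base_url="http://127.0.0.1:8080"):
    # Probe the server the endpoint provider points at; any connection
    # error is treated as "server not running". (Hypothetical helper.)
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=2) as resp:
            return resp.status == 200
    except OSError:
        return False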
@@ -378,10 +383,25 @@ def prepare_context(query: str, documents: List[Document], max_tokens: int) -> str:
 
     return truncate_text(context, max_tokens)
 
+# Initialize LlamaCppAgent
+def initialize_llama_cpp_agent():
+    main_model = LlamaCppEndpointSettings(
+        completions_endpoint_url="http://127.0.0.1:8080/completion"
+    )
+    llama_cpp_agent = LlamaCppAgent(
+        main_model,
+        debug_output=False,
+        system_prompt="You are an AI assistant designed to help with RAG tasks.",
+        predefined_messages_formatter_type=MessagesFormatterType.CHATML
+    )
+    return llama_cpp_agent
+
+# Modify the ask_question function to use LlamaCppAgent
 def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     if not question:
         return "Please enter a question."
 
+    llama_cpp_agent = initialize_llama_cpp_agent()
     model = get_model(temperature, top_p, repetition_penalty)
 
     # Update the chatbot's model
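For orientation, the helper added above can be exercised on its own. This sketch reuses only the constructor arguments and the get_chat_response call that appear in this diff, and assumes the llama.cpp server behind the endpoint URL is running:

from llama_cpp_agent.llm_agent import LlamaCppAgent
from llama_cpp_agent.messages_formatter import MessagesFormatterType
from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings

# Mirrors initialize_llama_cpp_agent() from the diff above.
endpoint = LlamaCppEndpointSettings(completions_endpoint_url="http://127.0.0.1:8080/completion")
agent = LlamaCppAgent(
    endpoint,
    debug_output=False,
    system_prompt="You are an AI assistant designed to help with RAG tasks.",
    predefined_messages_formatter_type=MessagesFormatterType.CHATML,
)
print(agent.get_chat_response("In one sentence, what is retrieval-augmented generation?", temperature=0.7))

The CHATML formatter matches models served with the ChatML chat template; a model using a different template would need a different MessagesFormatterType.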
@@ -395,17 +415,14 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     database = None
 
     max_attempts = 3
-    max_input_tokens = 20000
+    max_input_tokens = 20000
     max_output_tokens = 800
 
     if web_search:
         contextualized_question, topics, entity_tracker, _ = chatbot.process_question(question)
 
-        print(f"Contextualized question: {contextualized_question}")
-        print(f"User Instructions: {user_instructions}")
-
         try:
-            search_results = google_search(contextualized_question, num_results=5)
+            search_results = google_search(contextualized_question, num_results=5)
         except Exception as e:
            print(f"Error in web search: {e}")
            return f"I apologize, but I encountered an error while searching for information: {str(e)}"
@@ -426,8 +443,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
 
         database.save_local("faiss_database")
 
-
-        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)
 
         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
 
@@ -443,13 +459,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a concise and relevant answer to the question.
         """
 
-
-
-
-        current_topics = topics[:5]  # Limit to top 5 topics
-        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}  # Limit to top 3 entities per type
+        current_conv_context = truncate_text(chatbot.get_context(), max_input_tokens // 4)
+        current_topics = topics[:5]
+        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}
 
-        formatted_prompt = prompt_val.format(
+        formatted_prompt = prompt_template.format(
             context=context_str,
             conv_context=current_conv_context,
             question=question,
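The budget arithmetic across these hunks splits max_input_tokens = 20000 as follows: half for retrieved document context, a quarter for conversation history, and the rest for the prompt scaffold (instructions plus the question). Spelled out as plain arithmetic, not code from the commit:

max_input_tokens = 20000
context_budget = max_input_tokens // 2   # 10000 tokens: prepare_context(...) for web/PDF documents
conv_budget = max_input_tokens // 4      # 5000 tokens: truncate_text(chatbot.get_context(), ...)
remainder = max_input_tokens - context_budget - conv_budget  # 5000 tokens left for the template itself
assert (context_budget, conv_budget, remainder) == (10000, 5000, 5000)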
@@ -461,12 +475,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)
 
             try:
-                full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
+                # Use LlamaCppAgent for initial response generation
+                initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+                # Use generate_chunked_response for further refinement if needed
+                full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
                 answer = extract_answer(full_response, user_instructions)
                 all_answers.append(answer)
                 break
             except Exception as e:
-                print(f"Error in
+                print(f"Error in response generation: {e}")
                 if attempt == max_attempts - 1:
                     all_answers.append(f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question.")
 
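The rewritten try block is a two-stage pipeline: the local llama.cpp agent drafts an answer, then the existing HuggingFaceHub model refines that draft through generate_chunked_response. A stripped-down sketch of the retry loop; generate_chunked_response and extract_answer are app.py's own helpers, assumed here by name and signature only:

def two_stage_answer(agent, model, prompt, user_instructions="", max_attempts=3, max_output_tokens=800):
    # Stage 1: local llama.cpp draft; Stage 2: hub-model refinement.
    # (Sketch under assumed helper signatures, not the committed code.)
    for attempt in range(max_attempts):
        try:
            draft = agent.get_chat_response(prompt, temperature=0.7)
            full = generate_chunked_response(model, draft, max_tokens=max_output_tokens)
            return extract_answer(full, user_instructions)
        except Exception as e:
            print(f"Error in response generation: {e}")
            if attempt == max_attempts - 1:
                return "I apologize, but I encountered an error while generating the response."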
@@ -490,11 +509,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         if database is None:
             return "No documents available. Please upload PDF documents to answer questions."
 
-        retriever = database.as_retriever(search_kwargs={"k":
+        retriever = database.as_retriever(search_kwargs={"k": 5})
         relevant_docs = retriever.get_relevant_documents(question)
 
-
-        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)
 
         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
 
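The k=5 retriever call assumes the FAISS index written by the upload path (and by database.save_local above). A self-contained sketch with the same langchain-community APIs; the embedding model name is an assumption, and newer langchain-community releases require allow_dangerous_deserialization=True when reloading a pickled index:

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # assumed model
db = FAISS.from_texts(["FAISS stores dense vectors for similarity search."], embeddings)
db.save_local("faiss_database")

db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
retriever = db.as_retriever(search_kwargs={"k": 5})
docs = retriever.get_relevant_documents("What does FAISS store?")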
@@ -507,18 +525,22 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a summarized and direct answer to the question.
         """
 
-
-        formatted_prompt = prompt_val.format(context=context_str, question=question)
+        formatted_prompt = prompt_template.format(context=context_str, question=question)
 
         if estimate_tokens(formatted_prompt) > max_input_tokens:
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)
 
         try:
-            full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
+            # Use LlamaCppAgent for initial response generation
+            initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+            # Use generate_chunked_response for further refinement if needed
+            full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
             answer = extract_answer(full_response, user_instructions)
             return answer
         except Exception as e:
-            print(f"Error in
+            print(f"Error in response generation: {e}")
             if attempt == max_attempts - 1:
                 return f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question."
 
@@ -591,13 +613,14 @@ with gr.Blocks() as demo:
 
     enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()
 
+    # Update the chat function to use the modified ask_question function
     def chat(question, history, temperature, top_p, repetition_penalty, web_search, user_instructions):
        answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot, user_instructions)
        history.append((question, answer))
        return "", history
 
    submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox, instructions_input], outputs=[question_input, chatbot])
-
+
    clear_button = gr.Button("Clear Cache")
    clear_output = gr.Textbox(label="Cache Status")
    clear_button.click(clear_cache, inputs=[], outputs=clear_output)