Update app.py
Browse files
app.py
CHANGED
|
@@ -352,7 +352,7 @@ def estimate_tokens(text):
|
|
| 352 |
# Rough estimate: 1 token ~= 4 characters
|
| 353 |
return len(text) // 4
|
| 354 |
|
| 355 |
-
def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
|
| 356 |
if not question:
|
| 357 |
return "Please enter a question."
|
| 358 |
|
|
@@ -368,16 +368,15 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 368 |
else:
|
| 369 |
database = None
|
| 370 |
|
| 371 |
-
max_attempts = 3
|
| 372 |
context_reduction_factor = 0.7
|
| 373 |
-
max_tokens = 32000
|
| 374 |
|
| 375 |
if web_search:
|
| 376 |
-
contextualized_question, topics, entity_tracker,
|
| 377 |
|
| 378 |
-
# Log the contextualized question and instructions separately for debugging
|
| 379 |
print(f"Contextualized question: {contextualized_question}")
|
| 380 |
-
print(f"Instructions: {
|
| 381 |
|
| 382 |
try:
|
| 383 |
search_results = google_search(contextualized_question, num_results=3)
|
|
@@ -403,7 +402,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 403 |
|
| 404 |
context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
|
| 405 |
|
| 406 |
-
instruction_prompt = f"User Instructions: {
|
| 407 |
|
| 408 |
prompt_template = f"""
|
| 409 |
Answer the question based on the following web search results, conversation context, entity information, and user instructions:
|
|
@@ -419,7 +418,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 419 |
|
| 420 |
prompt_val = ChatPromptTemplate.from_template(prompt_template)
|
| 421 |
|
| 422 |
-
# Start with full context and progressively reduce if necessary
|
| 423 |
current_context = context_str
|
| 424 |
current_conv_context = chatbot.get_context()
|
| 425 |
current_topics = topics
|
|
@@ -434,13 +432,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 434 |
entities=json.dumps(current_entities)
|
| 435 |
)
|
| 436 |
|
| 437 |
-
# Estimate token count (rough estimate)
|
| 438 |
estimated_tokens = len(formatted_prompt) // 4
|
| 439 |
|
| 440 |
-
if estimated_tokens <= max_tokens - 1000:
|
| 441 |
break
|
| 442 |
|
| 443 |
-
# Reduce context if estimated token count is too high
|
| 444 |
current_context = current_context[:int(len(current_context) * context_reduction_factor)]
|
| 445 |
current_conv_context = current_conv_context[:int(len(current_conv_context) * context_reduction_factor)]
|
| 446 |
current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
|
|
@@ -450,7 +446,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 450 |
raise ValueError("Context reduced too much. Unable to process the query.")
|
| 451 |
|
| 452 |
full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
|
| 453 |
-
answer = extract_answer(full_response,
|
| 454 |
all_answers.append(answer)
|
| 455 |
break
|
| 456 |
|
|
@@ -469,12 +465,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 469 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|
| 470 |
answer += sources_section
|
| 471 |
|
| 472 |
-
# Update chatbot context with the answer
|
| 473 |
chatbot.add_to_history(answer)
|
| 474 |
|
| 475 |
return answer
|
| 476 |
|
| 477 |
-
|
| 478 |
for attempt in range(max_attempts):
|
| 479 |
try:
|
| 480 |
if database is None:
|
|
@@ -484,11 +479,14 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 484 |
relevant_docs = retriever.get_relevant_documents(question)
|
| 485 |
context_str = "\n".join([doc.page_content for doc in relevant_docs])
|
| 486 |
|
| 487 |
-
|
|
|
|
|
|
|
| 488 |
Answer the question based on the following context from the PDF document:
|
| 489 |
Context:
|
| 490 |
-
{context}
|
| 491 |
-
Question: {question}
|
|
|
|
| 492 |
Provide a summarized and direct answer to the question.
|
| 493 |
"""
|
| 494 |
|
|
@@ -498,17 +496,16 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 498 |
|
| 499 |
estimated_tokens = estimate_tokens(formatted_prompt)
|
| 500 |
|
| 501 |
-
if estimated_tokens <= max_tokens - 1000:
|
| 502 |
break
|
| 503 |
|
| 504 |
-
# Reduce context if estimated token count is too high
|
| 505 |
context_str = context_str[:int(len(context_str) * context_reduction_factor)]
|
| 506 |
|
| 507 |
if len(context_str) < 100:
|
| 508 |
raise ValueError("Context reduced too much. Unable to process the query.")
|
| 509 |
|
| 510 |
full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
|
| 511 |
-
answer = extract_answer(full_response)
|
| 512 |
|
| 513 |
return answer
|
| 514 |
|
|
@@ -524,6 +521,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
|
|
| 524 |
|
| 525 |
return "An unexpected error occurred. Please try again later."
|
| 526 |
|
|
|
|
| 527 |
def extract_answer(full_response, instructions=None):
|
| 528 |
answer_patterns = [
|
| 529 |
r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
|
|
@@ -575,6 +573,7 @@ with gr.Blocks() as demo:
|
|
| 575 |
with gr.Column(scale=2):
|
| 576 |
chatbot = gr.Chatbot(label="Conversation")
|
| 577 |
question_input = gr.Textbox(label="Ask a question")
|
|
|
|
| 578 |
submit_button = gr.Button("Submit")
|
| 579 |
with gr.Column(scale=1):
|
| 580 |
temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
|
|
@@ -584,12 +583,12 @@ with gr.Blocks() as demo:
|
|
| 584 |
|
| 585 |
enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()
|
| 586 |
|
| 587 |
-
def chat(question, history, temperature, top_p, repetition_penalty, web_search):
|
| 588 |
-
answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot)
|
| 589 |
history.append((question, answer))
|
| 590 |
return "", history
|
| 591 |
|
| 592 |
-
submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox], outputs=[question_input, chatbot])
|
| 593 |
|
| 594 |
clear_button = gr.Button("Clear Cache")
|
| 595 |
clear_output = gr.Textbox(label="Cache Status")
|
|
|
|
| 352 |
# Rough estimate: 1 token ~= 4 characters
|
| 353 |
return len(text) // 4
|
| 354 |
|
| 355 |
+
def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
|
| 356 |
if not question:
|
| 357 |
return "Please enter a question."
|
| 358 |
|
|
|
|
| 368 |
else:
|
| 369 |
database = None
|
| 370 |
|
| 371 |
+
max_attempts = 3
|
| 372 |
context_reduction_factor = 0.7
|
| 373 |
+
max_tokens = 32000
|
| 374 |
|
| 375 |
if web_search:
|
| 376 |
+
contextualized_question, topics, entity_tracker, _ = chatbot.process_question(question)
|
| 377 |
|
|
|
|
| 378 |
print(f"Contextualized question: {contextualized_question}")
|
| 379 |
+
print(f"User Instructions: {user_instructions}")
|
| 380 |
|
| 381 |
try:
|
| 382 |
search_results = google_search(contextualized_question, num_results=3)
|
|
|
|
| 402 |
|
| 403 |
context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
|
| 404 |
|
| 405 |
+
instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
|
| 406 |
|
| 407 |
prompt_template = f"""
|
| 408 |
Answer the question based on the following web search results, conversation context, entity information, and user instructions:
|
|
|
|
| 418 |
|
| 419 |
prompt_val = ChatPromptTemplate.from_template(prompt_template)
|
| 420 |
|
|
|
|
| 421 |
current_context = context_str
|
| 422 |
current_conv_context = chatbot.get_context()
|
| 423 |
current_topics = topics
|
|
|
|
| 432 |
entities=json.dumps(current_entities)
|
| 433 |
)
|
| 434 |
|
|
|
|
| 435 |
estimated_tokens = len(formatted_prompt) // 4
|
| 436 |
|
| 437 |
+
if estimated_tokens <= max_tokens - 1000:
|
| 438 |
break
|
| 439 |
|
|
|
|
| 440 |
current_context = current_context[:int(len(current_context) * context_reduction_factor)]
|
| 441 |
current_conv_context = current_conv_context[:int(len(current_conv_context) * context_reduction_factor)]
|
| 442 |
current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
|
|
|
|
| 446 |
raise ValueError("Context reduced too much. Unable to process the query.")
|
| 447 |
|
| 448 |
full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
|
| 449 |
+
answer = extract_answer(full_response, user_instructions)
|
| 450 |
all_answers.append(answer)
|
| 451 |
break
|
| 452 |
|
|
|
|
| 465 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|
| 466 |
answer += sources_section
|
| 467 |
|
|
|
|
| 468 |
chatbot.add_to_history(answer)
|
| 469 |
|
| 470 |
return answer
|
| 471 |
|
| 472 |
+
else: # PDF document chat
|
| 473 |
for attempt in range(max_attempts):
|
| 474 |
try:
|
| 475 |
if database is None:
|
|
|
|
| 479 |
relevant_docs = retriever.get_relevant_documents(question)
|
| 480 |
context_str = "\n".join([doc.page_content for doc in relevant_docs])
|
| 481 |
|
| 482 |
+
instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
|
| 483 |
+
|
| 484 |
+
prompt_template = f"""
|
| 485 |
Answer the question based on the following context from the PDF document:
|
| 486 |
Context:
|
| 487 |
+
{{context}}
|
| 488 |
+
Question: {{question}}
|
| 489 |
+
{instruction_prompt}
|
| 490 |
Provide a summarized and direct answer to the question.
|
| 491 |
"""
|
| 492 |
|
|
|
|
| 496 |
|
| 497 |
estimated_tokens = estimate_tokens(formatted_prompt)
|
| 498 |
|
| 499 |
+
if estimated_tokens <= max_tokens - 1000:
|
| 500 |
break
|
| 501 |
|
|
|
|
| 502 |
context_str = context_str[:int(len(context_str) * context_reduction_factor)]
|
| 503 |
|
| 504 |
if len(context_str) < 100:
|
| 505 |
raise ValueError("Context reduced too much. Unable to process the query.")
|
| 506 |
|
| 507 |
full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
|
| 508 |
+
answer = extract_answer(full_response, user_instructions)
|
| 509 |
|
| 510 |
return answer
|
| 511 |
|
|
|
|
| 521 |
|
| 522 |
return "An unexpected error occurred. Please try again later."
|
| 523 |
|
| 524 |
+
|
| 525 |
def extract_answer(full_response, instructions=None):
|
| 526 |
answer_patterns = [
|
| 527 |
r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
|
|
|
|
| 573 |
with gr.Column(scale=2):
|
| 574 |
chatbot = gr.Chatbot(label="Conversation")
|
| 575 |
question_input = gr.Textbox(label="Ask a question")
|
| 576 |
+
instructions_input = gr.Textbox(label="Instructions for response (optional)", placeholder="Enter any specific instructions for the response here")
|
| 577 |
submit_button = gr.Button("Submit")
|
| 578 |
with gr.Column(scale=1):
|
| 579 |
temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
|
|
|
|
| 583 |
|
| 584 |
enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()
|
| 585 |
|
| 586 |
+
def chat(question, history, temperature, top_p, repetition_penalty, web_search, user_instructions):
|
| 587 |
+
answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot, user_instructions)
|
| 588 |
history.append((question, answer))
|
| 589 |
return "", history
|
| 590 |
|
| 591 |
+
submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox, instructions_input], outputs=[question_input, chatbot])
|
| 592 |
|
| 593 |
clear_button = gr.Button("Clear Cache")
|
| 594 |
clear_output = gr.Textbox(label="Cache Status")
|