Update main.py
Browse files
main.py
CHANGED
|
@@ -568,35 +568,97 @@ async def query_system(
|
|
| 568 |
# Get session data
|
| 569 |
session = sessions[session_id]
|
| 570 |
retriever = session["retriever"]
|
| 571 |
-
|
|
|
|
|
|
|
| 572 |
|
| 573 |
# Create chain
|
| 574 |
chain = create_chain(retriever)
|
| 575 |
|
| 576 |
-
#
|
| 577 |
messages = chat_history.messages
|
| 578 |
-
langchain_chat_history = [(messages[i].content, messages[i+1].content)
|
| 579 |
-
for i in range(0, len(messages)-1, 2) if i+1 < len(messages)]
|
| 580 |
|
| 581 |
-
#
|
| 582 |
-
|
| 583 |
|
| 584 |
-
#
|
| 585 |
-
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
-
|
| 589 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 590 |
|
| 591 |
-
|
| 592 |
-
"
|
| 593 |
-
|
| 594 |
-
"
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
| 598 |
raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}")
|
| 599 |
|
|
|
|
| 600 |
@app.get("/sessions", response_model=List[Dict[str, Any]])
|
| 601 |
async def get_user_sessions(current_user: User = Depends(get_current_user)):
|
| 602 |
"""
|
|
|
|
| 568 |
# Get session data
|
| 569 |
session = sessions[session_id]
|
| 570 |
retriever = session["retriever"]
|
| 571 |
+
|
| 572 |
+
# Get or initialize chat history
|
| 573 |
+
chat_history = chat_manager.initialize_chat_history(session_id)
|
| 574 |
|
| 575 |
# Create chain
|
| 576 |
chain = create_chain(retriever)
|
| 577 |
|
| 578 |
+
# Extract messages properly for LangChain format
|
| 579 |
messages = chat_history.messages
|
|
|
|
|
|
|
| 580 |
|
| 581 |
+
# Process chat history safely
|
| 582 |
+
langchain_chat_history = []
|
| 583 |
|
| 584 |
+
# If messages exist, process them
|
| 585 |
+
if messages:
|
| 586 |
+
# Group messages by pairs (user, AI)
|
| 587 |
+
# This approach is safer than assuming perfect alternating pattern
|
| 588 |
+
i = 0
|
| 589 |
+
while i < len(messages) - 1:
|
| 590 |
+
user_message = messages[i].content
|
| 591 |
+
ai_message = messages[i+1].content
|
| 592 |
+
langchain_chat_history.append((user_message, ai_message))
|
| 593 |
+
i += 2
|
| 594 |
+
|
| 595 |
+
# Add debugging information
|
| 596 |
+
print(f"Chat history length: {len(langchain_chat_history)}")
|
| 597 |
+
print(f"Query: {request.query}")
|
| 598 |
|
| 599 |
+
try:
|
| 600 |
+
# Query the chain
|
| 601 |
+
result = chain.invoke({
|
| 602 |
+
"question": request.query,
|
| 603 |
+
"chat_history": langchain_chat_history
|
| 604 |
+
})
|
| 605 |
+
|
| 606 |
+
# Extract answer from result
|
| 607 |
+
answer = result.get("answer", "I couldn't find an answer to your question.")
|
| 608 |
+
|
| 609 |
+
# Update chat history
|
| 610 |
+
chat_history.add_user_message(request.query)
|
| 611 |
+
chat_history.add_ai_message(answer)
|
| 612 |
+
|
| 613 |
+
# Prepare source documents with proper error handling
|
| 614 |
+
source_docs = []
|
| 615 |
+
if "source_documents" in result and result["source_documents"]:
|
| 616 |
+
for doc in result["source_documents"]:
|
| 617 |
+
try:
|
| 618 |
+
# Different LangChain versions might structure documents differently
|
| 619 |
+
if hasattr(doc, 'page_content'):
|
| 620 |
+
# Regular Document object
|
| 621 |
+
content = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
|
| 622 |
+
source_docs.append(content)
|
| 623 |
+
elif isinstance(doc, dict) and 'page_content' in doc:
|
| 624 |
+
# Dictionary format
|
| 625 |
+
content = doc['page_content'][:100] + "..." if len(doc['page_content']) > 100 else doc['page_content']
|
| 626 |
+
source_docs.append(content)
|
| 627 |
+
elif isinstance(doc, str):
|
| 628 |
+
# String format
|
| 629 |
+
content = doc[:100] + "..." if len(doc) > 100 else doc
|
| 630 |
+
source_docs.append(content)
|
| 631 |
+
except Exception as doc_error:
|
| 632 |
+
print(f"Error processing source document: {str(doc_error)}")
|
| 633 |
+
|
| 634 |
+
return {
|
| 635 |
+
"answer": answer,
|
| 636 |
+
"session_id": session_id,
|
| 637 |
+
"source_documents": source_docs
|
| 638 |
+
}
|
| 639 |
|
| 640 |
+
except Exception as chain_error:
|
| 641 |
+
print(f"Chain invocation error: {str(chain_error)}")
|
| 642 |
+
# Provide a more graceful fallback
|
| 643 |
+
fallback_answer = "I apologize, but I encountered an error while processing your question. Please try rephrasing your query or asking about a different topic."
|
| 644 |
+
|
| 645 |
+
# Update chat history even in case of error
|
| 646 |
+
chat_history.add_user_message(request.query)
|
| 647 |
+
chat_history.add_ai_message(fallback_answer)
|
| 648 |
+
|
| 649 |
+
return {
|
| 650 |
+
"answer": fallback_answer,
|
| 651 |
+
"session_id": session_id,
|
| 652 |
+
"source_documents": []
|
| 653 |
+
}
|
| 654 |
|
| 655 |
except Exception as e:
|
| 656 |
+
print(f"Query system error: {str(e)}")
|
| 657 |
+
import traceback
|
| 658 |
+
traceback.print_exc()
|
| 659 |
raise HTTPException(status_code=500, detail=f"Error querying system: {str(e)}")
|
| 660 |
|
| 661 |
+
|
| 662 |
@app.get("/sessions", response_model=List[Dict[str, Any]])
|
| 663 |
async def get_user_sessions(current_user: User = Depends(get_current_user)):
|
| 664 |
"""
|