# Business_Chatbot/src/chatbot.py
# Uploaded with huggingface_hub (uploader: Ancastal; commit 401b16c, verified).
from typing import Dict, Any, Optional
from entity_extractor import EntityExtractor
from database_manager import DatabaseManager
from vector_store import VectorStore
from nl_to_sql import NaturalLanguageToSQL
from intent_classifier import IntentClassifier, IntentType
from rag_handler import RAGHandler
from transaction_clarifier import TransactionClarifier, ClarificationStatus
from models import ChatbotRequest, ChatbotResponse, PendingTransaction
class Chatbot:
    """Route chat messages to transaction, query, search, or storage handlers.

    Each incoming message is either treated as the answer to an outstanding
    clarification question (when the session has a pending transaction) or
    classified by ``IntentClassifier`` and dispatched to one of the private
    ``_handle_*`` methods.

    Pending transactions are kept in an in-memory dict keyed by session id;
    requests without a session id all share the key ``"default"``.
    """

    # Replies that abort a pending transaction (compared case-insensitively).
    _CANCEL_WORDS = frozenset({"cancel", "quit", "stop", "abort"})
    # Number of similar events requested from the vector store per search.
    _SEARCH_RESULT_LIMIT = 8
    # Display limits and column width used by the text formatters.
    _MAX_RECENT_TRANSACTIONS = 10
    _MAX_RESULT_ROWS = 20
    _COLUMN_WIDTH = 15

    def __init__(self):
        """Instantiate every collaborator and the pending-transaction store."""
        self.entity_extractor = EntityExtractor()
        self.db_manager = DatabaseManager()
        self.vector_store = VectorStore()
        self.nl_to_sql = NaturalLanguageToSQL()
        self.intent_classifier = IntentClassifier()
        self.rag_handler = RAGHandler()
        self.transaction_clarifier = TransactionClarifier()
        # Store pending transactions by session_id
        self.pending_transactions: Dict[str, PendingTransaction] = {}

    def process_message(self, request: ChatbotRequest) -> ChatbotResponse:
        """Process a user message and return the appropriate response.

        Args:
            request: Incoming request; ``message`` and ``session_id`` are read.

        Returns:
            The handler's response. For freshly classified messages the
            detected intent and its confidence are attached; replies to a
            pending clarification are returned as produced by the
            clarification handler.
        """
        # Tolerate a missing message instead of failing on ``None.strip()``.
        message = (request.message or "").strip()
        session_id = request.session_id or "default"

        # A pending transaction captures the next message of that session.
        if session_id in self.pending_transactions:
            print("A transaction is pending...")
            return self._handle_transaction_clarification(message, session_id)

        intent_result = self.intent_classifier.classify_intent(message)
        intent = intent_result.intent
        print(f"🎯 Intent: {intent.value} (confidence: {intent_result.confidence:.2f})")
        print(f"📝 Reasoning: {intent_result.reasoning}")

        # Route to the handler matching the classified intent.
        if intent == IntentType.TRANSACTION:
            response = self._handle_transaction_request(message, session_id)
        elif intent == IntentType.QUERY:
            response = self._handle_query_request(message)
        elif intent == IntentType.SEMANTIC_SEARCH:
            response = self._handle_search_request(message)
        else:  # GENERAL_INFO
            response = self._handle_general_information(message)

        response.intent_detected = intent.value
        response.intent_confidence = intent_result.confidence
        return response

    def _handle_transaction_request(self, message: str, session_id: str) -> ChatbotResponse:
        """Handle a purchase/sale request, asking for missing details if needed.

        Complete transactions are processed immediately. Incomplete ones are
        stored under ``session_id`` and a clarification question is returned.
        Any other completeness status is reported as a cancelled transaction.
        Exceptions are converted into an error response.
        """
        try:
            entities = self.entity_extractor.extract_entities(message)
            status, clarification = self.transaction_clarifier.analyze_transaction_completeness(entities)

            if status == ClarificationStatus.COMPLETE:
                return self._complete_transaction(entities, message)

            if status == ClarificationStatus.NEEDS_CLARIFICATION:
                # Remember what we have so the next message can fill the gaps.
                self.pending_transactions[session_id] = PendingTransaction(
                    entities=entities,
                    missing_fields=clarification.missing_fields,
                    session_id=session_id,
                    original_message=message,
                )
                clarification_message = self.transaction_clarifier.format_clarification_message(clarification)
                return ChatbotResponse(
                    response=clarification_message,
                    entities_extracted=entities,
                    awaiting_clarification=True,
                )

            return ChatbotResponse(
                response="Transaction cancelled.",
                entities_extracted=entities,
            )
        except Exception as e:
            return ChatbotResponse(
                response=f"Error processing transaction: {str(e)}",
                sql_executed=None,
                entities_extracted=None,
                vector_stored=False,
            )

    def _complete_transaction(self, entities, original_message: str) -> ChatbotResponse:
        """Persist a complete transaction and index it in the vector store.

        The database write decides success. Vector indexing is best-effort:
        if it fails after the database write succeeded, the transaction is
        still reported as processed (with ``vector_stored=False``) so the user
        is not prompted to re-enter a transaction that was already saved.
        """
        try:
            transaction_id, result_message = self.db_manager.process_transaction(entities)
            vector_stored = self._index_transaction(entities, original_message, transaction_id)
            return ChatbotResponse(
                response=result_message,
                sql_executed="Transaction processed successfully",
                entities_extracted=entities,
                vector_stored=vector_stored,
            )
        except Exception as e:
            return ChatbotResponse(
                response=f"Error completing transaction: {str(e)}",
                entities_extracted=entities,
            )

    def _index_transaction(self, entities, original_message: str, transaction_id):
        """Store the transaction in the vector store, linked by its SQL id.

        Returns whatever ``add_transaction_event`` returns, or ``False`` when
        indexing raises.
        """
        try:
            transaction_data = {
                "type": entities.transaction_type,
                "product": entities.product,
                "quantity": entities.quantity,
                "supplier": entities.supplier,
                "customer": entities.customer,
                "unit_price": entities.unit_price,
                "total": entities.total_amount,
            }
            return self.vector_store.add_transaction_event(
                transaction_data,
                original_message,
                sql_transaction_id=transaction_id,
            )
        except Exception as e:
            print(f"Warning: transaction {transaction_id} was saved but could not be indexed: {e}")
            return False

    @staticmethod
    def _build_full_context(pending: PendingTransaction) -> str:
        """Join the original message with every clarification reply so far."""
        clarifications = "\n".join(
            f"Clarification {number}: {reply}"
            for number, reply in enumerate(pending.clarification_responses, start=1)
        )
        return f"{pending.original_message}\n\n{clarifications}"

    def _handle_transaction_clarification(self, message: str, session_id: str) -> ChatbotResponse:
        """Handle the user's reply to a transaction clarification question.

        The reply may cancel the pending transaction, complete it, or leave it
        still incomplete (in which case another question is returned). The
        pending entry is removed whenever the conversation ends, including on
        errors.
        """
        try:
            pending = self.pending_transactions.get(session_id)
            if not pending:
                return ChatbotResponse(
                    response="No pending transaction found. Please start a new transaction."
                )

            if message.strip().lower() in self._CANCEL_WORDS:
                del self.pending_transactions[session_id]
                return ChatbotResponse(
                    response="Transaction cancelled. You can start a new one anytime."
                )

            # Keep every reply so the stored context reflects the whole dialogue.
            pending.clarification_responses.append(message)

            updated_entities, is_complete = self.transaction_clarifier.process_clarification_response(
                pending.entities,
                pending.missing_fields,
                message,
            )

            if not is_complete:
                status, clarification = self.transaction_clarifier.analyze_transaction_completeness(updated_entities)

                if status == ClarificationStatus.NEEDS_CLARIFICATION:
                    pending.entities = updated_entities
                    pending.missing_fields = clarification.missing_fields
                    clarification_message = self.transaction_clarifier.format_clarification_message(clarification)
                    return ChatbotResponse(
                        response=f"Thank you! I still need a bit more information:\n\n{clarification_message}",
                        entities_extracted=updated_entities,
                        awaiting_clarification=True,
                    )

                if status != ClarificationStatus.COMPLETE:
                    # Mirror _handle_transaction_request: any status other than
                    # COMPLETE / NEEDS_CLARIFICATION means the transaction must
                    # not be written to the database.
                    del self.pending_transactions[session_id]
                    return ChatbotResponse(
                        response="Transaction cancelled.",
                        entities_extracted=updated_entities,
                    )

            # Complete: store with the original message plus all clarifications.
            full_context = self._build_full_context(pending)
            del self.pending_transactions[session_id]
            return self._complete_transaction(updated_entities, full_context)
        except Exception as e:
            # Clean up on error so the session is not stuck in clarification.
            self.pending_transactions.pop(session_id, None)
            return ChatbotResponse(
                response=f"Error processing your response: {str(e)}. Please start a new transaction."
            )

    def _handle_query_request(self, message: str) -> ChatbotResponse:
        """Answer a data question by generating, validating and running SQL.

        Returns a response describing a validation failure, an empty result,
        an execution error reported by the database layer, or the formatted
        rows. Exceptions are converted into an error response.
        """
        try:
            sql_query, explanation = self.nl_to_sql.convert_to_sql(message)

            is_valid, validation_message = self.nl_to_sql.validate_sql(sql_query)
            if not is_valid:
                suggestion = self.nl_to_sql.suggest_corrections(message, validation_message)
                return ChatbotResponse(
                    response=f"I couldn't process that query: {validation_message}\n\n{suggestion}",
                    sql_executed=sql_query,
                )

            results = self.db_manager.query_data(sql_query)
            if not results:
                return ChatbotResponse(
                    response="No results found for your query.",
                    sql_executed=sql_query,
                )

            # A single row containing an "error" key is treated as a failure
            # reported by the database layer.
            if len(results) == 1 and "error" in results[0]:
                return ChatbotResponse(
                    response=f"Query execution error: {results[0]['error']}\n\nGenerated SQL: {sql_query}",
                    sql_executed=sql_query,
                )

            return ChatbotResponse(
                response=self._format_sql_results(results, explanation),
                sql_executed=sql_query,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error processing query: {str(e)}")

    def _handle_search_request(self, message: str) -> ChatbotResponse:
        """Answer a semantic-search request with a RAG-generated response."""
        try:
            enhanced_query = self.rag_handler.enhance_search_query(message)
            print(f"🔍 Enhanced query: {enhanced_query}")

            results = self.vector_store.search_similar_events(enhanced_query, self._SEARCH_RESULT_LIMIT)
            if not results:
                return ChatbotResponse(response="I couldn't find any relevant information to answer your query.")

            # The answer is generated from the user's original wording, not
            # the enhanced retrieval query.
            rag_response = self.rag_handler.generate_rag_response(message, results)
            return ChatbotResponse(
                response=rag_response,
                vector_stored=False,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error processing your search: {str(e)}")

    def _handle_general_information(self, message: str) -> ChatbotResponse:
        """Store a general-information message in the vector store."""
        try:
            stored = self.vector_store.add_general_event(message, "general_info")
            if stored:
                return ChatbotResponse(
                    response="Information stored successfully. I can help you find similar information later.",
                    vector_stored=True,
                )
            return ChatbotResponse(
                response="Information noted, but vector storage is not available.",
                vector_stored=False,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error storing information: {str(e)}")

    @staticmethod
    def _type_and_date(transaction: Dict[str, Any]) -> tuple:
        """Return ``(TYPE, date)`` for display, tolerating null or non-str values.

        The type is upper-cased (``"UNKNOWN"`` when missing) and the date is
        the first ten characters of its text form (empty when missing).
        """
        trans_type = str(transaction.get("type") or "unknown").upper()
        date = str(transaction.get("date") or "")[:10]
        return trans_type, date

    def _format_recent_transactions(self, data: Dict[str, list]) -> str:
        """Format the newest purchases and sales, most recent first.

        Args:
            data: Mapping that may contain ``"purchases"`` and ``"sales"``
                lists of transaction dicts.

        Returns:
            A text listing of at most ten transactions, or a fixed message
            when there are none.
        """
        all_transactions = [*(data.get("purchases") or []), *(data.get("sales") or [])]
        if not all_transactions:
            return "No recent transactions found."

        # Compare dates as text so null or non-string dates cannot break sorting.
        all_transactions.sort(key=lambda item: str(item.get("date") or ""), reverse=True)

        lines = ["Recent Transactions:", ""]
        for transaction in all_transactions[: self._MAX_RECENT_TRANSACTIONS]:
            trans_type, date = self._type_and_date(transaction)
            quantity = transaction.get("quantity", 0)
            product = transaction.get("product", "Unknown")
            if trans_type == "PURCHASE":
                lines.append(
                    f"🛒 {date} - PURCHASE: {quantity}x {product} "
                    f"from {transaction.get('supplier', 'Unknown')} - €{transaction.get('total_cost', 0)}"
                )
            else:
                lines.append(
                    f"💰 {date} - SALE: {quantity}x {product} "
                    f"to {transaction.get('customer', 'Unknown')} - €{transaction.get('total_amount', 0)}"
                )
        return "\n".join(lines) + "\n"

    def _format_search_results(self, results: list, search_term: str) -> str:
        """Format transactions matching ``search_term`` as a text listing."""
        if not results:
            return f"No transactions found for '{search_term}'."

        lines = [f"Found {len(results)} transaction(s) for '{search_term}':", ""]
        for transaction in results:
            trans_type, date = self._type_and_date(transaction)
            quantity = transaction.get("quantity", 0)
            product = transaction.get("product", "Unknown")
            total = transaction.get("total", 0)
            if trans_type == "PURCHASE":
                lines.append(
                    f"🛒 {date} - {quantity}x {product} "
                    f"from {transaction.get('supplier', 'Unknown')} - €{total}"
                )
            else:
                lines.append(
                    f"💰 {date} - {quantity}x {product} "
                    f"to {transaction.get('customer', 'Unknown')} - €{total}"
                )
        return "\n".join(lines) + "\n"

    @classmethod
    def _format_cell(cls, value: Any) -> str:
        """Render one table cell padded to the fixed column width.

        ``None`` becomes ``"N/A"``, floats get two decimals, and everything
        else is converted to text and truncated to the column width.
        """
        if value is None:
            text = "N/A"
        elif isinstance(value, float):
            text = f"{value:.2f}"
        else:
            text = str(value)[: cls._COLUMN_WIDTH]
        return text.ljust(cls._COLUMN_WIDTH)

    def _format_sql_results(self, results: list, explanation: str) -> str:
        """Format SQL query rows for display.

        Args:
            results: List of row dicts (column name -> value).
            explanation: Text shown under the heading.

        Returns:
            The heading followed by "No data found.", a single bold
            ``**Name:** value`` line for one-row/one-column results, or a
            fenced fixed-width table limited to the first twenty rows.
        """
        response = f"📊 Query Results:\n{explanation}\n\n"
        if not results:
            return response + "No data found."

        # Handle single value results (like COUNT, SUM)
        if len(results) == 1 and len(results[0]) == 1:
            key, value = next(iter(results[0].items()))
            return response + f"**{str(key).replace('_', ' ').title()}:** {value}"

        width = self._COLUMN_WIDTH
        # Headers are truncated like the cells so the columns stay aligned,
        # and the rule below is exactly as wide as the header line.
        header_line = " | ".join(
            str(header).replace("_", " ").title()[:width].ljust(width)
            for header in results[0].keys()
        )
        table = ["```", header_line, "-" * len(header_line)]
        table.extend(
            " | ".join(self._format_cell(value) for value in row.values())
            for row in results[: self._MAX_RESULT_ROWS]
        )
        response += "\n".join(table) + "\n"

        hidden_rows = len(results) - self._MAX_RESULT_ROWS
        if hidden_rows > 0:
            response += f"\n... and {hidden_rows} more rows\n"
        return response + "```"

    def get_linked_transaction_data(self, sql_transaction_id: int, transaction_type: str) -> Optional[Dict[str, Any]]:
        """Retrieve a transaction from both the SQL store and the vector store.

        Returns:
            A dict with ``sql_data``, ``vector_data`` and a ``linked`` flag
            (true when a vector entry was found), or ``None`` when the SQL
            lookup returns nothing or an error occurs.
        """
        try:
            sql_data = self.db_manager.get_transaction_by_id(sql_transaction_id, transaction_type)
            vector_data = self.vector_store.get_transaction_by_sql_id(sql_transaction_id, transaction_type)
            if not sql_data:
                return None
            return {
                "sql_data": sql_data,
                "vector_data": vector_data,
                "linked": vector_data is not None,
            }
        except Exception as e:
            print(f"Error retrieving linked transaction data: {e}")
            return None

    def close(self):
        """Clean up resources by closing the database manager."""
        self.db_manager.close()