Spaces:
Sleeping
Sleeping
| from typing import Dict, Any, Optional | |
| from entity_extractor import EntityExtractor | |
| from database_manager import DatabaseManager | |
| from vector_store import VectorStore | |
| from nl_to_sql import NaturalLanguageToSQL | |
| from intent_classifier import IntentClassifier, IntentType | |
| from rag_handler import RAGHandler | |
| from transaction_clarifier import TransactionClarifier, ClarificationStatus | |
| from models import ChatbotRequest, ChatbotResponse, PendingTransaction | |
class Chatbot:
    """Conversational front end that routes user messages to specialised handlers.

    Each message is classified by ``IntentClassifier`` and dispatched to one of:
    transaction recording (with multi-turn clarification), natural-language SQL
    queries, semantic search over the vector store (RAG), or free-form
    information storage.

    Annotations naming project models are written as strings so the class body
    does not evaluate them when the class is defined.
    """

    # Replies that abandon a pending transaction (compared after lower-casing).
    _CANCEL_WORDS = frozenset({"cancel", "quit", "stop", "abort"})

    def __init__(self):
        self.entity_extractor = EntityExtractor()
        self.db_manager = DatabaseManager()
        self.vector_store = VectorStore()
        self.nl_to_sql = NaturalLanguageToSQL()
        self.intent_classifier = IntentClassifier()
        self.rag_handler = RAGHandler()
        self.transaction_clarifier = TransactionClarifier()
        # Transactions awaiting clarification, keyed by session_id. An entry is
        # removed when its transaction completes, is cancelled, or errors out.
        self.pending_transactions: Dict[str, PendingTransaction] = {}

    def process_message(self, request: "ChatbotRequest") -> "ChatbotResponse":
        """Process a user message and return the appropriate response.

        If the session already has a pending transaction, the message is
        treated as an answer to the outstanding clarification question and
        intent classification is skipped. Otherwise the intent is classified
        and the message is routed to the matching handler; the detected intent
        and its confidence are attached to the returned response.

        Args:
            request: Incoming request. ``session_id`` falls back to
                ``"default"`` when it is empty or missing.

        Returns:
            The ``ChatbotResponse`` produced by the selected handler.
        """
        message = request.message.strip()
        # NOTE(review): every request without a session_id shares the
        # "default" slot, so anonymous users share one pending transaction.
        session_id = request.session_id or "default"

        # A pending transaction takes priority over intent classification.
        if session_id in self.pending_transactions:
            print("A transaction is pending...")
            return self._handle_transaction_clarification(message, session_id)

        # Classify intent using OpenAI
        intent_result = self.intent_classifier.classify_intent(message)
        print(f"🎯 Intent: {intent_result.intent.value} (confidence: {intent_result.confidence:.2f})")
        print(f"💭 Reasoning: {intent_result.reasoning}")

        if intent_result.intent == IntentType.TRANSACTION:
            response = self._handle_transaction_request(message, session_id)
        elif intent_result.intent == IntentType.QUERY:
            response = self._handle_query_request(message)
        elif intent_result.intent == IntentType.SEMANTIC_SEARCH:
            response = self._handle_search_request(message)
        else:  # GENERAL_INFO
            response = self._handle_general_information(message)

        response.intent_detected = intent_result.intent.value
        response.intent_confidence = intent_result.confidence
        return response

    def _handle_transaction_request(self, message: str, session_id: str) -> "ChatbotResponse":
        """Start a purchase/sale transaction, asking for missing details if needed.

        Entities are extracted from ``message`` and checked for completeness.
        A complete transaction is committed immediately. An incomplete one is
        stored in ``pending_transactions`` under ``session_id`` and the user is
        asked for the missing fields. Any other status is reported as a
        cancellation.

        Returns:
            A ``ChatbotResponse``. ``awaiting_clarification`` is set when the
            user must answer follow-up questions. Errors are returned as a
            response message rather than raised.
        """
        try:
            entities = self.entity_extractor.extract_entities(message)
            status, clarification = self.transaction_clarifier.analyze_transaction_completeness(entities)

            if status == ClarificationStatus.COMPLETE:
                return self._complete_transaction(entities, message)

            if status == ClarificationStatus.NEEDS_CLARIFICATION:
                self.pending_transactions[session_id] = PendingTransaction(
                    entities=entities,
                    missing_fields=clarification.missing_fields,
                    session_id=session_id,
                    original_message=message,
                )
                clarification_message = self.transaction_clarifier.format_clarification_message(clarification)
                return ChatbotResponse(
                    response=clarification_message,
                    entities_extracted=entities,
                    awaiting_clarification=True,
                )

            return ChatbotResponse(
                response="Transaction cancelled.",
                entities_extracted=entities,
            )
        except Exception as e:
            return ChatbotResponse(
                response=f"Error processing transaction: {str(e)}",
                sql_executed=None,
                entities_extracted=None,
                vector_stored=False,
            )

    def _complete_transaction(self, entities, original_message: str) -> "ChatbotResponse":
        """Persist a fully specified transaction to the database and vector store.

        The database call returns the SQL transaction id, which is passed to
        the vector store so the two records can be linked later (see
        ``get_linked_transaction_data``).

        Args:
            entities: Extracted transaction entities (type, product, quantity,
                supplier/customer, unit price, total).
            original_message: Text stored alongside the vector-store event; for
                clarified transactions it includes every clarification reply.

        Returns:
            A ``ChatbotResponse`` with the database's result message, or an
            error message if any step raised.
        """
        try:
            transaction_id, result_message = self.db_manager.process_transaction(entities)

            transaction_data = {
                "type": entities.transaction_type,
                "product": entities.product,
                "quantity": entities.quantity,
                "supplier": entities.supplier,
                "customer": entities.customer,
                "unit_price": entities.unit_price,
                "total": entities.total_amount,
            }
            # NOTE(review): if this vector-store write raises, the user sees an
            # error even though process_transaction above already ran.
            vector_stored = self.vector_store.add_transaction_event(
                transaction_data,
                original_message,
                sql_transaction_id=transaction_id,
            )

            return ChatbotResponse(
                response=result_message,
                sql_executed="Transaction processed successfully",
                entities_extracted=entities,
                vector_stored=vector_stored,
            )
        except Exception as e:
            return ChatbotResponse(
                response=f"Error completing transaction: {str(e)}",
                entities_extracted=entities,
            )

    @staticmethod
    def _build_full_context(pending: "PendingTransaction") -> str:
        """Join the original request with every clarification reply, numbered from 1."""
        clarifications = "\n".join(
            f"Clarification {number}: {reply}"
            for number, reply in enumerate(pending.clarification_responses, start=1)
        )
        return f"{pending.original_message}\n\n{clarifications}"

    def _handle_transaction_clarification(self, message: str, session_id: str) -> "ChatbotResponse":
        """Apply the user's answer to the session's pending transaction.

        A reply in ``_CANCEL_WORDS`` drops the pending transaction. Otherwise
        the reply is merged into the pending entities:

        * complete -> the transaction is committed using the original message
          plus all clarification replies as context;
        * still missing fields -> the pending record is updated and the user
          is asked again;
        * any other clarifier status -> the transaction is cancelled. This
          matches ``_handle_transaction_request``; previously it was committed
          anyway.

        The pending entry is always removed on completion, cancellation, or
        error.
        """
        try:
            pending = self.pending_transactions.get(session_id)
            if not pending:
                return ChatbotResponse(
                    response="No pending transaction found. Please start a new transaction."
                )

            if message.lower() in self._CANCEL_WORDS:
                del self.pending_transactions[session_id]
                return ChatbotResponse(
                    response="Transaction cancelled. You can start a new one anytime."
                )

            # Keep every reply so the stored context reflects the full dialogue.
            pending.clarification_responses.append(message)

            updated_entities, is_complete = self.transaction_clarifier.process_clarification_response(
                pending.entities,
                pending.missing_fields,
                message,
            )

            if not is_complete:
                status, clarification = self.transaction_clarifier.analyze_transaction_completeness(updated_entities)

                if status == ClarificationStatus.NEEDS_CLARIFICATION:
                    pending.entities = updated_entities
                    pending.missing_fields = clarification.missing_fields
                    clarification_message = self.transaction_clarifier.format_clarification_message(clarification)
                    return ChatbotResponse(
                        response=f"Thank you! I still need a bit more information:\n\n{clarification_message}",
                        entities_extracted=updated_entities,
                        awaiting_clarification=True,
                    )

                if status != ClarificationStatus.COMPLETE:
                    # Neither complete nor clarifiable: do not commit it.
                    del self.pending_transactions[session_id]
                    return ChatbotResponse(
                        response="Transaction cancelled.",
                        entities_extracted=updated_entities,
                    )
                # The completeness check says COMPLETE even though the reply
                # processor did not; fall through and commit.

            full_context = self._build_full_context(pending)
            del self.pending_transactions[session_id]
            return self._complete_transaction(updated_entities, full_context)
        except Exception as e:
            # Drop the pending state so the user is not stuck in a broken dialogue.
            self.pending_transactions.pop(session_id, None)
            return ChatbotResponse(
                response=f"Error processing your response: {str(e)}. Please start a new transaction."
            )

    def _handle_query_request(self, message: str) -> "ChatbotResponse":
        """Answer a data question by generating, validating and running SQL.

        The message is converted to SQL by ``nl_to_sql``. Invalid SQL yields a
        suggestion instead of being executed. Valid SQL runs through
        ``db_manager.query_data``. That call appears to report execution
        failures as a single row containing an ``"error"`` key, and this
        method checks for that shape.

        Returns:
            A ``ChatbotResponse`` whose ``sql_executed`` carries the generated
            SQL when one was produced.
        """
        try:
            sql_query, explanation = self.nl_to_sql.convert_to_sql(message)

            is_valid, validation_message = self.nl_to_sql.validate_sql(sql_query)
            if not is_valid:
                suggestion = self.nl_to_sql.suggest_corrections(message, validation_message)
                return ChatbotResponse(
                    response=f"I couldn't process that query: {validation_message}\n\n{suggestion}",
                    sql_executed=sql_query,
                )

            results = self.db_manager.query_data(sql_query)

            if not results:
                return ChatbotResponse(
                    response="No results found for your query.",
                    sql_executed=sql_query,
                )

            if len(results) == 1 and "error" in results[0]:
                return ChatbotResponse(
                    response=f"Query execution error: {results[0]['error']}\n\nGenerated SQL: {sql_query}",
                    sql_executed=sql_query,
                )

            return ChatbotResponse(
                response=self._format_sql_results(results, explanation),
                sql_executed=sql_query,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error processing query: {str(e)}")

    def _handle_search_request(self, message: str) -> "ChatbotResponse":
        """Answer a fuzzy question with retrieval-augmented generation.

        The query is rewritten by ``rag_handler.enhance_search_query``. The
        top 8 similar events are fetched from the vector store, and
        ``rag_handler`` composes the answer from them.
        """
        try:
            enhanced_query = self.rag_handler.enhance_search_query(message)
            print(f"🔍 Enhanced query: {enhanced_query}")

            results = self.vector_store.search_similar_events(enhanced_query, 8)
            if not results:
                return ChatbotResponse(response="I couldn't find any relevant information to answer your query.")

            rag_response = self.rag_handler.generate_rag_response(message, results)
            return ChatbotResponse(
                response=rag_response,
                vector_stored=False,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error processing your search: {str(e)}")

    def _handle_general_information(self, message: str) -> "ChatbotResponse":
        """Store a free-form message in the vector store as a ``general_info`` event.

        ``vector_stored`` in the response reflects the truthiness of the
        store call's return value.
        """
        try:
            stored = self.vector_store.add_general_event(message, "general_info")
            if stored:
                return ChatbotResponse(
                    response="Information stored successfully. I can help you find similar information later.",
                    vector_stored=True,
                )
            return ChatbotResponse(
                response="Information noted, but vector storage is not available.",
                vector_stored=False,
            )
        except Exception as e:
            return ChatbotResponse(response=f"Error storing information: {str(e)}")

    def _format_recent_transactions(self, data: Dict[str, list]) -> str:
        """Render the 10 most recent purchases and sales, newest first.

        Args:
            data: Mapping with optional ``"purchases"`` and ``"sales"`` lists of
                transaction dicts. Missing or ``None`` lists are treated as
                empty. A ``None`` ``date``/``type`` is tolerated.

        Returns:
            A multi-line summary, or ``"No recent transactions found."``.
        """
        all_transactions = [*(data.get("purchases") or []), *(data.get("sales") or [])]
        if not all_transactions:
            return "No recent transactions found."

        # `or ""` keeps a None date comparable with string dates.
        all_transactions.sort(key=lambda txn: txn.get("date") or "", reverse=True)

        lines = []
        for txn in all_transactions[:10]:
            trans_type = (txn.get("type") or "unknown").upper()
            date = str(txn.get("date") or "")[:10]  # keep only the YYYY-MM-DD part
            quantity = txn.get("quantity", 0)
            product = txn.get("product", "Unknown")
            if trans_type == "PURCHASE":
                lines.append(
                    f"🛒 {date} - PURCHASE: {quantity}x {product} from "
                    f"{txn.get('supplier', 'Unknown')} - €{txn.get('total_cost', 0)}\n"
                )
            else:
                lines.append(
                    f"💰 {date} - SALE: {quantity}x {product} to "
                    f"{txn.get('customer', 'Unknown')} - €{txn.get('total_amount', 0)}\n"
                )
        return "Recent Transactions:\n\n" + "".join(lines)

    def _format_search_results(self, results: list, search_term: str) -> str:
        """Render transaction search hits for ``search_term``, one per line.

        A ``None`` ``date``/``type`` in a hit is tolerated rather than raising.
        """
        if not results:
            return f"No transactions found for '{search_term}'."

        lines = [f"Found {len(results)} transaction(s) for '{search_term}':\n\n"]
        for txn in results:
            trans_type = (txn.get("type") or "unknown").upper()
            date = str(txn.get("date") or "")[:10]
            quantity = txn.get("quantity", 0)
            product = txn.get("product", "Unknown")
            if trans_type == "PURCHASE":
                counterparty = f"from {txn.get('supplier', 'Unknown')}"
                icon = "🛒"
            else:
                counterparty = f"to {txn.get('customer', 'Unknown')}"
                icon = "💰"
            lines.append(f"{icon} {date} - {quantity}x {product} {counterparty} - €{txn.get('total', 0)}\n")
        return "".join(lines)

    def _format_sql_results(self, results: list, explanation: str,
                            max_rows: int = 20, col_width: int = 15) -> str:
        """Format SQL result rows as a fixed-width table in a Markdown code fence.

        A single row with a single column (e.g. ``COUNT``/``SUM``) is shown as a
        bold ``**Label:** value`` line instead of a table.

        Args:
            results: Rows as dicts; column order comes from the first row.
            explanation: Text shown under the heading.
            max_rows: Maximum number of rows rendered; the rest are summarised.
            col_width: Column width. Non-float values are truncated to it;
                headers are padded but not truncated.

        Returns:
            The formatted text.
        """
        heading = f"📊 Query Results:\n{explanation}\n\n"
        if not results:
            return heading + "No data found."

        if len(results) == 1 and len(results[0]) == 1:
            (key, value), = results[0].items()
            return heading + f"**{key.replace('_', ' ').title()}:** {value}"

        def cell(value) -> str:
            # None -> N/A; floats to 2 dp; anything else stringified and clipped.
            if value is None:
                text = "N/A"
            elif isinstance(value, float):
                text = f"{value:.2f}"
            else:
                text = str(value)[:col_width]
            return text.ljust(col_width)

        header_line = " | ".join(
            f"{column.replace('_', ' ').title():<{col_width}}" for column in results[0].keys()
        )
        # The rule matches the header line's real width (was len(headers) * 17).
        parts = [heading, "```\n", header_line, "\n", "-" * len(header_line), "\n"]
        for row in results[:max_rows]:
            parts.append(" | ".join(cell(value) for value in row.values()) + "\n")
        if len(results) > max_rows:
            parts.append(f"\n... and {len(results) - max_rows} more rows\n")
        parts.append("```")
        return "".join(parts)

    def get_linked_transaction_data(self, sql_transaction_id: int, transaction_type: str) -> Optional[Dict[str, Any]]:
        """Fetch a transaction's SQL row together with its linked vector-store record.

        Returns:
            ``{"sql_data", "vector_data", "linked"}`` where ``linked`` is True
            when a vector-store record was found. Returns ``None`` when the SQL
            row does not exist or a lookup raises (the error is printed).
        """
        try:
            sql_data = self.db_manager.get_transaction_by_id(sql_transaction_id, transaction_type)
            if not sql_data:
                # Without a SQL row there is nothing to link; skip the vector lookup.
                return None

            vector_data = self.vector_store.get_transaction_by_sql_id(sql_transaction_id, transaction_type)
            return {
                "sql_data": sql_data,
                "vector_data": vector_data,
                "linked": vector_data is not None,
            }
        except Exception as e:
            print(f"Error retrieving linked transaction data: {e}")
            return None

    def close(self):
        """Release the database connection held by ``db_manager``."""
        self.db_manager.close()