"""Persistent vector store for chatbot events, backed by ChromaDB."""

import json
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

# The third-party packages are imported defensively so that this module can
# still be imported when one of them is missing.  The class already degrades
# to a no-op when the embedding model cannot be loaded; a missing package is
# simply another reason for that.  The module-level names are preserved.
try:
    import chromadb
except ImportError:
    chromadb = None

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    SentenceTransformer = None


class VectorStore:
    """Store and semantically search chatbot events in a Chroma collection.

    Every write/search method that needs an embedding returns a neutral
    value (``False`` or ``[]``) when the embedding model is unavailable, and
    all methods report backend errors with ``print`` instead of raising.
    """

    # (key in transaction_data, label used in the summary text)
    _SUMMARY_FIELDS = (
        ("product", "Product: "),
        ("quantity", "Quantity: "),
        ("supplier", "Supplier: "),
        ("customer", "Customer: "),
        ("total", "Total amount: €"),
    )

    def __init__(
        self,
        collection_name: str = "chatbot_events",
        persist_path: str = "./chroma_db",
        model_name: str = "all-MiniLM-L6-v2",
    ):
        """Open (or create) the collection and try to load the embedder.

        Args:
            collection_name: Name of the Chroma collection to use.
            persist_path: Directory handed to ``chromadb.PersistentClient``.
            model_name: Name passed to ``SentenceTransformer``.

        Raises:
            ImportError: If the ``chromadb`` package is not installed.
        """
        if chromadb is None:
            raise ImportError("chromadb is required to use VectorStore")

        self.client = chromadb.PersistentClient(path=persist_path)
        self.collection = self.client.get_or_create_collection(
            name=collection_name
        )

        try:
            if SentenceTransformer is None:
                raise ImportError("sentence-transformers is not installed")
            self.model = SentenceTransformer(model_name)
        except Exception as e:
            # Best effort: without a model the store stays readable, but
            # embedding-based methods become no-ops.
            print(f"Warning: Could not load sentence transformer model: {e}")
            self.model = None

    @staticmethod
    def _make_doc_id(prefix: str, timestamp: str) -> str:
        """Return a unique document ID of the form ``prefix_timestamp_suffix``.

        A random suffix is used instead of ``hash(text)``: string hashes are
        salted per process, and the same text added twice within one clock
        tick would otherwise yield a duplicate ID.
        """
        return f"{prefix}_{timestamp}_{uuid.uuid4().hex[:8]}"

    def add_transaction_event(
        self,
        transaction_data: Dict[str, Any],
        user_query: str,
        sql_transaction_id: Optional[int] = None,
    ) -> bool:
        """Add a transaction event to the vector store.

        Args:
            transaction_data: Transaction fields; ``type``, ``product``,
                ``quantity``, ``supplier``, ``customer`` and ``total`` are
                used for the summary when present.
            user_query: The user request that produced the transaction.
            sql_transaction_id: Optional ID of the matching SQL row, stored
                in the metadata so the entry can be looked up later.

        Returns:
            ``True`` when the event was stored, ``False`` when no embedding
            model is loaded or the backend raised.
        """
        if self.model is None:
            return False

        try:
            summary = self._create_event_summary(transaction_data, user_query)
            embedding = self.model.encode(summary).tolist()

            # One timestamp for both the ID and the metadata so they agree.
            timestamp = datetime.now().isoformat()

            # ``is None`` rather than ``or``: 0 is a legitimate SQL ID.
            sql_ref = (
                "unknown" if sql_transaction_id is None else sql_transaction_id
            )
            doc_id = self._make_doc_id(f"transaction_{sql_ref}", timestamp)

            transaction_type = transaction_data.get("type", "unknown")
            metadata = {
                "type": "transaction",
                "transaction_type": transaction_type,
                "timestamp": timestamp,
                "user_query": user_query,
                # default=str keeps values such as Decimal or datetime from
                # aborting the insert; this field is only a readable snapshot.
                "data": json.dumps(transaction_data, default=str),
            }

            if sql_transaction_id is not None:
                metadata["sql_transaction_id"] = sql_transaction_id
                # e.g. "purchases" or "sales"
                metadata["sql_table"] = f"{transaction_type}s"

            self.collection.add(
                documents=[summary],
                embeddings=[embedding],
                metadatas=[metadata],
                ids=[doc_id],
            )
            return True
        except Exception as e:
            print(f"Error adding transaction event: {e}")
            return False

    def get_transaction_by_sql_id(
        self, sql_transaction_id: int, transaction_type: str
    ) -> Optional[Dict[str, Any]]:
        """Retrieve the entry linked to a specific SQL transaction ID.

        Args:
            sql_transaction_id: Value stored under ``sql_transaction_id``.
            transaction_type: Value stored under ``transaction_type``.

        Returns:
            A dict with ``id``, ``document`` and ``metadata`` for the first
            match, or ``None`` when nothing matches or the backend raised.
        """
        try:
            # Combine the two conditions with an explicit $and: Chroma's
            # filter validation expects a single top-level operator/field.
            results = self.collection.get(
                where={
                    "$and": [
                        {"sql_transaction_id": sql_transaction_id},
                        {"transaction_type": transaction_type},
                    ]
                },
                limit=1,
            )

            if results and results["documents"]:
                return {
                    "id": results["ids"][0],
                    "document": results["documents"][0],
                    "metadata": results["metadatas"][0],
                }
            return None
        except Exception as e:
            print(f"Error retrieving transaction by SQL ID: {e}")
            return None

    def add_general_event(
        self, event_text: str, event_type: str = "general"
    ) -> bool:
        """Add a general event or piece of information to the vector store.

        Args:
            event_text: Text to embed and store as the document.
            event_type: Value stored under the ``type`` metadata key.

        Returns:
            ``True`` when stored, ``False`` when no embedding model is loaded
            or the backend raised.
        """
        if self.model is None:
            return False

        try:
            embedding = self.model.encode(event_text).tolist()
            timestamp = datetime.now().isoformat()
            doc_id = self._make_doc_id("event", timestamp)

            self.collection.add(
                documents=[event_text],
                embeddings=[embedding],
                metadatas=[{"type": event_type, "timestamp": timestamp}],
                ids=[doc_id],
            )
            return True
        except Exception as e:
            print(f"Error adding general event: {e}")
            return False

    def search_similar_events(
        self, query: str, n_results: int = 5
    ) -> List[Dict[str, Any]]:
        """Search for events that are semantically similar to ``query``.

        Args:
            query: Free-text query to embed.
            n_results: Maximum number of matches requested from the backend.

        Returns:
            A list of dicts with ``document``, ``distance`` (``None`` when the
            backend returned no distances) and ``metadata`` (``{}`` when the
            entry has none). Empty when no model is loaded or on error.
        """
        if self.model is None:
            return []

        try:
            query_embedding = self.model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
            )

            # Each field is a list with one inner list per query embedding;
            # only one embedding is sent, so only index 0 is read.
            documents = (results.get("documents") or [[]])[0] or []
            distances = (results.get("distances") or [[]])[0] or []
            metadatas = (results.get("metadatas") or [[]])[0] or []

            return [
                {
                    "document": doc,
                    "distance": distances[i] if i < len(distances) else None,
                    "metadata": (
                        metadatas[i] if i < len(metadatas) else None
                    ) or {},
                }
                for i, doc in enumerate(documents)
            ]
        except Exception as e:
            print(f"Error searching events: {e}")
            return []

    def get_recent_events(self, n_results: int = 10) -> List[Dict[str, Any]]:
        """Get the most recent events, newest first.

        Recency is decided by the ``timestamp`` metadata value; entries
        without one sort last.

        Args:
            n_results: Maximum number of events to return.

        Returns:
            Up to ``n_results`` dicts with ``document`` and ``metadata``.
            Empty when ``n_results`` is not positive or on error.
        """
        if n_results <= 0:
            return []

        try:
            # Fetch without a limit: limiting before sorting would return an
            # arbitrary slice rather than the newest entries.  This reads the
            # whole collection, which is acceptable for an event log but
            # worth revisiting if the collection grows very large.
            results = self.collection.get(include=["documents", "metadatas"])

            documents = results.get("documents") or []
            metadatas = results.get("metadatas") or []

            events = [
                {
                    "document": doc,
                    "metadata": (
                        metadatas[i] if i < len(metadatas) else None
                    ) or {},
                }
                for i, doc in enumerate(documents)
            ]

            # ISO-8601 strings sort chronologically as plain text.
            events.sort(
                key=lambda event: str(event["metadata"].get("timestamp") or ""),
                reverse=True,
            )
            return events[:n_results]
        except Exception as e:
            print(f"Error getting recent events: {e}")
            return []

    def _create_event_summary(
        self, transaction_data: Dict[str, Any], user_query: str
    ) -> str:
        """Create the text that is embedded for a transaction event.

        Args:
            transaction_data: Transaction fields; only the known keys that
                are present are included, in a fixed order.
            user_query: Original user request, appended for context.

        Returns:
            The summary parts joined with ``" | "``.
        """
        trans_type = transaction_data.get("type", "transaction")
        summary_parts = [f"Business {trans_type} event:"]

        summary_parts.extend(
            f"{label}{transaction_data[key]}"
            for key, label in self._SUMMARY_FIELDS
            if key in transaction_data
        )

        summary_parts.append(f"Original request: {user_query}")
        return " | ".join(summary_parts)

    def delete_collection(self) -> bool:
        """Delete the entire collection (use with caution).

        ``self.collection`` is not refreshed afterwards, so it keeps
        referring to the deleted collection.

        Returns:
            ``True`` on success, ``False`` when the backend raised.
        """
        try:
            self.client.delete_collection(name=self.collection.name)
            return True
        except Exception as e:
            print(f"Error deleting collection: {e}")
            return False

    def get_collection_count(self) -> int:
        """Get the number of documents in the collection.

        Returns:
            The backend's count, or ``0`` when the backend raised.
        """
        try:
            return self.collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0