Spaces:
Sleeping
Sleeping
import json
import logging
import uuid
from datetime import datetime
from typing import List, Dict, Any, Optional

import chromadb
from sentence_transformers import SentenceTransformer
| class VectorStore: | |
| def __init__(self, collection_name: str = "chatbot_events"): | |
| self.client = chromadb.PersistentClient(path="./chroma_db") | |
| self.collection = self.client.get_or_create_collection(name=collection_name) | |
| try: | |
| self.model = SentenceTransformer('all-MiniLM-L6-v2') | |
| except Exception as e: | |
| print(f"Warning: Could not load sentence transformer model: {e}") | |
| self.model = None | |
| def add_transaction_event(self, transaction_data: Dict[str, Any], user_query: str, sql_transaction_id: Optional[int] = None) -> bool: | |
| """Add a transaction event to the vector store""" | |
| if not self.model: | |
| return False | |
| try: | |
| # Create a semantic summary of the event | |
| summary = self._create_event_summary(transaction_data, user_query) | |
| # Generate embedding | |
| embedding = self.model.encode(summary).tolist() | |
| # Create document ID - include SQL ID if available for better linking | |
| doc_id = f"transaction_{sql_transaction_id or 'unknown'}_{datetime.now().isoformat()}_{hash(summary) % 10000}" | |
| # Prepare metadata with SQL transaction linking | |
| metadata = { | |
| "type": "transaction", | |
| "transaction_type": transaction_data.get("type", "unknown"), | |
| "timestamp": datetime.now().isoformat(), | |
| "user_query": user_query, | |
| "data": json.dumps(transaction_data) | |
| } | |
| # Add SQL transaction ID to metadata for linking | |
| if sql_transaction_id is not None: | |
| metadata["sql_transaction_id"] = sql_transaction_id | |
| metadata["sql_table"] = f"{transaction_data.get('type', 'unknown')}s" # purchases or sales | |
| # Store in vector database | |
| self.collection.add( | |
| documents=[summary], | |
| embeddings=[embedding], | |
| metadatas=[metadata], | |
| ids=[doc_id] | |
| ) | |
| return True | |
| except Exception as e: | |
| print(f"Error adding transaction event: {e}") | |
| return False | |
| def get_transaction_by_sql_id(self, sql_transaction_id: int, transaction_type: str) -> Optional[Dict[str, Any]]: | |
| """Retrieve vector store entry linked to a specific SQL transaction ID""" | |
| try: | |
| # Query the collection for entries with matching SQL transaction ID | |
| results = self.collection.get( | |
| where={ | |
| "sql_transaction_id": sql_transaction_id, | |
| "transaction_type": transaction_type | |
| }, | |
| limit=1 | |
| ) | |
| if results and results['documents']: | |
| return { | |
| "id": results['ids'][0], | |
| "document": results['documents'][0], | |
| "metadata": results['metadatas'][0] | |
| } | |
| return None | |
| except Exception as e: | |
| print(f"Error retrieving transaction by SQL ID: {e}") | |
| return None | |
| def add_general_event(self, event_text: str, event_type: str = "general") -> bool: | |
| """Add a general event or information to the vector store""" | |
| if not self.model: | |
| return False | |
| try: | |
| # Generate embedding | |
| embedding = self.model.encode(event_text).tolist() | |
| # Create document ID | |
| doc_id = f"event_{datetime.now().isoformat()}_{hash(event_text) % 10000}" | |
| # Store in vector database | |
| self.collection.add( | |
| documents=[event_text], | |
| embeddings=[embedding], | |
| metadatas=[{ | |
| "type": event_type, | |
| "timestamp": datetime.now().isoformat() | |
| }], | |
| ids=[doc_id] | |
| ) | |
| return True | |
| except Exception as e: | |
| print(f"Error adding general event: {e}") | |
| return False | |
| def search_similar_events(self, query: str, n_results: int = 5) -> List[Dict[str, Any]]: | |
| """Search for similar events based on semantic similarity""" | |
| if not self.model: | |
| return [] | |
| try: | |
| # Generate query embedding | |
| query_embedding = self.model.encode(query).tolist() | |
| # Search vector database | |
| results = self.collection.query( | |
| query_embeddings=[query_embedding], | |
| n_results=n_results | |
| ) | |
| # Format results | |
| formatted_results = [] | |
| if results['documents'] and results['documents'][0]: | |
| for i, doc in enumerate(results['documents'][0]): | |
| result = { | |
| "document": doc, | |
| "distance": results['distances'][0][i] if results['distances'] else None, | |
| "metadata": results['metadatas'][0][i] if results['metadatas'] else {} | |
| } | |
| formatted_results.append(result) | |
| return formatted_results | |
| except Exception as e: | |
| print(f"Error searching events: {e}") | |
| return [] | |
| def get_recent_events(self, n_results: int = 10) -> List[Dict[str, Any]]: | |
| """Get recent events from the vector store""" | |
| try: | |
| results = self.collection.get( | |
| limit=n_results, | |
| include=["documents", "metadatas"] | |
| ) | |
| formatted_results = [] | |
| if results['documents']: | |
| for i, doc in enumerate(results['documents']): | |
| result = { | |
| "document": doc, | |
| "metadata": results['metadatas'][i] if results['metadatas'] else {} | |
| } | |
| formatted_results.append(result) | |
| # Sort by timestamp if available | |
| formatted_results.sort( | |
| key=lambda x: x.get('metadata', {}).get('timestamp', ''), | |
| reverse=True | |
| ) | |
| return formatted_results | |
| except Exception as e: | |
| print(f"Error getting recent events: {e}") | |
| return [] | |
| def _create_event_summary(self, transaction_data: Dict[str, Any], user_query: str) -> str: | |
| """Create a semantic summary of a transaction event""" | |
| summary_parts = [] | |
| # Add transaction type | |
| trans_type = transaction_data.get("type", "transaction") | |
| summary_parts.append(f"Business {trans_type} event:") | |
| # Add key details | |
| if "product" in transaction_data: | |
| summary_parts.append(f"Product: {transaction_data['product']}") | |
| if "quantity" in transaction_data: | |
| summary_parts.append(f"Quantity: {transaction_data['quantity']}") | |
| if "supplier" in transaction_data: | |
| summary_parts.append(f"Supplier: {transaction_data['supplier']}") | |
| if "customer" in transaction_data: | |
| summary_parts.append(f"Customer: {transaction_data['customer']}") | |
| if "total" in transaction_data: | |
| summary_parts.append(f"Total amount: €{transaction_data['total']}") | |
| # Add original user query for context | |
| summary_parts.append(f"Original request: {user_query}") | |
| return " | ".join(summary_parts) | |
| def delete_collection(self): | |
| """Delete the entire collection (use with caution)""" | |
| try: | |
| self.client.delete_collection(name=self.collection.name) | |
| return True | |
| except Exception as e: | |
| print(f"Error deleting collection: {e}") | |
| return False | |
| def get_collection_count(self) -> int: | |
| """Get the number of documents in the collection""" | |
| try: | |
| return self.collection.count() | |
| except Exception as e: | |
| print(f"Error getting collection count: {e}") | |
| return 0 |