# Source: Business_Chatbot/src/vector_store.py
# Uploaded by Ancastal
# Commit 401b16c (verified): "Upload folder using huggingface_hub"
import chromadb
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional
import json
from datetime import datetime
class VectorStore:
    """Semantic event log backed by a persistent ChromaDB collection.

    Events (business transactions or free-form text) are embedded with a
    sentence-transformer model and stored together with metadata so they can
    later be retrieved by similarity, by linked SQL row id, or by recency.

    Every public method is best-effort: failures are printed and reported
    through the return value (``False``, ``None``, ``[]`` or ``0``) instead of
    being raised.
    """

    def __init__(
        self,
        collection_name: str = "chatbot_events",
        persist_path: str = "./chroma_db",
        model_name: str = "all-MiniLM-L6-v2",
    ):
        """Open (or create) the collection and load the embedding model.

        Args:
            collection_name: Name of the Chroma collection to use.
            persist_path: Directory where Chroma persists its data.
            model_name: Sentence-transformer model used for embeddings.

        If the model cannot be loaded, ``self.model`` is set to ``None`` and
        the methods that need embeddings return their "failure" value.
        """
        self.client = chromadb.PersistentClient(path=persist_path)
        self.collection = self.client.get_or_create_collection(name=collection_name)
        try:
            self.model = SentenceTransformer(model_name)
        except Exception as e:
            print(f"Warning: Could not load sentence transformer model: {e}")
            self.model = None

    def add_transaction_event(
        self,
        transaction_data: Dict[str, Any],
        user_query: str,
        sql_transaction_id: Optional[int] = None,
    ) -> bool:
        """Add a transaction event to the vector store.

        Args:
            transaction_data: Transaction fields; ``type``, ``product``,
                ``quantity``, ``supplier``, ``customer`` and ``total`` are used
                for the summary when present. The whole dict is stored as JSON
                in the metadata, so it must be JSON-serializable.
            user_query: The original user request that produced the event.
            sql_transaction_id: Id of the matching SQL row, if any. When given
                it is stored in the metadata so the entry can be found again
                with ``get_transaction_by_sql_id``.

        Returns:
            ``True`` if the event was stored, ``False`` if the model is
            unavailable or storing failed.
        """
        # Explicit None check: the model object defines __len__, so plain
        # truthiness would test "has modules" rather than "was loaded".
        if self.model is None:
            return False
        try:
            summary = self._create_event_summary(transaction_data, user_query)
            embedding = self.model.encode(summary).tolist()

            # Take the timestamp once so the id and the metadata agree.
            timestamp = datetime.now().isoformat()
            # `is None` rather than `or`: an id of 0 is a real id.
            id_part = "unknown" if sql_transaction_id is None else sql_transaction_id
            doc_id = f"transaction_{id_part}_{timestamp}_{hash(summary) % 10000}"

            transaction_type = transaction_data.get("type", "unknown")
            metadata = {
                "type": "transaction",
                "transaction_type": transaction_type,
                "timestamp": timestamp,
                "user_query": user_query,
                "data": json.dumps(transaction_data),
            }
            if sql_transaction_id is not None:
                metadata["sql_transaction_id"] = sql_transaction_id
                # Table name is the pluralised type, e.g. purchases or sales.
                metadata["sql_table"] = f"{transaction_type}s"

            self.collection.add(
                documents=[summary],
                embeddings=[embedding],
                metadatas=[metadata],
                ids=[doc_id],
            )
            return True
        except Exception as e:
            print(f"Error adding transaction event: {e}")
            return False

    @staticmethod
    def _sql_link_filter(sql_transaction_id: int, transaction_type: str) -> Dict[str, Any]:
        """Build the metadata filter matching one linked SQL transaction."""
        # Chroma requires exactly one top-level operator in a `where` filter,
        # so multiple conditions have to be combined with `$and`.
        return {
            "$and": [
                {"sql_transaction_id": sql_transaction_id},
                {"transaction_type": transaction_type},
            ]
        }

    def get_transaction_by_sql_id(
        self, sql_transaction_id: int, transaction_type: str
    ) -> Optional[Dict[str, Any]]:
        """Retrieve the vector store entry linked to a SQL transaction id.

        Args:
            sql_transaction_id: Id stored under ``sql_transaction_id``.
            transaction_type: Value stored under ``transaction_type``.

        Returns:
            A dict with ``id``, ``document`` and ``metadata`` for the first
            match, or ``None`` if nothing matches or the lookup fails.
        """
        try:
            results = self.collection.get(
                where=self._sql_link_filter(sql_transaction_id, transaction_type),
                limit=1,
            )
            if results and results['documents']:
                return {
                    "id": results['ids'][0],
                    "document": results['documents'][0],
                    "metadata": results['metadatas'][0],
                }
            return None
        except Exception as e:
            print(f"Error retrieving transaction by SQL ID: {e}")
            return None

    def add_general_event(self, event_text: str, event_type: str = "general") -> bool:
        """Add a general event or piece of information to the vector store.

        Args:
            event_text: Text to embed and store.
            event_type: Value stored under the ``type`` metadata key.

        Returns:
            ``True`` if stored, ``False`` if the model is unavailable or
            storing failed.
        """
        if self.model is None:
            return False
        try:
            embedding = self.model.encode(event_text).tolist()
            # Take the timestamp once so the id and the metadata agree.
            timestamp = datetime.now().isoformat()
            doc_id = f"event_{timestamp}_{hash(event_text) % 10000}"
            self.collection.add(
                documents=[event_text],
                embeddings=[embedding],
                metadatas=[{"type": event_type, "timestamp": timestamp}],
                ids=[doc_id],
            )
            return True
        except Exception as e:
            print(f"Error adding general event: {e}")
            return False

    def search_similar_events(self, query: str, n_results: int = 5) -> List[Dict[str, Any]]:
        """Search for stored events that are semantically similar to ``query``.

        Args:
            query: Free-text query to embed.
            n_results: Maximum number of matches requested from the collection.

        Returns:
            A list of dicts with ``document``, ``distance`` (``None`` if the
            backend returned no distances) and ``metadata``; ``[]`` if the
            model is unavailable or the search fails.
        """
        if self.model is None:
            return []
        try:
            query_embedding = self.model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=n_results,
            )

            # Results are nested one level per query; only one query was sent.
            documents = results['documents'][0] if results['documents'] else []
            distances = results['distances'][0] if results['distances'] else None
            metadatas = results['metadatas'][0] if results['metadatas'] else None

            return [
                {
                    "document": doc,
                    "distance": distances[i] if distances else None,
                    "metadata": (metadatas[i] if metadatas else None) or {},
                }
                for i, doc in enumerate(documents)
            ]
        except Exception as e:
            print(f"Error searching events: {e}")
            return []

    @staticmethod
    def _newest_first(
        documents: Optional[List[Any]],
        metadatas: Optional[List[Any]],
        n_results: int,
    ) -> List[Dict[str, Any]]:
        """Pair documents with metadata and return the ``n_results`` newest."""
        events = [
            # A missing metadata entry becomes {} so sorting cannot fail on it.
            {"document": doc, "metadata": (metadatas[i] if metadatas else None) or {}}
            for i, doc in enumerate(documents or [])
        ]
        # ISO-8601 timestamps sort chronologically as strings; entries without
        # a timestamp sort last.
        events.sort(key=lambda event: event["metadata"].get("timestamp", ""), reverse=True)
        return events[:n_results]

    def get_recent_events(self, n_results: int = 10) -> List[Dict[str, Any]]:
        """Get the most recent events from the vector store, newest first.

        The whole collection is read and sorted by the ``timestamp`` metadata
        before truncating, because limiting the fetch first would return the
        first ``n_results`` stored rows rather than the newest ones.

        Args:
            n_results: Maximum number of events to return.

        Returns:
            A list of dicts with ``document`` and ``metadata``; ``[]`` when
            ``n_results`` is not positive or the lookup fails.
        """
        if n_results <= 0:
            return []
        try:
            results = self.collection.get(include=["documents", "metadatas"])
            return self._newest_first(results['documents'], results['metadatas'], n_results)
        except Exception as e:
            print(f"Error getting recent events: {e}")
            return []

    def _create_event_summary(self, transaction_data: Dict[str, Any], user_query: str) -> str:
        """Create the text summary of a transaction that gets embedded.

        The summary is a `` | ``-joined line: the transaction type, each known
        detail field that is present, and the original user request.
        """
        trans_type = transaction_data.get("type", "transaction")
        summary_parts = [f"Business {trans_type} event:"]

        # (key in transaction_data, label prefix), in output order.
        labelled_fields = (
            ("product", "Product: "),
            ("quantity", "Quantity: "),
            ("supplier", "Supplier: "),
            ("customer", "Customer: "),
            ("total", "Total amount: €"),
        )
        summary_parts.extend(
            f"{label}{transaction_data[key]}"
            for key, label in labelled_fields
            if key in transaction_data
        )

        # Keep the original request for context.
        summary_parts.append(f"Original request: {user_query}")
        return " | ".join(summary_parts)

    def delete_collection(self):
        """Delete the entire collection (use with caution).

        ``self.collection`` is not replaced afterwards, so it keeps referring
        to the deleted collection.

        Returns:
            ``True`` on success, ``False`` if deletion failed.
        """
        try:
            self.client.delete_collection(name=self.collection.name)
            return True
        except Exception as e:
            print(f"Error deleting collection: {e}")
            return False

    def get_collection_count(self) -> int:
        """Return the number of documents in the collection (0 on failure)."""
        try:
            return self.collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0