setu / module_a /vector_db.py
khagu's picture
chore: finally untrack large database files
3998131
"""
Vector database module using ChromaDB
Stores and retrieves document chunks with embeddings
"""
import logging
from pathlib import Path
from typing import List, Dict, Any, Optional
try:
import chromadb
from chromadb.config import Settings
CHROMADB_AVAILABLE = True
except ImportError:
CHROMADB_AVAILABLE = False
from .config import VECTOR_DB_DIR, DEFAULT_RETRIEVAL_K
logger = logging.getLogger(__name__)
class LegalVectorDB:
"""ChromaDB vector database for legal documents"""
def __init__(self, persist_directory: Path = VECTOR_DB_DIR):
"""
Initialize ChromaDB with persistent storage
Args:
persist_directory: Directory to store the database
"""
if not CHROMADB_AVAILABLE:
raise ImportError(
"chromadb not installed. "
"Install with: pip install chromadb"
)
self.persist_directory = Path(persist_directory)
self.persist_directory.mkdir(parents=True, exist_ok=True)
logger.info(f"Initializing ChromaDB at {self.persist_directory}")
# Initialize ChromaDB client with persistent storage
self.client = chromadb.PersistentClient(
path=str(self.persist_directory)
)
# Create or get collection
self.collection_name = "nepal_legal_docs"
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
metadata={"description": "Nepal legal documents for RAG-based law explanation"}
)
current_count = self.collection.count()
logger.info(f"Collection '{self.collection_name}' ready. Current document count: {current_count}")
def add_chunks(
self,
chunks: List[Dict[str, Any]],
embeddings: List[List[float]]
) -> None:
"""
Add chunks with embeddings to the database
Args:
chunks: List of chunk dictionaries with 'chunk_id', 'text', and 'metadata'
embeddings: List of embedding vectors (as lists)
"""
if len(chunks) != len(embeddings):
raise ValueError(f"Number of chunks ({len(chunks)}) must match number of embeddings ({len(embeddings)})")
# Extract data from chunks
ids = [chunk['chunk_id'] for chunk in chunks]
documents = [chunk['text'] for chunk in chunks]
# Clean metadata: ChromaDB only accepts str, int, float, bool
# Remove None values and convert other types to strings
metadatas = []
for chunk in chunks:
cleaned_metadata = {}
for key, value in chunk['metadata'].items():
if value is None:
# Skip None values
continue
elif isinstance(value, (str, int, float, bool)):
# Keep valid types as-is
cleaned_metadata[key] = value
elif isinstance(value, list):
# Convert lists to comma-separated strings
if value: # Only include non-empty lists
cleaned_metadata[key] = ', '.join(str(item) for item in value)
else:
# Convert other types to strings
cleaned_metadata[key] = str(value)
metadatas.append(cleaned_metadata)
logger.info(f"Adding {len(chunks)} chunks to vector database")
# Add to ChromaDB
self.collection.add(
ids=ids,
documents=documents,
embeddings=embeddings,
metadatas=metadatas
)
total_count = self.collection.count()
logger.info(f"Successfully added chunks. Total documents in database: {total_count}")
def query(
self,
query_text: str,
n_results: int = DEFAULT_RETRIEVAL_K,
where: Optional[Dict] = None
) -> Dict[str, Any]:
"""
Query the database with a text query
Args:
query_text: Query string
n_results: Number of results to return
where: Optional metadata filter
Returns:
Dictionary with 'ids', 'documents', 'metadatas', and 'distances'
"""
logger.info(f"Querying database with: '{query_text[:50]}...' (n_results={n_results})")
results = self.collection.query(
query_texts=[query_text],
n_results=n_results,
where=where
)
return results
def query_with_embedding(
self,
query_embedding: List[float],
n_results: int = DEFAULT_RETRIEVAL_K,
where: Optional[Dict] = None
) -> Dict[str, Any]:
"""
Query with pre-computed embedding
Args:
query_embedding: Query embedding vector
n_results: Number of results to return
where: Optional metadata filter
Returns:
Dictionary with 'ids', 'documents', 'metadatas', and 'distances'
"""
logger.info(f"Querying database with embedding (n_results={n_results})")
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=n_results,
where=where
)
return results
def get_count(self) -> int:
"""Get the number of documents in the database"""
return self.collection.count()
def delete_collection(self) -> None:
"""Delete the entire collection (use with caution!)"""
logger.warning(f"Deleting collection '{self.collection_name}'")
self.client.delete_collection(name=self.collection_name)
logger.info("Collection deleted")
def peek(self, limit: int = 5) -> Dict[str, Any]:
"""
Peek at some documents in the database
Args:
limit: Number of documents to return
Returns:
Dictionary with sample documents
"""
return self.collection.peek(limit=limit)