|
|
""" |
|
|
Vector database module using ChromaDB |
|
|
Stores and retrieves document chunks with embeddings |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Any, Optional |
|
|
|
|
|
try: |
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
CHROMADB_AVAILABLE = True |
|
|
except ImportError: |
|
|
CHROMADB_AVAILABLE = False |
|
|
|
|
|
from .config import VECTOR_DB_DIR, DEFAULT_RETRIEVAL_K |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class LegalVectorDB: |
|
|
"""ChromaDB vector database for legal documents""" |
|
|
|
|
|
def __init__(self, persist_directory: Path = VECTOR_DB_DIR): |
|
|
""" |
|
|
Initialize ChromaDB with persistent storage |
|
|
|
|
|
Args: |
|
|
persist_directory: Directory to store the database |
|
|
""" |
|
|
if not CHROMADB_AVAILABLE: |
|
|
raise ImportError( |
|
|
"chromadb not installed. " |
|
|
"Install with: pip install chromadb" |
|
|
) |
|
|
|
|
|
self.persist_directory = Path(persist_directory) |
|
|
self.persist_directory.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
logger.info(f"Initializing ChromaDB at {self.persist_directory}") |
|
|
|
|
|
|
|
|
self.client = chromadb.PersistentClient( |
|
|
path=str(self.persist_directory) |
|
|
) |
|
|
|
|
|
|
|
|
self.collection_name = "nepal_legal_docs" |
|
|
self.collection = self.client.get_or_create_collection( |
|
|
name=self.collection_name, |
|
|
metadata={"description": "Nepal legal documents for RAG-based law explanation"} |
|
|
) |
|
|
|
|
|
current_count = self.collection.count() |
|
|
logger.info(f"Collection '{self.collection_name}' ready. Current document count: {current_count}") |
|
|
|
|
|
def add_chunks( |
|
|
self, |
|
|
chunks: List[Dict[str, Any]], |
|
|
embeddings: List[List[float]] |
|
|
) -> None: |
|
|
""" |
|
|
Add chunks with embeddings to the database |
|
|
|
|
|
Args: |
|
|
chunks: List of chunk dictionaries with 'chunk_id', 'text', and 'metadata' |
|
|
embeddings: List of embedding vectors (as lists) |
|
|
""" |
|
|
if len(chunks) != len(embeddings): |
|
|
raise ValueError(f"Number of chunks ({len(chunks)}) must match number of embeddings ({len(embeddings)})") |
|
|
|
|
|
|
|
|
ids = [chunk['chunk_id'] for chunk in chunks] |
|
|
documents = [chunk['text'] for chunk in chunks] |
|
|
|
|
|
|
|
|
|
|
|
metadatas = [] |
|
|
for chunk in chunks: |
|
|
cleaned_metadata = {} |
|
|
for key, value in chunk['metadata'].items(): |
|
|
if value is None: |
|
|
|
|
|
continue |
|
|
elif isinstance(value, (str, int, float, bool)): |
|
|
|
|
|
cleaned_metadata[key] = value |
|
|
elif isinstance(value, list): |
|
|
|
|
|
if value: |
|
|
cleaned_metadata[key] = ', '.join(str(item) for item in value) |
|
|
else: |
|
|
|
|
|
cleaned_metadata[key] = str(value) |
|
|
metadatas.append(cleaned_metadata) |
|
|
|
|
|
logger.info(f"Adding {len(chunks)} chunks to vector database") |
|
|
|
|
|
|
|
|
self.collection.add( |
|
|
ids=ids, |
|
|
documents=documents, |
|
|
embeddings=embeddings, |
|
|
metadatas=metadatas |
|
|
) |
|
|
|
|
|
total_count = self.collection.count() |
|
|
logger.info(f"Successfully added chunks. Total documents in database: {total_count}") |
|
|
|
|
|
def query( |
|
|
self, |
|
|
query_text: str, |
|
|
n_results: int = DEFAULT_RETRIEVAL_K, |
|
|
where: Optional[Dict] = None |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Query the database with a text query |
|
|
|
|
|
Args: |
|
|
query_text: Query string |
|
|
n_results: Number of results to return |
|
|
where: Optional metadata filter |
|
|
|
|
|
Returns: |
|
|
Dictionary with 'ids', 'documents', 'metadatas', and 'distances' |
|
|
""" |
|
|
logger.info(f"Querying database with: '{query_text[:50]}...' (n_results={n_results})") |
|
|
|
|
|
results = self.collection.query( |
|
|
query_texts=[query_text], |
|
|
n_results=n_results, |
|
|
where=where |
|
|
) |
|
|
|
|
|
return results |
|
|
|
|
|
def query_with_embedding( |
|
|
self, |
|
|
query_embedding: List[float], |
|
|
n_results: int = DEFAULT_RETRIEVAL_K, |
|
|
where: Optional[Dict] = None |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Query with pre-computed embedding |
|
|
|
|
|
Args: |
|
|
query_embedding: Query embedding vector |
|
|
n_results: Number of results to return |
|
|
where: Optional metadata filter |
|
|
|
|
|
Returns: |
|
|
Dictionary with 'ids', 'documents', 'metadatas', and 'distances' |
|
|
""" |
|
|
logger.info(f"Querying database with embedding (n_results={n_results})") |
|
|
|
|
|
results = self.collection.query( |
|
|
query_embeddings=[query_embedding], |
|
|
n_results=n_results, |
|
|
where=where |
|
|
) |
|
|
|
|
|
return results |
|
|
|
|
|
def get_count(self) -> int: |
|
|
"""Get the number of documents in the database""" |
|
|
return self.collection.count() |
|
|
|
|
|
def delete_collection(self) -> None: |
|
|
"""Delete the entire collection (use with caution!)""" |
|
|
logger.warning(f"Deleting collection '{self.collection_name}'") |
|
|
self.client.delete_collection(name=self.collection_name) |
|
|
logger.info("Collection deleted") |
|
|
|
|
|
def peek(self, limit: int = 5) -> Dict[str, Any]: |
|
|
""" |
|
|
Peek at some documents in the database |
|
|
|
|
|
Args: |
|
|
limit: Number of documents to return |
|
|
|
|
|
Returns: |
|
|
Dictionary with sample documents |
|
|
""" |
|
|
return self.collection.peek(limit=limit) |
|
|
|