# Deployed by hh786 — Hierarchical RAG system (commit c54dcef)
"""Vector database indexing and operations."""
import os
import sys
# DEBUG: Print all httpx versions
print("\n=== DEBUG INFO ===")
try:
import httpx
print(f"httpx version: {httpx.__version__}")
except Exception as e:
print(f"httpx error: {e}")
try:
import chromadb
print(f"chromadb version: {chromadb.__version__}")
except Exception as e:
print(f"chromadb error: {e}")
print("==================\n")
from typing import List, Dict, Any, Optional
from pathlib import Path
import chromadb
from sentence_transformers import SentenceTransformer
import numpy as np
class EmbeddingModel:
    """Wrapper around a SentenceTransformer model for embedding texts and queries."""

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
        """
        Initialize embedding model.

        Args:
            model_name: Name of the sentence transformer model
        """
        self.model_name = model_name
        # SECURITY: trust_remote_code=True executes code shipped with the
        # model repository — only pass model names from trusted sources.
        try:
            self.model = SentenceTransformer(model_name, trust_remote_code=True)
        except Exception as e:
            # Best-effort fallback to a small, widely available model so the
            # service can still come up if the requested model fails to load.
            print(f"Error loading model {model_name}: {e}")
            self.model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
        self.embedding_dim = self.model.get_sentence_embedding_dimension()

    def embed_texts(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of texts.

        Args:
            texts: List of text strings

        Returns:
            Numpy array of shape (len(texts), embedding_dim)
        """
        if not texts:
            # encode([]) does not guarantee a 2-D result; keep the shape contract.
            return np.empty((0, self.embedding_dim), dtype=np.float32)
        return self.model.encode(texts, show_progress_bar=False)

    def embed_query(self, query: str) -> np.ndarray:
        """
        Generate embedding for a single query.

        Args:
            query: Query string

        Returns:
            1-D numpy array of length embedding_dim
        """
        return self.model.encode([query], show_progress_bar=False)[0]
class VectorStore:
    """Vector database operations using ChromaDB."""

    def __init__(
        self,
        collection_name: str = "rag_documents",
        persist_directory: str = "./data/chroma",
        embedding_model: Optional["EmbeddingModel"] = None
    ):
        """
        Initialize vector store.

        Args:
            collection_name: Name of the collection
            persist_directory: Directory to persist the database
            embedding_model: Embedding model instance (a default one is
                created when omitted)
        """
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        # Create persist directory if it doesn't exist
        Path(persist_directory).mkdir(parents=True, exist_ok=True)
        # Initialize ChromaDB client - simplified
        self.client = chromadb.PersistentClient(path=persist_directory)
        # Initialize embedding model
        self.embedding_model = embedding_model or EmbeddingModel()
        self.collection = self._create_collection()

    def _create_collection(self):
        """Get or create the backing ChromaDB collection (cosine distance)."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"}
        )

    def add_documents(
        self,
        chunks: List[Dict[str, Any]],
        batch_size: int = 100
    ) -> int:
        """
        Add documents to the vector store.

        Args:
            chunks: List of chunk dictionaries with 'text' and 'metadata';
                each metadata dict must contain a unique 'chunk_id'.
            batch_size: Number of documents to embed and insert at once

        Returns:
            Number of documents added
        """
        if not chunks:
            return 0
        total_added = 0
        # Process in batches to bound memory used by embedding.
        for start in range(0, len(chunks), batch_size):
            batch = chunks[start:start + batch_size]
            texts = [chunk["text"] for chunk in batch]
            ids = [chunk["metadata"]["chunk_id"] for chunk in batch]
            embeddings = self.embedding_model.embed_texts(texts)
            # ChromaDB only accepts primitive metadata values: stringify
            # everything and drop None values.
            # NOTE(review): if every value were None this would yield an empty
            # dict, which some ChromaDB versions reject — chunks are assumed
            # to always carry at least 'chunk_id'; confirm upstream.
            metadatas = [
                {key: str(value)
                 for key, value in chunk["metadata"].items()
                 if value is not None}
                for chunk in batch
            ]
            self.collection.add(
                ids=ids,
                embeddings=embeddings.tolist(),
                documents=texts,
                metadatas=metadatas
            )
            total_added += len(batch)
        return total_added

    def search(
        self,
        query: str,
        n_results: int = 5,
        where: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Search the vector store.

        Args:
            query: Search query string
            n_results: Number of results to return
            where: Metadata filters (e.g., {"level1": "Clinical Care"} or
                {"$and": [{"level1": "Clinical Care"}, {"doc_type": "policy"}]})
            where_document: Document content filters

        Returns:
            List of search results with documents, metadata, and distances
        """
        query_embedding = self.embedding_model.embed_query(query)
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=n_results,
            where=where,
            where_document=where_document
        )
        # ChromaDB nests results one list per query; we issued a single query,
        # so flatten index [0] of each field into one dict per hit.
        return [
            {"id": rid, "document": doc, "metadata": meta, "distance": dist}
            for rid, doc, meta, dist in zip(
                results['ids'][0],
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

    def search_with_hierarchy(
        self,
        query: str,
        n_results: int = 5,
        level1: Optional[str] = None,
        level2: Optional[str] = None,
        level3: Optional[str] = None,
        doc_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Search with hierarchical filtering.

        Args:
            query: Search query string
            n_results: Number of results to return
            level1: Domain filter
            level2: Section filter
            level3: Topic filter
            doc_type: Document type filter

        Returns:
            List of search results
        """
        # Build individual equality filters for each provided level.
        filters = []
        if level1:
            filters.append({"level1": level1})
        if level2:
            filters.append({"level2": level2})
        if level3:
            filters.append({"level3": level3})
        if doc_type:
            filters.append({"doc_type": doc_type})
        # ChromaDB where-clause shape: None for no filters, the bare filter
        # for one, and an explicit $and wrapper for several.
        if not filters:
            where = None
        elif len(filters) == 1:
            where = filters[0]
        else:
            where = {"$and": filters}
        return self.search(query, n_results=n_results, where=where)

    def get_collection_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the collection.

        Returns:
            Dictionary with collection statistics
        """
        count = self.collection.count()
        # Get a single record to discover which metadata keys are present.
        sample = self.collection.get(limit=1)
        stats = {
            "collection_name": self.collection_name,
            "total_chunks": count,
            "embedding_dimension": self.embedding_model.embedding_dim,
            "model_name": self.embedding_model.model_name
        }
        if sample['metadatas']:
            stats["sample_metadata_keys"] = list(sample['metadatas'][0].keys())
        return stats

    def delete_collection(self) -> None:
        """Delete the entire collection."""
        self.client.delete_collection(name=self.collection_name)

    def clear_collection(self) -> None:
        """Clear all documents from the collection (delete and recreate)."""
        self.delete_collection()
        self.collection = self._create_collection()
class IndexManager:
    """Manage multiple vector stores and indexing operations."""

    def __init__(
        self,
        persist_directory: str = "./data/chroma",
        embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    ):
        """
        Initialize index manager.

        Args:
            persist_directory: Directory to persist databases
            embedding_model_name: Name of the embedding model
        """
        self.persist_directory = persist_directory
        # One shared embedding model for every store this manager creates.
        self.embedding_model = EmbeddingModel(embedding_model_name)
        # Cache of stores created during this session, keyed by collection name.
        self.stores: Dict[str, VectorStore] = {}

    def create_store(self, collection_name: str) -> VectorStore:
        """
        Create or get a vector store.

        Args:
            collection_name: Name of the collection

        Returns:
            VectorStore instance
        """
        store = self.stores.get(collection_name)
        if store is None:
            store = VectorStore(
                collection_name=collection_name,
                persist_directory=self.persist_directory,
                embedding_model=self.embedding_model
            )
            self.stores[collection_name] = store
        return store

    def get_store(self, collection_name: str) -> Optional[VectorStore]:
        """
        Get an existing vector store.

        Args:
            collection_name: Name of the collection

        Returns:
            VectorStore instance, or None if it was never created here
        """
        return self.stores.get(collection_name)

    def index_documents(
        self,
        chunks: List[Dict[str, Any]],
        collection_name: str = "rag_documents"
    ) -> Dict[str, Any]:
        """
        Index documents into a collection.

        Args:
            chunks: List of processed document chunks
            collection_name: Target collection name

        Returns:
            Dictionary with indexing statistics
        """
        store = self.create_store(collection_name)
        chunks_added = store.add_documents(chunks)
        # Report collection-wide stats plus how many chunks this call added.
        stats = store.get_collection_stats()
        stats["chunks_added"] = chunks_added
        return stats

    def list_collections(self) -> List[str]:
        """
        List all available collections.

        Returns:
            List of collection names
        """
        return list(self.stores.keys())