hpmor / src /vector_store.py
deenaik's picture
Initial commit
6ef4823
"""Vector store management for document embeddings."""
import os
from typing import List, Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from llama_index.core import Document, VectorStoreIndex, StorageContext
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from src.config import config
class VectorStoreManager:
    """Manage a persistent ChromaDB collection and the LlamaIndex index over it.

    Collection name, persist directory, embedding model and the default
    retrieval depth are read from ``src.config.config``. The Chroma collection,
    the LlamaIndex vector-store wrapper and the index itself are created lazily.

    Typical use::

        manager = VectorStoreManager()
        manager.get_or_create_index(documents)
        nodes = manager.query("some question")
    """

    def __init__(self):
        """Load the embedding model and open the persistent Chroma client.

        The collection, vector store and index start as None and are filled in
        by :meth:`initialize_collection`, :meth:`create_index` and
        :meth:`load_index`.
        """
        self.collection_name = config.collection_name
        self.persist_dir = str(config.chroma_persist_dir)
        self.embedding_model = config.embedding_model

        # Initialize embedding model (HuggingFace files cached under ./models).
        print(f"Loading embedding model: {self.embedding_model}")
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model,
            cache_folder="./models",
        )

        # Initialize ChromaDB client backed by on-disk storage at persist_dir.
        self.chroma_client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False),
        )

        # Populated lazily by initialize_collection / create_index / load_index.
        self.collection = None
        self.vector_store = None
        self.index = None

    def initialize_collection(self, reset: bool = False) -> None:
        """Open (or create) the Chroma collection and wrap it in a vector store.

        Args:
            reset: If True, first delete any existing collection with the
                configured name so the subsequent index starts empty. Any
                index previously held by this manager is discarded.
        """
        if reset:
            # Best-effort delete: the collection may simply not exist yet, and
            # the "not found" exception type differs across chromadb versions.
            try:
                self.chroma_client.delete_collection(name=self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            except Exception as exc:
                print(f"No existing collection deleted: {exc}")
            # A previously built/loaded index referenced the deleted collection.
            self.index = None

        # Create or get collection, using cosine distance for the HNSW index.
        self.collection = self.chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        print(f"Using collection: {self.collection_name}")

        # NOTE(review): recent ChromaVectorStore versions appear to ignore
        # `embedding_function` (embeddings come from the index's embed_model);
        # kept for compatibility — confirm against the installed version.
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.collection,
            embedding_function=self.embed_model,
        )

    def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
        """Embed ``documents`` and store them in the Chroma-backed index.

        Args:
            documents: Documents to embed and store (LlamaIndex splits them
                into nodes before embedding).
            show_progress: Forwarded to LlamaIndex to show progress output.

        Returns:
            The new index, also stored on ``self.index``.
        """
        if self.vector_store is None:
            self.initialize_collection()

        print(f"Creating index from {len(documents)} documents...")

        # Route writes into the Chroma vector store rather than an in-memory one.
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)

        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
            show_progress=show_progress,
        )

        print("Index created successfully!")
        return self.index

    def load_index(self) -> Optional[VectorStoreIndex]:
        """Rebuild an index object over vectors already persisted in Chroma.

        Returns:
            The loaded index (also stored on ``self.index``), or None when the
            collection holds no vectors.
        """
        if self.vector_store is None:
            self.initialize_collection()

        # Count once and reuse; each count() call queries the collection.
        num_vectors = self.collection.count()
        if num_vectors == 0:
            print("No existing index found in ChromaDB")
            return None

        print(f"Loading index with {num_vectors} vectors")

        # from_vector_store builds its own StorageContext around the vector
        # store, so none is passed (a duplicate `storage_context` keyword is
        # redundant at best and may be rejected by some llama_index versions).
        self.index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            embed_model=self.embed_model,
        )
        return self.index

    def get_or_create_index(
        self,
        documents: Optional[List[Document]] = None,
        force_recreate: bool = False,
    ) -> VectorStoreIndex:
        """Load the persisted index, or build a fresh one from ``documents``.

        Args:
            documents: Documents used when a new index has to be built.
            force_recreate: If True, skip loading and rebuild from scratch,
                deleting the existing collection first.

        Returns:
            The loaded or newly created index.

        Raises:
            ValueError: If a new index is needed but ``documents`` is None or
                empty.
        """
        if not force_recreate:
            index = self.load_index()
            if index is not None:
                return index

        if not documents:
            raise ValueError("No documents provided for creating index")

        # Start from an empty collection so stale vectors are not mixed in.
        self.initialize_collection(reset=True)
        return self.create_index(documents)

    def query(self, query_text: str, top_k: Optional[int] = None) -> List:
        """Retrieve the ``top_k`` nodes most similar to ``query_text``.

        Args:
            query_text: Text to embed and search with.
            top_k: Number of nodes to return; defaults to
                ``config.top_k_retrieval``.

        Returns:
            The nodes returned by the LlamaIndex retriever.

        Raises:
            ValueError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not initialized. Call get_or_create_index first.")
        if top_k is None:
            top_k = config.top_k_retrieval

        # Use the retriever directly instead of a query engine to avoid
        # requiring an LLM.
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        return retriever.retrieve(query_text)

    def get_stats(self) -> dict:
        """Return a summary of the configured collection, opening it if needed.

        Returns:
            Dict with the collection name, persist directory, embedding model
            name, current vector count and the collection's metadata.
        """
        if self.collection is None:
            self.initialize_collection()

        return {
            "collection_name": self.collection_name,
            "persist_dir": self.persist_dir,
            "embedding_model": self.embedding_model,
            "num_vectors": self.collection.count(),
            "metadata": self.collection.metadata,
        }
def main():
    """Smoke-test the vector store: index HPMOR, print stats, run one query.

    Uses ``force_recreate=True``, so the existing Chroma collection is deleted
    and rebuilt from the freshly processed documents.
    """
    from src.document_processor import HPMORProcessor

    # Process documents
    processor = HPMORProcessor()
    documents = processor.process()

    # Build a fresh index; the manager keeps it for query().
    vector_store = VectorStoreManager()
    vector_store.get_or_create_index(documents, force_recreate=True)

    # Get stats
    stats = vector_store.get_stats()
    print("\nVector Store Statistics:")
    for key, value in stats.items():
        print(f" {key}: {value}")

    # Test query
    test_query = "What is Harry's opinion on magic?"
    print(f"\nTest query: '{test_query}'")
    results = vector_store.query(test_query, top_k=3)

    print(f"\nFound {len(results)} relevant chunks:")
    for i, node in enumerate(results, 1):
        # A retrieved node's score is optional; formatting None with ":.4f"
        # would raise TypeError.
        score = f"{node.score:.4f}" if node.score is not None else "n/a"
        print(f"\n{i}. Score: {score}")
        print(f" Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
        print(f" Text preview: {node.text[:200]}...")
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not on import.
    main()