|
|
"""Vector store management for document embeddings.""" |
|
|
|
|
|
import os |
|
|
from typing import List, Optional |
|
|
from pathlib import Path |
|
|
|
|
|
import chromadb |
|
|
from chromadb.config import Settings |
|
|
from llama_index.core import Document, VectorStoreIndex, StorageContext |
|
|
from llama_index.vector_stores.chroma import ChromaVectorStore |
|
|
from llama_index.embeddings.huggingface import HuggingFaceEmbedding |
|
|
from llama_index.core.node_parser import SentenceSplitter |
|
|
|
|
|
from src.config import config |
|
|
|
|
|
|
|
|
class VectorStoreManager:
    """Manage a ChromaDB vector store for document embeddings.

    Wraps a persistent ChromaDB client plus a HuggingFace embedding model
    and exposes helpers to build, load, and query a LlamaIndex
    ``VectorStoreIndex`` backed by a single Chroma collection.
    """

    def __init__(self):
        """Load the embedding model and open the persistent Chroma client.

        Collection name, persist directory, and embedding model name are
        read from the shared ``config`` object.
        """
        self.collection_name = config.collection_name
        self.persist_dir = str(config.chroma_persist_dir)
        self.embedding_model = config.embedding_model

        # May download model weights on first run; cache locally so
        # subsequent startups are fast.
        print(f"Loading embedding model: {self.embedding_model}")
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model,
            cache_folder="./models"
        )

        self.chroma_client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False)
        )

        # Populated lazily by initialize_collection() / create_index() /
        # load_index().
        self.collection = None
        self.vector_store = None
        self.index = None

    def initialize_collection(self, reset: bool = False) -> None:
        """Initialize the ChromaDB collection and wrap it in a vector store.

        Args:
            reset: When True, drop any existing collection with the same
                name before (re)creating it.
        """
        if reset:
            # Best-effort delete: the collection may not exist yet, and the
            # exception type raised varies across chromadb versions, so a
            # broad catch is deliberate here.
            try:
                self.chroma_client.delete_collection(name=self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            except Exception:
                pass

        self.collection = self.chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"}  # cosine similarity for text embeddings
        )
        print(f"Using collection: {self.collection_name}")

        # Embeddings are computed by LlamaIndex through `embed_model` at
        # index build/query time, so the store only needs the raw
        # collection. (The previous `embedding_function=` kwarg is not a
        # parameter of llama-index's ChromaVectorStore and was dropped.)
        self.vector_store = ChromaVectorStore(
            chroma_collection=self.collection
        )

    def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
        """Create a vector index from documents.

        Args:
            documents: Documents to embed and persist into the collection.
            show_progress: Whether to show an embedding progress bar.

        Returns:
            The newly built index (also cached on ``self.index``).
        """
        if not self.vector_store:
            self.initialize_collection()

        print(f"Creating index from {len(documents)} documents...")

        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
            show_progress=show_progress
        )

        print("Index created successfully!")
        return self.index

    def load_index(self) -> Optional[VectorStoreIndex]:
        """Load an existing index from ChromaDB storage.

        Returns:
            The rehydrated index, or None when the collection is empty.
        """
        if not self.vector_store:
            self.initialize_collection()

        # Hoisted: count() queries the store, so call it once instead of
        # once for the emptiness check and again for the log line.
        num_vectors = self.collection.count()
        if num_vectors == 0:
            print("No existing index found in ChromaDB")
            return None

        print(f"Loading index with {num_vectors} vectors")

        storage_context = StorageContext.from_defaults(
            vector_store=self.vector_store
        )

        self.index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            storage_context=storage_context,
            embed_model=self.embed_model
        )

        return self.index

    def get_or_create_index(
        self,
        documents: Optional[List[Document]] = None,
        force_recreate: bool = False
    ) -> VectorStoreIndex:
        """Return an index, loading the stored one or building a new one.

        Args:
            documents: Documents used to build a fresh index when none can
                be loaded (or when ``force_recreate`` is True).
            force_recreate: Skip loading and rebuild from scratch.

        Returns:
            A ready-to-query index.

        Raises:
            ValueError: If a new index is required but no documents were
                provided.
        """
        if not force_recreate:
            index = self.load_index()
            if index is not None:
                return index

        if not documents:
            raise ValueError("No documents provided for creating index")

        # Reset so a rebuilt index never mixes with stale vectors.
        self.initialize_collection(reset=True)
        return self.create_index(documents)

    def query(self, query_text: str, top_k: Optional[int] = None) -> List:
        """Retrieve the most similar chunks for a query string.

        Args:
            query_text: Natural-language query.
            top_k: Number of chunks to retrieve; defaults to
                ``config.top_k_retrieval`` when None. (Annotation fixed
                from ``int`` to ``Optional[int]`` to match the default.)

        Returns:
            List of retrieved nodes with similarity scores.

        Raises:
            ValueError: If no index has been created or loaded yet.
        """
        if not self.index:
            raise ValueError("Index not initialized. Call get_or_create_index first.")

        if top_k is None:
            top_k = config.top_k_retrieval

        retriever = self.index.as_retriever(
            similarity_top_k=top_k
        )

        return retriever.retrieve(query_text)

    def get_stats(self) -> dict:
        """Return summary statistics about the vector store.

        Returns:
            Dict with the collection name, persist directory, embedding
            model name, stored vector count, and collection metadata.
        """
        if not self.collection:
            self.initialize_collection()

        return {
            "collection_name": self.collection_name,
            "persist_dir": self.persist_dir,
            "embedding_model": self.embedding_model,
            "num_vectors": self.collection.count(),
            "metadata": self.collection.metadata,
        }
|
|
|
|
|
|
|
|
def main():
    """End-to-end smoke test: build the index, show stats, run a query."""
    from src.document_processor import HPMORProcessor

    # Build the corpus and a brand-new index from it.
    documents = HPMORProcessor().process()
    manager = VectorStoreManager()
    manager.get_or_create_index(documents, force_recreate=True)

    # Report store-level statistics.
    print("\nVector Store Statistics:")
    for key, value in manager.get_stats().items():
        print(f"  {key}: {value}")

    # Exercise retrieval with a sample question.
    test_query = "What is Harry's opinion on magic?"
    print(f"\nTest query: '{test_query}'")
    results = manager.query(test_query, top_k=3)

    print(f"\nFound {len(results)} relevant chunks:")
    for i, node in enumerate(results, 1):
        print(f"\n{i}. Score: {node.score:.4f}")
        print(f"  Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
        print(f"  Text preview: {node.text[:200]}...")


if __name__ == "__main__":
    main()