File size: 6,348 Bytes
6ef4823 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
"""Vector store management for document embeddings."""
import os
from typing import List, Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings
from llama_index.core import Document, VectorStoreIndex, StorageContext
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from src.config import config
class VectorStoreManager:
    """Manage a persistent ChromaDB collection and the LlamaIndex index over it.

    Construction loads the HuggingFace embedding model and opens a
    ``chromadb.PersistentClient``. The collection, vector store and index are
    created lazily by ``initialize_collection``, ``create_index``,
    ``load_index`` or ``get_or_create_index``.
    """

    def __init__(self):
        self.collection_name = config.collection_name
        self.persist_dir = str(config.chroma_persist_dir)
        self.embedding_model = config.embedding_model

        # Initialize embedding model; weights are cached under ./models
        # (relative to the current working directory).
        print(f"Loading embedding model: {self.embedding_model}")
        self.embed_model = HuggingFaceEmbedding(
            model_name=self.embedding_model,
            cache_folder="./models",
        )

        # Initialize ChromaDB client backed by on-disk storage, telemetry off.
        self.chroma_client = chromadb.PersistentClient(
            path=self.persist_dir,
            settings=Settings(anonymized_telemetry=False),
        )

        # Populated lazily by initialize_collection() / create_index() / load_index().
        self.collection = None
        self.vector_store = None
        self.index = None

    def initialize_collection(self, reset: bool = False) -> None:
        """Open (or create) the Chroma collection and wrap it in a vector store.

        Args:
            reset: If True, delete any existing collection with this name first
                so that indexing starts from an empty collection.
        """
        if reset:
            # Best-effort delete: the collection may legitimately not exist yet,
            # but report the failure instead of hiding it.
            try:
                self.chroma_client.delete_collection(name=self.collection_name)
                print(f"Deleted existing collection: {self.collection_name}")
            except Exception as exc:
                print(
                    f"Could not delete collection {self.collection_name!r} "
                    f"(it may not exist): {exc}"
                )
            # Any previously built index refers to the deleted collection.
            self.index = None

        # Create or get collection, using cosine distance for the HNSW index.
        self.collection = self.chroma_client.get_or_create_collection(
            name=self.collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        print(f"Using collection: {self.collection_name}")

        # Embeddings are computed by LlamaIndex with self.embed_model (passed to
        # the index), so the vector store only needs the collection handle.
        self.vector_store = ChromaVectorStore(chroma_collection=self.collection)

    def create_index(self, documents: List[Document], show_progress: bool = True) -> VectorStoreIndex:
        """Embed ``documents`` into the collection and build a new index.

        Opens the collection first if it has not been initialized. Vectors
        already in the collection are kept; call
        ``initialize_collection(reset=True)`` beforehand for a clean rebuild.

        Args:
            documents: Documents to chunk, embed and store.
            show_progress: Forwarded to ``VectorStoreIndex.from_documents``.

        Returns:
            The created index, also stored on ``self.index``.
        """
        if self.vector_store is None:
            self.initialize_collection()

        print(f"Creating index from {len(documents)} documents...")
        storage_context = StorageContext.from_defaults(vector_store=self.vector_store)
        self.index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
            show_progress=show_progress,
        )
        print("Index created successfully!")
        return self.index

    def load_index(self) -> Optional[VectorStoreIndex]:
        """Wrap the existing Chroma collection in an index without re-embedding.

        Returns:
            The loaded index (also stored on ``self.index``), or ``None`` when
            the collection contains no vectors.
        """
        if self.vector_store is None:
            self.initialize_collection()

        num_vectors = self.collection.count()
        if num_vectors == 0:
            print("No existing index found in ChromaDB")
            return None

        print(f"Loading index with {num_vectors} vectors")
        # from_vector_store creates its own StorageContext around the vector
        # store, so none is passed in here.
        self.index = VectorStoreIndex.from_vector_store(
            self.vector_store,
            embed_model=self.embed_model,
        )
        return self.index

    def get_or_create_index(
        self,
        documents: Optional[List[Document]] = None,
        force_recreate: bool = False
    ) -> VectorStoreIndex:
        """Return the persisted index, building it from ``documents`` if needed.

        Args:
            documents: Documents to index when no persisted index exists or
                when ``force_recreate`` is set.
            force_recreate: Skip loading and rebuild from ``documents``,
                deleting the existing collection first.

        Returns:
            The loaded or newly created index.

        Raises:
            ValueError: If a new index must be built but ``documents`` is empty
                or None. This is raised before any existing data is deleted.
        """
        if not force_recreate:
            index = self.load_index()
            if index is not None:
                return index

        if not documents:
            raise ValueError("No documents provided for creating index")

        # Start from an empty collection so the rebuilt index has no stale vectors.
        self.initialize_collection(reset=True)
        return self.create_index(documents)

    def query(self, query_text: str, top_k: Optional[int] = None) -> List:
        """Retrieve the chunks most similar to ``query_text``.

        Args:
            query_text: Query text to embed and search with.
            top_k: Number of results; defaults to ``config.top_k_retrieval``.

        Returns:
            The retrieved nodes with their similarity scores.

        Raises:
            ValueError: If no index has been created or loaded yet.
        """
        if self.index is None:
            raise ValueError("Index not initialized. Call get_or_create_index first.")
        if top_k is None:
            top_k = config.top_k_retrieval

        # Use retriever directly instead of query engine to avoid LLM requirement.
        retriever = self.index.as_retriever(similarity_top_k=top_k)
        return retriever.retrieve(query_text)

    def get_stats(self) -> dict:
        """Return basic information about the collection and embedding setup.

        Opens the collection first if it has not been initialized.
        """
        if self.collection is None:
            self.initialize_collection()
        return {
            "collection_name": self.collection_name,
            "persist_dir": self.persist_dir,
            "embedding_model": self.embedding_model,
            "num_vectors": self.collection.count(),
            "metadata": self.collection.metadata,
        }
def main():
    """Smoke-test the vector store: rebuild the index, print stats, run one query.

    Note: this uses ``force_recreate=True``, so any existing collection is deleted.
    """
    from src.document_processor import HPMORProcessor

    # Process documents
    processor = HPMORProcessor()
    documents = processor.process()

    # Rebuild the vector store from scratch; the returned index is kept on
    # the manager and used by query() below.
    vector_store = VectorStoreManager()
    vector_store.get_or_create_index(documents, force_recreate=True)

    # Get stats
    stats = vector_store.get_stats()
    print("\nVector Store Statistics:")
    for key, value in stats.items():
        print(f" {key}: {value}")

    # Test query
    test_query = "What is Harry's opinion on magic?"
    print(f"\nTest query: '{test_query}'")
    results = vector_store.query(test_query, top_k=3)
    print(f"\nFound {len(results)} relevant chunks:")
    for i, node in enumerate(results, 1):
        # A retrieved node's score is optional; formatting None with :.4f would raise.
        score_text = f"{node.score:.4f}" if node.score is not None else "n/a"
        print(f"\n{i}. Score: {score_text}")
        print(f" Chapter: {node.metadata.get('chapter_title', 'Unknown')}")
        print(f" Text preview: {node.text[:200]}...")
# Run the smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()