Spaces:
Sleeping
Sleeping
File size: 7,904 Bytes
14f13a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
"""
Vector retrieval system using ChromaDB.
Handles document storage, indexing, and semantic search.
"""
import logging
from typing import List, Dict, Any, Optional
from pathlib import Path
import chromadb
from chromadb.config import Settings as ChromaSettings
from chromadb.utils import embedding_functions
from src.embeddings import EmbeddingGenerator
from src.chunking import DocumentChunk
logger = logging.getLogger(__name__)
class DocumentRetriever:
    """
    Manages document storage and retrieval using ChromaDB.

    Features:
    - Persistent vector storage
    - Semantic similarity search
    - Metadata filtering
    - Source attribution
    """

    def __init__(
        self,
        persist_directory: str,
        collection_name: str,
        embedding_generator: EmbeddingGenerator
    ):
        """
        Initialize retriever.

        Args:
            persist_directory: Path to ChromaDB storage (created if missing)
            collection_name: Name of the collection
            embedding_generator: Embedding generator instance used for both
                documents and queries
        """
        self.persist_directory = Path(persist_directory)
        self.persist_directory.mkdir(parents=True, exist_ok=True)
        self.collection_name = collection_name
        self.embedding_generator = embedding_generator

        # Initialize ChromaDB client (telemetry off; reset allowed for tests).
        logger.info(f"Initializing ChromaDB at {persist_directory}")
        self.client = chromadb.PersistentClient(
            path=str(self.persist_directory),
            settings=ChromaSettings(
                anonymized_telemetry=False,
                allow_reset=True
            )
        )

        # New collections are created with cosine distance so similarity
        # scores land in [0, 1].  A pre-existing collection may have been
        # built with Chroma's default L2 space, so remember which space is
        # actually in effect for score conversion in retrieve().
        self.collection = self._get_or_create_collection()
        coll_meta = self.collection.metadata or {}
        self._use_cosine = coll_meta.get("hnsw:space") == "cosine"
        logger.info(f"Collection '{collection_name}' ready. Count: {self.collection.count()}. Distance: {'cosine' if self._use_cosine else 'l2'}")

    def _get_or_create_collection(self):
        """Get the existing collection, or create it with cosine distance.

        Returns:
            chromadb Collection object.
        """
        try:
            # Try to get existing collection
            collection = self.client.get_collection(
                name=self.collection_name
            )
            logger.info(f"Loaded existing collection: {self.collection_name}")
        except Exception:
            # Broad except is deliberate: Chroma raises different exception
            # types for a missing collection depending on its version.
            # Cosine distance keeps retrieve() scores in [0, 1].
            collection = self.client.create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine", "description": "Developer documentation chunks"}
            )
            logger.info(f"Created new collection: {self.collection_name}")
        return collection

    def add_documents(self, chunks: List[DocumentChunk]) -> None:
        """
        Add document chunks to the vector store.

        Embeddings are generated in one pass, then inserted in batches of 100
        to keep individual ChromaDB requests small.

        Args:
            chunks: List of DocumentChunk objects (no-op if empty)
        """
        if not chunks:
            logger.warning("No chunks to add")
            return
        logger.info(f"Adding {len(chunks)} chunks to collection")

        # Prepare parallel lists for ChromaDB
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.metadata for chunk in chunks]
        ids = [chunk.chunk_id for chunk in chunks]

        # One embedding call for all documents; assumed to return a 2-D numpy
        # array aligned with `documents` (sliced and .tolist()'d below) —
        # TODO confirm against EmbeddingGenerator.embed_documents.
        embeddings = self.embedding_generator.embed_documents(documents)

        # Add to collection in batches
        batch_size = 100
        for i in range(0, len(chunks), batch_size):
            batch_end = min(i + batch_size, len(chunks))
            self.collection.add(
                embeddings=embeddings[i:batch_end].tolist(),
                documents=documents[i:batch_end],
                metadatas=metadatas[i:batch_end],
                ids=ids[i:batch_end]
            )
            logger.debug(f"Added batch {i//batch_size + 1}")
        logger.info(f"Successfully added {len(chunks)} chunks. Total: {self.collection.count()}")

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """
        Retrieve relevant documents for a query.

        Args:
            query: Search query
            top_k: Number of results to return
            filter_metadata: Optional metadata filters (ChromaDB `where` clause)

        Returns:
            List of dicts with "content", "metadata", "score" (similarity in
            [0, 1], higher is better), and "id".
        """
        logger.debug(f"Retrieving top {top_k} results for query: {query[:100]}...")
        # Generate query embedding
        query_embedding = self.embedding_generator.embed_query(query)

        # Search
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=top_k,
            where=filter_metadata,
            include=["documents", "metadatas", "distances"]
        )

        # Format results; Chroma nests each field one level per query vector.
        formatted_results = []
        if results["documents"] and results["documents"][0]:
            for i in range(len(results["documents"][0])):
                score = self._distance_to_score(results["distances"][0][i])
                formatted_results.append({
                    "content": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i],
                    "score": score,
                    "id": results["ids"][0][i] if "ids" in results else None
                })
        logger.info(f"Retrieved {len(formatted_results)} results")
        return formatted_results

    def _distance_to_score(self, distance: float) -> float:
        """Convert a ChromaDB distance into a similarity score in [0, 1].

        - cosine space: distance = 1 - cosine_similarity, so score = 1 - d.
        - l2 space: Chroma reports *squared* Euclidean distance; for
          unit-normalized embeddings ||a-b||^2 = 2 * (1 - cos(a, b)), so
          score = 1 - d / 2.  (The previous formula ``1 - d**2 / 2`` squared
          an already-squared distance, deflating scores.)
          NOTE(review): assumes the embedding generator produces
          unit-normalized vectors — confirm in EmbeddingGenerator.

        Negative results are clamped to 0.
        """
        if self._use_cosine:
            score = 1.0 - distance
        else:
            score = 1.0 - distance / 2.0
        return max(0.0, score)

    def get_collection_stats(self) -> Dict[str, Any]:
        """Get statistics about the collection.

        Returns:
            Dict with total chunk count, collection name, metadata field
            names (sampled from one stored chunk), and embedding dimension.
        """
        count = self.collection.count()
        # Sample a single document to discover its metadata fields; empty
        # collections yield an empty field list.
        sample = self.collection.peek(limit=1)
        metadata_fields = list(sample["metadatas"][0].keys()) if sample["metadatas"] else []
        return {
            "total_chunks": count,
            "collection_name": self.collection_name,
            "metadata_fields": metadata_fields,
            "embedding_dimension": self.embedding_generator.embedding_dim
        }

    def delete_collection(self) -> None:
        """Delete the entire collection. Raises if it does not exist."""
        logger.warning(f"Deleting collection: {self.collection_name}")
        self.client.delete_collection(name=self.collection_name)

    def reset_collection(self) -> None:
        """Reset collection (delete and recreate)."""
        logger.warning("Resetting collection")
        try:
            self.delete_collection()
        except Exception:
            # Best-effort delete: the collection may not exist yet.
            pass
        self.collection = self._get_or_create_collection()
def create_retriever(
    persist_directory: Optional[str] = None,
    collection_name: Optional[str] = None,
    embedding_generator: Optional[EmbeddingGenerator] = None
) -> DocumentRetriever:
    """
    Factory function to create retriever.

    Any argument left unset (or falsy) falls back to the project-wide
    settings / default embedding generator.

    Args:
        persist_directory: Optional directory override
        collection_name: Optional collection name override
        embedding_generator: Optional embedding generator override

    Returns:
        DocumentRetriever instance
    """
    # Imported at call time so the current project settings are picked up
    # and module import stays lightweight.
    from src.config import settings
    from src.embeddings import create_embedding_generator

    return DocumentRetriever(
        persist_directory=persist_directory or settings.chroma_persist_dir,
        collection_name=collection_name or settings.collection_name,
        embedding_generator=embedding_generator or create_embedding_generator(),
    )
|