researchpilot-api / src /embeddings /embedding_cache.py
Subhadip007's picture
feat: vector database indexing complete
daafb32
"""
Disk-based cache for computed embeddings.
PROBLEM WE'RE SOLVING:
Embedding 15,664 chunks takes ~30-60 minutes on CPU.
If you restart your pipeline or add 10 new papers,
you don't want to re-embed the 15,654 unchanged chunks.
SOLUTION:
Save embeddings to disk as numpy .npy files.
Build an index that maps chunk_id -> array row index.
On next run, load from disk instead of recomputing.
STORAGE FORMAT:
data/embeddings/
|-- embeddings.npy <- numpy array, shape (N, 768)
|-- chunk_ids.npy <- chunk IDs in same order as rows
|-- embedding_index.json <- human-readable metadata (count, dimension, model name, sample of chunk IDs)
WHY NUMPY .npy OVER JSON:
Storing 15,664 * 768 floats as JSON = ~90MB of text
Storing as .npy binary = ~46MB + loads 100x faster
"""
import json
import numpy as np
from pathlib import Path
from src.utils.logger import get_logger
from config.settings import EMBEDDINGS_DIR, EMBEDDING_DIMENSION
logger = get_logger(__name__)
class EmbeddingCache:
"""
Manages persistent storage of chunk embeddings
"""
def __init__(self):
self.embedding_file = EMBEDDINGS_DIR / "embeddings.npy"
self.chunk_ids_file = EMBEDDINGS_DIR / "chunk_ids.npy"
self.index_file = EMBEDDINGS_DIR / "embedding_index.json"
# In-memory state
self._embeddings: np.ndarray = None # Shape (N, 768)
self._chunk_ids: list[str] = None # length N
self._id_to_row: dict = None # chunk_id -> row index
def exists(self) -> bool:
"""Check if cached embeddings exists on disk"""
return (
self.embedding_file.exists() and
self.chunk_ids_file.exists() and
self.index_file.exists()
)
def load(self) -> bool:
"""
Load embeddings from disk into memory
Returns True if loaded successfully. False if no cache exists
"""
if not self.exists():
logger.info("No embedding cache found on disk")
return False
logger.info("Loading embeddings from disk cache...")
# Load numpy arrays - mmap_mode='r' means memory-mapped read
# WHY mmap: The array is NOT fully loaded into RAM immediately
# It's read from disk only when specific rows are accessed
# This is critical for large arrays on machines with limited RAM
self._embeddings = np.load(
str(self.embedding_file),
mmap_mode = 'r'
)
# chunk_ids are stored as numpy array of strings
# We convert back to Python list for easier indexing
self._chunk_ids = list(
np.load(str(self.chunk_ids_file), allow_pickle = True)
)
# Build the reverse lookup: chunk_id -> row number
self._id_to_row = {
chunk_id: idx
for idx, chunk_id in enumerate(self._chunk_ids)
}
logger.info(
f"Cache loaded: {self._embeddings.shape[0]:,} embeddings"
f"dimension = {self._embeddings.shape[1]}"
)
return True
def save(self, embeddings: np.ndarray, chunk_ids: list[str]):
"""
Save embeddings and their chunk IDs to disk.
Args:
embeddings: numpy array of shape (N, 768)
chunk_ids: list of N chunk ID strings (same order as rows)
"""
assert len(embeddings) == len(chunk_ids), (
f"Mismatch {len(embeddings)} embeddings vs {len(chunk_ids)} IDs"
)
logger.info(f"Saving {len(embeddings):,} embeddings to disk...")
# Save the embedding matrix
np.save(str(self.embedding_file), embeddings)
# Save chunk IDs as numpy object array (handles strings)
np.save(str(self.chunk_ids_file), np.array(chunk_ids, dtype = object))
# Save human-readable index file
index = {
"total_embeddings": len(embeddings),
"embedding_dimension": embeddings.shape[1],
"model_name": "BAAI/bge-base-en-v1.5",
"chunk_id_sample": chunk_ids[:5], # First 5 for verification
}
with open(self.index_file, "w", encoding = 'utf-8') as f:
json.dump(index, f, indent = 2)
# Update in-memory state
self._embeddings = embeddings
self._chunk_ids = chunk_ids
self._id_to_row = {cid: i for i, cid in enumerate(chunk_ids)}
logger.info(
f"Saved embeddings: {self.embedding_file}"
f"({self.embedding_file.stat().st_size / 1024 / 1024:.1f} MB)"
)
def get_embeddings(self, chunk_id: str) -> np.ndarray | None:
"""Get the embedding vector for a specific chunk ID."""
if self._id_to_row is None:
return None
row = self._id_to_row.get(chunk_id)
if row is None:
return None
return self._embeddings[row]
def get_all(self) -> tuple[np.ndarray, list[str]]:
"""Return all embeddings and their chunk IDs."""
return self._embeddings, self._chunk_ids
@property
def size(self) -> int:
"""Number of cached embeddings"""
if self._chunk_ids is None:
return 0
return len(self._chunk_ids)