"""
Disk-based cache for computed embeddings.
PROBLEM WE'RE SOLVING:
Embedding 15,664 chunks takes ~30-60 minutes on CPU.
If you restart your pipeline or add 10 new papers,
you don't want to re-embed the ~15,664 chunks that haven't changed.
SOLUTION:
Save embeddings to disk as numpy .npy files.
Rebuild a chunk_id -> array row index lookup in memory when loading.
On next run, load from disk instead of recomputing.
STORAGE FORMAT:
data/embeddings/
|-- embeddings.npy <- numpy array, shape (N, 768)
|-- chunk_ids.npy <- chunk IDs in same order as rows
|-- embedding_index.json <- human-readable metadata (count, dimension, model, sample IDs)
WHY NUMPY .npy OVER JSON:
Storing 15,664 * 768 floats as JSON = ~90MB of text
Storing as .npy binary = ~46MB + loads 100x faster
"""
import json
import numpy as np
from pathlib import Path
from src.utils.logger import get_logger
from config.settings import EMBEDDINGS_DIR, EMBEDDING_DIMENSION
# Module-level logger used by EmbeddingCache, obtained from the project's
# get_logger helper (src.utils.logger) and named after this module.
logger = get_logger(__name__)
class EmbeddingCache:
"""
Manages persistent storage of chunk embeddings
"""
def __init__(self):
self.embedding_file = EMBEDDINGS_DIR / "embeddings.npy"
self.chunk_ids_file = EMBEDDINGS_DIR / "chunk_ids.npy"
self.index_file = EMBEDDINGS_DIR / "embedding_index.json"
# In-memory state
self._embeddings: np.ndarray = None # Shape (N, 768)
self._chunk_ids: list[str] = None # length N
self._id_to_row: dict = None # chunk_id -> row index
def exists(self) -> bool:
"""Check if cached embeddings exists on disk"""
return (
self.embedding_file.exists() and
self.chunk_ids_file.exists() and
self.index_file.exists()
)
def load(self) -> bool:
"""
Load embeddings from disk into memory
Returns True if loaded successfully. False if no cache exists
"""
if not self.exists():
logger.info("No embedding cache found on disk")
return False
logger.info("Loading embeddings from disk cache...")
# Load numpy arrays - mmap_mode='r' means memory-mapped read
# WHY mmap: The array is NOT fully loaded into RAM immediately
# It's read from disk only when specific rows are accessed
# This is critical for large arrays on machines with limited RAM
self._embeddings = np.load(
str(self.embedding_file),
mmap_mode = 'r'
)
# chunk_ids are stored as numpy array of strings
# We convert back to Python list for easier indexing
self._chunk_ids = list(
np.load(str(self.chunk_ids_file), allow_pickle = True)
)
# Build the reverse lookup: chunk_id -> row number
self._id_to_row = {
chunk_id: idx
for idx, chunk_id in enumerate(self._chunk_ids)
}
logger.info(
f"Cache loaded: {self._embeddings.shape[0]:,} embeddings"
f"dimension = {self._embeddings.shape[1]}"
)
return True
def save(self, embeddings: np.ndarray, chunk_ids: list[str]):
"""
Save embeddings and their chunk IDs to disk.
Args:
embeddings: numpy array of shape (N, 768)
chunk_ids: list of N chunk ID strings (same order as rows)
"""
assert len(embeddings) == len(chunk_ids), (
f"Mismatch {len(embeddings)} embeddings vs {len(chunk_ids)} IDs"
)
logger.info(f"Saving {len(embeddings):,} embeddings to disk...")
# Save the embedding matrix
np.save(str(self.embedding_file), embeddings)
# Save chunk IDs as numpy object array (handles strings)
np.save(str(self.chunk_ids_file), np.array(chunk_ids, dtype = object))
# Save human-readable index file
index = {
"total_embeddings": len(embeddings),
"embedding_dimension": embeddings.shape[1],
"model_name": "BAAI/bge-base-en-v1.5",
"chunk_id_sample": chunk_ids[:5], # First 5 for verification
}
with open(self.index_file, "w", encoding = 'utf-8') as f:
json.dump(index, f, indent = 2)
# Update in-memory state
self._embeddings = embeddings
self._chunk_ids = chunk_ids
self._id_to_row = {cid: i for i, cid in enumerate(chunk_ids)}
logger.info(
f"Saved embeddings: {self.embedding_file}"
f"({self.embedding_file.stat().st_size / 1024 / 1024:.1f} MB)"
)
def get_embeddings(self, chunk_id: str) -> np.ndarray | None:
"""Get the embedding vector for a specific chunk ID."""
if self._id_to_row is None:
return None
row = self._id_to_row.get(chunk_id)
if row is None:
return None
return self._embeddings[row]
def get_all(self) -> tuple[np.ndarray, list[str]]:
"""Return all embeddings and their chunk IDs."""
return self._embeddings, self._chunk_ids
@property
def size(self) -> int:
"""Number of cached embeddings"""
if self._chunk_ids is None:
return 0
return len(self._chunk_ids) |