File size: 5,291 Bytes
daafb32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
"""
Disk-based cache for computed embeddings.

PROBLEM WE'RE SOLVING:
    Embedding 15,664 chunks takes ~30-60 minutes on CPU.
    If you restart your pipeline or add 10 new papers,
    you don't want to re-embed the chunks that are already cached.

SOLUTION:
    Save embeddings to disk as numpy .npy files.
    Build an index that maps chunk_id -> array row index.
    On next run, load from disk instead of recomputing.

STORAGE FORMAT:
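    data/embeddings/
    |-- embeddings.npy        <- numpy array, shape (N, 768)
    |-- chunk_ids.npy         <- chunk IDs in the same order as the rows
    |-- embedding_index.json  <- human-readable metadata (count, dimension,
                                 model name, first few chunk IDs)

    The chunk_id -> row mapping is not stored on disk; it is rebuilt in
    memory from chunk_ids.npy each time the cache is loaded.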

WHY NUMPY .npy OVER JSON:
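    15,664 * 768 = ~12M floats. As JSON text that is roughly 90-250MB,
    depending on how many digits are printed per float.
    As .npy binary (float32) it is ~46MB, and it loads much faster
    because no text parsing is needed.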
"""

import json
import numpy as np
from pathlib import Path

from src.utils.logger import get_logger
from config.settings import EMBEDDINGS_DIR, EMBEDDING_DIMENSION

# Module-level logger, obtained from the project's get_logger helper and named
# after this module; EmbeddingCache uses it for load/save status messages.
logger = get_logger(__name__)



class EmbeddingCache:
    """
    Manages persistent storage of chunk embeddings
    """


    def  __init__(self):
        self.embedding_file  = EMBEDDINGS_DIR / "embeddings.npy"
        self.chunk_ids_file  = EMBEDDINGS_DIR / "chunk_ids.npy"
        self.index_file      = EMBEDDINGS_DIR / "embedding_index.json"


        # In-memory state
        self._embeddings: np.ndarray = None     # Shape (N, 768)
        self._chunk_ids: list[str]   = None     # length N   
        self._id_to_row:    dict     = None     # chunk_id -> row index


    def exists(self) -> bool:
        """Check if cached embeddings exists on disk"""
        return (
            self.embedding_file.exists() and 
            self.chunk_ids_file.exists() and
            self.index_file.exists()
        )


    def load(self) -> bool:
        """
        Load embeddings from disk into memory

        Returns True if loaded successfully. False if no cache exists
        """
        if not self.exists():
            logger.info("No embedding cache found on disk")
            return False

        logger.info("Loading embeddings from disk cache...")


        # Load numpy arrays - mmap_mode='r' means memory-mapped read
        # WHY mmap: The array is NOT fully loaded into RAM immediately
        # It's read from disk only when specific rows are accessed
        # This is critical for large arrays on machines with limited RAM
        self._embeddings = np.load(
            str(self.embedding_file),
            mmap_mode = 'r'
        )

        # chunk_ids are stored as numpy array of strings
        # We convert back to Python list for easier indexing
        self._chunk_ids = list(
            np.load(str(self.chunk_ids_file), allow_pickle = True)
        )

        # Build the reverse lookup: chunk_id -> row number
        self._id_to_row = {
            chunk_id: idx
            for idx, chunk_id in enumerate(self._chunk_ids)
        }

        logger.info(
            f"Cache loaded: {self._embeddings.shape[0]:,} embeddings"
            f"dimension = {self._embeddings.shape[1]}"
        )

        return True

    
    def save(self, embeddings: np.ndarray, chunk_ids: list[str]):
        """
        Save embeddings and their chunk IDs to disk.

        Args:
            embeddings: numpy array of shape (N, 768)
            chunk_ids:  list of N chunk ID strings (same order as rows)
        """

        assert len(embeddings) == len(chunk_ids), (
            f"Mismatch {len(embeddings)} embeddings vs {len(chunk_ids)} IDs"
        )

        logger.info(f"Saving {len(embeddings):,} embeddings to disk...")

        # Save the embedding matrix
        np.save(str(self.embedding_file), embeddings)  

        # Save chunk IDs as numpy object array (handles strings)
        np.save(str(self.chunk_ids_file), np.array(chunk_ids, dtype = object))

        # Save human-readable index file
        index = {
            "total_embeddings": len(embeddings),
            "embedding_dimension": embeddings.shape[1],
            "model_name": "BAAI/bge-base-en-v1.5",
            "chunk_id_sample": chunk_ids[:5],   # First 5 for verification
        }

        with open(self.index_file, "w", encoding = 'utf-8') as f:
            json.dump(index, f, indent = 2)



        # Update in-memory state
        self._embeddings = embeddings
        self._chunk_ids  = chunk_ids
        self._id_to_row  = {cid: i for i, cid in enumerate(chunk_ids)}


        logger.info(
            f"Saved embeddings: {self.embedding_file}"
            f"({self.embedding_file.stat().st_size / 1024 / 1024:.1f} MB)"
        )


    def get_embeddings(self, chunk_id: str) -> np.ndarray | None:
        """Get the embedding vector for a specific chunk ID."""
        if self._id_to_row is None:
            return None
        
        row = self._id_to_row.get(chunk_id)

        if row is None:
            return None
        
        return self._embeddings[row]



    def get_all(self) -> tuple[np.ndarray, list[str]]:
        """Return all embeddings and their chunk IDs."""
        return self._embeddings, self._chunk_ids

    
    @property
    def size(self) -> int:
        """Number of cached embeddings"""
        if self._chunk_ids is None:
            return 0

        return len(self._chunk_ids)