Tuminha commited on
Commit
6989c33
Β·
verified Β·
1 Parent(s): 46fa8d2

Upload src/embed_index.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/embed_index.py +148 -0
src/embed_index.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Embeddings + FAISS index build/save/load.
3
+ """
4
+ from typing import List
5
+ from pathlib import Path
6
+ import os
7
+ import platform
8
+
9
+ # On macOS, FAISS and PyTorch both ship libomp and loading both copies without
10
+ # telling LibOMP they're duplicates aborts the interpreter. Setting this flag
11
+ # before importing either library prevents the crash when building embeddings.
12
+ if platform.system() == "Darwin":
13
+ os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
14
+
15
+ import numpy as np
16
+ # Import FAISS before torch/sentence-transformers so libomp loads in a safe order on macOS.
17
+ import faiss
18
+ from sentence_transformers import SentenceTransformer
19
+ import pandas as pd
20
+
21
+
22
+ def embed_texts(texts: List[str], model_name: str):
23
+ """
24
+ Return matrix of embeddings for texts.
25
+
26
+ # TODO hints:
27
+ # - Load SentenceTransformer by name; encode with normalize_embeddings=True if available.
28
+ # - Batch encode; return numpy array (n, d).
29
+
30
+ # Acceptance:
31
+ # - Returns embeddings and model reference (if needed).
32
+ """
33
+ model = SentenceTransformer(model_name)
34
+ embeddings = model.encode(texts, normalize_embeddings=True, show_progress_bar=True)
35
+ # Ensure numpy array and float32 for FAISS compatibility
36
+ embeddings = np.array(embeddings, dtype=np.float32)
37
+ return embeddings, model
38
+
39
+
40
+ def build_faiss_index(embeddings):
41
+ """
42
+ Build a FAISS index and return it.
43
+
44
+ # TODO hints:
45
+ # - Use IndexFlatIP or L2; ensure vectors are normalized if using IP.
46
+
47
+ # Acceptance:
48
+ # - Returns a FAISS index ready for add/search.
49
+ """
50
+ # Ensure embeddings are numpy array and float32
51
+ if not isinstance(embeddings, np.ndarray):
52
+ embeddings = np.array(embeddings, dtype=np.float32)
53
+ if embeddings.dtype != np.float32:
54
+ embeddings = embeddings.astype(np.float32)
55
+
56
+ # Make a copy before normalizing to avoid in-place modification issues
57
+ # (normalize_L2 modifies the array in-place)
58
+ embeddings = embeddings.copy()
59
+
60
+ # Ensure embeddings are normalized for IndexFlatIP (inner product = cosine similarity)
61
+ # Note: embeddings should already be normalized from embed_texts, but normalize_L2 is idempotent
62
+ faiss.normalize_L2(embeddings)
63
+
64
+ # Create IndexFlatIP (Inner Product) for normalized vectors
65
+ dimension = embeddings.shape[1]
66
+ index = faiss.IndexFlatIP(dimension)
67
+ index.add(embeddings)
68
+
69
+ return index
70
+
71
+
72
+ def save_index(index, meta_rows, out_dir: str):
73
+ """
74
+ Persist FAISS index + metadata (CSV/Parquet) to data/index/.
75
+
76
+ Args:
77
+ index: FAISS index to save
78
+ meta_rows: List of dicts or DataFrame with metadata (chunk IDs, source info)
79
+ out_dir: Output directory path
80
+
81
+ # TODO hints:
82
+ # - Write index to .faiss and metadata to .parquet with chunk IDs and source info.
83
+
84
+ # Acceptance:
85
+ # - Files exist in data/index/.
86
+ """
87
+ out_path = Path(out_dir)
88
+ out_path.mkdir(parents=True, exist_ok=True)
89
+
90
+ # Save FAISS index
91
+ index_path = out_path / 'index.faiss'
92
+ faiss.write_index(index, str(index_path))
93
+
94
+ # Convert meta_rows to DataFrame if it's a list
95
+ if isinstance(meta_rows, list):
96
+ meta_df = pd.DataFrame(meta_rows)
97
+ elif isinstance(meta_rows, pd.DataFrame):
98
+ meta_df = meta_rows
99
+ else:
100
+ raise ValueError("meta_rows must be a list of dicts or a pandas DataFrame")
101
+
102
+ # Save metadata
103
+ metadata_path = out_path / 'metadata.parquet'
104
+ meta_df.to_parquet(metadata_path, index=False)
105
+
106
+ print(f"βœ… Saved index to: {index_path}")
107
+ print(f"βœ… Saved metadata to: {metadata_path}")
108
+ print(f" Index size: {index.ntotal} vectors")
109
+ print(f" Metadata rows: {len(meta_df)}")
110
+
111
+
112
+ def load_index(in_dir: str):
113
+ """
114
+ Load FAISS index + metadata.
115
+
116
+ Args:
117
+ in_dir: Input directory path containing index.faiss and metadata.parquet
118
+
119
+ # TODO hints:
120
+ # - Read index and matching metadata frame; sanity-check row counts.
121
+
122
+ # Acceptance:
123
+ # - Returns (index, metadata_df).
124
+ """
125
+ in_path = Path(in_dir)
126
+
127
+ # Load FAISS index
128
+ index_path = in_path / 'index.faiss'
129
+ if not index_path.exists():
130
+ raise FileNotFoundError(f"Index file not found: {index_path}")
131
+ index = faiss.read_index(str(index_path))
132
+
133
+ # Load metadata
134
+ metadata_path = in_path / 'metadata.parquet'
135
+ if not metadata_path.exists():
136
+ raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
137
+ meta_df = pd.read_parquet(metadata_path)
138
+
139
+ # Sanity check: row counts should match
140
+ if index.ntotal != len(meta_df):
141
+ raise ValueError(
142
+ f"Mismatch: index has {index.ntotal} vectors but metadata has {len(meta_df)} rows"
143
+ )
144
+
145
+ print(f"βœ… Loaded index: {index.ntotal} vectors, dimension {index.d}")
146
+ print(f"βœ… Loaded metadata: {len(meta_df)} rows")
147
+
148
+ return index, meta_df