Spaces:
Sleeping
Sleeping
Commit ·
11ce507
1
Parent(s): 0a02cd7
MOD: index algo
Browse files- exp_pipeline/pipeline.py +4 -1
- retriever/faiss_index.py +9 -10
exp_pipeline/pipeline.py
CHANGED
|
@@ -42,8 +42,11 @@ def run_pipeline(split: str = "train"):
|
|
| 42 |
raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
|
| 43 |
|
| 44 |
# 4. 建立FAISS索引
|
| 45 |
-
index = build_faiss_index(embeddings, texts)
|
| 46 |
logger.info("FAISS index built successfully")
|
|
|
|
|
|
|
|
|
|
| 47 |
return index
|
| 48 |
|
| 49 |
if __name__ == "__main__":
|
|
|
|
| 42 |
raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
|
| 43 |
|
| 44 |
# 4. 建立FAISS索引
|
| 45 |
+
index = build_faiss_index(embeddings, texts, index_type="HNSW")
|
| 46 |
logger.info("FAISS index built successfully")
|
| 47 |
+
# 持久化index到./index文件夹
|
| 48 |
+
index.save("../index/msmarco_hnsw")
|
| 49 |
+
logger.info("FAISS index saved to ./index/msmarco_hnsw")
|
| 50 |
return index
|
| 51 |
|
| 52 |
if __name__ == "__main__":
|
retriever/faiss_index.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# 工厂函数,供pipeline调用
|
| 2 |
-
def build_faiss_index(embeddings, texts, metadata=None, index_type="
|
| 3 |
if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
|
| 4 |
raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
|
| 5 |
dimension = embeddings.shape[1]
|
|
@@ -16,7 +16,7 @@ import logging
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
class FAISSIndex:
|
| 19 |
-
def __init__(self, dimension: int, index_type: str = "
|
| 20 |
self.dimension = dimension
|
| 21 |
self.index_type = index_type
|
| 22 |
self.index = None
|
|
@@ -29,29 +29,28 @@ class FAISSIndex:
|
|
| 29 |
"""Build FAISS index from embeddings"""
|
| 30 |
if embeddings.shape[1] != self.dimension:
|
| 31 |
raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
|
| 32 |
-
|
| 33 |
# Normalize embeddings for cosine similarity
|
| 34 |
faiss.normalize_L2(embeddings)
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
nlist = min(4096, len(embeddings) // 100)
|
| 39 |
quantizer = faiss.IndexFlatIP(self.dimension)
|
| 40 |
self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
|
| 41 |
self.index.train(embeddings)
|
| 42 |
self.index.add(embeddings)
|
| 43 |
else:
|
| 44 |
-
# Flat index for small datasets
|
| 45 |
self.index = faiss.IndexFlatIP(self.dimension)
|
| 46 |
self.index.add(embeddings)
|
| 47 |
-
|
| 48 |
# Store text and metadata
|
| 49 |
for i, text in enumerate(texts):
|
| 50 |
self.id_to_text[i] = text
|
| 51 |
if metadata and i < len(metadata):
|
| 52 |
self.id_to_metadata[i] = metadata[i]
|
| 53 |
-
|
| 54 |
-
logger.info(f"Built FAISS index with {len(embeddings)} vectors")
|
| 55 |
|
| 56 |
def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
|
| 57 |
"""Search for similar vectors"""
|
|
|
|
| 1 |
# 工厂函数,供pipeline调用
|
| 2 |
+
def build_faiss_index(embeddings, texts, metadata=None, index_type="HNSW"):
|
| 3 |
if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
|
| 4 |
raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
|
| 5 |
dimension = embeddings.shape[1]
|
|
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
class FAISSIndex:
|
| 19 |
+
def __init__(self, dimension: int, index_type: str = "HNSW"):
|
| 20 |
self.dimension = dimension
|
| 21 |
self.index_type = index_type
|
| 22 |
self.index = None
|
|
|
|
| 29 |
"""Build FAISS index from embeddings"""
|
| 30 |
if embeddings.shape[1] != self.dimension:
|
| 31 |
raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
|
|
|
|
| 32 |
# Normalize embeddings for cosine similarity
|
| 33 |
faiss.normalize_L2(embeddings)
|
| 34 |
+
if self.index_type == "HNSW":
|
| 35 |
+
# HNSW index for fast approximate search
|
| 36 |
+
self.index = faiss.IndexHNSWFlat(self.dimension, 32) # 32 is default M
|
| 37 |
+
self.index.hnsw.efConstruction = 200
|
| 38 |
+
self.index.add(embeddings)
|
| 39 |
+
elif self.index_type == "IVF":
|
| 40 |
nlist = min(4096, len(embeddings) // 100)
|
| 41 |
quantizer = faiss.IndexFlatIP(self.dimension)
|
| 42 |
self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
|
| 43 |
self.index.train(embeddings)
|
| 44 |
self.index.add(embeddings)
|
| 45 |
else:
|
|
|
|
| 46 |
self.index = faiss.IndexFlatIP(self.dimension)
|
| 47 |
self.index.add(embeddings)
|
|
|
|
| 48 |
# Store text and metadata
|
| 49 |
for i, text in enumerate(texts):
|
| 50 |
self.id_to_text[i] = text
|
| 51 |
if metadata and i < len(metadata):
|
| 52 |
self.id_to_metadata[i] = metadata[i]
|
| 53 |
+
logger.info(f"Built FAISS {self.index_type} index with {len(embeddings)} vectors")
|
|
|
|
| 54 |
|
| 55 |
def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
|
| 56 |
"""Search for similar vectors"""
|