goodmodeler committed on
Commit
11ce507
·
1 Parent(s): 0a02cd7

MOD: index algo

Browse files
exp_pipeline/pipeline.py CHANGED
@@ -42,8 +42,11 @@ def run_pipeline(split: str = "train"):
42
  raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
43
 
44
  # 4. 建立FAISS索引
45
- index = build_faiss_index(embeddings, texts)
46
  logger.info("FAISS index built successfully")
 
 
 
47
  return index
48
 
49
  if __name__ == "__main__":
 
42
  raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")
43
 
44
  # 4. 建立FAISS索引
45
+ index = build_faiss_index(embeddings, texts, index_type="HNSW")
46
  logger.info("FAISS index built successfully")
47
+ # 持久化index到./index文件夹
48
+ index.save("../index/msmarco_hnsw")
49
+ logger.info("FAISS index saved to ./index/msmarco_hnsw")
50
  return index
51
 
52
  if __name__ == "__main__":
retriever/faiss_index.py CHANGED
@@ -1,5 +1,5 @@
1
  # 工厂函数,供pipeline调用
2
- def build_faiss_index(embeddings, texts, metadata=None, index_type="IVF"):
3
  if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
4
  raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
5
  dimension = embeddings.shape[1]
@@ -16,7 +16,7 @@ import logging
16
  logger = logging.getLogger(__name__)
17
 
18
  class FAISSIndex:
19
- def __init__(self, dimension: int, index_type: str = "IVF"):
20
  self.dimension = dimension
21
  self.index_type = index_type
22
  self.index = None
@@ -29,29 +29,28 @@ class FAISSIndex:
29
  """Build FAISS index from embeddings"""
30
  if embeddings.shape[1] != self.dimension:
31
  raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
32
-
33
  # Normalize embeddings for cosine similarity
34
  faiss.normalize_L2(embeddings)
35
-
36
- if self.index_type == "IVF":
37
- # IVF index for large datasets
 
 
 
38
  nlist = min(4096, len(embeddings) // 100)
39
  quantizer = faiss.IndexFlatIP(self.dimension)
40
  self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
41
  self.index.train(embeddings)
42
  self.index.add(embeddings)
43
  else:
44
- # Flat index for small datasets
45
  self.index = faiss.IndexFlatIP(self.dimension)
46
  self.index.add(embeddings)
47
-
48
  # Store text and metadata
49
  for i, text in enumerate(texts):
50
  self.id_to_text[i] = text
51
  if metadata and i < len(metadata):
52
  self.id_to_metadata[i] = metadata[i]
53
-
54
- logger.info(f"Built FAISS index with {len(embeddings)} vectors")
55
 
56
  def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
57
  """Search for similar vectors"""
 
1
  # 工厂函数,供pipeline调用
2
+ def build_faiss_index(embeddings, texts, metadata=None, index_type="HNSW"):
3
  if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
4
  raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
5
  dimension = embeddings.shape[1]
 
16
  logger = logging.getLogger(__name__)
17
 
18
  class FAISSIndex:
19
+ def __init__(self, dimension: int, index_type: str = "HNSW"):
20
  self.dimension = dimension
21
  self.index_type = index_type
22
  self.index = None
 
29
  """Build FAISS index from embeddings"""
30
  if embeddings.shape[1] != self.dimension:
31
  raise ValueError(f"Embedding dimension {embeddings.shape[1]} != {self.dimension}")
 
32
  # Normalize embeddings for cosine similarity
33
  faiss.normalize_L2(embeddings)
34
+ if self.index_type == "HNSW":
35
+ # HNSW index for fast approximate search
36
+ self.index = faiss.IndexHNSWFlat(self.dimension, 32) # 32 is default M
37
+ self.index.hnsw.efConstruction = 200
38
+ self.index.add(embeddings)
39
+ elif self.index_type == "IVF":
40
  nlist = min(4096, len(embeddings) // 100)
41
  quantizer = faiss.IndexFlatIP(self.dimension)
42
  self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist)
43
  self.index.train(embeddings)
44
  self.index.add(embeddings)
45
  else:
 
46
  self.index = faiss.IndexFlatIP(self.dimension)
47
  self.index.add(embeddings)
 
48
  # Store text and metadata
49
  for i, text in enumerate(texts):
50
  self.id_to_text[i] = text
51
  if metadata and i < len(metadata):
52
  self.id_to_metadata[i] = metadata[i]
53
+ logger.info(f"Built FAISS {self.index_type} index with {len(embeddings)} vectors")
 
54
 
55
  def search(self, query_embeddings: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]:
56
  """Search for similar vectors"""