# safe_rag/exp_pipeline/pipeline.py
# NOTE(review): the lines below were page-scrape residue (author avatar text,
# commit message, commit hash) and broke the file as Python; kept as comments.
#   goodmodeler — commit 9339c96 "ADD: test log"
"""
End-to-end pipeline for dataset download, preprocessing, embedding, and indexing.
"""
import logging
from data_processing.data_loader import DataLoader
from data_processing.preprocessor import Preprocessor
from retriever.embedder import Embedder
from retriever.faiss_index import build_faiss_index
logger = logging.getLogger(__name__)
def run_pipeline(
    split: str = "train",
    device: str = "cuda",
    index_type: str = "HNSW",
    index_path: str = "../index/msmarco_hnsw",
):
    """Run the end-to-end retrieval pipeline: download, preprocess, embed, index.

    Args:
        split: MS MARCO Passage Ranking dataset split to load (e.g. "train").
        device: Device string passed to the embedding model (e.g. "cuda", "cpu").
        index_type: FAISS index type passed to ``build_faiss_index``.
        index_path: Path prefix where the built FAISS index is persisted.

    Returns:
        The built FAISS index object.

    Raises:
        ValueError: If the embedding step produces an empty or non-2D array.
    """
    # 1. Download the MS MARCO Passage Ranking dataset.
    data_loader = DataLoader()
    raw_data = data_loader.get_passage_dataset(split)
    logger.info(
        "Loaded %d samples from MS MARCO Passage Ranking [%s]", len(raw_data), split
    )

    # 2. Preprocess the data. A HuggingFace ``datasets`` object is converted
    # from column-oriented (dict of columns) to row-oriented (list of dicts).
    preprocessor = Preprocessor()
    if hasattr(raw_data, "to_dict"):
        columns = raw_data.to_dict()
        raw_data = [dict(zip(columns.keys(), row)) for row in zip(*columns.values())]

    # MS MARCO Passage v2.1 stores the texts under passages["passage_text"].
    passages = []
    for item in raw_data:
        if "passages" in item and "passage_text" in item["passages"]:
            passages.extend(item["passages"]["passage_text"])
    processed = preprocessor.preprocess_passages(passages)
    texts = [p["text"] for p in processed]
    logger.info("Processed %d passages", len(texts))

    # 3. Generate embeddings. Was littered with debug print() calls left over
    # from a "test log" commit; use debug-level logging instead.
    embedder = Embedder(device=device)
    embeddings = embedder.encode(texts)
    logger.debug(
        "Embedding shape: %s, texts count: %d",
        getattr(embeddings, "shape", None),
        len(texts),
    )
    # Guard against an empty or malformed embedding matrix before indexing.
    if (
        embeddings is None
        or not hasattr(embeddings, "shape")
        or len(embeddings.shape) != 2
        or embeddings.shape[0] == 0
    ):
        raise ValueError(
            "Embeddings is empty or not a 2D array. Check input texts and embedding model."
        )

    # 4. Build and persist the FAISS index. The log message previously claimed
    # "./index/..." while the index was actually saved to "../index/..."; log
    # the real destination instead.
    index = build_faiss_index(embeddings, texts, index_type=index_type)
    logger.info("FAISS index built successfully")
    index.save(index_path)
    logger.info("FAISS index saved to %s", index_path)
    return index
if __name__ == "__main__":
    # Script entry point: build the index from the training split.
    run_pipeline(split="train")