Spaces:
Sleeping
Sleeping
| """ | |
| End-to-end pipeline for dataset download, preprocessing, embedding, and indexing. | |
| """ | |
| import logging | |
| from data_processing.data_loader import DataLoader | |
| from data_processing.preprocessor import Preprocessor | |
| from retriever.embedder import Embedder | |
| from retriever.faiss_index import build_faiss_index | |
# Per-module logger; records are tagged with this module's dotted name.
logger = logging.getLogger(name=__name__)
def run_pipeline(
    split: str = "train",
    device: str = "cuda",
    index_path: str = "../index/msmarco_hnsw",
):
    """Run the MS MARCO passage pipeline: load, preprocess, embed, index, save.

    Args:
        split: Dataset split name forwarded to ``DataLoader.get_passage_dataset``.
        device: Device string forwarded to ``Embedder`` (defaults to the
            previously hard-coded ``"cuda"``).
        index_path: Path handed to ``index.save``. A relative path is resolved
            against the current working directory, not this file's location.

    Returns:
        The index object produced by ``build_faiss_index`` (after it was saved).

    Raises:
        ValueError: If preprocessing yields no passage texts, if the embedder
            returns something other than a non-empty 2D array, or if the number
            of embedding rows differs from the number of texts.
    """
    from pathlib import Path

    # 1. Download the MS MARCO Passage Ranking dataset.
    data_loader = DataLoader()
    raw_data = data_loader.get_passage_dataset(split)
    logger.info(
        "Loaded %d samples from MS MARCO Passage Ranking [%s]", len(raw_data), split
    )

    # 2. Preprocess the data.
    preprocessor = Preprocessor()
    # A HuggingFace ``datasets`` object exposes ``to_dict()`` (column -> list of
    # values); transpose it into a list of per-row dicts.
    if hasattr(raw_data, "to_dict"):
        columns = raw_data.to_dict()
        raw_data = [dict(zip(columns.keys(), row)) for row in zip(*columns.values())]

    # MS MARCO Passage v2.1 rows carry their passage strings in
    # item["passages"]["passage_text"]; rows without that field are skipped.
    passages = []
    for item in raw_data:
        if "passages" in item and "passage_text" in item["passages"]:
            passages.extend(item["passages"]["passage_text"])
    processed = preprocessor.preprocess_passages(passages)
    texts = [p["text"] for p in processed]
    logger.info("Processed %d passages", len(texts))
    # Fail before constructing the embedder if there is nothing to embed.
    if not texts:
        raise ValueError(
            f"No passage texts extracted from split {split!r}. "
            "Check the dataset schema and the preprocessor."
        )

    # 3. Generate embeddings.
    embedder = Embedder(device=device)
    embeddings = embedder.encode(texts)
    shape = getattr(embeddings, "shape", None)
    logger.info("Embedding shape: %s; texts count: %d", shape, len(texts))
    if shape is None or len(shape) != 2 or shape[0] == 0:
        raise ValueError(
            "Embeddings is empty or not a 2D array. "
            "Check input texts and embedding model."
        )
    # build_faiss_index receives both arrays; presumably it pairs them by
    # position, so a length mismatch would attach vectors to the wrong passages.
    if shape[0] != len(texts):
        raise ValueError(
            f"Embedding count ({shape[0]}) does not match text count ({len(texts)})."
        )

    # 4. Build the FAISS index.
    index = build_faiss_index(embeddings, texts, index_type="HNSW")
    logger.info("FAISS index built successfully")

    # Persist the index. Create the parent directory first in case index.save
    # does not create it itself.
    Path(index_path).parent.mkdir(parents=True, exist_ok=True)
    index.save(index_path)
    # Log the path that was actually written.
    logger.info("FAISS index saved to %s", index_path)
    return index
if __name__ == "__main__":
    # This module reports progress via logger.info, but with no handler
    # configured the root logger's default WARNING level drops INFO records;
    # enable them for command-line runs.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    run_pipeline("train")