Spaces:
Sleeping
Sleeping
Commit
·
0a02cd7
1
Parent(s):
8a3396b
ADD: pipeline
Browse files- data_processing/data_loader.py +16 -61
- data_processing/preprocessor.py +15 -9
- exp_pipeline/pipeline.py +50 -0
- retriever/faiss_index.py +8 -0
data_processing/data_loader.py
CHANGED
|
@@ -1,74 +1,29 @@
|
|
| 1 |
-
|
| 2 |
import logging
|
|
|
|
| 3 |
|
| 4 |
logger = logging.getLogger(__name__)
|
| 5 |
|
| 6 |
class DataLoader:
|
| 7 |
def __init__(self, cache_dir: str = "./cache"):
|
| 8 |
self.cache_dir = cache_dir
|
| 9 |
-
|
| 10 |
-
def
|
| 11 |
-
"""Load
|
| 12 |
-
try:
|
| 13 |
-
# Simplified version - return empty list for demo
|
| 14 |
-
logger.info(f"Loading HotpotQA {split} (simplified version)")
|
| 15 |
-
return []
|
| 16 |
-
except Exception as e:
|
| 17 |
-
logger.error(f"Failed to load HotpotQA: {e}")
|
| 18 |
-
raise
|
| 19 |
-
|
| 20 |
-
def load_triviaqa(self, split: str = "train"):
|
| 21 |
-
"""Load TriviaQA dataset for open-domain QA (simplified version)"""
|
| 22 |
-
try:
|
| 23 |
-
logger.info(f"Loading TriviaQA {split} (simplified version)")
|
| 24 |
-
return []
|
| 25 |
-
except Exception as e:
|
| 26 |
-
logger.error(f"Failed to load TriviaQA: {e}")
|
| 27 |
-
raise
|
| 28 |
-
|
| 29 |
-
def load_wikipedia(self, language: str = "en", date: str = "20231101"):
|
| 30 |
-
"""Load Wikipedia dump for knowledge base (simplified version)"""
|
| 31 |
-
try:
|
| 32 |
-
logger.info(f"Loading Wikipedia {language} (simplified version)")
|
| 33 |
-
return []
|
| 34 |
-
except Exception as e:
|
| 35 |
-
logger.error(f"Failed to load Wikipedia: {e}")
|
| 36 |
-
raise
|
| 37 |
-
|
| 38 |
-
def load_nq_open(self, split: str = "train"):
|
| 39 |
-
"""Load Natural Questions Open dataset (simplified version)"""
|
| 40 |
-
try:
|
| 41 |
-
logger.info(f"Loading NQ Open {split} (simplified version)")
|
| 42 |
-
return []
|
| 43 |
-
except Exception as e:
|
| 44 |
-
logger.error(f"Failed to load NQ Open: {e}")
|
| 45 |
-
raise
|
| 46 |
-
|
| 47 |
-
def get_qa_datasets(self) -> Dict[str, List]:
|
| 48 |
-
"""Load all QA datasets (simplified version)"""
|
| 49 |
-
datasets = {}
|
| 50 |
try:
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
logger.info("All QA datasets loaded successfully")
|
| 55 |
-
return datasets
|
| 56 |
except Exception as e:
|
| 57 |
-
logger.error(f"Failed to load
|
| 58 |
raise
|
| 59 |
-
|
| 60 |
-
def
|
| 61 |
-
"""Load
|
| 62 |
try:
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
return
|
| 66 |
-
"Machine learning is a subset of artificial intelligence that focuses on algorithms.",
|
| 67 |
-
"The capital of France is Paris.",
|
| 68 |
-
"Python is a popular programming language used for data science.",
|
| 69 |
-
"The Great Wall of China is one of the most famous landmarks in the world.",
|
| 70 |
-
"Climate change refers to long-term shifts in global temperatures and weather patterns."
|
| 71 |
-
]
|
| 72 |
except Exception as e:
|
| 73 |
-
logger.error(f"Failed to load
|
| 74 |
raise
|
|
|
|
| 1 |
+
|
| 2 |
import logging
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
| 6 |
|
| 7 |
class DataLoader:
    """Download QA/retrieval datasets from Hugging Face.

    Currently only MS MARCO Passage Ranking (v2.1) is supported.
    """

    def __init__(self, cache_dir: str = "./cache"):
        # Directory handed to `datasets.load_dataset` so repeated runs reuse
        # previously downloaded files instead of re-fetching them.
        self.cache_dir = cache_dir

    def load_msmarco_passage(self, split: str = "train"):
        """Load the MS MARCO Passage Ranking dataset (v2.1) from Hugging Face.

        Args:
            split: Dataset split to load (e.g. "train").

        Returns:
            A `datasets.Dataset` for the requested split.

        Raises:
            Exception: Re-raised after logging if the download/load fails.
        """
        try:
            logger.info(f"Downloading MS MARCO Passage Ranking {split} (v2.1) from Hugging Face")
            # FIX: the original never used self.cache_dir, so the configured
            # cache directory was ignored and datasets fell back to its default.
            ds = load_dataset("ms_marco", "v2.1", split=split, cache_dir=self.cache_dir)
            return ds
        except Exception as e:
            logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
            raise

    def get_passage_dataset(self, split: str = "train"):
        """Convenience wrapper around `load_msmarco_passage` used by the pipeline.

        Args:
            split: Dataset split to load (e.g. "train").

        Returns:
            A `datasets.Dataset` for the requested split.

        Raises:
            Exception: Re-raised after logging if loading fails.
        """
        try:
            ds = self.load_msmarco_passage(split)
            logger.info("MS MARCO Passage Ranking loaded successfully")
            return ds
        except Exception as e:
            logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
            raise
|
data_processing/preprocessor.py
CHANGED
|
@@ -66,17 +66,25 @@ class Preprocessor:
|
|
| 66 |
return processed
|
| 67 |
|
| 68 |
def preprocess_qa_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 69 |
-
"""Preprocess QA data"""
|
| 70 |
processed = []
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
for item in data:
|
| 73 |
if not isinstance(item, dict):
|
| 74 |
continue
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
processed_item = {
|
| 81 |
'question': self.clean_text(question),
|
| 82 |
'answer': self.clean_text(answer),
|
|
@@ -85,9 +93,7 @@ class Preprocessor:
|
|
| 85 |
'answer_tokens': self.tokenize(answer),
|
| 86 |
'context_tokens': self.tokenize(context)
|
| 87 |
}
|
| 88 |
-
|
| 89 |
processed.append(processed_item)
|
| 90 |
-
|
| 91 |
return processed
|
| 92 |
|
| 93 |
def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
|
|
|
|
| 66 |
return processed
|
| 67 |
|
| 68 |
def preprocess_qa_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 69 |
+
"""Preprocess QA data, auto convert dict/list fields to string"""
|
| 70 |
processed = []
|
| 71 |
+
def to_str(val):
|
| 72 |
+
if isinstance(val, dict):
|
| 73 |
+
# 拼接所有value
|
| 74 |
+
return " ".join([to_str(v) for v in val.values()])
|
| 75 |
+
elif isinstance(val, list):
|
| 76 |
+
return " ".join([to_str(v) for v in val])
|
| 77 |
+
elif val is None:
|
| 78 |
+
return ""
|
| 79 |
+
return str(val)
|
| 80 |
+
|
| 81 |
for item in data:
|
| 82 |
if not isinstance(item, dict):
|
| 83 |
continue
|
| 84 |
+
question = to_str(item.get('question', ''))
|
| 85 |
+
answer = to_str(item.get('answer', ''))
|
| 86 |
+
context = to_str(item.get('context', ''))
|
| 87 |
+
|
|
|
|
| 88 |
processed_item = {
|
| 89 |
'question': self.clean_text(question),
|
| 90 |
'answer': self.clean_text(answer),
|
|
|
|
| 93 |
'answer_tokens': self.tokenize(answer),
|
| 94 |
'context_tokens': self.tokenize(context)
|
| 95 |
}
|
|
|
|
| 96 |
processed.append(processed_item)
|
|
|
|
| 97 |
return processed
|
| 98 |
|
| 99 |
def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
|
exp_pipeline/pipeline.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
End-to-end pipeline for dataset download, preprocessing, embedding, and indexing.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from data_processing.data_loader import DataLoader
|
| 6 |
+
from data_processing.preprocessor import Preprocessor
|
| 7 |
+
from retriever.embedder import Embedder
|
| 8 |
+
from retriever.faiss_index import build_faiss_index
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def run_pipeline(split: str = "train", device: str = "cuda"):
    """Run the end-to-end pipeline: download -> preprocess -> embed -> index.

    Args:
        split: MS MARCO Passage Ranking split to process.
        device: Device passed to the Embedder. Defaults to "cuda" (the value
            that was previously hard-coded); pass "cpu" when no GPU exists.

    Returns:
        The FAISS index built over the processed passages.

    Raises:
        ValueError: If the computed embeddings are empty or not a 2D array.
    """
    # 1. Download the MS MARCO Passage Ranking dataset.
    data_loader = DataLoader()
    raw_data = data_loader.get_passage_dataset(split)
    logger.info(f"Loaded {len(raw_data)} samples from MS MARCO Passage Ranking [{split}]")

    # 2. Preprocess the data.
    preprocessor = Preprocessor()
    # Convert a HuggingFace datasets object (columnar dict) into row dicts.
    if hasattr(raw_data, "to_dict"):
        raw_data = raw_data.to_dict()
        raw_data = [dict(zip(raw_data.keys(), v)) for v in zip(*raw_data.values())]

    # MS MARCO Passage v2.1 keeps the texts under passages["passage_text"].
    passages = []
    for item in raw_data:
        if "passages" in item and "passage_text" in item["passages"]:
            passages.extend(item["passages"]["passage_text"])
    processed = preprocessor.preprocess_passages(passages)
    texts = [p["text"] for p in processed]

    logger.info(f"Processed {len(texts)} passages")

    # 3. Compute embeddings.
    embedder = Embedder(device=device)
    embeddings = embedder.encode(texts)
    # FIX: route diagnostics through the module logger instead of bare
    # print(), consistent with the rest of this module.
    logger.info(f"Embedding shape: {getattr(embeddings, 'shape', None)}")
    logger.info(f"Texts count: {len(texts)}")
    if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
        raise ValueError("Embeddings is empty or not a 2D array. Check input texts and embedding model.")

    # 4. Build the FAISS index.
    index = build_faiss_index(embeddings, texts)
    logger.info("FAISS index built successfully")
    return index


if __name__ == "__main__":
    run_pipeline("train")
|
retriever/faiss_index.py
CHANGED
|
@@ -1,3 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import faiss
|
| 2 |
import numpy as np
|
| 3 |
import pickle
|
|
|
|
| 1 |
+
# Factory function called by the pipeline to build an index in one step.
def build_faiss_index(embeddings, texts, metadata=None, index_type="IVF"):
    """Build a FAISSIndex from a 2D embedding matrix and its source texts.

    Args:
        embeddings: 2D array-like of shape (num_texts, dim) with a `.shape`
            attribute (e.g. a numpy array).
        texts: Texts corresponding to each embedding row, in order.
        metadata: Optional per-text metadata forwarded to the index.
        index_type: FAISS index flavor to construct (default "IVF").

    Returns:
        A populated FAISSIndex instance.

    Raises:
        ValueError: If embeddings are missing/empty/not 2D, or if the number
            of embedding rows does not match the number of texts.
    """
    if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
        raise ValueError(f"Embeddings is empty or not a 2D array. Got shape: {getattr(embeddings, 'shape', None)}")
    # FIX: reject a row-count/text-count mismatch, which would otherwise
    # silently associate embeddings with the wrong passages downstream.
    if len(texts) != embeddings.shape[0]:
        raise ValueError(f"Got {embeddings.shape[0]} embeddings for {len(texts)} texts")
    dimension = embeddings.shape[1]
    index = FAISSIndex(dimension, index_type=index_type)
    index.build_index(embeddings, texts, metadata)
    return index
|
| 9 |
import faiss
|
| 10 |
import numpy as np
|
| 11 |
import pickle
|