goodmodeler committed
Commit 0a02cd7 · 1 Parent(s): 8a3396b

ADD: pipeline

data_processing/data_loader.py CHANGED
@@ -1,74 +1,29 @@
-from typing import Dict, List, Optional
+
 import logging
+from datasets import load_dataset
 
 logger = logging.getLogger(__name__)
 
 class DataLoader:
     def __init__(self, cache_dir: str = "./cache"):
         self.cache_dir = cache_dir
-
-    def load_hotpotqa(self, split: str = "train"):
-        """Load HotpotQA dataset for multi-hop reasoning (simplified version)"""
-        try:
-            # Simplified version - return empty list for demo
-            logger.info(f"Loading HotpotQA {split} (simplified version)")
-            return []
-        except Exception as e:
-            logger.error(f"Failed to load HotpotQA: {e}")
-            raise
-
-    def load_triviaqa(self, split: str = "train"):
-        """Load TriviaQA dataset for open-domain QA (simplified version)"""
-        try:
-            logger.info(f"Loading TriviaQA {split} (simplified version)")
-            return []
-        except Exception as e:
-            logger.error(f"Failed to load TriviaQA: {e}")
-            raise
-
-    def load_wikipedia(self, language: str = "en", date: str = "20231101"):
-        """Load Wikipedia dump for knowledge base (simplified version)"""
-        try:
-            logger.info(f"Loading Wikipedia {language} (simplified version)")
-            return []
-        except Exception as e:
-            logger.error(f"Failed to load Wikipedia: {e}")
-            raise
-
-    def load_nq_open(self, split: str = "train"):
-        """Load Natural Questions Open dataset (simplified version)"""
-        try:
-            logger.info(f"Loading NQ Open {split} (simplified version)")
-            return []
-        except Exception as e:
-            logger.error(f"Failed to load NQ Open: {e}")
-            raise
-
-    def get_qa_datasets(self) -> Dict[str, List]:
-        """Load all QA datasets (simplified version)"""
-        datasets = {}
+
+    def load_msmarco_passage(self, split: str = "train"):
+        """Load MS MARCO Passage Ranking dataset from Hugging Face (v2.1)"""
         try:
-            datasets['hotpotqa'] = self.load_hotpotqa()
-            datasets['triviaqa'] = self.load_triviaqa()
-            datasets['nq_open'] = self.load_nq_open()
-            logger.info("All QA datasets loaded successfully")
-            return datasets
+            logger.info(f"Downloading MS MARCO Passage Ranking {split} (v2.1) from Hugging Face")
+            ds = load_dataset("ms_marco", "v2.1", split=split)
+            return ds
         except Exception as e:
-            logger.error(f"Failed to load QA datasets: {e}")
+            logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
             raise
-
-    def get_knowledge_base(self) -> List[str]:
-        """Load knowledge base (simplified version)"""
+
+    def get_passage_dataset(self, split: str = "train"):
+        """Load MS MARCO Passage Ranking dataset"""
         try:
-            logger.info("Loading knowledge base (simplified version)")
-            # Return some sample passages for demo
-            return [
-                "Machine learning is a subset of artificial intelligence that focuses on algorithms.",
-                "The capital of France is Paris.",
-                "Python is a popular programming language used for data science.",
-                "The Great Wall of China is one of the most famous landmarks in the world.",
-                "Climate change refers to long-term shifts in global temperatures and weather patterns."
-            ]
+            ds = self.load_msmarco_passage(split)
+            logger.info("MS MARCO Passage Ranking loaded successfully")
+            return ds
         except Exception as e:
-            logger.error(f"Failed to load knowledge base: {e}")
+            logger.error(f"Failed to load MS MARCO Passage Ranking: {e}")
             raise
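
For context, a minimal usage sketch of the new loader. The field layout shown (a `query` string plus a nested `passages` dict of parallel lists) reflects how the `ms_marco` v2.1 config is generally laid out on the Hugging Face Hub; treat it as an assumption and verify against the dataset card:

```python
# Hypothetical usage sketch; assumes the `datasets` package is installed
# and the ms_marco dataset is reachable on the Hugging Face Hub.
from data_processing.data_loader import DataLoader

loader = DataLoader()
ds = loader.get_passage_dataset("train")

sample = ds[0]
print(sample["query"])                         # the search query
print(sample["passages"]["passage_text"][:2])  # candidate passages (parallel lists)
```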
data_processing/preprocessor.py CHANGED
@@ -66,17 +66,25 @@ class Preprocessor:
         return processed
 
     def preprocess_qa_data(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Preprocess QA data"""
+        """Preprocess QA data, auto-converting dict/list fields to strings"""
         processed = []
-
+        def to_str(val):
+            if isinstance(val, dict):
+                # concatenate all values
+                return " ".join([to_str(v) for v in val.values()])
+            elif isinstance(val, list):
+                return " ".join([to_str(v) for v in val])
+            elif val is None:
+                return ""
+            return str(val)
+
         for item in data:
             if not isinstance(item, dict):
                 continue
-
-            question = item.get('question', '')
-            answer = item.get('answer', '')
-            context = item.get('context', '')
-
+            question = to_str(item.get('question', ''))
+            answer = to_str(item.get('answer', ''))
+            context = to_str(item.get('context', ''))
+
             processed_item = {
                 'question': self.clean_text(question),
                 'answer': self.clean_text(answer),
@@ -85,9 +93,7 @@ class Preprocessor:
                 'answer_tokens': self.tokenize(answer),
                 'context_tokens': self.tokenize(context)
             }
-
             processed.append(processed_item)
-
         return processed
 
     def create_chunks(self, text: str, chunk_size: int = 512, overlap: int = 50) -> List[str]:
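
The recursive `to_str` helper guards against MS MARCO fields that are not plain strings (answers arrive as lists, passages as dicts of parallel lists). A quick standalone illustration of its flattening behavior, with made-up inputs:

```python
# Standalone copy of the to_str helper, for illustration only.
def to_str(val):
    if isinstance(val, dict):
        return " ".join(to_str(v) for v in val.values())
    if isinstance(val, list):
        return " ".join(to_str(v) for v in val)
    return "" if val is None else str(val)

assert to_str(["Paris", {"alt": ["paris, france"]}]) == "Paris paris, france"
assert to_str(None) == ""
assert to_str(42) == "42"
```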
exp_pipeline/pipeline.py ADDED
@@ -0,0 +1,50 @@
+"""
+End-to-end pipeline for dataset download, preprocessing, embedding, and indexing.
+"""
+import logging
+from data_processing.data_loader import DataLoader
+from data_processing.preprocessor import Preprocessor
+from retriever.embedder import Embedder
+from retriever.faiss_index import build_faiss_index
+
+logger = logging.getLogger(__name__)
+
+
+def run_pipeline(split: str = "train"):
+    # 1. Download the MS MARCO Passage Ranking dataset
+    data_loader = DataLoader()
+    raw_data = data_loader.get_passage_dataset(split)
+    logger.info(f"Loaded {len(raw_data)} samples from MS MARCO Passage Ranking [{split}]")
+
+    # 2. Preprocess the data
+    preprocessor = Preprocessor()
+    # Convert the Hugging Face Dataset object to a list of row dicts
+    if hasattr(raw_data, "to_dict"):
+        raw_data = raw_data.to_dict()
+        raw_data = [dict(zip(raw_data.keys(), v)) for v in zip(*raw_data.values())]
+
+    # MS MARCO Passage v2.1: use the passages["passage_text"] field
+    passages = []
+    for item in raw_data:
+        if "passages" in item and "passage_text" in item["passages"]:
+            passages.extend(item["passages"]["passage_text"])
+    processed = preprocessor.preprocess_passages(passages)
+    texts = [p["text"] for p in processed]
+
+    logger.info(f"Processed {len(texts)} passages")
+
+    # 3. Generate embeddings
+    embedder = Embedder(device="cuda")
+    embeddings = embedder.encode(texts)
+    print(f"Embedding shape: {getattr(embeddings, 'shape', None)}")
+    print(f"Texts count: {len(texts)}")
+    if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
+        raise ValueError("Embeddings array is empty or not 2D. Check the input texts and the embedding model.")
+
+    # 4. Build the FAISS index
+    index = build_faiss_index(embeddings, texts)
+    logger.info("FAISS index built successfully")
+    return index
+
+if __name__ == "__main__":
+    run_pipeline("train")
retriever/faiss_index.py CHANGED
@@ -1,3 +1,11 @@
+# Factory function, called by the pipeline
+def build_faiss_index(embeddings, texts, metadata=None, index_type="IVF"):
+    if embeddings is None or not hasattr(embeddings, 'shape') or len(embeddings.shape) != 2 or embeddings.shape[0] == 0:
+        raise ValueError(f"Embeddings array is empty or not 2D. Got shape: {getattr(embeddings, 'shape', None)}")
+    dimension = embeddings.shape[1]
+    index = FAISSIndex(dimension, index_type=index_type)
+    index.build_index(embeddings, texts, metadata)
+    return index
 import faiss
 import numpy as np
 import pickle
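
The `FAISSIndex` class itself is not part of this diff, so the following is only a rough sketch, under stated assumptions, of what `build_index` with `index_type="IVF"` plausibly does. It uses only standard faiss calls (an IVF index must be trained before vectors are added); the class shape and `nlist` default are assumptions, not the repo's actual implementation:

```python
import faiss
import numpy as np

class FAISSIndexSketch:
    """Hypothetical stand-in for the repo's FAISSIndex class."""

    def __init__(self, dimension: int, index_type: str = "IVF", nlist: int = 100):
        if index_type == "IVF":
            quantizer = faiss.IndexFlatL2(dimension)       # coarse quantizer
            self.index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
        else:
            self.index = faiss.IndexFlatL2(dimension)      # brute-force fallback
        self.texts, self.metadata = [], None

    def build_index(self, embeddings, texts, metadata=None):
        x = np.ascontiguousarray(embeddings, dtype=np.float32)
        if not self.index.is_trained:
            self.index.train(x)  # IVF requires a training pass over the vectors
        self.index.add(x)
        self.texts, self.metadata = list(texts), metadata
```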