szxllm commited on
Commit
ddb2b53
·
verified ·
1 Parent(s): 263d741

Update grpo_dataloader.py

Browse files
Files changed (1) hide show
  1. grpo_dataloader.py +5 -30
grpo_dataloader.py CHANGED
@@ -1,6 +1,3 @@
1
- """
2
- GRPO专用数据加载器
3
- """
4
  import torch
5
  from torch.utils.data import Dataset, DataLoader
6
  from datasets import load_dataset, interleave_datasets
@@ -18,9 +15,6 @@ from data_config import (
18
 
19
 
20
  class GRPOPromptDataset(Dataset):
21
- """
22
- GRPO Prompt数据集 - 用于生成阶段
23
- """
24
  def __init__(
25
  self,
26
  mix_name: str = 'default',
@@ -35,8 +29,7 @@ class GRPOPromptDataset(Dataset):
35
 
36
  self.tokenizer = tokenizer
37
  self.max_length = max_length
38
-
39
- # 获取混合配置
40
  if mix_name not in GRPO_PROMPT_MIX:
41
  raise ValueError(
42
  f"Unknown mix: {mix_name}. "
@@ -46,12 +39,7 @@ class GRPOPromptDataset(Dataset):
46
  mix_config = GRPO_PROMPT_MIX[mix_name]
47
  dataset_names = mix_config.get('datasets', [])
48
  weights = mix_config.get('weights', [])
49
-
50
- logger.info(f"Loading GRPO prompt mix: {mix_name}")
51
- logger.info(f" Datasets: {dataset_names}")
52
- logger.info(f" Weights: {weights}")
53
-
54
- # 加载数据集
55
  all_datasets = []
56
 
57
  for name in dataset_names:
@@ -60,12 +48,10 @@ class GRPOPromptDataset(Dataset):
60
  continue
61
 
62
  config = GRPO_DATASETS[name]
63
-
64
- # 验证文件存在
65
  data_file = config.get('data_files')
66
  if data_file and not os.path.exists(data_file):
67
  logger.error(f"Data file not found: {data_file}")
68
- logger.error(f"请先运行 download_grpo_datasets.py 下载数据")
69
  continue
70
 
71
  try:
@@ -79,13 +65,10 @@ class GRPOPromptDataset(Dataset):
79
  load_kwargs['data_files'] = config['data_files']
80
 
81
  ds = load_dataset(**load_kwargs)
82
-
83
- # 限制样本数
84
  if config.get('max_samples'):
85
  ds = ds.select(range(min(len(ds), config['max_samples'])))
86
 
87
  all_datasets.append(ds)
88
- logger.info(f" Loaded {name}: {len(ds)} samples")
89
 
90
  except Exception as e:
91
  logger.error(f"Error loading {name}: {e}")
@@ -93,8 +76,7 @@ class GRPOPromptDataset(Dataset):
93
 
94
  if not all_datasets:
95
  raise ValueError("No datasets loaded successfully")
96
-
97
- # 合并数据集
98
  if len(all_datasets) == 1:
99
  self.dataset = all_datasets[0]
100
  else:
@@ -107,7 +89,6 @@ class GRPOPromptDataset(Dataset):
107
  stopping_strategy='all_exhausted'
108
  )
109
 
110
- # 限制总样本数
111
  if max_samples and len(self.dataset) > max_samples:
112
  self.dataset = self.dataset.select(range(max_samples))
113
 
@@ -119,15 +100,12 @@ class GRPOPromptDataset(Dataset):
119
  def __getitem__(self, idx):
120
  try:
121
  sample = self.dataset[idx]
122
-
123
- # 提取prompt
124
  prompt = sample.get('prompt', '')
125
 
126
  if not prompt:
127
  logger.warning(f"Empty prompt at index {idx}")
128
  return None
129
-
130
- # Tokenize (不添加EOS,因为这是prompt)
131
  encoding = self.tokenizer(
132
  prompt,
133
  max_length=self.max_length,
@@ -149,8 +127,6 @@ class GRPOPromptDataset(Dataset):
149
 
150
 
151
  def grpo_collate_fn(batch):
152
- """GRPO专用collate函数"""
153
- # 过滤None
154
  batch = [item for item in batch if item is not None]
155
 
156
  if not batch:
@@ -172,7 +148,6 @@ def create_grpo_prompt_dataloader(
172
  max_samples: Optional[int] = None,
173
  shuffle: bool = True
174
  ):
175
- """创建GRPO prompt数据加载器"""
176
  dataset = GRPOPromptDataset(
177
  mix_name=mix_name,
178
  tokenizer=tokenizer,
 
 
 
 
1
  import torch
2
  from torch.utils.data import Dataset, DataLoader
3
  from datasets import load_dataset, interleave_datasets
 
15
 
16
 
17
  class GRPOPromptDataset(Dataset):
 
 
 
18
  def __init__(
19
  self,
20
  mix_name: str = 'default',
 
29
 
30
  self.tokenizer = tokenizer
31
  self.max_length = max_length
32
+
 
33
  if mix_name not in GRPO_PROMPT_MIX:
34
  raise ValueError(
35
  f"Unknown mix: {mix_name}. "
 
39
  mix_config = GRPO_PROMPT_MIX[mix_name]
40
  dataset_names = mix_config.get('datasets', [])
41
  weights = mix_config.get('weights', [])
42
+
 
 
 
 
 
43
  all_datasets = []
44
 
45
  for name in dataset_names:
 
48
  continue
49
 
50
  config = GRPO_DATASETS[name]
51
+
 
52
  data_file = config.get('data_files')
53
  if data_file and not os.path.exists(data_file):
54
  logger.error(f"Data file not found: {data_file}")
 
55
  continue
56
 
57
  try:
 
65
  load_kwargs['data_files'] = config['data_files']
66
 
67
  ds = load_dataset(**load_kwargs)
 
 
68
  if config.get('max_samples'):
69
  ds = ds.select(range(min(len(ds), config['max_samples'])))
70
 
71
  all_datasets.append(ds)
 
72
 
73
  except Exception as e:
74
  logger.error(f"Error loading {name}: {e}")
 
76
 
77
  if not all_datasets:
78
  raise ValueError("No datasets loaded successfully")
79
+
 
80
  if len(all_datasets) == 1:
81
  self.dataset = all_datasets[0]
82
  else:
 
89
  stopping_strategy='all_exhausted'
90
  )
91
 
 
92
  if max_samples and len(self.dataset) > max_samples:
93
  self.dataset = self.dataset.select(range(max_samples))
94
 
 
100
  def __getitem__(self, idx):
101
  try:
102
  sample = self.dataset[idx]
103
+
 
104
  prompt = sample.get('prompt', '')
105
 
106
  if not prompt:
107
  logger.warning(f"Empty prompt at index {idx}")
108
  return None
 
 
109
  encoding = self.tokenizer(
110
  prompt,
111
  max_length=self.max_length,
 
127
 
128
 
129
  def grpo_collate_fn(batch):
 
 
130
  batch = [item for item in batch if item is not None]
131
 
132
  if not batch:
 
148
  max_samples: Optional[int] = None,
149
  shuffle: bool = True
150
  ):
 
151
  dataset = GRPOPromptDataset(
152
  mix_name=mix_name,
153
  tokenizer=tokenizer,