szxllm committed on
Commit
e576d4e
·
verified ·
1 Parent(s): 810cabd

Update dcpo.py

Browse files
Files changed (1) hide show
  1. dcpo.py +28 -70
dcpo.py CHANGED
@@ -5,19 +5,11 @@ import numpy as np
5
  import logging
6
  import hashlib
7
 
8
- # ========== 关键修改1: 导入改进的验证器 ==========
9
  from math_verifier import MathReward
10
- # 如果要使用渐进式奖励,取消下面的注释:
11
- # from progressive_reward import ProgressiveMathReward
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
  class DCPOTrainer:
16
- """
17
- DCPO: Dynamic Clipping Policy Optimization Trainer
18
- 修复版:包含 DDP 设备修复、优化器状态恢复修复、显存优化、attention_mask 和 position_ids 修复。
19
- 改进版:集成改进的奖励验证器
20
- """
21
  def __init__(
22
  self,
23
  actor_model,
@@ -33,21 +25,18 @@ class DCPOTrainer:
33
  use_amp: bool = True,
34
  gradient_accumulation_steps: int = 1,
35
  inner_batch_size: int = 4,
36
- # ========== 关键修改2: 新增参数 ==========
37
- use_reference_comparison: bool = True, # 是否使用参考推理对比
38
- use_progressive_reward: bool = False, # 是否使用渐进式奖励
39
- phase1_steps: int = 2000, # 渐进式阶段1步数
40
- phase2_steps: int = 4000 # 渐进式阶段2步数
41
  ):
42
  self.actor = actor_model
43
  self.ref_model = ref_model
44
  self.tokenizer = tokenizer
45
 
46
- # ========== 关键修改3: 初始化验证器 ==========
47
  self.use_progressive_reward = use_progressive_reward
48
 
49
  if use_progressive_reward:
50
- # 使用渐进式奖励(实验性)
51
  from progressive_reward import ProgressiveMathReward
52
  self.math_verifier = ProgressiveMathReward(
53
  use_reference_comparison=use_reference_comparison,
@@ -55,13 +44,10 @@ class DCPOTrainer:
55
  phase2_steps=phase2_steps,
56
  verbose=True
57
  )
58
- logger.info("使用渐进式奖励验证器")
59
  else:
60
- # 使用标准改进版验证器(推荐)
61
  self.math_verifier = MathReward(
62
  use_reference_comparison=use_reference_comparison
63
  )
64
- logger.info(f"使用改进版奖励验证器 (reference_comparison={use_reference_comparison})")
65
 
66
  self.group_size = group_size
67
  self.eps_low = eps_low
@@ -74,51 +60,41 @@ class DCPOTrainer:
74
  self.gradient_accumulation_steps = gradient_accumulation_steps
75
  self.inner_batch_size = inner_batch_size
76
  self.experience_buffer = []
77
-
78
- # ========== 关键修改4: 添加当前步数跟踪(用于渐进式奖励) ==========
79
  self.current_step = 0
80
-
81
- # 自动获取设备:兼容 DDP
82
  if hasattr(actor_model, 'module'):
83
  self.device = next(actor_model.module.parameters()).device
84
  else:
85
  self.device = next(actor_model.parameters()).device
86
-
87
- # 优化器初始化
88
  self.optimizer = torch.optim.AdamW(
89
  self.actor.parameters(),
90
  lr=learning_rate,
91
  weight_decay=0.01
92
  )
93
-
94
- # 混合精度 Scaler
95
  self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
96
 
97
  if self.ref_model:
98
  self.ref_model.eval()
99
  self.ref_model.requires_grad_(False)
100
 
101
- # SAS 统计缓存
102
  self.sas_stats = {}
103
 
104
  def _get_stable_hash(self, text):
105
- """生成跨进程/跨运行一致的哈希值"""
106
  return hashlib.md5(text.encode('utf-8')).hexdigest()
107
 
108
  def state_dict(self):
109
- """导出 Trainer 状态"""
110
  return {
111
  'optimizer_state_dict': self.optimizer.state_dict(),
112
  'sas_stats': self.sas_stats,
113
  'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
114
- 'current_step': self.current_step # 保存当前步数
115
  }
116
 
117
  def load_state_dict(self, state_dict):
118
- """加载 Trainer 状态,并修复优化器 Tensor 设备问题"""
119
  if 'optimizer_state_dict' in state_dict:
120
  self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
121
- # 强制将优化器状态移动到当前 GPU
122
  for state in self.optimizer.state.values():
123
  for k, v in state.items():
124
  if isinstance(v, torch.Tensor):
@@ -126,24 +102,20 @@ class DCPOTrainer:
126
 
127
  if 'sas_stats' in state_dict:
128
  self.sas_stats = state_dict['sas_stats']
129
- logger.info(f"Loaded SAS stats for {len(self.sas_stats)} prompts")
130
 
131
  if 'scaler_state_dict' in state_dict and state_dict['scaler_state_dict'] is not None:
132
  self.scaler.load_state_dict(state_dict['scaler_state_dict'])
133
 
134
  if 'current_step' in state_dict:
135
  self.current_step = state_dict['current_step']
136
- logger.info(f"Loaded current_step: {self.current_step}")
137
 
138
- # ========== 关键修改5: 新增方法用于更新步数 ==========
139
  def update_step(self, step):
140
- """更新当前训练步数(用于渐进式奖励)"""
141
  self.current_step = step
142
  if self.use_progressive_reward:
143
  self.math_verifier.update_step(step)
144
 
145
  def _get_unwrapped_model(self, model):
146
- """辅助函数:获取原始模型(剥离 DDP wrapper)"""
147
  if hasattr(model, 'module'):
148
  return model.module
149
  return model
@@ -199,17 +171,13 @@ class DCPOTrainer:
199
  expanded_gts = []
200
  for gt in ground_truths:
201
  expanded_gts.extend([gt] * self.group_size)
202
-
203
- # ========== 计算奖励(使用改进的验证器)==========
204
- # 改进的验证器会自动处理 reasoning 和 reference_completion 字段
205
  raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
206
  rewards_tensor = torch.tensor(raw_rewards, device=self.device, dtype=torch.float32)
207
 
208
- # 计算旧策略的 Log Probs
209
  gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
210
  full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)
211
 
212
- # ✅ 修复:构建正确的 position_ids(考虑左 padding)
213
  batch_size = sequences.size(0)
214
  seq_len = sequences.size(1)
215
  position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=self.device)
@@ -229,7 +197,7 @@ class DCPOTrainer:
229
  actor_out = self.actor(
230
  full_input_data,
231
  attention_mask=full_attention_mask,
232
- position_ids=position_ids # ✅ 添加 position_ids
233
  )
234
 
235
  logits = actor_out['logits'][:, :-1, :]
@@ -237,14 +205,13 @@ class DCPOTrainer:
237
  log_probs = F.log_softmax(logits, dim=-1)
238
  per_token_log_probs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
239
 
240
- # ✅ 显存优化:提前移到 CPU
241
  return {
242
  'prompts_text': prompts_text,
243
- 'sequences': sequences.detach().cpu(), # ✅ 移到 CPU
244
- 'old_log_probs': per_token_log_probs.detach().cpu(), # ✅ 移到 CPU
245
- 'rewards': rewards_tensor.cpu(), # ✅ 移到 CPU
246
- 'attention_mask': full_attention_mask.cpu(), # ✅ 新增:保存 attention_mask
247
- 'position_ids': position_ids.cpu(), # ✅ 新增:保存 position_ids
248
  'prompt_length': prompt_len
249
  }
250
 
@@ -268,7 +235,6 @@ class DCPOTrainer:
268
  mu_old = stats['mu_total']
269
  var_old = stats['var_total']
270
 
271
- # 增量更新公式
272
  mu_total = (mu_new + (i - 1) * mu_old) / i
273
  term3 = ((i - 1) / i) * (mu_old - mu_new)**2
274
  var_total = (var_new + (i - 1) * var_old + term3) / i
@@ -304,7 +270,6 @@ class DCPOTrainer:
304
  return torch.cat(final_advantages)
305
 
306
  def train_step(self, experience):
307
- """执行训练步骤:梯度累积 -> PPO/GRPO Update"""
308
  self.experience_buffer.append(experience)
309
  if len(self.experience_buffer) < self.gradient_accumulation_steps:
310
  return None
@@ -312,7 +277,7 @@ class DCPOTrainer:
312
  all_advantages = []
313
  for exp in self.experience_buffer:
314
  adv = self._compute_sas_advantages(exp)
315
- exp['advantages'] = adv.detach() # 保持在 CPU
316
  all_advantages.append(exp['advantages'])
317
 
318
  self.actor.train()
@@ -325,17 +290,15 @@ class DCPOTrainer:
325
 
326
  padded_seqs = []
327
  padded_old_lp = []
328
- padded_attn_masks = [] # ✅ 新增
329
- padded_pos_ids = [] # ✅ 新增
330
  prompt_lens_list = []
331
 
332
  for e in self.experience_buffer:
333
  padded_seqs.append(pad_tensor(e['sequences'], max_seq_len, self.tokenizer.pad_token_id))
334
-
335
- # ✅ 修复:使用 0.0 填充(exp(0)=1,数值稳定)
336
  padded_old_lp.append(pad_tensor(e['old_log_probs'], max_lp_len, 0.0))
337
 
338
- # ✅ 新增:padding attention_mask 和 position_ids
339
  padded_attn_masks.append(pad_tensor(e['attention_mask'], max_seq_len, 0))
340
  padded_pos_ids.append(pad_tensor(e['position_ids'], max_seq_len, 0))
341
 
@@ -346,8 +309,8 @@ class DCPOTrainer:
346
  cat_old_log_probs = torch.cat(padded_old_lp, dim=0)
347
  cat_advantages = torch.cat(all_advantages, dim=0)
348
  cat_prompt_lens = torch.tensor(prompt_lens_list)
349
- cat_attention_masks = torch.cat(padded_attn_masks, dim=0) # ✅ 新增
350
- cat_position_ids = torch.cat(padded_pos_ids, dim=0) # ✅ 新增
351
 
352
  self.experience_buffer = []
353
 
@@ -356,8 +319,8 @@ class DCPOTrainer:
356
  cat_old_log_probs,
357
  cat_advantages,
358
  cat_prompt_lens,
359
- cat_attention_masks, # ✅ 新增
360
- cat_position_ids # ✅ 新增
361
  )
362
  dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)
363
 
@@ -366,13 +329,11 @@ class DCPOTrainer:
366
 
367
  for _ in range(self.grpo_epochs):
368
  for batch in dataloader:
369
- # ✅ 解包所有数据
370
  seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]
371
 
372
  input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}
373
 
374
  with torch.amp.autocast('cuda', enabled=self.use_amp):
375
- # ✅ 修复:传入 attention_mask 和 position_ids
376
  outputs = self.actor(
377
  input_data,
378
  attention_mask=attn_masks,
@@ -383,22 +344,19 @@ class DCPOTrainer:
383
 
384
  new_log_probs = F.log_softmax(logits, dim=-1)
385
  new_token_log_probs = torch.gather(new_log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
386
-
387
- # Mask 构建(保持原有逻辑)
388
  mask = torch.zeros_like(new_token_log_probs)
389
  for i, pl in enumerate(p_lens):
390
  pl_val = int(pl.item())
391
  start_idx = max(0, pl_val - 1)
392
  if start_idx < mask.size(1):
393
  mask[i, start_idx:] = 1.0
394
-
395
- # 过滤 padding 和无效的 old_log_probs
396
  is_padding = (targets == self.tokenizer.pad_token_id)
397
- is_valid_old_lp = (old_lp != 0.0) # ✅ 修改:过滤填充值
398
  mask = mask * (~is_padding).float() * is_valid_old_lp.float()
399
 
400
- # 修复:DCPO Loss 计算 - 数值稳定性
401
- q_probs = torch.exp(old_lp).clamp(min=1e-10, max=1.0) # ✅ clamp 避免除零
402
  term_low = 1.0 - (4.0 * self.eps_low) / q_probs
403
  lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_low, min=0.0))
404
  term_high = 1.0 + (4.0 * self.eps_high) / q_probs
 
5
  import logging
6
  import hashlib
7
 
 
8
  from math_verifier import MathReward
 
 
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
  class DCPOTrainer:
 
 
 
 
 
13
  def __init__(
14
  self,
15
  actor_model,
 
25
  use_amp: bool = True,
26
  gradient_accumulation_steps: int = 1,
27
  inner_batch_size: int = 4,
28
+ use_reference_comparison: bool = True,
29
+ use_progressive_reward: bool = False,
30
+ phase1_steps: int = 2000,
31
+ phase2_steps: int = 4000
 
32
  ):
33
  self.actor = actor_model
34
  self.ref_model = ref_model
35
  self.tokenizer = tokenizer
36
 
 
37
  self.use_progressive_reward = use_progressive_reward
38
 
39
  if use_progressive_reward:
 
40
  from progressive_reward import ProgressiveMathReward
41
  self.math_verifier = ProgressiveMathReward(
42
  use_reference_comparison=use_reference_comparison,
 
44
  phase2_steps=phase2_steps,
45
  verbose=True
46
  )
 
47
  else:
 
48
  self.math_verifier = MathReward(
49
  use_reference_comparison=use_reference_comparison
50
  )
 
51
 
52
  self.group_size = group_size
53
  self.eps_low = eps_low
 
60
  self.gradient_accumulation_steps = gradient_accumulation_steps
61
  self.inner_batch_size = inner_batch_size
62
  self.experience_buffer = []
 
 
63
  self.current_step = 0
64
+
 
65
  if hasattr(actor_model, 'module'):
66
  self.device = next(actor_model.module.parameters()).device
67
  else:
68
  self.device = next(actor_model.parameters()).device
69
+
 
70
  self.optimizer = torch.optim.AdamW(
71
  self.actor.parameters(),
72
  lr=learning_rate,
73
  weight_decay=0.01
74
  )
75
+
 
76
  self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
77
 
78
  if self.ref_model:
79
  self.ref_model.eval()
80
  self.ref_model.requires_grad_(False)
81
 
 
82
  self.sas_stats = {}
83
 
84
  def _get_stable_hash(self, text):
 
85
  return hashlib.md5(text.encode('utf-8')).hexdigest()
86
 
87
  def state_dict(self):
 
88
  return {
89
  'optimizer_state_dict': self.optimizer.state_dict(),
90
  'sas_stats': self.sas_stats,
91
  'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
92
+ 'current_step': self.current_step
93
  }
94
 
95
  def load_state_dict(self, state_dict):
 
96
  if 'optimizer_state_dict' in state_dict:
97
  self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
 
98
  for state in self.optimizer.state.values():
99
  for k, v in state.items():
100
  if isinstance(v, torch.Tensor):
 
102
 
103
  if 'sas_stats' in state_dict:
104
  self.sas_stats = state_dict['sas_stats']
105
+
106
 
107
  if 'scaler_state_dict' in state_dict and state_dict['scaler_state_dict'] is not None:
108
  self.scaler.load_state_dict(state_dict['scaler_state_dict'])
109
 
110
  if 'current_step' in state_dict:
111
  self.current_step = state_dict['current_step']
 
112
 
 
113
  def update_step(self, step):
 
114
  self.current_step = step
115
  if self.use_progressive_reward:
116
  self.math_verifier.update_step(step)
117
 
118
  def _get_unwrapped_model(self, model):
 
119
  if hasattr(model, 'module'):
120
  return model.module
121
  return model
 
171
  expanded_gts = []
172
  for gt in ground_truths:
173
  expanded_gts.extend([gt] * self.group_size)
174
+
 
 
175
  raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
176
  rewards_tensor = torch.tensor(raw_rewards, device=self.device, dtype=torch.float32)
177
 
 
178
  gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
179
  full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)
180
 
 
181
  batch_size = sequences.size(0)
182
  seq_len = sequences.size(1)
183
  position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=self.device)
 
197
  actor_out = self.actor(
198
  full_input_data,
199
  attention_mask=full_attention_mask,
200
+ position_ids=position_ids
201
  )
202
 
203
  logits = actor_out['logits'][:, :-1, :]
 
205
  log_probs = F.log_softmax(logits, dim=-1)
206
  per_token_log_probs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
207
 
 
208
  return {
209
  'prompts_text': prompts_text,
210
+ 'sequences': sequences.detach().cpu(),
211
+ 'old_log_probs': per_token_log_probs.detach().cpu(),
212
+ 'rewards': rewards_tensor.cpu(),
213
+ 'attention_mask': full_attention_mask.cpu(),
214
+ 'position_ids': position_ids.cpu(),
215
  'prompt_length': prompt_len
216
  }
217
 
 
235
  mu_old = stats['mu_total']
236
  var_old = stats['var_total']
237
 
 
238
  mu_total = (mu_new + (i - 1) * mu_old) / i
239
  term3 = ((i - 1) / i) * (mu_old - mu_new)**2
240
  var_total = (var_new + (i - 1) * var_old + term3) / i
 
270
  return torch.cat(final_advantages)
271
 
272
  def train_step(self, experience):
 
273
  self.experience_buffer.append(experience)
274
  if len(self.experience_buffer) < self.gradient_accumulation_steps:
275
  return None
 
277
  all_advantages = []
278
  for exp in self.experience_buffer:
279
  adv = self._compute_sas_advantages(exp)
280
+ exp['advantages'] = adv.detach()
281
  all_advantages.append(exp['advantages'])
282
 
283
  self.actor.train()
 
290
 
291
  padded_seqs = []
292
  padded_old_lp = []
293
+ padded_attn_masks = []
294
+ padded_pos_ids = []
295
  prompt_lens_list = []
296
 
297
  for e in self.experience_buffer:
298
  padded_seqs.append(pad_tensor(e['sequences'], max_seq_len, self.tokenizer.pad_token_id))
299
+
 
300
  padded_old_lp.append(pad_tensor(e['old_log_probs'], max_lp_len, 0.0))
301
 
 
302
  padded_attn_masks.append(pad_tensor(e['attention_mask'], max_seq_len, 0))
303
  padded_pos_ids.append(pad_tensor(e['position_ids'], max_seq_len, 0))
304
 
 
309
  cat_old_log_probs = torch.cat(padded_old_lp, dim=0)
310
  cat_advantages = torch.cat(all_advantages, dim=0)
311
  cat_prompt_lens = torch.tensor(prompt_lens_list)
312
+ cat_attention_masks = torch.cat(padded_attn_masks, dim=0)
313
+ cat_position_ids = torch.cat(padded_pos_ids, dim=0)
314
 
315
  self.experience_buffer = []
316
 
 
319
  cat_old_log_probs,
320
  cat_advantages,
321
  cat_prompt_lens,
322
+ cat_attention_masks,
323
+ cat_position_ids
324
  )
325
  dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)
326
 
 
329
 
330
  for _ in range(self.grpo_epochs):
331
  for batch in dataloader:
 
332
  seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]
333
 
334
  input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}
335
 
336
  with torch.amp.autocast('cuda', enabled=self.use_amp):
 
337
  outputs = self.actor(
338
  input_data,
339
  attention_mask=attn_masks,
 
344
 
345
  new_log_probs = F.log_softmax(logits, dim=-1)
346
  new_token_log_probs = torch.gather(new_log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)
347
+
 
348
  mask = torch.zeros_like(new_token_log_probs)
349
  for i, pl in enumerate(p_lens):
350
  pl_val = int(pl.item())
351
  start_idx = max(0, pl_val - 1)
352
  if start_idx < mask.size(1):
353
  mask[i, start_idx:] = 1.0
354
+
 
355
  is_padding = (targets == self.tokenizer.pad_token_id)
356
+ is_valid_old_lp = (old_lp != 0.0)
357
  mask = mask * (~is_padding).float() * is_valid_old_lp.float()
358
 
359
+ q_probs = torch.exp(old_lp).clamp(min=1e-10, max=1.0)
 
360
  term_low = 1.0 - (4.0 * self.eps_low) / q_probs
361
  lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_low, min=0.0))
362
  term_high = 1.0 + (4.0 * self.eps_high) / q_probs