szxllm committed on
Commit
810cabd
·
verified ·
1 Parent(s): 1ef8665

Upload 2 files

Browse files
Files changed (2) hide show
  1. dcpo.py +431 -0
  2. dcpo_train.py +404 -0
dcpo.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import DataLoader, TensorDataset
4
+ import numpy as np
5
+ import logging
6
+ import hashlib
7
+
8
+ # ========== 关键修改1: 导入改进的验证器 ==========
9
+ from math_verifier import MathReward
10
+ # 如果要使用渐进式奖励,取消下面的注释:
11
+ # from progressive_reward import ProgressiveMathReward
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
class DCPOTrainer:
    """DCPO (Dynamic Clipping Policy Optimization) trainer.

    The trainer samples ``group_size`` completions per prompt from the actor,
    scores them with a math reward verifier, converts the rewards into
    smoothed advantages using per-prompt running statistics (SAS), and updates
    the actor with a clipped surrogate loss whose clip bounds depend on the
    old per-token probability.

    The actor may be wrapped (anything exposing ``.module``, e.g. DDP); the
    wrapper is used for forward/backward and the inner module for generation.
    """

    def __init__(
        self,
        actor_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        group_size: int = 4,
        eps_low: float = 0.16,
        eps_high: float = 0.2,
        r_max: float = 10.0,
        grpo_epochs: int = 1,
        max_grad_norm: float = 1.0,
        use_amp: bool = True,
        gradient_accumulation_steps: int = 1,
        inner_batch_size: int = 4,
        use_reference_comparison: bool = True,
        use_progressive_reward: bool = False,
        phase1_steps: int = 2000,
        phase2_steps: int = 4000
    ):
        """Build the trainer, its reward verifier, optimizer and AMP scaler.

        Args:
            actor_model: Policy being trained (optionally wrapped, see class doc).
            ref_model: Optional reference model; if given it is put in eval
                mode and frozen. It is not otherwise used by this class.
            tokenizer: Tokenizer used for prompts and for decoding completions.
            learning_rate: AdamW learning rate.
            group_size: Number of sampled completions per prompt.
            eps_low: Lower dynamic-clip coefficient.
            eps_high: Upper dynamic-clip coefficient.
            r_max: Hard upper bound on the probability ratio.
            grpo_epochs: Passes over the buffered experience per update.
            max_grad_norm: Gradient-norm clipping threshold.
            use_amp: Enable CUDA autocast and gradient scaling.
            gradient_accumulation_steps: Number of experiences buffered
                before ``train_step`` runs an update.
            inner_batch_size: Mini-batch size (in sequences) during the update.
            use_reference_comparison: Forwarded to the reward verifier.
            use_progressive_reward: Use the experimental progressive verifier.
            phase1_steps: End of progressive phase 1 (progressive verifier only).
            phase2_steps: End of progressive phase 2 (progressive verifier only).
        """
        self.actor = actor_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer

        # Reward verifier selection.
        self.use_progressive_reward = use_progressive_reward

        if use_progressive_reward:
            # Experimental verifier; imported lazily so the module is only
            # required when the option is enabled.
            from progressive_reward import ProgressiveMathReward
            self.math_verifier = ProgressiveMathReward(
                use_reference_comparison=use_reference_comparison,
                phase1_steps=phase1_steps,
                phase2_steps=phase2_steps,
                verbose=True
            )
            logger.info("使用渐进式奖励验证器")
        else:
            # Standard verifier (recommended).
            self.math_verifier = MathReward(
                use_reference_comparison=use_reference_comparison
            )
            logger.info(f"使用改进版奖励验证器 (reference_comparison={use_reference_comparison})")

        self.group_size = group_size
        self.eps_low = eps_low
        self.eps_high = eps_high
        self.r_max = r_max
        self.grpo_epochs = grpo_epochs
        self.use_amp = use_amp
        self.max_grad_norm = max_grad_norm

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.inner_batch_size = inner_batch_size
        self.experience_buffer = []

        # Current training step; only consumed by the progressive verifier.
        self.current_step = 0

        # Resolve the device from the parameters, looking through a wrapper.
        if hasattr(actor_model, 'module'):
            self.device = next(actor_model.module.parameters()).device
        else:
            self.device = next(actor_model.parameters()).device

        self.optimizer = torch.optim.AdamW(
            self.actor.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        # Gradient scaler for mixed precision (a no-op when use_amp is False).
        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

        # Explicit None check: module truthiness can be False for containers
        # that define __len__ (e.g. an empty nn.Sequential).
        if self.ref_model is not None:
            self.ref_model.eval()
            self.ref_model.requires_grad_(False)

        # Per-prompt running reward statistics, keyed by a stable prompt hash.
        self.sas_stats = {}

    def _get_stable_hash(self, text):
        """Return an MD5 hex digest of ``text`` (stable across processes/runs)."""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def state_dict(self):
        """Export optimizer, scaler, SAS statistics and step counter."""
        return {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'sas_stats': self.sas_stats,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'current_step': self.current_step
        }

    def load_state_dict(self, state_dict):
        """Restore trainer state produced by :meth:`state_dict`.

        Every key is optional; missing keys leave the current value untouched.
        """
        if 'optimizer_state_dict' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # Make sure tensor state lives on the training device. The 'step'
            # counter is deliberately left where the optimizer placed it:
            # forcing it onto the GPU adds a host sync on every step and
            # breaks non-capturable optimizers on some PyTorch releases.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if k != 'step' and isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)

        if 'sas_stats' in state_dict:
            self.sas_stats = state_dict['sas_stats']
            logger.info(f"Loaded SAS stats for {len(self.sas_stats)} prompts")

        if 'scaler_state_dict' in state_dict and state_dict['scaler_state_dict'] is not None:
            self.scaler.load_state_dict(state_dict['scaler_state_dict'])

        if 'current_step' in state_dict:
            self.current_step = state_dict['current_step']
            logger.info(f"Loaded current_step: {self.current_step}")

    def update_step(self, step):
        """Record the current training step and forward it to the progressive verifier."""
        self.current_step = step
        if self.use_progressive_reward:
            self.math_verifier.update_step(step)

    def _get_unwrapped_model(self, model):
        """Return ``model.module`` when present (e.g. a DDP wrapper), else ``model``."""
        if hasattr(model, 'module'):
            return model.module
        return model

    @torch.no_grad()
    def generate_and_prepare(self, prompt_batch, max_gen_len=512, temperature=1.0):
        """Sample completions for a prompt batch and package them as experience.

        Args:
            prompt_batch: Dict with parallel lists under ``'prompt'`` and
                ``'ground_truth'``.
            max_gen_len: Maximum number of new tokens per completion.
            temperature: Sampling temperature.

        Returns:
            Dict of CPU tensors (``sequences``, ``old_log_probs``, ``rewards``,
            ``attention_mask``, ``position_ids``) plus ``prompts_text`` and the
            padded ``prompt_length``. Rows are grouped per prompt, with
            ``group_size`` consecutive rows for each prompt.
        """
        self.actor.eval()
        prompts_text = prompt_batch['prompt']
        ground_truths = prompt_batch['ground_truth']

        inputs = self.tokenizer(
            prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left"
        ).to(self.device)

        prompts_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        prompt_len = int(prompts_ids.shape[1])

        # Each prompt is repeated group_size times, back to back.
        prompts_ids_repeated = prompts_ids.repeat_interleave(self.group_size, dim=0)
        attention_mask_repeated = attention_mask.repeat_interleave(self.group_size, dim=0)

        input_data = {
            'segments': [{'type': 'text', 'data': prompts_ids_repeated, 'modality_id': 0}],
            'attention_mask': attention_mask_repeated
        }

        # Generation goes through the unwrapped model.
        # NOTE(review): generate() is assumed to return only the newly
        # generated ids (they are concatenated after the prompt below).
        unwrapped_actor = self._get_unwrapped_model(self.actor)
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            generated_ids = unwrapped_actor.generate(
                input_data,
                max_new_tokens=max_gen_len,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id
            )

        sequences = torch.cat([prompts_ids_repeated, generated_ids], dim=1)
        decoded_responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # The verifier receives responses that always open with a <think> tag.
        full_responses_for_reward = []
        for r in decoded_responses:
            if not r.strip().startswith("<think>"):
                full_responses_for_reward.append("<think>\n" + r.strip())
            else:
                full_responses_for_reward.append(r)

        expanded_gts = []
        for gt in ground_truths:
            expanded_gts.extend([gt] * self.group_size)

        raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
        rewards_tensor = torch.tensor(raw_rewards, device=self.device, dtype=torch.float32)

        # Old-policy log-probs are computed over prompt + completion.
        gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
        full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)

        # Position ids count attended tokens only, so left padding does not
        # push the first real token away from position 0; masked slots stay 0.
        # Vectorised (no per-row loop / .item() sync); identical to filling
        # arange over the attended span when that span is contiguous.
        attended = full_attention_mask.long()
        position_ids = (attended.cumsum(dim=1) - 1).clamp(min=0) * attended

        full_input_data = {'segments': [{'type': 'text', 'data': sequences, 'modality_id': 0}]}

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            actor_out = self.actor(
                full_input_data,
                attention_mask=full_attention_mask,
                position_ids=position_ids
            )

        # Logit at position t predicts token t+1.
        logits = actor_out['logits'][:, :-1, :]
        targets = sequences[:, 1:]
        log_probs = F.log_softmax(logits, dim=-1)
        per_token_log_probs = torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

        # Everything is moved to CPU so buffered experience does not hold GPU memory.
        return {
            'prompts_text': prompts_text,
            'sequences': sequences.detach().cpu(),
            'old_log_probs': per_token_log_probs.detach().cpu(),
            'rewards': rewards_tensor.cpu(),
            'attention_mask': full_attention_mask.cpu(),
            'position_ids': position_ids.cpu(),
            'prompt_length': prompt_len
        }

    def _update_sas_stats(self, prompt_text, new_rewards):
        """Fold one group of rewards into the prompt's running mean/variance.

        Returns:
            ``(mu_new, std_new, mu_total, std_total)``: statistics of this
            group alone and of all groups seen so far for the prompt. The
            standard deviations are ``sqrt(var + 1e-8)``.
        """
        prompt_hash = self._get_stable_hash(prompt_text)

        mu_new = new_rewards.mean().item()
        var_new = new_rewards.var(unbiased=False).item() if len(new_rewards) > 1 else 0.0

        if prompt_hash not in self.sas_stats:
            # First visit: the running statistics equal the group statistics.
            self.sas_stats[prompt_hash] = {
                'i': 1,
                'mu_total': mu_new,
                'var_total': var_new
            }
            return mu_new, np.sqrt(var_new + 1e-8), mu_new, np.sqrt(var_new + 1e-8)

        stats = self.sas_stats[prompt_hash]
        i = stats['i'] + 1
        mu_old = stats['mu_total']
        var_old = stats['var_total']

        # Pooled mean/variance of (i - 1) earlier groups and the new one,
        # treating every group as equally sized.
        mu_total = (mu_new + (i - 1) * mu_old) / i
        term3 = ((i - 1) / i) * (mu_old - mu_new)**2
        var_total = (var_new + (i - 1) * var_old + term3) / i

        stats['i'] = i
        stats['mu_total'] = mu_total
        stats['var_total'] = var_total

        return mu_new, np.sqrt(var_new + 1e-8), mu_total, np.sqrt(var_total + 1e-8)

    def _compute_sas_advantages(self, experience_batch):
        """Compute one smoothed advantage per sampled sequence.

        Side effect: updates ``self.sas_stats`` for every prompt in the batch.

        Returns:
            1-D tensor with ``group_size`` advantages per prompt, in row order.
        """
        prompts = experience_batch['prompts_text']
        rewards = experience_batch['rewards'].view(-1, self.group_size)

        final_advantages = []

        for idx, prompt in enumerate(prompts):
            group_rewards = rewards[idx]
            mu_new, std_new, mu_total, std_total = self._update_sas_stats(prompt, group_rewards)

            # Standardise against this group and against the running history.
            A_new = (group_rewards - mu_new) / (std_new + 1e-8)
            A_total = (group_rewards - mu_total) / (std_total + 1e-8)

            i = self.sas_stats[self._get_stable_hash(prompt)]['i']

            # Two complementary blends of the two estimates ...
            SA_new = ((i - 1) / i) * A_new + (1 / i) * A_total
            SA_total = (1 / i) * A_new + ((i - 1) / i) * A_total

            # ... of which the one with the smaller magnitude is kept per sample.
            mask = (torch.abs(SA_new) < torch.abs(SA_total)).float()
            A_final = mask * SA_new + (1 - mask) * SA_total
            final_advantages.append(A_final)

        return torch.cat(final_advantages)

    def train_step(self, experience):
        """Buffer ``experience`` and, once enough is buffered, update the actor.

        Returns:
            ``None`` while fewer than ``gradient_accumulation_steps``
            experiences are buffered; otherwise the mean loss over the
            optimizer updates performed (the buffer is emptied).
        """
        self.experience_buffer.append(experience)
        if len(self.experience_buffer) < self.gradient_accumulation_steps:
            return None

        all_advantages = []
        for exp in self.experience_buffer:
            adv = self._compute_sas_advantages(exp)
            exp['advantages'] = adv.detach()  # stays on CPU
            all_advantages.append(exp['advantages'])

        self.actor.train()

        max_seq_len = max(e['sequences'].size(1) for e in self.experience_buffer)
        max_lp_len = max(e['old_log_probs'].size(1) for e in self.experience_buffer)

        def pad_tensor(t, target_len, val):
            # Right-pad the last dimension up to target_len.
            return F.pad(t, (0, target_len - t.size(1)), value=val)

        padded_seqs = []
        padded_old_lp = []
        padded_attn_masks = []
        padded_pos_ids = []
        prompt_lens_list = []

        for e in self.experience_buffer:
            padded_seqs.append(pad_tensor(e['sequences'], max_seq_len, self.tokenizer.pad_token_id))
            # 0.0 doubles as the "no old log-prob here" sentinel filtered below.
            padded_old_lp.append(pad_tensor(e['old_log_probs'], max_lp_len, 0.0))
            padded_attn_masks.append(pad_tensor(e['attention_mask'], max_seq_len, 0))
            padded_pos_ids.append(pad_tensor(e['position_ids'], max_seq_len, 0))
            prompt_lens_list.extend([e['prompt_length']] * (len(e['sequences'])))

        # The dataset stays on CPU; mini-batches are moved to the device on demand.
        cat_sequences = torch.cat(padded_seqs, dim=0)
        cat_old_log_probs = torch.cat(padded_old_lp, dim=0)
        cat_advantages = torch.cat(all_advantages, dim=0)
        cat_prompt_lens = torch.tensor(prompt_lens_list)
        cat_attention_masks = torch.cat(padded_attn_masks, dim=0)
        cat_position_ids = torch.cat(padded_pos_ids, dim=0)

        self.experience_buffer = []

        dataset = TensorDataset(
            cat_sequences,
            cat_old_log_probs,
            cat_advantages,
            cat_prompt_lens,
            cat_attention_masks,
            cat_position_ids
        )
        dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)

        # Upper bound for the log-ratio, equivalent to capping the ratio at r_max.
        log_r_max = float(np.log(self.r_max))

        total_loss = 0
        update_steps = 0

        for _ in range(self.grpo_epochs):
            for batch in dataloader:
                seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]

                input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}

                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    outputs = self.actor(
                        input_data,
                        attention_mask=attn_masks,
                        position_ids=pos_ids
                    )
                    logits = outputs['logits'][:, :-1, :]
                    targets = seqs[:, 1:]

                    new_log_probs = F.log_softmax(logits, dim=-1)
                    new_token_log_probs = torch.gather(new_log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

                    # Response mask: target index (prompt_len - 1) is the first
                    # generated token because targets are shifted by one.
                    start_idx = (p_lens - 1).clamp(min=0).unsqueeze(1)
                    token_pos = torch.arange(new_token_log_probs.size(1), device=new_token_log_probs.device)
                    mask = (token_pos.unsqueeze(0) >= start_idx).to(new_token_log_probs.dtype)

                    # Drop padding targets and slots that only carry the 0.0 filler.
                    is_padding = (targets == self.tokenizer.pad_token_id)
                    is_valid_old_lp = (old_lp != 0.0)
                    mask = mask * (~is_padding).float() * is_valid_old_lp.float()

                    # Dynamic clip bounds as a function of the old token probability q.
                    q_probs = torch.exp(old_lp).clamp(min=1e-10, max=1.0)  # avoids division by zero
                    term_low = 1.0 - (4.0 * self.eps_low) / q_probs
                    lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_low, min=0.0))
                    term_high = 1.0 + (4.0 * self.eps_high) / q_probs
                    upper_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_high, min=0.0))

                    # Clamp in log space *before* exponentiating: exp() of a
                    # large difference overflows to inf, and its gradient
                    # (0 * inf) would poison every parameter with NaN.
                    log_ratio = torch.clamp(new_token_log_probs - old_lp, max=log_r_max)
                    ratio = torch.clamp(torch.exp(log_ratio), 0, self.r_max)

                    advs_expanded = advs.unsqueeze(1).expand_as(ratio)
                    surr1 = ratio * advs_expanded
                    clipped_ratio = torch.min(torch.max(ratio, lower_clip), upper_clip)
                    surr2 = clipped_ratio * advs_expanded

                    # Pessimistic surrogate, averaged over each response's
                    # valid tokens and then over the responses in the batch.
                    element_wise_loss = torch.min(surr1, surr2)
                    masked_loss = element_wise_loss * mask
                    response_lens = torch.clamp(mask.sum(dim=1), min=1.0)
                    per_response_loss = masked_loss.sum(dim=1) / response_lens
                    loss = -per_response_loss.mean()

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale first so the norm is clipped on true gradient values.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += loss.item()
                update_steps += 1

        return total_loss / max(update_steps, 1)
dcpo_train.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.distributed as dist
4
+ from torch.nn.parallel import DistributedDataParallel as DDP
5
+ from transformers import AutoTokenizer
6
+ from torch.utils.data import DataLoader, Dataset
7
+ import json
8
+ import logging
9
+ from tqdm import tqdm
10
+ import glob
11
+ from datetime import datetime
12
+ import gc
13
+ import warnings
14
+
15
+ warnings.filterwarnings("ignore", category=FutureWarning)
16
+
17
+ from model import MultiModalDenseTransformer
18
+ from dcpo import DCPOTrainer
19
+
20
+ # ================= DDP 设置 =================
21
def setup_distributed():
    """Initialise ``torch.distributed`` when launched by a distributed launcher.

    The process counts as distributed when both ``RANK`` and ``WORLD_SIZE``
    are present in the environment; the NCCL process group is then created
    and the CUDA device is pinned to ``LOCAL_RANK``.

    Returns:
        ``(rank, local_rank, world_size)``, or ``(0, 0, 1)`` in
        single-process mode.
    """
    env = os.environ
    launched_distributed = "RANK" in env and "WORLD_SIZE" in env

    if not launched_distributed:
        print("Initialized Single GPU Mode")
        return 0, 0, 1

    dist.init_process_group(backend="nccl")
    rank = int(env["RANK"])
    local_rank = int(env["LOCAL_RANK"])
    world_size = int(env["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    if rank == 0:
        print(f"Initialized DDP: Rank {rank}/{world_size}")
    return rank, local_rank, world_size
34
+
35
# Resolve the process topology once, at import time.
RANK, LOCAL_RANK, WORLD_SIZE = setup_distributed()
# Rank 0 is the only process that prints, logs at INFO and writes checkpoints.
IS_MAIN = (RANK == 0)

logger = logging.getLogger(__name__)
# Non-main ranks are limited to WARNING and above.
logger.setLevel(logging.WARNING if not IS_MAIN else logging.INFO)
40
+
41
+ # ================= 数据集 =================
42
class MathDataset(Dataset):
    """In-memory dataset backed by a UTF-8 JSON Lines file.

    Every non-blank line of the file is parsed with ``json.loads`` and kept
    as one sample; blank / whitespace-only lines are skipped.
    """

    def __init__(self, path):
        """Read and parse all records from ``path``."""
        self.data = []
        with open(path, 'r', encoding='utf-8') as handle:
            self.data.extend(json.loads(raw) for raw in handle if raw.strip())

    def __len__(self):
        """Return the number of parsed records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the parsed record stored at ``idx``."""
        return self.data[idx]
55
+
56
def math_collate(batch):
    """Collate sample dicts into a dict of parallel lists.

    Args:
        batch: Iterable of dicts holding ``'prompt'`` and ``'ground_truth'``.

    Returns:
        ``{'prompt': [...], 'ground_truth': [...]}`` in input order.
    """
    collated = {}
    for field in ('prompt', 'ground_truth'):
        collated[field] = [sample[field] for sample in batch]
    return collated
61
+
62
+ # ================= 主函数 =================
63
def main():
    """Run DCPO post-training end to end.

    Builds the model and trainer from the inline ``CONFIG``, loads either a
    resume checkpoint or the SFT checkpoint, then loops: sample a prompt
    batch, generate and score completions, and hand the experience to the
    trainer. Metrics, logs and checkpoints are written by rank 0 only.
    """
    # ------------------ configuration ------------------
    CONFIG = {
        'sft_checkpoint': '/root/checkpoints/dcpo_posttrain_round3/step_1200.pt',
        'data_path': '/root/dataset/r1_zero_math.jsonl',
        'save_dir': '/root/checkpoints/dcpo_training',
        'resume_from': None,

        'model_dim': 1536,
        'n_layers': 12,
        'n_heads': 12,
        'n_kv_heads': 4,

        'group_size': 4,
        'batch_size': 1,
        'learning_rate': 1e-6,
        'max_steps': 5000,
        'max_gen_len': 512,
        'save_interval': 1400,

        'dcpo_eps_low': 0.16,
        'dcpo_eps_high': 0.2,
        'dcpo_r_max': 10.0,

        'gradient_accumulation_steps': 8,
        'inner_batch_size': 4,

        # Reward verifier options.
        'use_reference_comparison': True,   # compare against the reference reasoning
        'use_progressive_reward': False,    # experimental progressive reward
        'phase1_steps': 2000,               # progressive phase 1 (lenient format)
        'phase2_steps': 4000,               # progressive phase 2 (medium format)
    }
    # ---------------------------------------------------

    # Log file handler (rank 0 only).
    file_handler = None
    if IS_MAIN:
        os.makedirs(CONFIG['save_dir'], exist_ok=True)
        current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
        log_file = os.path.join(CONFIG['save_dir'], f"dcpo_train_{current_time}.log")

        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)

        # Record the full configuration in the log file.
        logger.info(f"DCPO Configuration: {json.dumps(CONFIG, indent=2)}")

        # Record which reward verifier is active.
        if CONFIG['use_progressive_reward']:
            logger.info(f"使用渐进式奖励验证器:")
            logger.info(f" - 阶段1 (0-{CONFIG['phase1_steps']}): 宽松格式")
            logger.info(f" - 阶段2 ({CONFIG['phase1_steps']}-{CONFIG['phase2_steps']}): 中等格式")
            logger.info(f" - 阶段3 ({CONFIG['phase2_steps']}+): 完整要求")
        else:
            logger.info(f"使用标准改进版验证器 (reference_comparison={CONFIG['use_reference_comparison']})")

        # Make sure the metrics file exists without truncating a previous run.
        metrics_file = os.path.join(CONFIG['save_dir'], "metrics.jsonl")
        if not os.path.exists(metrics_file):
            with open(metrics_file, 'w', encoding='utf-8') as f:
                pass

    # 1. Tokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # 2. Model
    def create_model():
        return MultiModalDenseTransformer(
            model_dim=CONFIG['model_dim'],
            vocab_size=len(tokenizer),
            n_layers=CONFIG['n_layers'],
            n_heads=CONFIG['n_heads'],
            n_kv_heads=CONFIG['n_kv_heads'],
            max_seq_len=2048,
            use_gradient_checkpointing=True
        )

    device = torch.device(f"cuda:{LOCAL_RANK}")

    if IS_MAIN:
        print("Initializing Actor Model...")

    actor = create_model().to(device)
    ref = None

    if WORLD_SIZE > 1:
        actor = DDP(actor, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)

    # 3. Trainer
    trainer = DCPOTrainer(
        actor_model=actor,
        ref_model=ref,
        tokenizer=tokenizer,
        learning_rate=CONFIG['learning_rate'],
        group_size=CONFIG['group_size'],
        eps_low=CONFIG['dcpo_eps_low'],
        eps_high=CONFIG['dcpo_eps_high'],
        r_max=CONFIG['dcpo_r_max'],
        use_amp=True,
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
        inner_batch_size=CONFIG['inner_batch_size'],
        use_reference_comparison=CONFIG['use_reference_comparison'],
        use_progressive_reward=CONFIG['use_progressive_reward'],
        phase1_steps=CONFIG['phase1_steps'],
        phase2_steps=CONFIG['phase2_steps']
    )

    # 4. Restore state
    start_step = 0
    samples_seen = 0

    if CONFIG['resume_from']:
        resume_path = CONFIG['resume_from']
        if IS_MAIN:
            print(f"Resuming from: {resume_path}")

        checkpoint = torch.load(resume_path, map_location='cpu')

        if WORLD_SIZE > 1:
            actor.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            actor.load_state_dict(checkpoint['model_state_dict'])

        if 'trainer_state_dict' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state_dict'])

        if 'rng_state' in checkpoint:
            torch.set_rng_state(checkpoint['rng_state'])
        # The saved value is None when the checkpoint was written without CUDA.
        if checkpoint.get('cuda_rng_state') is not None:
            try:
                torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])
            except Exception:
                # Device count differs from the saving run: restore only the
                # state recorded for this process's device index.
                torch.cuda.set_rng_state(checkpoint['cuda_rng_state'][LOCAL_RANK])

        start_step = checkpoint.get('step', 0) + 1
        samples_seen = checkpoint.get('samples_seen', start_step * CONFIG['batch_size'] * WORLD_SIZE)

        # Re-sync the progressive verifier with the resumed step.
        if CONFIG['use_progressive_reward']:
            trainer.update_step(start_step)
            if IS_MAIN:
                print(f"Restored progressive reward state to step {start_step}")

        del checkpoint
        gc.collect()
        torch.cuda.empty_cache()
    else:
        if IS_MAIN:
            print(f"Loading SFT checkpoint: {CONFIG['sft_checkpoint']}")
        checkpoint = torch.load(CONFIG['sft_checkpoint'], map_location='cpu')
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        # Strip the DDP prefix in case the checkpoint was saved from a wrapped model.
        new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

        if WORLD_SIZE > 1:
            actor.module.load_state_dict(new_state_dict)
        else:
            actor.load_state_dict(new_state_dict)

        del checkpoint, state_dict, new_state_dict
        gc.collect()
        torch.cuda.empty_cache()

    # 5. Dataloader
    dataset = MathDataset(CONFIG['data_path'])
    if WORLD_SIZE > 1:
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=WORLD_SIZE, rank=RANK, shuffle=True, seed=42
        )
    else:
        sampler = None

    dataloader = DataLoader(
        dataset, batch_size=CONFIG['batch_size'],
        collate_fn=math_collate, sampler=sampler, shuffle=(sampler is None)
    )

    if IS_MAIN:
        print(f"Starting Training from step {start_step}")

    # 6. Restore the dataloader position.
    # One step consumes one batch, so the epoch is step // len(dataloader).
    if sampler is not None:
        epoch = start_step // len(dataloader)
        sampler.set_epoch(epoch)

    data_iter = iter(dataloader)
    steps_in_epoch = start_step % len(dataloader)

    if start_step > 0 and steps_in_epoch > 0:
        if IS_MAIN:
            print(f"Fast-forwarding dataloader by {steps_in_epoch} steps...")

        for _ in range(steps_in_epoch):
            try:
                next(data_iter)
            except StopIteration:
                if sampler is not None:
                    epoch += 1
                    sampler.set_epoch(epoch)
                data_iter = iter(dataloader)
                next(data_iter)

    # 7. Training loop
    progress_bar = tqdm(
        range(start_step, CONFIG['max_steps']),
        disable=not IS_MAIN,
        initial=start_step,
        total=CONFIG['max_steps'],
        ncols=120,
        mininterval=1.0
    )

    running_reward = 0.0
    running_loss = 0.0

    for step in progress_bar:
        try:
            # Keep the progressive verifier in step with training.
            if CONFIG['use_progressive_reward']:
                trainer.update_step(step)

            try:
                batch = next(data_iter)
            except StopIteration:
                if sampler is not None:
                    # The iterator is exhausted exactly when `step` is the
                    # first step of a new epoch, so the epoch index is
                    # step // len(dataloader) -- the same formula the resume
                    # path uses. (Adding 1 here skipped an epoch and made a
                    # resumed run see a different shuffle.)
                    epoch = step // len(dataloader)
                    sampler.set_epoch(epoch)
                data_iter = iter(dataloader)
                batch = next(data_iter)

            samples_seen += CONFIG['batch_size'] * WORLD_SIZE

            # 1. Generate completions and score them.
            experience = trainer.generate_and_prepare(
                batch,
                max_gen_len=CONFIG['max_gen_len']
            )

            # Exponential moving average of the mean reward (seeded by the first value).
            step_reward = experience['rewards'].mean().item()
            if running_reward == 0: running_reward = step_reward
            else: running_reward = 0.95 * running_reward + 0.05 * step_reward

            # 2. Training step (returns None while experience is being buffered).
            loss = trainer.train_step(experience)

            # Abbreviated fields for the progress-bar description.
            status_dict = {"Rw": f"{running_reward:.2f}"}

            # Show the progressive phase when that verifier is active.
            if CONFIG['use_progressive_reward'] and hasattr(trainer.math_verifier, 'current_phase'):
                status_dict["Ph"] = f"{trainer.math_verifier.current_phase}"

            if loss is not None:
                if running_loss == 0: running_loss = loss
                else: running_loss = 0.9 * running_loss + 0.1 * loss
                status_dict["Ls"] = f"{running_loss:.3f}"

                if IS_MAIN:
                    current_lr = trainer.optimizer.param_groups[0]['lr']
                    metrics_data = {
                        "step": step,
                        "running_reward": float(running_reward),
                        "reward": float(step_reward),
                        "loss": float(loss),
                        "lr": float(current_lr),
                        "samples_seen": samples_seen,
                        "timestamp": datetime.now().isoformat()
                    }

                    if CONFIG['use_progressive_reward'] and hasattr(trainer.math_verifier, 'current_phase'):
                        metrics_data['reward_phase'] = trainer.math_verifier.current_phase

                    with open(os.path.join(CONFIG['save_dir'], "metrics.jsonl"), "a", encoding='utf-8') as f:
                        f.write(json.dumps(metrics_data) + "\n")

                    if step % 10 == 0:
                        log_msg = f"Step {step} | Reward: {step_reward:.4f} | Loss: {loss:.4f}"
                        progress_bar.write(log_msg)
                        if file_handler:
                            # Written straight to the file handler under the name "train".
                            file_handler.emit(logging.LogRecord(
                                name="train", level=logging.INFO, pathname=__file__, lineno=0,
                                msg=log_msg, args=(), exc_info=None
                            ))
            else:
                status_dict["St"] = "Acc"

            progress_bar.set_description(f"{' '.join([f'{k}:{v}' for k,v in status_dict.items()])}")

            # Checkpointing
            is_accum_boundary = (len(trainer.experience_buffer) == 0)

            if step > 0 and step % CONFIG['save_interval'] == 0 and IS_MAIN:
                if not is_accum_boundary:
                    # Buffered experience is not part of the checkpoint.
                    msg = "Saving checkpoint during gradient accumulation! Partial gradients will be lost."
                    progress_bar.write(msg)
                    if file_handler: logger.warning(msg)

                save_path = f"{CONFIG['save_dir']}/step_{step}.pt"
                model_to_save = actor.module if hasattr(actor, 'module') else actor

                torch.save({
                    'step': step,
                    'samples_seen': samples_seen,
                    'model_state_dict': model_to_save.state_dict(),
                    'trainer_state_dict': trainer.state_dict(),
                    'rng_state': torch.get_rng_state(),
                    'cuda_rng_state': torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
                }, save_path)

                msg = f"Checkpoint saved: {save_path}"
                progress_bar.write(msg)
                if file_handler: logger.info(msg)

            del experience
            del batch

        except Exception as e:
            # Best effort: a failing step is reported and skipped. The error
            # is logged on every rank so failures on non-main ranks are not
            # silently swallowed.
            # NOTE(review): under DDP a rank that skips a step may fall out
            # of sync with the others -- confirm this is acceptable.
            err_msg = f"Step {step} Error: {e}"
            logger.error(err_msg)
            if IS_MAIN:
                progress_bar.write(err_msg)
                import traceback
                traceback.print_exc()
            continue

    if IS_MAIN:
        final_path = f"{CONFIG['save_dir']}/final_dcpo.pt"
        model_to_save = actor.module if hasattr(actor, 'module') else actor
        torch.save({'model_state_dict': model_to_save.state_dict()}, final_path)
        print("DCPO Training Finished.")

    # Release the log file now that nothing else will be logged.
    if file_handler is not None:
        logger.removeHandler(file_handler)
        file_handler.close()

    if WORLD_SIZE > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    main()