szxllm committed
Commit e68927b · verified · 1 Parent(s): 4f003b4

Update grpo.py

Files changed (1)
  1. grpo.py +538 -629
grpo.py CHANGED
@@ -1,630 +1,539 @@
-"""
-An improved GRPO (Group Relative Policy Optimization) trainer
-Fixes all known issues
-"""
-
-import torch
-import torch.optim as optim
-import torch.nn.functional as F
-from torch.utils.data import DataLoader, TensorDataset
-from typing import Dict, List, Tuple, Optional
-from tqdm import tqdm
-import numpy as np
-import gc
-import logging
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-
-class GRPOTrainer:
-    """
-    GRPO trainer - Group Relative Policy Optimization
-    Follows the DeepSeekMath/DeepSeek-V3 approach
-
-    Main fixes:
-    1. Fixed the generate() return format issue
-    2. Fixed reward_model output handling
-    3. Added full mixed-precision training support
-    4. Improved the numerical stability of the KL divergence computation
-    5. Fixed the past_key_values usage logic
-    6. Improved memory management and error handling
-    """
-
-    def __init__(
-        self,
-        actor_model,
-        reward_model,
-        ref_model,
-        tokenizer,
-        learning_rate: float = 1e-6,
-        kl_coef: float = 0.04,
-        group_size: int = 4,
-        clip_epsilon: float = 0.2,
-        max_grad_norm: float = 1.0,
-        grpo_epochs: int = 1,
-        update_batch_size: int = 4,
-        use_amp: bool = True,
-        value_clip: bool = False,
-        entropy_coef: float = 0.01,
-        advantage_normalization: str = 'group',  # 'group', 'global', 'none'
-        kl_estimation_method: str = 'forward'  # 'forward', 'reverse', 'symmetric'
-    ):
-        """
-        Initialize the GRPO trainer
-
-        Args:
-            actor_model: the policy model to train
-            reward_model: the reward model (frozen)
-            ref_model: the reference model (frozen)
-            tokenizer: the tokenizer
-            learning_rate: learning rate
-            kl_coef: KL divergence penalty coefficient
-            group_size: number of samples generated per prompt
-            clip_epsilon: PPO clip range
-            max_grad_norm: gradient clipping threshold
-            grpo_epochs: number of training epochs per batch of experience
-            update_batch_size: mini-batch size for updates
-            use_amp: whether to use mixed-precision training
-            value_clip: whether to clip values (no value network is used at present)
-            entropy_coef: entropy regularization coefficient
-            advantage_normalization: advantage normalization scheme
-            kl_estimation_method: KL divergence estimation method
-        """
-        self.actor = actor_model
-        self.reward_model = reward_model
-        self.ref_model = ref_model
-        self.tokenizer = tokenizer
-
-        self.kl_coef = kl_coef
-        self.group_size = group_size
-        self.clip_epsilon = clip_epsilon
-        self.max_grad_norm = max_grad_norm
-        self.grpo_epochs = grpo_epochs
-        self.update_batch_size = update_batch_size
-        self.use_amp = use_amp
-        self.entropy_coef = entropy_coef
-        self.advantage_normalization = advantage_normalization
-        self.kl_estimation_method = kl_estimation_method
-
-        self.device = next(actor_model.parameters()).device
-
-        # Freeze the reference and reward models
-        self.ref_model.eval()
-        self.ref_model.requires_grad_(False)
-        self.reward_model.eval()
-        self.reward_model.requires_grad_(False)
-
-        # Optimizer configuration
-        self.optimizer = optim.AdamW(
-            filter(lambda p: p.requires_grad, actor_model.parameters()),
-            lr=learning_rate,
-            weight_decay=0.01,
-            betas=(0.9, 0.95),
-            eps=1e-8
-        )
-
-        # Mixed-precision training - fix: add a GradScaler
-        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
-
-        # Training statistics
-        self.training_stats = {
-            'iterations': 0,
-            'total_samples': 0,
-            'avg_rewards': [],
-            'avg_kl': [],
-            'policy_losses': []
-        }
-
-        logger.info(f"GRPO Trainer initialized:")
-        logger.info(f"  Group Size: {group_size}")
-        logger.info(f"  KL Coef: {kl_coef}")
-        logger.info(f"  Clip Epsilon: {clip_epsilon}")
-        logger.info(f"  Learning Rate: {learning_rate}")
-        logger.info(f"  Update Batch Size: {update_batch_size}")
-        logger.info(f"  Mixed Precision: {use_amp}")
-        logger.info(f"  KL Estimation: {kl_estimation_method}")
-
-    def _compute_kl_divergence(
-        self,
-        log_probs: torch.Tensor,
-        ref_log_probs: torch.Tensor,
-        mask: torch.Tensor
-    ) -> torch.Tensor:
-        """
-        Compute the KL divergence (with improved numerical stability)
-
-        Args:
-            log_probs: log probabilities under the current policy
-            ref_log_probs: log probabilities under the reference policy
-            mask: mask over valid tokens
-
-        Returns:
-            KL divergence (scalar)
-        """
-        if self.kl_estimation_method == 'forward':
-            # KL(π||π_ref) = Σ π * log(π/π_ref)
-            # Σ exp(log_π) * (log_π - log_π_ref)
-            # For numerical stability, work with log_π - log_π_ref
-            kl = log_probs - ref_log_probs
-        elif self.kl_estimation_method == 'reverse':
-            # KL(π_ref||π) = Σ π_ref * log(π_ref/π)
-            kl = ref_log_probs - log_probs
-        else:  # symmetric
-            # Symmetric KL divergence
-            forward_kl = log_probs - ref_log_probs
-            reverse_kl = ref_log_probs - log_probs
-            kl = 0.5 * (forward_kl + reverse_kl)
-
-        # Apply the mask and sum
-        kl_penalty = (kl * mask).sum(dim=-1)
-        return kl_penalty
-
-    @torch.no_grad()
-    def generate_experience(
-        self,
-        prompts_dataloader: DataLoader,
-        max_gen_len: int,
-        temperature: float = 1.0,
-        top_p: float = 0.9
-    ) -> Dict:
-        """
-        Generate experience data: sample -> compute log-probs -> compute rewards (with KL)
-
-        Fixes:
-        1. Handle the return value of generate() correctly
-        2. Fix reward_model output handling
-        3. Improve numerical stability
-        """
-        self.actor.eval()
-
-        all_sequences = []
-        all_log_probs = []
-        all_advantages = []
-        all_prompt_lens = []
-        all_rewards = []
-
-        logger.info("Generating experience...")
-
-        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
-            try:
-                # Handle different input formats
-                if isinstance(prompts, (list, tuple)):
-                    prompts = prompts[0]
-
-                prompts = prompts.to(self.device)
-                batch_size = prompts.shape[0]
-
-                # Expand prompts so that each generates group_size samples
-                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
-                prompt_len = prompts_repeated.shape[1]
-
-                input_data = {
-                    'segments': [{
-                        'type': 'text',
-                        'data': prompts_repeated,
-                        'modality_id': 0
-                    }]
-                }
-
-                # 1. Sampling (fix: generate() returns only the newly generated tokens)
-                with torch.amp.autocast('cuda', enabled=self.use_amp):
-                    response_ids = self.actor.generate(
-                        input_data,
-                        max_new_tokens=max_gen_len,
-                        do_sample=True,
-                        temperature=temperature,
-                        top_p=top_p,
-                        eos_token_id=self.tokenizer.eos_token_id,
-                        pad_token_id=self.tokenizer.pad_token_id,
-                        use_cache=True  # use the KV cache to speed up generation
-                    )
-
-                # Fix: concatenate the full sequence (prompt + response)
-                sequences = torch.cat([prompts_repeated, response_ids], dim=1)
-
-                # Check the sequence length
-                if sequences.shape[1] <= prompt_len:
-                    logger.warning("Generated sequence too short, skipping batch")
-                    continue
-
-                full_input_data = {
-                    'segments': [{
-                        'type': 'text',
-                        'data': sequences,
-                        'modality_id': 0
-                    }]
-                }
-
-                # 2. Compute log-probs under the current and reference policies
-                with torch.amp.autocast('cuda', enabled=self.use_amp):
-                    actor_out = self.actor(full_input_data)
-                    ref_out = self.ref_model(full_input_data)
-
-                logits = actor_out['logits'][:, :-1, :]
-                ref_logits = ref_out['logits'][:, :-1, :]
-                targets = sequences[:, 1:]
-
-                # Compute log probabilities (improved numerical stability)
-                log_probs = F.log_softmax(logits, dim=-1)
-                ref_log_probs = F.log_softmax(ref_logits, dim=-1)
-
-                # Extract the log probability of each target token
-                per_token_log_probs = torch.gather(
-                    log_probs, -1, targets.unsqueeze(-1)
-                ).squeeze(-1)
-                per_token_ref_log_probs = torch.gather(
-                    ref_log_probs, -1, targets.unsqueeze(-1)
-                ).squeeze(-1)
-
-                # 3. Compute the KL divergence (response tokens only)
-                response_mask = torch.arange(
-                    sequences.size(1) - 1, device=self.device
-                ) >= (prompt_len - 1)
-                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
-                response_mask = response_mask.float()
-
-                # Use the improved KL computation
-                kl_penalty = self._compute_kl_divergence(
-                    per_token_log_probs,
-                    per_token_ref_log_probs,
-                    response_mask
-                )
-
-                # 4. Compute the environment reward (fix: handle reward_model output correctly)
-                with torch.amp.autocast('cuda', enabled=self.use_amp):
-                    reward_output = self.reward_model(full_input_data)
-
-                # reward_model returns (batch_size, seq_len); take the reward at the last position
-                if reward_output.dim() == 2:
-                    raw_rewards = reward_output[:, -1]
-                else:
-                    raw_rewards = reward_output.squeeze(-1)
-
-                # 5. Combine into the total reward: R_total = R_env - β * KL
-                total_rewards = raw_rewards - self.kl_coef * kl_penalty
-
-                # 6. Compute the group-relative advantage
-                rewards_grouped = total_rewards.view(batch_size, self.group_size)
-
-                if self.advantage_normalization == 'group':
-                    # Per-group standardization
-                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
-                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
-                    advantages = (rewards_grouped - mean_grouped) / std_grouped
-                elif self.advantage_normalization == 'global':
-                    # Global standardization
-                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
-                        rewards_grouped.std() + 1e-8
-                    )
-                else:  # 'none'
-                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)
-
-                advantages = advantages.view(-1)
-
-                # Store the data
-                all_sequences.append(sequences.cpu())
-                all_log_probs.append(per_token_log_probs.detach().cpu())
-                all_advantages.append(advantages.detach().cpu())
-                all_prompt_lens.append(
-                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
-                )
-                all_rewards.append(total_rewards.detach().cpu())
-
-                # Clean up intermediates
-                del logits, ref_logits, actor_out, ref_out
-                del log_probs, ref_log_probs, reward_output
-
-            except Exception as e:
-                logger.error(f"Error generating experience for batch: {e}")
-                import traceback
-                traceback.print_exc()
-                continue
-
-            finally:
-                torch.cuda.empty_cache()
-
-        if not all_sequences:
-            raise RuntimeError("No valid sequences generated")
-
-        # Merge all data
-        experience = {
-            'sequences': torch.cat(all_sequences, dim=0),
-            'log_probs': torch.cat(all_log_probs, dim=0),
-            'advantages': torch.cat(all_advantages, dim=0),
-            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
-            'rewards': torch.cat(all_rewards, dim=0)
-        }
-
-        # Statistics
-        logger.info(f"Generated {len(experience['sequences'])} sequences")
-        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
-        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
-        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")
-
-        return experience
-
-    def grpo_step(
-        self,
-        dataset: TensorDataset
-    ) -> Dict[str, float]:
-        """
-        Run one GRPO optimization pass
-
-        Fixes:
-        1. Use GradScaler for mixed-precision training
-        2. Improved loss computation
-        3. Better statistics collection
-        """
-        self.actor.train()
-
-        dataloader = DataLoader(
-            dataset,
-            batch_size=self.update_batch_size,
-            shuffle=True,
-            drop_last=False
-        )
-
-        epoch_stats = {
-            'total_loss': 0.0,
-            'policy_loss': 0.0,
-            'entropy': 0.0,
-            'approx_kl': 0.0,
-            'clip_fraction': 0.0,
-            'steps': 0
-        }
-
-        for batch_data in dataloader:
-            sequences, old_log_probs, advantages, prompt_lens = batch_data
-
-            sequences = sequences.to(self.device)
-            old_log_probs = old_log_probs.to(self.device)
-            advantages = advantages.to(self.device)
-
-            input_data = {
-                'segments': [{
-                    'type': 'text',
-                    'data': sequences,
-                    'modality_id': 0
-                }]
-            }
-
-            # Fix: use GradScaler for mixed-precision training
-            with torch.amp.autocast('cuda', enabled=self.use_amp):
-                outputs = self.actor(input_data)
-                logits = outputs['logits'][:, :-1, :]
-
-                # Compute the new log probabilities
-                targets = sequences[:, 1:]
-                log_probs_dist = F.log_softmax(logits, dim=-1)
-                new_log_probs = torch.gather(
-                    log_probs_dist, -1, targets.unsqueeze(-1)
-                ).squeeze(-1)
-
-                # Build the response mask
-                mask = torch.zeros_like(new_log_probs)
-                for i, pl in enumerate(prompt_lens):
-                    mask[i, pl-1:] = 1.0
-
-                # Compute the probability ratio
-                ratio = torch.exp(new_log_probs - old_log_probs)
-
-                # Broadcast advantages across the sequence dimension
-                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)
-
-                # PPO clipped objective
-                surr1 = ratio * adv_expanded
-                surr2 = torch.clamp(
-                    ratio,
-                    1.0 - self.clip_epsilon,
-                    1.0 + self.clip_epsilon
-                ) * adv_expanded
-
-                # Policy loss (minimize the negated objective)
-                policy_loss = -torch.min(surr1, surr2)
-                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)
-
-                # Entropy bonus (encourages exploration)
-                probs = F.softmax(logits, dim=-1)
-                entropy = -(probs * log_probs_dist).sum(dim=-1)
-                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)
-
-                # Total loss
-                loss = policy_loss - self.entropy_coef * entropy_bonus
-
-            # Statistics
-            with torch.no_grad():
-                log_ratio = new_log_probs - old_log_probs
-                approx_kl = ((ratio - 1) - log_ratio) * mask
-                approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)
-
-                clip_fraction = ((ratio > 1 + self.clip_epsilon) |
-                                 (ratio < 1 - self.clip_epsilon)).float()
-                clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)
-
-            # Fix: backpropagate through the GradScaler
-            self.optimizer.zero_grad()
-            self.scaler.scale(loss).backward()
-
-            # Gradient clipping
-            self.scaler.unscale_(self.optimizer)
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                self.actor.parameters(),
-                self.max_grad_norm
-            )
-
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-
-            # Accumulate statistics
-            epoch_stats['total_loss'] += loss.item()
-            epoch_stats['policy_loss'] += policy_loss.item()
-            epoch_stats['entropy'] += entropy_bonus.item()
-            epoch_stats['approx_kl'] += approx_kl.item()
-            epoch_stats['clip_fraction'] += clip_fraction.item()
-            epoch_stats['steps'] += 1
-
-        # Compute averages
-        for key in epoch_stats:
-            if key != 'steps':
-                epoch_stats[key] /= max(epoch_stats['steps'], 1)
-
-        return epoch_stats
-
-    def train(
-        self,
-        prompt_dataloader: DataLoader,
-        num_iterations: int = 1,
-        max_gen_len: int = 50,
-        temperature: float = 1.0,
-        save_every: int = 5,
-        save_path: str = "checkpoints"
-    ):
-        """
-        Full GRPO training loop
-
-        Args:
-            prompt_dataloader: dataloader that provides the prompts
-            num_iterations: number of training iterations
-            max_gen_len: maximum generation length
-            temperature: sampling temperature
-            save_every: save a checkpoint every N iterations
-            save_path: checkpoint save path
-        """
-        logger.info(f"\n{'='*80}")
-        logger.info(f"Starting GRPO Training")
-        logger.info(f"  Iterations: {num_iterations}")
-        logger.info(f"  Max Gen Length: {max_gen_len}")
-        logger.info(f"  Temperature: {temperature}")
-        logger.info(f"{'='*80}\n")
-
-        for iteration in range(num_iterations):
-            try:
-                # 1. Generate experience
-                experience = self.generate_experience(
-                    prompt_dataloader,
-                    max_gen_len,
-                    temperature
-                )
-
-                dataset = TensorDataset(
-                    experience['sequences'],
-                    experience['log_probs'],
-                    experience['advantages'],
-                    experience['prompt_lengths']
-                )
-
-                # 2. Policy optimization (multiple epochs)
-                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
-                all_epoch_stats = []
-
-                for epoch in range(self.grpo_epochs):
-                    stats = self.grpo_step(dataset)
-                    all_epoch_stats.append(stats)
-
-                    logger.info(
-                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
-                        f"Loss: {stats['total_loss']:.4f} | "
-                        f"KL: {stats['approx_kl']:.4f} | "
-                        f"Clip%: {stats['clip_fraction']*100:.1f}"
-                    )
-
-                # 3. Aggregate statistics
-                avg_stats = {
-                    key: np.mean([s[key] for s in all_epoch_stats])
-                    for key in all_epoch_stats[0].keys()
-                }
-
-                self.training_stats['iterations'] += 1
-                self.training_stats['total_samples'] += len(experience['sequences'])
-                self.training_stats['avg_rewards'].append(
-                    experience['rewards'].mean().item()
-                )
-                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
-                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])
-
-                # 4. Print progress
-                logger.info(f"\n{'='*80}")
-                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
-                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
-                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
-                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
-                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
-                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
-                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
-                logger.info(f"{'='*80}\n")
-
-                # 5. Save a checkpoint
-                if (iteration + 1) % save_every == 0:
-                    self.save_checkpoint(
-                        f"{save_path}/grpo_iter_{iteration+1}.pt"
-                    )
-
-                # 6. Free memory
-                del experience, dataset
-                gc.collect()
-                torch.cuda.empty_cache()
-
-            except Exception as e:
-                logger.error(f"Error in iteration {iteration+1}: {e}")
-                import traceback
-                traceback.print_exc()
-                continue
-
-        logger.info("GRPO Training Complete!")
-        self.print_training_summary()
-
-    def save_checkpoint(self, path: str):
-        """Save a training checkpoint"""
-        import os
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-
-        checkpoint = {
-            'actor_state_dict': self.actor.state_dict(),
-            'optimizer_state_dict': self.optimizer.state_dict(),
-            'scaler_state_dict': self.scaler.state_dict(),  # fix: save the scaler state
-            'training_stats': self.training_stats,
-            'config': {
-                'kl_coef': self.kl_coef,
-                'group_size': self.group_size,
-                'clip_epsilon': self.clip_epsilon,
-            }
-        }
-
-        torch.save(checkpoint, path)
-        logger.info(f"Checkpoint saved to {path}")
-
-    def load_checkpoint(self, path: str):
-        """Load a training checkpoint"""
-        checkpoint = torch.load(path, map_location=self.device)
-
-        self.actor.load_state_dict(checkpoint['actor_state_dict'])
-        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-
-        # Fix: load the scaler state
-        if 'scaler_state_dict' in checkpoint and self.use_amp:
-            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
-
-        self.training_stats = checkpoint['training_stats']
-
-        logger.info(f"Checkpoint loaded from {path}")
-
-    def print_training_summary(self):
-        """Print a training summary"""
-        logger.info("\n" + "="*80)
-        logger.info("Training Summary")
-        logger.info("="*80)
-        logger.info(f"Total Iterations: {self.training_stats['iterations']}")
-        logger.info(f"Total Samples: {self.training_stats['total_samples']}")
-
-        if self.training_stats['avg_rewards']:
-            logger.info(
-                f"Final Avg Reward: "
-                f"{self.training_stats['avg_rewards'][-1]:.4f}"
-            )
-            logger.info(
-                f"Reward Improvement: "
-                f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
-            )
-
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+from typing import Dict, List, Tuple, Optional
+from tqdm import tqdm
+import numpy as np
+import gc
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class GRPOTrainer:
+    def __init__(
+        self,
+        actor_model,
+        reward_model,
+        ref_model,
+        tokenizer,
+        learning_rate: float = 1e-6,
+        kl_coef: float = 0.04,
+        group_size: int = 4,
+        clip_epsilon: float = 0.2,
+        max_grad_norm: float = 1.0,
+        grpo_epochs: int = 1,
+        update_batch_size: int = 4,
+        use_amp: bool = True,
+        value_clip: bool = False,
+        entropy_coef: float = 0.01,
+        advantage_normalization: str = 'group',  # 'group', 'global', 'none'
+        kl_estimation_method: str = 'forward'  # 'forward', 'reverse', 'symmetric'
+    ):
+        self.actor = actor_model
+        self.reward_model = reward_model
+        self.ref_model = ref_model
+        self.tokenizer = tokenizer
+
+        self.kl_coef = kl_coef
+        self.group_size = group_size
+        self.clip_epsilon = clip_epsilon
+        self.max_grad_norm = max_grad_norm
+        self.grpo_epochs = grpo_epochs
+        self.update_batch_size = update_batch_size
+        self.use_amp = use_amp
+        self.entropy_coef = entropy_coef
+        self.advantage_normalization = advantage_normalization
+        self.kl_estimation_method = kl_estimation_method
+
+        self.device = next(actor_model.parameters()).device
+
+        # Freeze the reference and reward models
+        self.ref_model.eval()
+        self.ref_model.requires_grad_(False)
+        self.reward_model.eval()
+        self.reward_model.requires_grad_(False)
+
+        # Optimizer configuration
+        self.optimizer = optim.AdamW(
+            filter(lambda p: p.requires_grad, actor_model.parameters()),
+            lr=learning_rate,
+            weight_decay=0.01,
+            betas=(0.9, 0.95),
+            eps=1e-8
+        )
+
+        # Mixed-precision training
+        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
+
+        self.training_stats = {
+            'iterations': 0,
+            'total_samples': 0,
+            'avg_rewards': [],
+            'avg_kl': [],
+            'policy_losses': []
+        }
+
+        logger.info(f"GRPO Trainer initialized:")
+        logger.info(f"  Group Size: {group_size}")
+        logger.info(f"  KL Coef: {kl_coef}")
+        logger.info(f"  Clip Epsilon: {clip_epsilon}")
+        logger.info(f"  Learning Rate: {learning_rate}")
+        logger.info(f"  Update Batch Size: {update_batch_size}")
+        logger.info(f"  Mixed Precision: {use_amp}")
+        logger.info(f"  KL Estimation: {kl_estimation_method}")
+
+    def _compute_kl_divergence(
+        self,
+        log_probs: torch.Tensor,
+        ref_log_probs: torch.Tensor,
+        mask: torch.Tensor
+    ) -> torch.Tensor:
+
+        if self.kl_estimation_method == 'forward':
+            kl = log_probs - ref_log_probs
+        elif self.kl_estimation_method == 'reverse':
+            kl = ref_log_probs - log_probs
+        else:
+            forward_kl = log_probs - ref_log_probs
+            reverse_kl = ref_log_probs - log_probs
+            kl = 0.5 * (forward_kl + reverse_kl)
+
+        kl_penalty = (kl * mask).sum(dim=-1)
+        return kl_penalty
+
+    @torch.no_grad()
+    def generate_experience(
+        self,
+        prompts_dataloader: DataLoader,
+        max_gen_len: int,
+        temperature: float = 1.0,
+        top_p: float = 0.9
+    ) -> Dict:
+
+        self.actor.eval()
+
+        all_sequences = []
+        all_log_probs = []
+        all_advantages = []
+        all_prompt_lens = []
+        all_rewards = []
+
+        logger.info("Generating experience...")
+
+        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
+            try:
+                # Handle different input formats
+                if isinstance(prompts, (list, tuple)):
+                    prompts = prompts[0]
+
+                prompts = prompts.to(self.device)
+                batch_size = prompts.shape[0]
+
+                # Expand prompts so that each generates group_size samples
+                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
+                prompt_len = prompts_repeated.shape[1]
+
+                input_data = {
+                    'segments': [{
+                        'type': 'text',
+                        'data': prompts_repeated,
+                        'modality_id': 0
+                    }]
+                }
+
+                # 1. Sampling
+                with torch.amp.autocast('cuda', enabled=self.use_amp):
+                    response_ids = self.actor.generate(
+                        input_data,
+                        max_new_tokens=max_gen_len,
+                        do_sample=True,
+                        temperature=temperature,
+                        top_p=top_p,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        pad_token_id=self.tokenizer.pad_token_id,
+                        use_cache=True
+                    )
+
+                sequences = torch.cat([prompts_repeated, response_ids], dim=1)
+
+                # Check the sequence length
+                if sequences.shape[1] <= prompt_len:
+                    logger.warning("Generated sequence too short, skipping batch")
+                    continue
+
+                full_input_data = {
+                    'segments': [{
+                        'type': 'text',
+                        'data': sequences,
+                        'modality_id': 0
+                    }]
+                }
+
+                # 2. Compute log-probs under the current and reference policies
+                with torch.amp.autocast('cuda', enabled=self.use_amp):
+                    actor_out = self.actor(full_input_data)
+                    ref_out = self.ref_model(full_input_data)
+
+                logits = actor_out['logits'][:, :-1, :]
+                ref_logits = ref_out['logits'][:, :-1, :]
+                targets = sequences[:, 1:]
+
+                log_probs = F.log_softmax(logits, dim=-1)
+                ref_log_probs = F.log_softmax(ref_logits, dim=-1)
+
+                # Extract the log probability of each target token
+                per_token_log_probs = torch.gather(
+                    log_probs, -1, targets.unsqueeze(-1)
+                ).squeeze(-1)
+                per_token_ref_log_probs = torch.gather(
+                    ref_log_probs, -1, targets.unsqueeze(-1)
+                ).squeeze(-1)
+
+                # 3. Compute the KL divergence (response tokens only)
+                response_mask = torch.arange(
+                    sequences.size(1) - 1, device=self.device
+                ) >= (prompt_len - 1)
+                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
+                response_mask = response_mask.float()
+
+                kl_penalty = self._compute_kl_divergence(
+                    per_token_log_probs,
+                    per_token_ref_log_probs,
+                    response_mask
+                )
+
+                with torch.amp.autocast('cuda', enabled=self.use_amp):
+                    reward_output = self.reward_model(full_input_data)
+
+                # reward_model returns (batch_size, seq_len); take the reward at the last position
+                if reward_output.dim() == 2:
+                    raw_rewards = reward_output[:, -1]
+                else:
+                    raw_rewards = reward_output.squeeze(-1)
+
+                # 5. Combine into the total reward: R_total = R_env - β * KL
+                total_rewards = raw_rewards - self.kl_coef * kl_penalty
+
+                # 6. Compute the group-relative advantage
+                rewards_grouped = total_rewards.view(batch_size, self.group_size)
+
+                if self.advantage_normalization == 'group':
+                    # Per-group standardization
+                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
+                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
+                    advantages = (rewards_grouped - mean_grouped) / std_grouped
+                elif self.advantage_normalization == 'global':
+                    # Global standardization
+                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
+                        rewards_grouped.std() + 1e-8
+                    )
+                else:  # 'none'
+                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)
+
+                advantages = advantages.view(-1)
+
+                # Store the data
+                all_sequences.append(sequences.cpu())
+                all_log_probs.append(per_token_log_probs.detach().cpu())
+                all_advantages.append(advantages.detach().cpu())
+                all_prompt_lens.append(
+                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
+                )
+                all_rewards.append(total_rewards.detach().cpu())
+
+                # Clean up intermediates
+                del logits, ref_logits, actor_out, ref_out
+                del log_probs, ref_log_probs, reward_output
+
+            except Exception as e:
+                logger.error(f"Error generating experience for batch: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+            finally:
+                torch.cuda.empty_cache()
+
+        if not all_sequences:
+            raise RuntimeError("No valid sequences generated")
+
+        # Merge all data
+        experience = {
+            'sequences': torch.cat(all_sequences, dim=0),
+            'log_probs': torch.cat(all_log_probs, dim=0),
+            'advantages': torch.cat(all_advantages, dim=0),
+            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
+            'rewards': torch.cat(all_rewards, dim=0)
+        }
+
+        # Statistics
+        logger.info(f"Generated {len(experience['sequences'])} sequences")
+        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
+        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
+        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")
+
+        return experience
+
+    def grpo_step(
+        self,
+        dataset: TensorDataset
+    ) -> Dict[str, float]:
+        self.actor.train()
+
+        dataloader = DataLoader(
+            dataset,
+            batch_size=self.update_batch_size,
+            shuffle=True,
+            drop_last=False
+        )
+
+        epoch_stats = {
+            'total_loss': 0.0,
+            'policy_loss': 0.0,
+            'entropy': 0.0,
+            'approx_kl': 0.0,
+            'clip_fraction': 0.0,
+            'steps': 0
+        }
+
+        for batch_data in dataloader:
+            sequences, old_log_probs, advantages, prompt_lens = batch_data
+
+            sequences = sequences.to(self.device)
+            old_log_probs = old_log_probs.to(self.device)
+            advantages = advantages.to(self.device)
+
+            input_data = {
+                'segments': [{
+                    'type': 'text',
+                    'data': sequences,
+                    'modality_id': 0
+                }]
+            }
+
+            with torch.amp.autocast('cuda', enabled=self.use_amp):
+                outputs = self.actor(input_data)
+                logits = outputs['logits'][:, :-1, :]
+
+                # Compute the new log probabilities
+                targets = sequences[:, 1:]
+                log_probs_dist = F.log_softmax(logits, dim=-1)
+                new_log_probs = torch.gather(
+                    log_probs_dist, -1, targets.unsqueeze(-1)
+                ).squeeze(-1)
+
+                # Build the response mask
+                mask = torch.zeros_like(new_log_probs)
+                for i, pl in enumerate(prompt_lens):
+                    mask[i, pl-1:] = 1.0
+
+                # Compute the probability ratio
+                ratio = torch.exp(new_log_probs - old_log_probs)
+
+                # Broadcast advantages across the sequence dimension
+                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)
+
+                # PPO clipped objective
+                surr1 = ratio * adv_expanded
+                surr2 = torch.clamp(
+                    ratio,
+                    1.0 - self.clip_epsilon,
+                    1.0 + self.clip_epsilon
+                ) * adv_expanded
+
+                # Policy loss
+                policy_loss = -torch.min(surr1, surr2)
+                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)
+
+                # Entropy bonus
+                probs = F.softmax(logits, dim=-1)
+                entropy = -(probs * log_probs_dist).sum(dim=-1)
+                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)
+
+                # Total loss
+                loss = policy_loss - self.entropy_coef * entropy_bonus
+
+            # Statistics
+            with torch.no_grad():
+                log_ratio = new_log_probs - old_log_probs
+                approx_kl = ((ratio - 1) - log_ratio) * mask
+                approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)
+
+                clip_fraction = ((ratio > 1 + self.clip_epsilon) |
+                                 (ratio < 1 - self.clip_epsilon)).float()
+                clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)
+
+            self.optimizer.zero_grad()
+            self.scaler.scale(loss).backward()
+
+            # Gradient clipping
+            self.scaler.unscale_(self.optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(
+                self.actor.parameters(),
+                self.max_grad_norm
+            )
+
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+
+            # Accumulate statistics
+            epoch_stats['total_loss'] += loss.item()
+            epoch_stats['policy_loss'] += policy_loss.item()
+            epoch_stats['entropy'] += entropy_bonus.item()
+            epoch_stats['approx_kl'] += approx_kl.item()
+            epoch_stats['clip_fraction'] += clip_fraction.item()
+            epoch_stats['steps'] += 1
+
+        # Compute averages
+        for key in epoch_stats:
+            if key != 'steps':
+                epoch_stats[key] /= max(epoch_stats['steps'], 1)
+
+        return epoch_stats
+
+    def train(
+        self,
+        prompt_dataloader: DataLoader,
+        num_iterations: int = 1,
+        max_gen_len: int = 50,
+        temperature: float = 1.0,
+        save_every: int = 5,
+        save_path: str = "checkpoints"
+    ):
+
+        logger.info(f"\n{'='*80}")
+        logger.info(f"Starting GRPO Training")
+        logger.info(f"  Iterations: {num_iterations}")
+        logger.info(f"  Max Gen Length: {max_gen_len}")
+        logger.info(f"  Temperature: {temperature}")
+        logger.info(f"{'='*80}\n")
+
+        for iteration in range(num_iterations):
+            try:
+                # 1. Generate experience
+                experience = self.generate_experience(
+                    prompt_dataloader,
+                    max_gen_len,
+                    temperature
+                )
+
+                dataset = TensorDataset(
+                    experience['sequences'],
+                    experience['log_probs'],
+                    experience['advantages'],
+                    experience['prompt_lengths']
+                )
+
+                # 2. Policy optimization
+                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
+                all_epoch_stats = []
+
+                for epoch in range(self.grpo_epochs):
+                    stats = self.grpo_step(dataset)
+                    all_epoch_stats.append(stats)
+
+                    logger.info(
+                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
+                        f"Loss: {stats['total_loss']:.4f} | "
+                        f"KL: {stats['approx_kl']:.4f} | "
+                        f"Clip%: {stats['clip_fraction']*100:.1f}"
+                    )
+
+                # 3. Aggregate statistics
+                avg_stats = {
+                    key: np.mean([s[key] for s in all_epoch_stats])
+                    for key in all_epoch_stats[0].keys()
+                }
+
+                self.training_stats['iterations'] += 1
+                self.training_stats['total_samples'] += len(experience['sequences'])
+                self.training_stats['avg_rewards'].append(
+                    experience['rewards'].mean().item()
+                )
+                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
+                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])
+
+                # 4. Print progress
+                logger.info(f"\n{'='*80}")
+                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
+                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
+                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
+                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
+                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
+                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
+                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
+                logger.info(f"{'='*80}\n")
+
+                # 5. Save a checkpoint
+                if (iteration + 1) % save_every == 0:
+                    self.save_checkpoint(
+                        f"{save_path}/grpo_iter_{iteration+1}.pt"
+                    )
+
+                # 6. Free memory
+                del experience, dataset
+                gc.collect()
+                torch.cuda.empty_cache()
+
+            except Exception as e:
+                logger.error(f"Error in iteration {iteration+1}: {e}")
+                import traceback
+                traceback.print_exc()
+                continue
+
+        logger.info("GRPO Training Complete!")
+        self.print_training_summary()
+
+    def save_checkpoint(self, path: str):
+        import os
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+
+        checkpoint = {
+            'actor_state_dict': self.actor.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scaler_state_dict': self.scaler.state_dict(),
+            'training_stats': self.training_stats,
+            'config': {
+                'kl_coef': self.kl_coef,
+                'group_size': self.group_size,
+                'clip_epsilon': self.clip_epsilon,
+            }
+        }
+
+        torch.save(checkpoint, path)
+        logger.info(f"Checkpoint saved to {path}")
+
+    def load_checkpoint(self, path: str):
+        checkpoint = torch.load(path, map_location=self.device)
+
+        self.actor.load_state_dict(checkpoint['actor_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+        if 'scaler_state_dict' in checkpoint and self.use_amp:
+            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+        self.training_stats = checkpoint['training_stats']
+
+        logger.info(f"Checkpoint loaded from {path}")
+
+    def print_training_summary(self):
+        logger.info("\n" + "="*80)
+        logger.info("Training Summary")
+        logger.info("="*80)
+        logger.info(f"Total Iterations: {self.training_stats['iterations']}")
+        logger.info(f"Total Samples: {self.training_stats['total_samples']}")
+
+        if self.training_stats['avg_rewards']:
+            logger.info(
+                f"Final Avg Reward: "
+                f"{self.training_stats['avg_rewards'][-1]:.4f}"
+            )
+            logger.info(
+                f"Reward Improvement: "
+                f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
+            )
+
         logger.info("="*80 + "\n")
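
Note on usage (not part of the commit): grpo.py defines the trainer but no entry point. The sketch below is a hypothetical, CPU-friendly example of driving GRPOTrainer end to end. ToyLM, ToyReward, and the SimpleNamespace tokenizer stub are invented stand-ins; they exist only to illustrate the interface the trainer assumes: models take {'segments': [{'type': 'text', 'data': ids, 'modality_id': 0}]} and return {'logits': (batch, seq, vocab)}, the reward model returns per-token scores of shape (batch, seq_len), and generate() returns only the newly sampled tokens.

import torch
import torch.nn as nn
from types import SimpleNamespace
from torch.utils.data import DataLoader, TensorDataset

from grpo import GRPOTrainer  # assumes grpo.py (this commit) is on the path


class ToyLM(nn.Module):
    """Minimal stand-in for the actor/reference models."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, input_data):
        ids = input_data['segments'][0]['data']        # (B, T) token ids
        return {'logits': self.head(self.embed(ids))}  # (B, T, vocab)

    @torch.no_grad()
    def generate(self, input_data, max_new_tokens, **kwargs):
        # Returns only the newly sampled tokens, matching the convention
        # generate_experience() relies on.
        ids = input_data['segments'][0]['data']
        for _ in range(max_new_tokens):
            logits = self.forward({'segments': [{'type': 'text', 'data': ids,
                                                 'modality_id': 0}]})['logits']
            next_id = torch.multinomial(torch.softmax(logits[:, -1], dim=-1), 1)
            ids = torch.cat([ids, next_id], dim=1)
        return ids[:, -max_new_tokens:]


class ToyReward(nn.Module):
    """Stand-in reward model returning a (B, T) score map; the trainer
    reads the score at the last position."""

    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.score = nn.Linear(hidden, 1)

    def forward(self, input_data):
        ids = input_data['segments'][0]['data']
        return self.score(self.embed(ids)).squeeze(-1)  # (B, T)


tokenizer = SimpleNamespace(eos_token_id=1, pad_token_id=0)  # minimal stub

trainer = GRPOTrainer(
    actor_model=ToyLM(),
    reward_model=ToyReward(),
    ref_model=ToyLM(),
    tokenizer=tokenizer,
    group_size=2,    # two sampled responses per prompt
    use_amp=False,   # keep the sketch CPU-friendly
)

# Prompts: a (num_prompts, prompt_len) tensor of token ids;
# generate_experience() expands each batch group_size-fold before sampling.
prompt_ids = torch.randint(2, 100, (8, 12))
prompt_loader = DataLoader(TensorDataset(prompt_ids), batch_size=4, shuffle=True)

trainer.train(prompt_dataloader=prompt_loader, num_iterations=1, max_gen_len=8)

With group_size=2, each prompt yields two sampled responses whose total rewards (environment reward minus kl_coef times the KL penalty) are standardized within the group to form the advantages, which is the group-relative signal GRPO trains on.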