import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
import numpy as np
import gc
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GRPOTrainer:
    def __init__(
        self,
        actor_model,
        reward_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        kl_coef: float = 0.04,
        group_size: int = 4,
        clip_epsilon: float = 0.2,
        max_grad_norm: float = 1.0,
        grpo_epochs: int = 1,
        update_batch_size: int = 4,
        use_amp: bool = True,
        value_clip: bool = False,  # accepted for API compatibility; GRPO trains no value function
        entropy_coef: float = 0.01,
        advantage_normalization: str = 'group',  # 'group', 'global', 'none'
        kl_estimation_method: str = 'forward'    # 'forward', 'reverse', 'symmetric'
    ):
        self.actor = actor_model
        self.reward_model = reward_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.kl_coef = kl_coef
        self.group_size = group_size
        self.clip_epsilon = clip_epsilon
        self.max_grad_norm = max_grad_norm
        self.grpo_epochs = grpo_epochs
        self.update_batch_size = update_batch_size
        self.use_amp = use_amp
        self.entropy_coef = entropy_coef
        self.advantage_normalization = advantage_normalization
        self.kl_estimation_method = kl_estimation_method
        self.device = next(actor_model.parameters()).device

        # Freeze the reference model and the reward model
        self.ref_model.eval()
        self.ref_model.requires_grad_(False)
        self.reward_model.eval()
        self.reward_model.requires_grad_(False)

        # Optimizer configuration
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, actor_model.parameters()),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        # Mixed-precision training
        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        self.training_stats = {
            'iterations': 0,
            'total_samples': 0,
            'avg_rewards': [],
            'avg_kl': [],
            'policy_losses': []
        }

        logger.info("GRPO Trainer initialized:")
        logger.info(f"  Group Size: {group_size}")
        logger.info(f"  KL Coef: {kl_coef}")
        logger.info(f"  Clip Epsilon: {clip_epsilon}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Update Batch Size: {update_batch_size}")
        logger.info(f"  Mixed Precision: {use_amp}")
        logger.info(f"  KL Estimation: {kl_estimation_method}")

    def _compute_kl_divergence(
        self,
        log_probs: torch.Tensor,
        ref_log_probs: torch.Tensor,
        mask: torch.Tensor
    ) -> torch.Tensor:
        if self.kl_estimation_method == 'forward':
            # k1-style estimate of KL(pi || pi_ref)
            kl = log_probs - ref_log_probs
        elif self.kl_estimation_method == 'reverse':
            # Negated log-ratio, used as a simple reverse-KL proxy
            kl = ref_log_probs - log_probs
        else:  # 'symmetric'
            # The naive average of the forward and reverse log-ratios cancels to zero,
            # so use the non-negative k3-style estimator in each direction instead.
            log_ratio = log_probs - ref_log_probs
            forward_kl = torch.exp(-log_ratio) - 1 + log_ratio  # estimates KL(pi || pi_ref)
            reverse_kl = torch.exp(log_ratio) - 1 - log_ratio   # estimates KL(pi_ref || pi)
            kl = 0.5 * (forward_kl + reverse_kl)

        kl_penalty = (kl * mask).sum(dim=-1)
        return kl_penalty

    @torch.no_grad()
    def generate_experience(
        self,
        prompts_dataloader: DataLoader,
        max_gen_len: int,
        temperature: float = 1.0,
        top_p: float = 0.9
    ) -> Dict:
        self.actor.eval()

        all_sequences = []
        all_log_probs = []
        all_advantages = []
        all_prompt_lens = []
        all_rewards = []

        logger.info("Generating experience...")

        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
            try:
                # Handle different input formats
                if isinstance(prompts, (list, tuple)):
                    prompts = prompts[0]
                prompts = prompts.to(self.device)
                batch_size = prompts.shape[0]

                # Repeat each prompt group_size times to sample a group of responses
                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
                prompt_len = prompts_repeated.shape[1]

                input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': prompts_repeated,
                        'modality_id': 0
                    }]
                }

                # 1. Sample responses from the current policy
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    response_ids = self.actor.generate(
                        input_data,
                        max_new_tokens=max_gen_len,
                        do_sample=True,
                        temperature=temperature,
                        top_p=top_p,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                        use_cache=True
                    )

                sequences = torch.cat([prompts_repeated, response_ids], dim=1)

                # Sanity-check the generated length
                if sequences.shape[1] <= prompt_len:
                    logger.warning("Generated sequence too short, skipping batch")
                    continue

                full_input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': sequences,
                        'modality_id': 0
                    }]
                }

                # 2. Compute log-probs under the current policy and the reference policy
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    actor_out = self.actor(full_input_data)
                    ref_out = self.ref_model(full_input_data)

                logits = actor_out['logits'][:, :-1, :]
                ref_logits = ref_out['logits'][:, :-1, :]
                targets = sequences[:, 1:]

                log_probs = F.log_softmax(logits, dim=-1)
                ref_log_probs = F.log_softmax(ref_logits, dim=-1)

                # Gather the log-probability of each target token
                per_token_log_probs = torch.gather(
                    log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)
                per_token_ref_log_probs = torch.gather(
                    ref_log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # 3. Compute the KL penalty (response tokens only)
                response_mask = torch.arange(
                    sequences.size(1) - 1, device=self.device
                ) >= (prompt_len - 1)
                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
                response_mask = response_mask.float()

                kl_penalty = self._compute_kl_divergence(
                    per_token_log_probs, per_token_ref_log_probs, response_mask
                )

                # 4. Score the full sequences with the reward model
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    reward_output = self.reward_model(full_input_data)

                # The reward model returns (batch_size, seq_len); take the reward at the last position
                if reward_output.dim() == 2:
                    raw_rewards = reward_output[:, -1]
                else:
                    raw_rewards = reward_output.squeeze(-1)

                # 5. Combine into the total reward: R_total = R_env - beta * KL
                total_rewards = raw_rewards - self.kl_coef * kl_penalty

                # 6. Compute group-relative advantages
                rewards_grouped = total_rewards.view(batch_size, self.group_size)

                if self.advantage_normalization == 'group':
                    # Normalize within each group
                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
                    advantages = (rewards_grouped - mean_grouped) / std_grouped
                elif self.advantage_normalization == 'global':
                    # Normalize across the whole batch
                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
                        rewards_grouped.std() + 1e-8
                    )
                else:  # 'none'
                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)

                advantages = advantages.view(-1)

                # Store the rollout data
                all_sequences.append(sequences.cpu())
                all_log_probs.append(per_token_log_probs.detach().cpu())
                all_advantages.append(advantages.detach().cpu())
                all_prompt_lens.append(
                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
                )
                all_rewards.append(total_rewards.detach().cpu())

                # Free intermediate tensors
                del logits, ref_logits, actor_out, ref_out
                del log_probs, ref_log_probs, reward_output

            except Exception as e:
                logger.error(f"Error generating experience for batch: {e}")
                import traceback
                traceback.print_exc()
                continue
            finally:
                torch.cuda.empty_cache()

        if not all_sequences:
            raise RuntimeError("No valid sequences generated")

        # Merge all rollouts (assumes every batch produced sequences of the same padded length)
        experience = {
            'sequences': torch.cat(all_sequences, dim=0),
            'log_probs': torch.cat(all_log_probs, dim=0),
            'advantages': torch.cat(all_advantages, dim=0),
            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
            'rewards': torch.cat(all_rewards, dim=0)
        }

        # Summary statistics
        logger.info(f"Generated {len(experience['sequences'])} sequences")
        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")

        return experience

    def grpo_step(
        self,
        dataset: TensorDataset
    ) -> Dict[str, float]:
        self.actor.train()

        dataloader = DataLoader(
            dataset,
            batch_size=self.update_batch_size,
            shuffle=True,
            drop_last=False
        )

        epoch_stats = {
            'total_loss': 0.0,
            'policy_loss': 0.0,
            'entropy': 0.0,
            'approx_kl': 0.0,
            'clip_fraction': 0.0,
            'steps': 0
        }

        for batch_data in dataloader:
            sequences, old_log_probs, advantages, prompt_lens = batch_data
            sequences = sequences.to(self.device)
            old_log_probs = old_log_probs.to(self.device)
            advantages = advantages.to(self.device)

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.actor(input_data)
                logits = outputs['logits'][:, :-1, :]

                # Recompute log-probabilities under the current policy
                targets = sequences[:, 1:]
                log_probs_dist = F.log_softmax(logits, dim=-1)
                new_log_probs = torch.gather(
                    log_probs_dist, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Build the response mask (1 for response tokens, 0 for prompt tokens)
                mask = torch.zeros_like(new_log_probs)
                for i, pl in enumerate(prompt_lens):
                    mask[i, pl - 1:] = 1.0

                # Probability ratio between the new and old policies
                ratio = torch.exp(new_log_probs - old_log_probs)

                # Broadcast the sequence-level advantage over all tokens
                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)

                # PPO-style clipped surrogate objective
                surr1 = ratio * adv_expanded
                surr2 = torch.clamp(
                    ratio, 1.0 - self.clip_epsilon, 1.0 + self.clip_epsilon
                ) * adv_expanded

                # Policy loss (masked mean over response tokens)
                policy_loss = -torch.min(surr1, surr2)
                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)

                # Entropy bonus
                probs = F.softmax(logits, dim=-1)
                entropy = -(probs * log_probs_dist).sum(dim=-1)
                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)

                # Total loss
                loss = policy_loss - self.entropy_coef * entropy_bonus

            # Diagnostics
            with torch.no_grad():
                log_ratio = new_log_probs - old_log_probs
                approx_kl = ((ratio - 1) - log_ratio) * mask
                approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)

                clip_fraction = ((ratio > 1 + self.clip_epsilon) |
                                 (ratio < 1 - self.clip_epsilon)).float()
                clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Gradient clipping
            self.scaler.unscale_(self.optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(), self.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Accumulate statistics
            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['entropy'] += entropy_bonus.item()
            epoch_stats['approx_kl'] += approx_kl.item()
            epoch_stats['clip_fraction'] += clip_fraction.item()
            epoch_stats['steps'] += 1

        # Average over optimization steps
        for key in epoch_stats:
            if key != 'steps':
                epoch_stats[key] /= max(epoch_stats['steps'], 1)

        return epoch_stats

    def train(
        self,
        prompt_dataloader: DataLoader,
        num_iterations: int = 1,
        max_gen_len: int = 50,
        temperature: float = 1.0,
        save_every: int = 5,
        save_path: str = "checkpoints"
    ):
        logger.info(f"\n{'='*80}")
        logger.info("Starting GRPO Training")
        logger.info(f"  Iterations: {num_iterations}")
        logger.info(f"  Max Gen Length: {max_gen_len}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"{'='*80}\n")

        for iteration in range(num_iterations):
            try:
                # 1. Generate experience
                experience = self.generate_experience(
                    prompt_dataloader, max_gen_len, temperature
                )

                dataset = TensorDataset(
                    experience['sequences'],
                    experience['log_probs'],
                    experience['advantages'],
                    experience['prompt_lengths']
                )

                # 2. Optimize the policy
                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
                all_epoch_stats = []
                for epoch in range(self.grpo_epochs):
                    stats = self.grpo_step(dataset)
                    all_epoch_stats.append(stats)
                    logger.info(
                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
                        f"Loss: {stats['total_loss']:.4f} | "
                        f"KL: {stats['approx_kl']:.4f} | "
                        f"Clip%: {stats['clip_fraction']*100:.1f}"
                    )

                # 3. Aggregate statistics
                avg_stats = {
                    key: np.mean([s[key] for s in all_epoch_stats])
                    for key in all_epoch_stats[0].keys()
                }

                self.training_stats['iterations'] += 1
                self.training_stats['total_samples'] += len(experience['sequences'])
                self.training_stats['avg_rewards'].append(
                    experience['rewards'].mean().item()
                )
                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])

                # 4. Log progress
                logger.info(f"\n{'='*80}")
                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
                logger.info(f"{'='*80}\n")

                # 5. Save a checkpoint
                if (iteration + 1) % save_every == 0:
                    self.save_checkpoint(
                        f"{save_path}/grpo_iter_{iteration+1}.pt"
                    )

                # 6. Free memory
                del experience, dataset
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                logger.error(f"Error in iteration {iteration+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        logger.info("GRPO Training Complete!")
        self.print_training_summary()

    def save_checkpoint(self, path: str):
        import os
        os.makedirs(os.path.dirname(path), exist_ok=True)

        checkpoint = {
            'actor_state_dict': self.actor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'training_stats': self.training_stats,
            'config': {
                'kl_coef': self.kl_coef,
                'group_size': self.group_size,
                'clip_epsilon': self.clip_epsilon,
            }
        }
        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        checkpoint = torch.load(path, map_location=self.device)
        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scaler_state_dict' in checkpoint and self.use_amp:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        self.training_stats = checkpoint['training_stats']
        logger.info(f"Checkpoint loaded from {path}")

    def print_training_summary(self):
        logger.info("\n" + "=" * 80)
        logger.info("Training Summary")
        logger.info("=" * 80)
        logger.info(f"Total Iterations: {self.training_stats['iterations']}")
        logger.info(f"Total Samples: {self.training_stats['total_samples']}")
        if self.training_stats['avg_rewards']:
            logger.info(
                f"Final Avg Reward: "
                f"{self.training_stats['avg_rewards'][-1]:.4f}"
            )
            logger.info(
                f"Reward Improvement: "
                f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
            )
        logger.info("=" * 80 + "\n")
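

# Illustration (not part of the trainer): a minimal sketch of the group-relative
# advantage computation performed in generate_experience above, under the 'group'
# normalization setting. The reward values below are made up purely for the demo.
def _demo_group_advantages():
    group_size = 4
    # Two prompts, four sampled responses each: shape (batch_size * group_size,)
    total_rewards = torch.tensor([1.0, 2.0, 3.0, 4.0,   # rewards for prompt 0's group
                                  0.0, 0.0, 1.0, 1.0])  # rewards for prompt 1's group
    rewards_grouped = total_rewards.view(-1, group_size)
    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
    # Each response is scored relative to its own group, so no learned value baseline is needed
    advantages = ((rewards_grouped - mean_grouped) / std_grouped).view(-1)
    print(advantages)  # responses above their group mean receive positive advantages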
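

# Usage sketch (hypothetical): the model/tokenizer loading below is a placeholder,
# not part of this module. The actor, reference, and reward models are assumed to
# accept the {'segments': [...]} input format and expose the generate()/forward()
# interfaces that GRPOTrainer relies on above.
if __name__ == "__main__":
    # actor_model, ref_model, reward_model, tokenizer = load_models()  # hypothetical loader
    # prompt_ids = ...  # LongTensor of tokenized prompts, shape (num_prompts, prompt_len)
    # prompt_dataloader = DataLoader(TensorDataset(prompt_ids), batch_size=2, shuffle=True)
    #
    # trainer = GRPOTrainer(
    #     actor_model=actor_model,
    #     reward_model=reward_model,
    #     ref_model=ref_model,
    #     tokenizer=tokenizer,
    #     group_size=4,
    #     kl_coef=0.04,
    #     grpo_epochs=1,
    # )
    # trainer.train(
    #     prompt_dataloader,
    #     num_iterations=10,
    #     max_gen_len=50,
    #     temperature=1.0,
    #     save_every=5,
    #     save_path="checkpoints",
    # )
    pass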