# MultiModal / grpo.py
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from typing import Dict
from tqdm import tqdm
import numpy as np
import gc
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class GRPOTrainer:
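    """Group Relative Policy Optimization (GRPO) trainer.

    For each prompt the actor samples `group_size` responses. Each response
    is scored by a frozen reward model and penalized by a KL term against a
    frozen reference model; advantages are computed relative to the group
    (no value network), and the policy is updated with a PPO-style clipped
    surrogate objective.
    """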
def __init__(
self,
actor_model,
reward_model,
ref_model,
tokenizer,
learning_rate: float = 1e-6,
kl_coef: float = 0.04,
group_size: int = 4,
clip_epsilon: float = 0.2,
max_grad_norm: float = 1.0,
grpo_epochs: int = 1,
update_batch_size: int = 4,
use_amp: bool = True,
value_clip: bool = False,
entropy_coef: float = 0.01,
advantage_normalization: str = 'group', # 'group', 'global', 'none'
kl_estimation_method: str = 'forward' # 'forward', 'reverse', 'symmetric'
):
self.actor = actor_model
self.reward_model = reward_model
self.ref_model = ref_model
self.tokenizer = tokenizer
self.kl_coef = kl_coef
self.group_size = group_size
self.clip_epsilon = clip_epsilon
self.max_grad_norm = max_grad_norm
self.grpo_epochs = grpo_epochs
self.update_batch_size = update_batch_size
self.use_amp = use_amp
self.entropy_coef = entropy_coef
self.advantage_normalization = advantage_normalization
self.kl_estimation_method = kl_estimation_method
self.device = next(actor_model.parameters()).device
        # Freeze the reference and reward models
self.ref_model.eval()
self.ref_model.requires_grad_(False)
self.reward_model.eval()
self.reward_model.requires_grad_(False)
        # Optimizer configuration
self.optimizer = optim.AdamW(
filter(lambda p: p.requires_grad, actor_model.parameters()),
lr=learning_rate,
weight_decay=0.01,
betas=(0.9, 0.95),
eps=1e-8
)
        # Mixed-precision training
self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)
self.training_stats = {
'iterations': 0,
'total_samples': 0,
'avg_rewards': [],
'avg_kl': [],
'policy_losses': []
}
logger.info(f"GRPO Trainer initialized:")
logger.info(f" Group Size: {group_size}")
logger.info(f" KL Coef: {kl_coef}")
logger.info(f" Clip Epsilon: {clip_epsilon}")
logger.info(f" Learning Rate: {learning_rate}")
logger.info(f" Update Batch Size: {update_batch_size}")
logger.info(f" Mixed Precision: {use_amp}")
logger.info(f" KL Estimation: {kl_estimation_method}")
def _compute_kl_divergence(
self,
log_probs: torch.Tensor,
ref_log_probs: torch.Tensor,
mask: torch.Tensor
) -> torch.Tensor:
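        """Per-sequence KL penalty between the actor and reference policies.

        `log_probs` and `ref_log_probs` hold per-token log-probabilities of
        the sampled tokens; `mask` selects response tokens. Returns the
        masked per-token KL estimate summed over tokens, one value per
        sequence.
        """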
        if self.kl_estimation_method == 'forward':
            # k1 estimator of KL(pi || pi_ref) from samples drawn from pi
            kl = log_probs - ref_log_probs
        elif self.kl_estimation_method == 'reverse':
            kl = ref_log_probs - log_probs
        else:  # 'symmetric'
            # Averaging the two signed k1 estimates would cancel to exactly
            # zero, so use the symmetric k2 estimator 0.5 * (log-ratio)^2
            # instead.
            kl = 0.5 * (log_probs - ref_log_probs) ** 2
kl_penalty = (kl * mask).sum(dim=-1)
return kl_penalty
@torch.no_grad()
def generate_experience(
self,
prompts_dataloader: DataLoader,
max_gen_len: int,
temperature: float = 1.0,
top_p: float = 0.9
) -> Dict:
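        """Roll out the current policy and build a GRPO experience buffer.

        For each prompt batch: sample `group_size` responses per prompt,
        compute per-token log-probs under the actor and reference models,
        score the full sequences with the reward model, subtract the KL
        penalty, and turn total rewards into group-relative advantages.
        Returns a dict of CPU tensors: sequences, log_probs, advantages,
        prompt_lengths, rewards.
        """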
self.actor.eval()
all_sequences = []
all_log_probs = []
all_advantages = []
all_prompt_lens = []
all_rewards = []
logger.info("Generating experience...")
for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
try:
                # Handle different input formats
if isinstance(prompts, (list, tuple)):
prompts = prompts[0]
prompts = prompts.to(self.device)
batch_size = prompts.shape[0]
                # Repeat prompts so each one is sampled group_size times
prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
prompt_len = prompts_repeated.shape[1]
input_data = {
'segments': [{
'type': 'text',
'data': prompts_repeated,
'modality_id': 0
}]
}
                # 1. Sample responses from the current policy
with torch.amp.autocast('cuda', enabled=self.use_amp):
response_ids = self.actor.generate(
input_data,
max_new_tokens=max_gen_len,
do_sample=True,
temperature=temperature,
top_p=top_p,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
use_cache=True
)
sequences = torch.cat([prompts_repeated, response_ids], dim=1)
                # Sanity-check the generated length
if sequences.shape[1] <= prompt_len:
logger.warning("Generated sequence too short, skipping batch")
continue
full_input_data = {
'segments': [{
'type': 'text',
'data': sequences,
'modality_id': 0
}]
}
                # 2. Log-probs under the current and reference policies
with torch.amp.autocast('cuda', enabled=self.use_amp):
actor_out = self.actor(full_input_data)
ref_out = self.ref_model(full_input_data)
logits = actor_out['logits'][:, :-1, :]
ref_logits = ref_out['logits'][:, :-1, :]
targets = sequences[:, 1:]
log_probs = F.log_softmax(logits, dim=-1)
ref_log_probs = F.log_softmax(ref_logits, dim=-1)
                # Gather the log-prob of each sampled token
per_token_log_probs = torch.gather(
log_probs, -1, targets.unsqueeze(-1)
).squeeze(-1)
per_token_ref_log_probs = torch.gather(
ref_log_probs, -1, targets.unsqueeze(-1)
).squeeze(-1)
                # 3. KL penalty (response tokens only)
response_mask = torch.arange(
sequences.size(1) - 1, device=self.device
) >= (prompt_len - 1)
response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
response_mask = response_mask.float()
kl_penalty = self._compute_kl_divergence(
per_token_log_probs,
per_token_ref_log_probs,
response_mask
)
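                # 4. Score the full sequences with the frozen reward model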
with torch.amp.autocast('cuda', enabled=self.use_amp):
reward_output = self.reward_model(full_input_data)
                # reward_model returns (batch_size, seq_len); take the reward at the last position
if reward_output.dim() == 2:
raw_rewards = reward_output[:, -1]
else:
raw_rewards = reward_output.squeeze(-1)
                # 5. Total reward: R_total = R_env - β * KL
total_rewards = raw_rewards - self.kl_coef * kl_penalty
                # 6. Group-relative advantages
rewards_grouped = total_rewards.view(batch_size, self.group_size)
if self.advantage_normalization == 'group':
                    # Per-group normalization
mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
advantages = (rewards_grouped - mean_grouped) / std_grouped
elif self.advantage_normalization == 'global':
                    # Global normalization
advantages = (rewards_grouped - rewards_grouped.mean()) / (
rewards_grouped.std() + 1e-8
)
else: # 'none'
advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)
advantages = advantages.view(-1)
                # Stash rollout data on CPU
all_sequences.append(sequences.cpu())
all_log_probs.append(per_token_log_probs.detach().cpu())
all_advantages.append(advantages.detach().cpu())
all_prompt_lens.append(
torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
)
all_rewards.append(total_rewards.detach().cpu())
                # Free intermediate tensors
del logits, ref_logits, actor_out, ref_out
del log_probs, ref_log_probs, reward_output
except Exception as e:
logger.error(f"Error generating experience for batch: {e}")
import traceback
traceback.print_exc()
continue
finally:
torch.cuda.empty_cache()
if not all_sequences:
raise RuntimeError("No valid sequences generated")
        # Concatenate all rollout data (assumes every batch produced
        # sequences of the same length, e.g. fixed-length prompts with
        # full-length generation)
experience = {
'sequences': torch.cat(all_sequences, dim=0),
'log_probs': torch.cat(all_log_probs, dim=0),
'advantages': torch.cat(all_advantages, dim=0),
'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
'rewards': torch.cat(all_rewards, dim=0)
}
        # Summary statistics
logger.info(f"Generated {len(experience['sequences'])} sequences")
logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")
return experience
def grpo_step(
self,
dataset: TensorDataset
) -> Dict[str, float]:
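        """Run one optimization epoch over the experience buffer.

        Recomputes per-token log-probs, forms the PPO-style clipped
        surrogate loss over response tokens, adds an entropy bonus, and
        steps the optimizer with AMP scaling and gradient clipping. Returns
        statistics averaged over the epoch.
        """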
self.actor.train()
dataloader = DataLoader(
dataset,
batch_size=self.update_batch_size,
shuffle=True,
drop_last=False
)
epoch_stats = {
'total_loss': 0.0,
'policy_loss': 0.0,
'entropy': 0.0,
'approx_kl': 0.0,
'clip_fraction': 0.0,
'steps': 0
}
for batch_data in dataloader:
sequences, old_log_probs, advantages, prompt_lens = batch_data
sequences = sequences.to(self.device)
old_log_probs = old_log_probs.to(self.device)
advantages = advantages.to(self.device)
input_data = {
'segments': [{
'type': 'text',
'data': sequences,
'modality_id': 0
}]
}
with torch.amp.autocast('cuda', enabled=self.use_amp):
outputs = self.actor(input_data)
logits = outputs['logits'][:, :-1, :]
                # Log-probs under the updated policy
targets = sequences[:, 1:]
log_probs_dist = F.log_softmax(logits, dim=-1)
new_log_probs = torch.gather(
log_probs_dist, -1, targets.unsqueeze(-1)
).squeeze(-1)
                # Response mask: 1 on generated tokens, 0 on the prompt
mask = torch.zeros_like(new_log_probs)
for i, pl in enumerate(prompt_lens):
mask[i, pl-1:] = 1.0
                # Probability ratio between new and old policies
ratio = torch.exp(new_log_probs - old_log_probs)
                # Broadcast advantages across the sequence dimension
adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)
                # PPO-style clipped surrogate
surr1 = ratio * adv_expanded
surr2 = torch.clamp(
ratio,
1.0 - self.clip_epsilon,
1.0 + self.clip_epsilon
) * adv_expanded
                # Policy loss: masked mean over response tokens
policy_loss = -torch.min(surr1, surr2)
policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)
                # Entropy bonus
probs = F.softmax(logits, dim=-1)
entropy = -(probs * log_probs_dist).sum(dim=-1)
entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)
                # Total loss
loss = policy_loss - self.entropy_coef * entropy_bonus
                # Diagnostics: k3 KL estimate and clip fraction
with torch.no_grad():
log_ratio = new_log_probs - old_log_probs
approx_kl = ((ratio - 1) - log_ratio) * mask
approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)
clip_fraction = ((ratio > 1 + self.clip_epsilon) |
(ratio < 1 - self.clip_epsilon)).float()
clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)
self.optimizer.zero_grad()
self.scaler.scale(loss).backward()
            # Gradient clipping (unscale first so the norm is computed on true gradients)
self.scaler.unscale_(self.optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.actor.parameters(),
self.max_grad_norm
)
self.scaler.step(self.optimizer)
self.scaler.update()
            # Accumulate statistics
epoch_stats['total_loss'] += loss.item()
epoch_stats['policy_loss'] += policy_loss.item()
epoch_stats['entropy'] += entropy_bonus.item()
epoch_stats['approx_kl'] += approx_kl.item()
epoch_stats['clip_fraction'] += clip_fraction.item()
epoch_stats['steps'] += 1
        # Average over optimization steps
for key in epoch_stats:
if key != 'steps':
epoch_stats[key] /= max(epoch_stats['steps'], 1)
return epoch_stats
def train(
self,
prompt_dataloader: DataLoader,
num_iterations: int = 1,
max_gen_len: int = 50,
temperature: float = 1.0,
save_every: int = 5,
save_path: str = "checkpoints"
):
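        """Full GRPO loop alternating rollout generation and policy updates.

        Each iteration generates a fresh experience buffer, optimizes the
        actor for `grpo_epochs` epochs, logs statistics, and writes a
        checkpoint every `save_every` iterations.
        """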
logger.info(f"\n{'='*80}")
logger.info(f"Starting GRPO Training")
logger.info(f" Iterations: {num_iterations}")
logger.info(f" Max Gen Length: {max_gen_len}")
logger.info(f" Temperature: {temperature}")
logger.info(f"{'='*80}\n")
for iteration in range(num_iterations):
try:
                # 1. Generate experience (rollouts)
experience = self.generate_experience(
prompt_dataloader,
max_gen_len,
temperature
)
dataset = TensorDataset(
experience['sequences'],
experience['log_probs'],
experience['advantages'],
experience['prompt_lengths']
)
                # 2. Policy optimization
logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
all_epoch_stats = []
for epoch in range(self.grpo_epochs):
stats = self.grpo_step(dataset)
all_epoch_stats.append(stats)
logger.info(
f" Epoch {epoch+1}/{self.grpo_epochs} | "
f"Loss: {stats['total_loss']:.4f} | "
f"KL: {stats['approx_kl']:.4f} | "
f"Clip%: {stats['clip_fraction']*100:.1f}"
)
                # 3. Aggregate statistics
avg_stats = {
key: np.mean([s[key] for s in all_epoch_stats])
for key in all_epoch_stats[0].keys()
}
self.training_stats['iterations'] += 1
self.training_stats['total_samples'] += len(experience['sequences'])
self.training_stats['avg_rewards'].append(
experience['rewards'].mean().item()
)
self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
self.training_stats['policy_losses'].append(avg_stats['policy_loss'])
                # 4. Log progress
logger.info(f"\n{'='*80}")
logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
logger.info(f" Avg Reward: {experience['rewards'].mean():.4f}")
logger.info(f" Avg Advantage: {experience['advantages'].mean():.4f}")
logger.info(f" Policy Loss: {avg_stats['policy_loss']:.4f}")
logger.info(f" Approx KL: {avg_stats['approx_kl']:.4f}")
logger.info(f" Entropy: {avg_stats['entropy']:.4f}")
logger.info(f" Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
logger.info(f"{'='*80}\n")
                # 5. Save checkpoint
if (iteration + 1) % save_every == 0:
self.save_checkpoint(
f"{save_path}/grpo_iter_{iteration+1}.pt"
)
                # 6. Free memory
del experience, dataset
gc.collect()
torch.cuda.empty_cache()
except Exception as e:
logger.error(f"Error in iteration {iteration+1}: {e}")
import traceback
traceback.print_exc()
continue
logger.info("GRPO Training Complete!")
self.print_training_summary()
def save_checkpoint(self, path: str):
import os
os.makedirs(os.path.dirname(path), exist_ok=True)
checkpoint = {
'actor_state_dict': self.actor.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scaler_state_dict': self.scaler.state_dict(),
'training_stats': self.training_stats,
'config': {
'kl_coef': self.kl_coef,
'group_size': self.group_size,
'clip_epsilon': self.clip_epsilon,
}
}
torch.save(checkpoint, path)
logger.info(f"Checkpoint saved to {path}")
def load_checkpoint(self, path: str):
checkpoint = torch.load(path, map_location=self.device)
self.actor.load_state_dict(checkpoint['actor_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
if 'scaler_state_dict' in checkpoint and self.use_amp:
self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
self.training_stats = checkpoint['training_stats']
logger.info(f"Checkpoint loaded from {path}")
def print_training_summary(self):
logger.info("\n" + "="*80)
logger.info("Training Summary")
logger.info("="*80)
logger.info(f"Total Iterations: {self.training_stats['iterations']}")
logger.info(f"Total Samples: {self.training_stats['total_samples']}")
if self.training_stats['avg_rewards']:
logger.info(
f"Final Avg Reward: "
f"{self.training_stats['avg_rewards'][-1]:.4f}"
)
logger.info(
f"Reward Improvement: "
f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
)
logger.info("="*80 + "\n")