"""GRPO (Group Relative Policy Optimization) trainer.

Samples a group of responses per prompt, scores them with a frozen reward
model, normalizes rewards within each group to form advantages, and updates
the actor with a PPO-style clipped objective plus a KL penalty against a
frozen reference model.
"""

import gc
import logging
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class GRPOTrainer:
    """GRPO trainer for a policy model that takes segment-style inputs."""

    def __init__(
        self,
        actor_model,
        reward_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        kl_coef: float = 0.04,
        group_size: int = 4,
        clip_epsilon: float = 0.2,
        max_grad_norm: float = 1.0,
        grpo_epochs: int = 1,
        update_batch_size: int = 4,
        use_amp: bool = True,
        value_clip: bool = False,  # accepted but unused: GRPO has no value head
        entropy_coef: float = 0.01,
        advantage_normalization: str = 'group',
        kl_estimation_method: str = 'forward'
    ):
        self.actor = actor_model
        self.reward_model = reward_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer

        self.kl_coef = kl_coef
        self.group_size = group_size
        self.clip_epsilon = clip_epsilon
        self.max_grad_norm = max_grad_norm
        self.grpo_epochs = grpo_epochs
        self.update_batch_size = update_batch_size
        self.use_amp = use_amp
        self.entropy_coef = entropy_coef
        self.advantage_normalization = advantage_normalization
        self.kl_estimation_method = kl_estimation_method

        self.device = next(actor_model.parameters()).device

        # The reference and reward models are frozen: eval mode, no gradients.
        self.ref_model.eval()
        self.ref_model.requires_grad_(False)
        self.reward_model.eval()
        self.reward_model.requires_grad_(False)

        # Only the actor's trainable parameters are optimized.
        self.optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, actor_model.parameters()),
            lr=learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.95),
            eps=1e-8
        )

        self.scaler = torch.amp.GradScaler('cuda', enabled=self.use_amp)

        self.training_stats = {
            'iterations': 0,
            'total_samples': 0,
            'avg_rewards': [],
            'avg_kl': [],
            'policy_losses': []
        }

        logger.info("GRPO Trainer initialized:")
        logger.info(f"  Group Size: {group_size}")
        logger.info(f"  KL Coef: {kl_coef}")
        logger.info(f"  Clip Epsilon: {clip_epsilon}")
        logger.info(f"  Learning Rate: {learning_rate}")
        logger.info(f"  Update Batch Size: {update_batch_size}")
        logger.info(f"  Mixed Precision: {use_amp}")
        logger.info(f"  KL Estimation: {kl_estimation_method}")

    def _compute_kl_divergence(
        self,
        log_probs: torch.Tensor,
        ref_log_probs: torch.Tensor,
        mask: torch.Tensor
    ) -> torch.Tensor:
        """Per-sequence KL penalty between the actor and the reference policy.

        Uses sample-based estimators on the sampled tokens: 'forward' is the
        simple k1 estimator of KL(actor || ref), 'reverse' is its negation.
        Any other value falls back to the k3 estimator; note that naively
        averaging the forward and reverse differences would cancel to
        exactly zero, so it is not used here.
        """
        if self.kl_estimation_method == 'forward':
            kl = log_probs - ref_log_probs
        elif self.kl_estimation_method == 'reverse':
            kl = ref_log_probs - log_probs
        else:
            # k3 estimator of KL(actor || ref): exp(-x) - 1 + x with
            # x = log_probs - ref_log_probs. Unlike k1, it is always
            # non-negative and typically lower-variance.
            log_ratio = log_probs - ref_log_probs
            kl = torch.exp(-log_ratio) - 1 + log_ratio

        # Sum over response tokens only (mask is 1 on the response region).
        kl_penalty = (kl * mask).sum(dim=-1)
        return kl_penalty

    @torch.no_grad()
    def generate_experience(
        self,
        prompts_dataloader: DataLoader,
        max_gen_len: int,
        temperature: float = 1.0,
        top_p: float = 0.9
    ) -> Dict:
        """Sample a group of responses per prompt and score them.

        Returns a dict of CPU tensors: full sequences, per-token log-probs
        under the sampling policy, group-normalized advantages, prompt
        lengths, and KL-penalized rewards. Assumes every batch from the
        dataloader yields prompts padded to the same length, so the
        per-batch tensors can be concatenated.
        """
        self.actor.eval()

        all_sequences = []
        all_log_probs = []
        all_advantages = []
        all_prompt_lens = []
        all_rewards = []

        logger.info("Generating experience...")

        for prompts in tqdm(prompts_dataloader, desc="Generating Experience"):
            try:
                if isinstance(prompts, (list, tuple)):
                    prompts = prompts[0]

                prompts = prompts.to(self.device)
                batch_size = prompts.shape[0]

                # Repeat each prompt group_size times so that every prompt
                # receives a group of sampled responses.
                prompts_repeated = prompts.repeat_interleave(self.group_size, dim=0)
                prompt_len = prompts_repeated.shape[1]

                input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': prompts_repeated,
                        'modality_id': 0
                    }]
                }

                # Note: generate() is assumed to return only the newly
                # sampled tokens, not the prompt prefix.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    response_ids = self.actor.generate(
                        input_data,
                        max_new_tokens=max_gen_len,
                        do_sample=True,
                        temperature=temperature,
                        top_p=top_p,
                        eos_token_id=self.tokenizer.eos_token_id,
                        pad_token_id=self.tokenizer.pad_token_id,
                        use_cache=True
                    )

                sequences = torch.cat([prompts_repeated, response_ids], dim=1)

                if sequences.shape[1] <= prompt_len:
                    logger.warning("Generated sequence too short, skipping batch")
                    continue

                full_input_data = {
                    'segments': [{
                        'type': 'text',
                        'data': sequences,
                        'modality_id': 0
                    }]
                }

                # Re-score the full sequences under the actor and the frozen
                # reference model.
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    actor_out = self.actor(full_input_data)
                    ref_out = self.ref_model(full_input_data)

                # Shift by one: logits at position t predict token t + 1.
                logits = actor_out['logits'][:, :-1, :]
                ref_logits = ref_out['logits'][:, :-1, :]
                targets = sequences[:, 1:]

                log_probs = F.log_softmax(logits, dim=-1)
                ref_log_probs = F.log_softmax(ref_logits, dim=-1)

                per_token_log_probs = torch.gather(
                    log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)
                per_token_ref_log_probs = torch.gather(
                    ref_log_probs, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Mask covering response positions only (targets are shifted
                # by one, hence prompt_len - 1). Padding after EOS is not
                # excluded here.
                response_mask = torch.arange(
                    sequences.size(1) - 1, device=self.device
                ) >= (prompt_len - 1)
                response_mask = response_mask.unsqueeze(0).expand_as(per_token_log_probs)
                response_mask = response_mask.float()

                kl_penalty = self._compute_kl_divergence(
                    per_token_log_probs,
                    per_token_ref_log_probs,
                    response_mask
                )

                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    reward_output = self.reward_model(full_input_data)

                # Per-token reward heads return (B, T): keep the last token's
                # score; otherwise squeeze down to a (B,) vector.
                if reward_output.dim() == 2:
                    raw_rewards = reward_output[:, -1]
                else:
                    raw_rewards = reward_output.squeeze(-1)

                # KL-penalized reward, as in standard RLHF.
                total_rewards = raw_rewards - self.kl_coef * kl_penalty

                # Reshape to (num_prompts, group_size) for group statistics.
                rewards_grouped = total_rewards.view(batch_size, self.group_size)

                if self.advantage_normalization == 'group':
                    # Standard GRPO: z-score each reward within its group.
                    mean_grouped = rewards_grouped.mean(dim=1, keepdim=True)
                    std_grouped = rewards_grouped.std(dim=1, keepdim=True) + 1e-8
                    advantages = (rewards_grouped - mean_grouped) / std_grouped
                elif self.advantage_normalization == 'global':
                    # z-score over the whole batch instead.
                    advantages = (rewards_grouped - rewards_grouped.mean()) / (
                        rewards_grouped.std() + 1e-8
                    )
                else:
                    # Mean-centering within the group only.
                    advantages = rewards_grouped - rewards_grouped.mean(dim=1, keepdim=True)

                advantages = advantages.view(-1)

                # Move everything to CPU so GPU memory is freed between batches.
                all_sequences.append(sequences.cpu())
                all_log_probs.append(per_token_log_probs.detach().cpu())
                all_advantages.append(advantages.detach().cpu())
                all_prompt_lens.append(
                    torch.full((sequences.size(0),), prompt_len, dtype=torch.long)
                )
                all_rewards.append(total_rewards.detach().cpu())

                del logits, ref_logits, actor_out, ref_out
                del log_probs, ref_log_probs, reward_output

            except Exception as e:
                logger.error(f"Error generating experience for batch: {e}")
                import traceback
                traceback.print_exc()
                continue

            finally:
                torch.cuda.empty_cache()

        if not all_sequences:
            raise RuntimeError("No valid sequences generated")

        experience = {
            'sequences': torch.cat(all_sequences, dim=0),
            'log_probs': torch.cat(all_log_probs, dim=0),
            'advantages': torch.cat(all_advantages, dim=0),
            'prompt_lengths': torch.cat(all_prompt_lens, dim=0),
            'rewards': torch.cat(all_rewards, dim=0)
        }

        logger.info(f"Generated {len(experience['sequences'])} sequences")
        logger.info(f"Avg Reward: {experience['rewards'].mean().item():.4f}")
        logger.info(f"Reward Std: {experience['rewards'].std().item():.4f}")
        logger.info(f"Avg Advantage: {experience['advantages'].mean().item():.4f}")

        return experience

    def grpo_step(
        self,
        dataset: TensorDataset
    ) -> Dict[str, float]:
        """Run one optimization epoch over the collected experience."""
        self.actor.train()

        dataloader = DataLoader(
            dataset,
            batch_size=self.update_batch_size,
            shuffle=True,
            drop_last=False
        )

        epoch_stats = {
            'total_loss': 0.0,
            'policy_loss': 0.0,
            'entropy': 0.0,
            'approx_kl': 0.0,
            'clip_fraction': 0.0,
            'steps': 0
        }

        for batch_data in dataloader:
            sequences, old_log_probs, advantages, prompt_lens = batch_data

            sequences = sequences.to(self.device)
            old_log_probs = old_log_probs.to(self.device)
            advantages = advantages.to(self.device)

            input_data = {
                'segments': [{
                    'type': 'text',
                    'data': sequences,
                    'modality_id': 0
                }]
            }

            with torch.amp.autocast('cuda', enabled=self.use_amp):
                outputs = self.actor(input_data)
                logits = outputs['logits'][:, :-1, :]

                # Recompute per-token log-probs under the current policy.
                targets = sequences[:, 1:]
                log_probs_dist = F.log_softmax(logits, dim=-1)
                new_log_probs = torch.gather(
                    log_probs_dist, -1, targets.unsqueeze(-1)
                ).squeeze(-1)

                # Mask out prompt tokens; only response tokens contribute.
                mask = torch.zeros_like(new_log_probs)
                for i, pl in enumerate(prompt_lens):
                    mask[i, pl - 1:] = 1.0

                # Importance ratio between the current and sampling policies.
                ratio = torch.exp(new_log_probs - old_log_probs)

                # Broadcast the per-sequence advantage to every response token.
                adv_expanded = advantages.unsqueeze(-1).expand_as(new_log_probs)

                # PPO-style clipped surrogate objective.
                surr1 = ratio * adv_expanded
                surr2 = torch.clamp(
                    ratio,
                    1.0 - self.clip_epsilon,
                    1.0 + self.clip_epsilon
                ) * adv_expanded

                policy_loss = -torch.min(surr1, surr2)
                policy_loss = (policy_loss * mask).sum() / (mask.sum() + 1e-8)

                # Entropy bonus encourages exploration.
                probs = F.softmax(logits, dim=-1)
                entropy = -(probs * log_probs_dist).sum(dim=-1)
                entropy_bonus = (entropy * mask).sum() / (mask.sum() + 1e-8)

                loss = policy_loss - self.entropy_coef * entropy_bonus

            # Diagnostics (no gradients): k3 estimate of KL(old || new) and
            # the fraction of tokens hitting the clipping bounds.
            with torch.no_grad():
                log_ratio = new_log_probs - old_log_probs
                approx_kl = ((ratio - 1) - log_ratio) * mask
                approx_kl = approx_kl.sum() / (mask.sum() + 1e-8)

                clip_fraction = ((ratio > 1 + self.clip_epsilon) |
                                 (ratio < 1 - self.clip_epsilon)).float()
                clip_fraction = (clip_fraction * mask).sum() / (mask.sum() + 1e-8)

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()

            # Unscale before clipping so the norm is computed on true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.actor.parameters(),
                self.max_grad_norm
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()

            epoch_stats['total_loss'] += loss.item()
            epoch_stats['policy_loss'] += policy_loss.item()
            epoch_stats['entropy'] += entropy_bonus.item()
            epoch_stats['approx_kl'] += approx_kl.item()
            epoch_stats['clip_fraction'] += clip_fraction.item()
            epoch_stats['steps'] += 1

        # Average the accumulated statistics over the number of steps.
        for key in epoch_stats:
            if key != 'steps':
                epoch_stats[key] /= max(epoch_stats['steps'], 1)

        return epoch_stats

    def train(
        self,
        prompt_dataloader: DataLoader,
        num_iterations: int = 1,
        max_gen_len: int = 50,
        temperature: float = 1.0,
        save_every: int = 5,
        save_path: str = "checkpoints"
    ):
        """Outer loop: generate experience, then optimize the policy on it."""
        logger.info(f"\n{'='*80}")
        logger.info("Starting GRPO Training")
        logger.info(f"  Iterations: {num_iterations}")
        logger.info(f"  Max Gen Length: {max_gen_len}")
        logger.info(f"  Temperature: {temperature}")
        logger.info(f"{'='*80}\n")

        for iteration in range(num_iterations):
            try:
                # Phase 1: sample responses and compute advantages.
                experience = self.generate_experience(
                    prompt_dataloader,
                    max_gen_len,
                    temperature
                )

                dataset = TensorDataset(
                    experience['sequences'],
                    experience['log_probs'],
                    experience['advantages'],
                    experience['prompt_lengths']
                )

                # Phase 2: several optimization epochs on the same experience.
                logger.info(f"Optimizing policy for {self.grpo_epochs} epochs...")
                all_epoch_stats = []

                for epoch in range(self.grpo_epochs):
                    stats = self.grpo_step(dataset)
                    all_epoch_stats.append(stats)

                    logger.info(
                        f"  Epoch {epoch+1}/{self.grpo_epochs} | "
                        f"Loss: {stats['total_loss']:.4f} | "
                        f"KL: {stats['approx_kl']:.4f} | "
                        f"Clip%: {stats['clip_fraction']*100:.1f}"
                    )

                avg_stats = {
                    key: np.mean([s[key] for s in all_epoch_stats])
                    for key in all_epoch_stats[0].keys()
                }

                self.training_stats['iterations'] += 1
                self.training_stats['total_samples'] += len(experience['sequences'])
                self.training_stats['avg_rewards'].append(
                    experience['rewards'].mean().item()
                )
                self.training_stats['avg_kl'].append(avg_stats['approx_kl'])
                self.training_stats['policy_losses'].append(avg_stats['policy_loss'])

                logger.info(f"\n{'='*80}")
                logger.info(f"Iteration {iteration+1}/{num_iterations} Complete")
                logger.info(f"  Avg Reward: {experience['rewards'].mean():.4f}")
                logger.info(f"  Avg Advantage: {experience['advantages'].mean():.4f}")
                logger.info(f"  Policy Loss: {avg_stats['policy_loss']:.4f}")
                logger.info(f"  Approx KL: {avg_stats['approx_kl']:.4f}")
                logger.info(f"  Entropy: {avg_stats['entropy']:.4f}")
                logger.info(f"  Clip Fraction: {avg_stats['clip_fraction']*100:.1f}%")
                logger.info(f"{'='*80}\n")

                if (iteration + 1) % save_every == 0:
                    self.save_checkpoint(
                        f"{save_path}/grpo_iter_{iteration+1}.pt"
                    )

                # Free the experience buffer before the next iteration.
                del experience, dataset
                gc.collect()
                torch.cuda.empty_cache()

            except Exception as e:
                logger.error(f"Error in iteration {iteration+1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        logger.info("GRPO Training Complete!")
        self.print_training_summary()

    def save_checkpoint(self, path: str):
        import os
        # Fall back to "." when the path has no directory component.
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

        checkpoint = {
            'actor_state_dict': self.actor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),
            'training_stats': self.training_stats,
            'config': {
                'kl_coef': self.kl_coef,
                'group_size': self.group_size,
                'clip_epsilon': self.clip_epsilon,
            }
        }

        torch.save(checkpoint, path)
        logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        checkpoint = torch.load(path, map_location=self.device)

        self.actor.load_state_dict(checkpoint['actor_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if 'scaler_state_dict' in checkpoint and self.use_amp:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        self.training_stats = checkpoint['training_stats']

        logger.info(f"Checkpoint loaded from {path}")

    def print_training_summary(self):
        logger.info("\n" + "="*80)
        logger.info("Training Summary")
        logger.info("="*80)
        logger.info(f"Total Iterations: {self.training_stats['iterations']}")
        logger.info(f"Total Samples: {self.training_stats['total_samples']}")

        if self.training_stats['avg_rewards']:
            logger.info(
                f"Final Avg Reward: "
                f"{self.training_stats['avg_rewards'][-1]:.4f}"
            )
            logger.info(
                f"Reward Improvement: "
                f"{self.training_stats['avg_rewards'][-1] - self.training_stats['avg_rewards'][0]:.4f}"
            )

        logger.info("="*80 + "\n")