| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, TensorDataset |
| import numpy as np |
| import logging |
| import hashlib |
|
|
| from math_verifier import MathReward |
|
|
| logger = logging.getLogger(__name__) |
|
|
class DCPOTrainer:
    """Group-based policy-gradient trainer with per-token dynamic clipping.

    The trainer works in two phases:

    1. ``generate_and_prepare`` samples ``group_size`` responses per prompt,
       scores them with the math verifier and records the behaviour policy's
       per-token log-probabilities.
    2. ``train_step`` buffers those experiences, turns rewards into smoothed
       advantages (``_compute_sas_advantages``) and runs clipped-surrogate
       updates on the actor.

    NOTE(review): "SAS" presumably stands for the smoothed advantage
    standardisation of the DCPO method - confirm against the project docs.
    """

    def __init__(
        self,
        actor_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        group_size: int = 4,
        eps_low: float = 0.16,
        eps_high: float = 0.2,
        r_max: float = 10.0,
        grpo_epochs: int = 1,
        max_grad_norm: float = 1.0,
        use_amp: bool = True,
        gradient_accumulation_steps: int = 1,
        inner_batch_size: int = 4,
        use_reference_comparison: bool = True,
        use_progressive_reward: bool = False,
        phase1_steps: int = 2000,
        phase2_steps: int = 4000
    ):
        """Set up the reward function, optimizer and AMP scaler.

        Args:
            actor_model: Policy being trained; may be a wrapper exposing the
                real model as ``.module``.
            ref_model: Optional reference model. It is only frozen and put in
                eval mode here; no method of this class calls it.
            tokenizer: Tokenizer used for prompts and decoding; must expose
                ``pad_token_id``.
            learning_rate: AdamW learning rate.
            group_size: Number of sampled responses per prompt.
            eps_low: Lower clipping parameter of the dynamic bound.
            eps_high: Upper clipping parameter of the dynamic bound.
            r_max: Hard upper limit applied to the importance ratio.
            grpo_epochs: Passes over the buffered experiences per update.
            max_grad_norm: Gradient-norm clipping threshold.
            use_amp: Enables CUDA autocast and gradient scaling.
            gradient_accumulation_steps: Number of experiences buffered by
                ``train_step`` before an update is run.
            inner_batch_size: Mini-batch size used inside ``train_step``.
            use_reference_comparison: Forwarded to the reward class.
            use_progressive_reward: Selects ``ProgressiveMathReward`` instead
                of ``MathReward``.
            phase1_steps: Forwarded to ``ProgressiveMathReward``.
            phase2_steps: Forwarded to ``ProgressiveMathReward``.
        """
        self.actor = actor_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer

        self.use_progressive_reward = use_progressive_reward

        if use_progressive_reward:
            # Imported lazily so the module is only required when requested.
            from progressive_reward import ProgressiveMathReward
            self.math_verifier = ProgressiveMathReward(
                use_reference_comparison=use_reference_comparison,
                phase1_steps=phase1_steps,
                phase2_steps=phase2_steps,
                verbose=True
            )
        else:
            self.math_verifier = MathReward(
                use_reference_comparison=use_reference_comparison
            )

        self.group_size = group_size
        self.eps_low = eps_low
        self.eps_high = eps_high
        self.r_max = r_max
        self.grpo_epochs = grpo_epochs
        self.use_amp = use_amp
        self.max_grad_norm = max_grad_norm

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.inner_batch_size = inner_batch_size
        self.experience_buffer = []
        self.current_step = 0

        # The trainer-wide device is the device of the first actor parameter.
        self.device = next(self._get_unwrapped_model(actor_model).parameters()).device

        self.optimizer = torch.optim.AdamW(
            self.actor.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )

        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

        # Explicit None check: module truthiness can depend on __len__.
        if self.ref_model is not None:
            self.ref_model.eval()
            self.ref_model.requires_grad_(False)

        # prompt hash -> {'i': visits, 'mu_total': mean, 'var_total': variance}
        self.sas_stats = {}

    def _get_stable_hash(self, text) -> str:
        """Return the MD5 hex digest of ``text`` (stable across processes)."""
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def state_dict(self) -> dict:
        """Return optimizer, scaler, SAS statistics and step counter state."""
        return {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'sas_stats': self.sas_stats,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'current_step': self.current_step
        }

    def load_state_dict(self, state_dict) -> None:
        """Restore whatever parts of a ``state_dict()`` result are present.

        Missing keys are skipped, so partial checkpoints can be loaded.
        """
        if 'optimizer_state_dict' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # Keep every state tensor next to the parameter it belongs to
            # rather than on one trainer-wide device, and leave the 'step'
            # counter wherever the optimizer itself placed it.
            for param, param_state in self.optimizer.state.items():
                for key, value in param_state.items():
                    if key != 'step' and isinstance(value, torch.Tensor):
                        param_state[key] = value.to(param.device)

        if 'sas_stats' in state_dict:
            self.sas_stats = state_dict['sas_stats']

        if state_dict.get('scaler_state_dict') is not None:
            self.scaler.load_state_dict(state_dict['scaler_state_dict'])

        if 'current_step' in state_dict:
            self.current_step = state_dict['current_step']

    def update_step(self, step) -> None:
        """Record the global step and forward it to a progressive reward."""
        self.current_step = step
        if self.use_progressive_reward:
            self.math_verifier.update_step(step)

    def _get_unwrapped_model(self, model):
        """Return ``model.module`` when ``model`` is a wrapper, else ``model``."""
        return getattr(model, 'module', model)

    @staticmethod
    def _token_log_probs(logits: torch.Tensor, sequences: torch.Tensor) -> torch.Tensor:
        """Return the log-probability assigned to each next token.

        Position ``t`` of the result is the log-probability of
        ``sequences[:, t + 1]`` under ``logits[:, t]``, so the result has one
        column fewer than ``sequences``. The softmax is always evaluated in
        float32 so that old and new log-probabilities are comparable no matter
        what dtype the model emitted.
        """
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1, dtype=torch.float32)
        targets = sequences[:, 1:]
        return torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        """Build position ids that start at 0 on each row's first real token.

        For every row the span ``[first, first + count)`` is numbered
        ``0 .. count - 1``, where ``first`` is the index of the first mask
        entry equal to 1 and ``count`` is the number of such entries. All
        other positions, and rows without any real token, get 0.
        """
        is_real = (attention_mask == 1).long()
        # argmax returns the first maximal index, i.e. the first real token.
        first = is_real.argmax(dim=1, keepdim=True)
        count = is_real.sum(dim=1, keepdim=True)
        offsets = torch.arange(attention_mask.size(1), device=attention_mask.device).unsqueeze(0) - first
        inside = (offsets >= 0) & (offsets < count)
        return torch.where(inside, offsets, torch.zeros_like(offsets))

    @torch.no_grad()
    def generate_and_prepare(self, prompt_batch, max_gen_len=512, temperature=1.0):
        """Sample responses for a prompt batch and package them for training.

        Args:
            prompt_batch: Mapping with ``'prompt'`` and ``'ground_truth'``
                entries (presumably parallel lists of strings - confirm
                against the data loader).
            max_gen_len: Maximum number of new tokens per response.
            temperature: Sampling temperature.

        Returns:
            Dict of CPU tensors and metadata: ``prompts_text``, ``sequences``
            (prompt + generated ids), ``old_log_probs``, ``rewards``,
            ``attention_mask``, ``position_ids`` and ``prompt_length`` (the
            left-padded prompt width shared by the whole batch).
        """
        self.actor.eval()
        prompts_text = prompt_batch['prompt']
        ground_truths = prompt_batch['ground_truth']

        inputs = self.tokenizer(
            prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left"
        ).to(self.device)

        prompts_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        prompt_len = int(prompts_ids.shape[1])

        # Each prompt is repeated group_size times, keeping groups contiguous.
        prompts_ids_repeated = prompts_ids.repeat_interleave(self.group_size, dim=0)
        attention_mask_repeated = attention_mask.repeat_interleave(self.group_size, dim=0)

        input_data = {
            'segments': [{'type': 'text', 'data': prompts_ids_repeated, 'modality_id': 0}],
            'attention_mask': attention_mask_repeated
        }

        unwrapped_actor = self._get_unwrapped_model(self.actor)
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            generated_ids = unwrapped_actor.generate(
                input_data,
                max_new_tokens=max_gen_len,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                pad_token_id=self.tokenizer.pad_token_id
            )

        # The generated ids are appended after the prompt, i.e. they are
        # treated as containing only the new tokens.
        sequences = torch.cat([prompts_ids_repeated, generated_ids], dim=1)
        decoded_responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # The reward function is fed responses that begin with "<think>".
        full_responses_for_reward = [
            r if r.strip().startswith("<think>") else "<think>\n" + r.strip()
            for r in decoded_responses
        ]
        expanded_gts = [gt for gt in ground_truths for _ in range(self.group_size)]

        raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
        # Rewards are only consumed on the CPU, so build them there directly.
        rewards_tensor = torch.tensor(raw_rewards, dtype=torch.float32)

        gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
        full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)
        position_ids = self._position_ids_from_mask(full_attention_mask)

        full_input_data = {'segments': [{'type': 'text', 'data': sequences, 'modality_id': 0}]}

        with torch.amp.autocast('cuda', enabled=self.use_amp):
            actor_out = self.actor(
                full_input_data,
                attention_mask=full_attention_mask,
                position_ids=position_ids
            )

        per_token_log_probs = self._token_log_probs(actor_out['logits'], sequences)

        return {
            'prompts_text': prompts_text,
            'sequences': sequences.detach().cpu(),
            'old_log_probs': per_token_log_probs.detach().cpu(),
            'rewards': rewards_tensor.cpu(),
            'attention_mask': full_attention_mask.cpu(),
            'position_ids': position_ids.cpu(),
            'prompt_length': prompt_len
        }

    def _update_sas_stats(self, prompt_text, new_rewards):
        """Update the running SAS mean and variance for one prompt.

        The stored statistics are merged with the new reward group as if all
        groups seen so far had equal size.

        Args:
            prompt_text: Prompt whose statistics are updated.
            new_rewards: 1-D tensor with the rewards of the current group.

        Returns:
            ``(mu_new, std_new, mu_total, std_total)`` as floats: mean and
            standard deviation of the current group, then of all groups seen
            for this prompt. Each standard deviation is
            ``sqrt(variance + 1e-8)``.
        """
        prompt_hash = self._get_stable_hash(prompt_text)

        mu_new = new_rewards.mean().item()
        var_new = new_rewards.var(unbiased=False).item() if len(new_rewards) > 1 else 0.0
        std_new = float(np.sqrt(var_new + 1e-8))

        stats = self.sas_stats.get(prompt_hash)
        if stats is None:
            # First visit: the cumulative statistics equal the group's own.
            self.sas_stats[prompt_hash] = {
                'i': 1,
                'mu_total': mu_new,
                'var_total': var_new
            }
            return mu_new, std_new, mu_new, std_new

        i = stats['i'] + 1
        mu_old = stats['mu_total']
        var_old = stats['var_total']

        mu_total = (mu_new + (i - 1) * mu_old) / i
        # Between-group term of the pooled variance.
        between = ((i - 1) / i) * (mu_old - mu_new) ** 2
        var_total = (var_new + (i - 1) * var_old + between) / i

        stats['i'] = i
        stats['mu_total'] = mu_total
        stats['var_total'] = var_total

        return mu_new, std_new, mu_total, float(np.sqrt(var_total + 1e-8))

    def _compute_sas_advantages(self, experience_batch) -> torch.Tensor:
        """Turn group rewards into smoothed, standardised advantages.

        For each prompt the rewards are standardised once with the current
        group's statistics and once with the cumulative ones. The two results
        are blended with weights ``(i - 1) / i`` and ``1 / i`` (``i`` = number
        of visits of the prompt) in both orders, and per response the blend
        with the smaller magnitude is kept.

        Side effect: updates ``self.sas_stats`` for every prompt in the batch.

        Returns:
            1-D tensor with one advantage per response, in batch order.
        """
        prompts = experience_batch['prompts_text']
        rewards = experience_batch['rewards'].view(-1, self.group_size)

        final_advantages = []
        for idx, prompt in enumerate(prompts):
            group_rewards = rewards[idx]
            mu_new, std_new, mu_total, std_total = self._update_sas_stats(prompt, group_rewards)

            a_new = (group_rewards - mu_new) / (std_new + 1e-8)
            a_total = (group_rewards - mu_total) / (std_total + 1e-8)

            i = self.sas_stats[self._get_stable_hash(prompt)]['i']
            w_hist = (i - 1) / i
            w_cur = 1 / i

            sa_new = w_hist * a_new + w_cur * a_total
            sa_total = w_cur * a_new + w_hist * a_total

            # Keep whichever blend is closer to zero.
            final_advantages.append(
                torch.where(torch.abs(sa_new) < torch.abs(sa_total), sa_new, sa_total)
            )

        return torch.cat(final_advantages)

    def _response_token_mask(self, sequences, attention_mask, prompt_lens) -> torch.Tensor:
        """Return a float mask over target positions that belong to responses.

        Target position ``t`` predicts ``sequences[:, t + 1]``. It is selected
        when that token lies at or after the prompt width, is attended to, and
        is not the pad token. Padding is identified from the attention mask
        and token ids only, never from log-probability values, so real tokens
        whose log-probability is exactly 0 stay in the loss.
        """
        num_targets = sequences.size(1) - 1
        first_target = (prompt_lens - 1).clamp(min=0).unsqueeze(1)
        target_index = torch.arange(num_targets, device=sequences.device).unsqueeze(0)
        in_response = target_index >= first_target
        is_real = (attention_mask[:, 1:] == 1) & (sequences[:, 1:] != self.tokenizer.pad_token_id)
        return (in_response & is_real).float()

    def _clipped_surrogate_loss(self, new_log_probs, old_log_probs, advantages, mask) -> torch.Tensor:
        """Compute the dynamically clipped surrogate loss.

        The clipping interval of each token depends on its behaviour-policy
        probability ``q``: ``0.5 + 0.5 * sqrt(max(1 - 4 * eps_low / q, 0))``
        below and ``0.5 + 0.5 * sqrt(1 + 4 * eps_high / q)`` above, with the
        raw ratio additionally limited to ``[0, r_max]``. Token terms are
        averaged per response over ``mask`` and then over the batch.
        """
        q_probs = torch.exp(old_log_probs).clamp(min=1e-10, max=1.0)
        lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(1.0 - (4.0 * self.eps_low) / q_probs, min=0.0))
        upper_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(1.0 + (4.0 * self.eps_high) / q_probs, min=0.0))

        ratio = torch.clamp(torch.exp(new_log_probs - old_log_probs), 0, self.r_max)
        advs_expanded = advantages.unsqueeze(1).expand_as(ratio)

        surr_raw = ratio * advs_expanded
        surr_clipped = torch.min(torch.max(ratio, lower_clip), upper_clip) * advs_expanded

        masked = torch.min(surr_raw, surr_clipped) * mask
        response_lens = torch.clamp(mask.sum(dim=1), min=1.0)
        return -(masked.sum(dim=1) / response_lens).mean()

    def train_step(self, experience):
        """Buffer one experience and, once enough are buffered, update the actor.

        ``gradient_accumulation_steps`` is the number of experiences collected
        before an update; gradients are not accumulated across mini-batches -
        every mini-batch performs its own optimizer step.

        Args:
            experience: Dict as returned by ``generate_and_prepare``. An
                ``'advantages'`` entry is added to it in place.

        Returns:
            ``None`` while still buffering, otherwise the mean mini-batch loss
            of this update.
        """
        self.experience_buffer.append(experience)
        if len(self.experience_buffer) < self.gradient_accumulation_steps:
            return None

        all_advantages = []
        for exp in self.experience_buffer:
            exp['advantages'] = self._compute_sas_advantages(exp).detach()
            all_advantages.append(exp['advantages'])

        self.actor.train()

        max_seq_len = max(e['sequences'].size(1) for e in self.experience_buffer)
        max_lp_len = max(e['old_log_probs'].size(1) for e in self.experience_buffer)

        def pad_tensor(t, target_len, val):
            # Right-pad the last dimension up to target_len.
            return F.pad(t, (0, target_len - t.size(1)), value=val)

        padded_seqs = []
        padded_old_lp = []
        padded_attn_masks = []
        padded_pos_ids = []
        prompt_lens_list = []

        for e in self.experience_buffer:
            padded_seqs.append(pad_tensor(e['sequences'], max_seq_len, self.tokenizer.pad_token_id))
            padded_old_lp.append(pad_tensor(e['old_log_probs'], max_lp_len, 0.0))
            padded_attn_masks.append(pad_tensor(e['attention_mask'], max_seq_len, 0))
            padded_pos_ids.append(pad_tensor(e['position_ids'], max_seq_len, 0))
            # Every row of an experience shares that experience's prompt width.
            prompt_lens_list.extend([e['prompt_length']] * len(e['sequences']))

        dataset = TensorDataset(
            torch.cat(padded_seqs, dim=0),
            torch.cat(padded_old_lp, dim=0),
            torch.cat(all_advantages, dim=0),
            torch.tensor(prompt_lens_list),
            torch.cat(padded_attn_masks, dim=0),
            torch.cat(padded_pos_ids, dim=0)
        )

        self.experience_buffer = []

        dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)

        total_loss = 0
        update_steps = 0

        for _ in range(self.grpo_epochs):
            for batch in dataloader:
                seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]

                input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}

                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    outputs = self.actor(
                        input_data,
                        attention_mask=attn_masks,
                        position_ids=pos_ids
                    )

                new_token_log_probs = self._token_log_probs(outputs['logits'], seqs)
                mask = self._response_token_mask(seqs, attn_masks, p_lens)
                loss = self._clipped_surrogate_loss(new_token_log_probs, old_lp, advs, mask)

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += loss.item()
                update_steps += 1

        return total_loss / max(update_steps, 1)