import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import logging
import hashlib
from math_verifier import MathReward

logger = logging.getLogger(__name__)


class DCPOTrainer:
    """Policy-gradient trainer with probability-dependent clipping and SAS advantages.

    Each call to :meth:`generate_and_prepare` samples ``group_size`` responses
    per prompt, scores them with a math reward function and records the
    sampling-time token log-probabilities. :meth:`train_step` buffers such
    experience batches until ``gradient_accumulation_steps`` of them are
    available, converts rewards into smoothed per-prompt advantages
    (:meth:`_compute_sas_advantages`) and runs ``grpo_epochs`` passes of
    clipped-ratio policy updates whose per-token clip bounds depend on the old
    token probability.
    """

    def __init__(
        self,
        actor_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        group_size: int = 4,
        eps_low: float = 0.16,
        eps_high: float = 0.2,
        r_max: float = 10.0,
        grpo_epochs: int = 1,
        max_grad_norm: float = 1.0,
        use_amp: bool = True,
        gradient_accumulation_steps: int = 1,
        inner_batch_size: int = 4,
        use_reference_comparison: bool = True,
        use_progressive_reward: bool = False,
        phase1_steps: int = 2000,
        phase2_steps: int = 4000,
    ):
        """Set up the reward function, optimizer and AMP gradient scaler.

        Args:
            actor_model: Policy being trained. If it exposes ``.module`` (e.g. a
                wrapper), the inner module is used for device lookup and
                generation.
            ref_model: Optional reference model; if given it is frozen and put
                in eval mode. It is not otherwise used by this class.
            tokenizer: Tokenizer providing ``pad_token_id``, ``batch_decode``
                and batch encoding with left padding.
            learning_rate: AdamW learning rate (weight decay is fixed at 0.01).
            group_size: Number of sampled responses per prompt.
            eps_low: Parameter of the lower, probability-dependent clip bound.
            eps_high: Parameter of the upper, probability-dependent clip bound.
            r_max: Upper clamp applied to the raw importance ratio.
            grpo_epochs: Passes over the collected experience per update.
            max_grad_norm: Gradient-norm clipping threshold.
            use_amp: Enable CUDA autocast and gradient scaling.
            gradient_accumulation_steps: Number of experience batches buffered
                before an update is run.
            inner_batch_size: Mini-batch size used during the update passes.
            use_reference_comparison: Forwarded to the reward function.
            use_progressive_reward: Use ``ProgressiveMathReward`` (configured
                with ``phase1_steps`` / ``phase2_steps``) instead of
                ``MathReward``.
        """
        self.actor = actor_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.use_progressive_reward = use_progressive_reward

        if use_progressive_reward:
            # Only imported when the progressive reward is actually requested.
            from progressive_reward import ProgressiveMathReward
            self.math_verifier = ProgressiveMathReward(
                use_reference_comparison=use_reference_comparison,
                phase1_steps=phase1_steps,
                phase2_steps=phase2_steps,
                verbose=True,
            )
        else:
            self.math_verifier = MathReward(
                use_reference_comparison=use_reference_comparison
            )

        self.group_size = group_size
        self.eps_low = eps_low
        self.eps_high = eps_high
        self.r_max = r_max
        self.grpo_epochs = grpo_epochs
        self.use_amp = use_amp
        self.max_grad_norm = max_grad_norm
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.inner_batch_size = inner_batch_size

        # Experience batches accumulated by train_step until an update is due.
        self.experience_buffer = []
        self.current_step = 0

        self.device = next(self._get_unwrapped_model(actor_model).parameters()).device

        self.optimizer = torch.optim.AdamW(
            self.actor.parameters(), lr=learning_rate, weight_decay=0.01
        )
        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

        # Explicit `is not None`: container modules (e.g. an empty nn.Sequential)
        # define __len__ and would otherwise evaluate as False.
        if self.ref_model is not None:
            self.ref_model.eval()
            self.ref_model.requires_grad_(False)

        # Per-prompt running reward statistics, keyed by a stable prompt hash.
        self.sas_stats = {}

    def _get_stable_hash(self, text):
        """Return the MD5 hex digest of ``text``.

        Unlike the built-in ``hash()`` (randomised per process for str), this
        key is identical across runs, so saved ``sas_stats`` stay valid.
        """
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def state_dict(self):
        """Return the trainer state needed to resume training."""
        return {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'sas_stats': self.sas_stats,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'current_step': self.current_step,
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`; absent keys are skipped."""
        if 'optimizer_state_dict' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # Move every optimizer state tensor onto the training device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(self.device)
        if 'sas_stats' in state_dict:
            self.sas_stats = state_dict['sas_stats']
        scaler_state = state_dict.get('scaler_state_dict')
        if scaler_state is not None:
            self.scaler.load_state_dict(scaler_state)
        if 'current_step' in state_dict:
            self.current_step = state_dict['current_step']

    def update_step(self, step):
        """Record the global step and forward it to a progressive reward function."""
        self.current_step = step
        if self.use_progressive_reward:
            self.math_verifier.update_step(step)

    def _get_unwrapped_model(self, model):
        """Return ``model.module`` if the model is wrapped, else the model itself."""
        if hasattr(model, 'module'):
            return model.module
        return model

    @staticmethod
    def _gather_token_log_probs(logits, targets):
        """Return log_softmax(logits)[..., target] per position, in float32.

        The explicit float32 cast keeps sampling-time and update-time
        log-probs equally precise, regardless of the surrounding autocast state.
        """
        log_probs = F.log_softmax(logits.float(), dim=-1)
        return torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _build_response_mask(prompt_lens, like):
        """Return a mask shaped/typed like ``like`` that is 1 from index prompt_len - 1 on.

        Target index t predicts token t + 1, so index prompt_len - 1 is the
        first generated token. Rows whose start is past the end are all zero.
        """
        positions = torch.arange(like.size(1), device=like.device)
        starts = (prompt_lens.to(like.device) - 1).clamp(min=0)
        return (positions.unsqueeze(0) >= starts.unsqueeze(1)).to(like.dtype)

    @torch.no_grad()
    def generate_and_prepare(self, prompt_batch, max_gen_len=512, temperature=1.0, top_p=0.95):
        """Sample ``group_size`` responses per prompt and build one experience batch.

        Args:
            prompt_batch: Mapping with ``'prompt'`` (sequence of prompt strings)
                and ``'ground_truth'`` (one reference answer per prompt).
            max_gen_len: Maximum number of new tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold (default matches the previously
                hard-coded value).

        Returns:
            dict (tensors on CPU) with:
                'prompts_text': the input prompts,
                'sequences': prompt + generated ids, ``group_size`` adjacent
                    rows per prompt,
                'old_log_probs': float32 log-prob of each next token under the
                    actor at sampling time (length ``seq_len - 1``),
                'rewards': one float32 reward per sequence,
                'attention_mask': 1 for real tokens, 0 for padding,
                'position_ids': positions counted from each row's first real token,
                'prompt_length': padded prompt length shared by the batch.
        """
        self.actor.eval()
        prompts_text = prompt_batch['prompt']
        ground_truths = prompt_batch['ground_truth']

        inputs = self.tokenizer(
            prompts_text, return_tensors="pt", padding=True, padding_side="left"
        ).to(self.device)
        prompts_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        prompt_len = int(prompts_ids.shape[1])

        # Repeat each prompt group_size times; copies of one prompt stay adjacent.
        prompts_ids_repeated = prompts_ids.repeat_interleave(self.group_size, dim=0)
        attention_mask_repeated = attention_mask.repeat_interleave(self.group_size, dim=0)
        input_data = {
            'segments': [{'type': 'text', 'data': prompts_ids_repeated, 'modality_id': 0}],
            'attention_mask': attention_mask_repeated,
        }

        # Use the unwrapped model for inference.
        unwrapped_actor = self._get_unwrapped_model(self.actor)
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            generated_ids = unwrapped_actor.generate(
                input_data,
                max_new_tokens=max_gen_len,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id,
            )
        # Assumes generate() returns only the new tokens; the prompt is prepended here.
        sequences = torch.cat([prompts_ids_repeated, generated_ids], dim=1)
        decoded_responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Handle the think tag.
        # NOTE(review): str.startswith("") is always True, so the prefixing
        # branch never runs and every response passes through unchanged. The
        # intended tag literal appears to be missing - confirm and restore it.
        full_responses_for_reward = []
        for r in decoded_responses:
            if not r.strip().startswith(""):
                full_responses_for_reward.append("\n" + r.strip())
            else:
                full_responses_for_reward.append(r)

        # One ground truth per sampled response, matching repeat_interleave order.
        expanded_gts = [gt for gt in ground_truths for _ in range(self.group_size)]
        raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
        rewards_tensor = torch.tensor(raw_rewards, device=self.device, dtype=torch.float32)

        # NOTE(review): if pad_token_id == eos_token_id, EOS is masked out here
        # as well (and therefore excluded from the loss) - confirm intended.
        gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
        full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)

        batch_size, seq_len = sequences.size(0), sequences.size(1)
        position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=self.device)
        for i in range(batch_size):
            # Locate the first non-padding token.
            non_pad_positions = (full_attention_mask[i] == 1).nonzero(as_tuple=True)[0]
            if len(non_pad_positions) > 0:
                start_pos = non_pad_positions[0].item()
                valid_len = len(non_pad_positions)
                # Number real tokens 0, 1, 2, ... from the first one; this
                # assumes real tokens are contiguous after the left padding.
                position_ids[i, start_pos:start_pos + valid_len] = torch.arange(
                    valid_len, device=self.device
                )

        full_input_data = {'segments': [{'type': 'text', 'data': sequences, 'modality_id': 0}]}
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            actor_out = self.actor(
                full_input_data,
                attention_mask=full_attention_mask,
                position_ids=position_ids,
            )
        per_token_log_probs = self._gather_token_log_probs(
            actor_out['logits'][:, :-1, :], sequences[:, 1:]
        )

        return {
            'prompts_text': prompts_text,
            'sequences': sequences.detach().cpu(),
            'old_log_probs': per_token_log_probs.detach().cpu(),
            'rewards': rewards_tensor.cpu(),
            'attention_mask': full_attention_mask.cpu(),
            'position_ids': position_ids.cpu(),
            'prompt_length': prompt_len,
        }

    def _update_sas_stats(self, prompt_text, new_rewards):
        """Update the SAS mean and variance statistics for one prompt.

        The pooled statistics merge the new group into the previous i - 1
        groups with equal weight per group (exact when every group has the
        same size). Variances are population variances (``unbiased=False``).

        Returns:
            (mu_new, std_new, mu_total, std_total), where each std is
            sqrt(var + 1e-8).
        """
        prompt_hash = self._get_stable_hash(prompt_text)
        mu_new = new_rewards.mean().item()
        var_new = new_rewards.var(unbiased=False).item() if len(new_rewards) > 1 else 0.0

        if prompt_hash not in self.sas_stats:
            self.sas_stats[prompt_hash] = {'i': 1, 'mu_total': mu_new, 'var_total': var_new}
            return mu_new, np.sqrt(var_new + 1e-8), mu_new, np.sqrt(var_new + 1e-8)

        stats = self.sas_stats[prompt_hash]
        i = stats['i'] + 1
        mu_old = stats['mu_total']
        var_old = stats['var_total']

        mu_total = (mu_new + (i - 1) * mu_old) / i
        # Between-group term of the pooled-variance formula.
        term3 = ((i - 1) / i) * (mu_old - mu_new) ** 2
        var_total = (var_new + (i - 1) * var_old + term3) / i

        stats['i'] = i
        stats['mu_total'] = mu_total
        stats['var_total'] = var_total
        return mu_new, np.sqrt(var_new + 1e-8), mu_total, np.sqrt(var_total + 1e-8)

    def _compute_sas_advantages(self, experience_batch):
        """Compute smoothed per-response advantages for one experience batch.

        For each prompt group, rewards are standardised with the group's own
        statistics (A_new) and with the prompt's pooled history (A_total). Two
        convex mixtures weighted by the visit count i are formed, and the one
        with the smaller magnitude is kept element-wise. Updates
        ``self.sas_stats`` as a side effect.

        Returns:
            1-D tensor with one advantage per sequence, in batch order.
        """
        prompts = experience_batch['prompts_text']
        rewards = experience_batch['rewards'].view(-1, self.group_size)
        final_advantages = []
        for idx, prompt in enumerate(prompts):
            group_rewards = rewards[idx]
            mu_new, std_new, mu_total, std_total = self._update_sas_stats(prompt, group_rewards)
            A_new = (group_rewards - mu_new) / (std_new + 1e-8)
            A_total = (group_rewards - mu_total) / (std_total + 1e-8)
            i = self.sas_stats[self._get_stable_hash(prompt)]['i']
            SA_new = ((i - 1) / i) * A_new + (1 / i) * A_total
            SA_total = (1 / i) * A_new + ((i - 1) / i) * A_total
            mask = (torch.abs(SA_new) < torch.abs(SA_total)).float()
            A_final = mask * SA_new + (1 - mask) * SA_total
            final_advantages.append(A_final)
        return torch.cat(final_advantages)

    def _collate_experience_buffer(self):
        """Right-pad buffered experiences to a common length and pack them into a CPU TensorDataset."""
        buffer = self.experience_buffer
        max_seq_len = max(e['sequences'].size(1) for e in buffer)
        max_lp_len = max(e['old_log_probs'].size(1) for e in buffer)

        def pad_right(t, target_len, value):
            return F.pad(t, (0, target_len - t.size(1)), value=value)

        pad_id = self.tokenizer.pad_token_id
        sequences = torch.cat([pad_right(e['sequences'], max_seq_len, pad_id) for e in buffer], dim=0)
        old_log_probs = torch.cat(
            [pad_right(e['old_log_probs'], max_lp_len, 0.0) for e in buffer], dim=0
        )
        advantages = torch.cat([e['advantages'] for e in buffer], dim=0)
        prompt_lens = torch.tensor(
            [e['prompt_length'] for e in buffer for _ in range(len(e['sequences']))]
        )
        attention_masks = torch.cat(
            [pad_right(e['attention_mask'], max_seq_len, 0) for e in buffer], dim=0
        )
        position_ids = torch.cat(
            [pad_right(e['position_ids'], max_seq_len, 0) for e in buffer], dim=0
        )
        return TensorDataset(
            sequences, old_log_probs, advantages, prompt_lens, attention_masks, position_ids
        )

    def train_step(self, experience):
        """Buffer ``experience`` and, once enough are collected, run the policy update.

        Returns:
            ``None`` while fewer than ``gradient_accumulation_steps`` experience
            batches are buffered; otherwise the mean loss over all mini-batch
            updates (0.0 if none ran). The buffer is cleared after an update.
        """
        self.experience_buffer.append(experience)
        if len(self.experience_buffer) < self.gradient_accumulation_steps:
            return None

        # Advantages read and update the running per-prompt statistics.
        for exp in self.experience_buffer:
            exp['advantages'] = self._compute_sas_advantages(exp).detach()

        self.actor.train()
        # Memory: the dataset stays on CPU; only mini-batches move to the device.
        dataset = self._collate_experience_buffer()
        self.experience_buffer = []
        dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)

        pad_id = self.tokenizer.pad_token_id
        total_loss = 0.0
        update_steps = 0
        for _ in range(self.grpo_epochs):
            for batch in dataloader:
                seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]
                input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    outputs = self.actor(
                        input_data, attention_mask=attn_masks, position_ids=pos_ids
                    )
                targets = seqs[:, 1:]
                new_token_log_probs = self._gather_token_log_probs(
                    outputs['logits'][:, :-1, :], targets
                )

                # Score only generated, non-padding tokens. Padded old_log_probs
                # positions always coincide with padded targets, so no extra
                # old_lp check is needed; testing `old_lp != 0.0` would also
                # drop real tokens whose log-prob is exactly 0.0.
                mask = self._build_response_mask(p_lens, new_token_log_probs)
                mask = mask * (targets != pad_id).to(mask.dtype)

                # Per-token clip bounds that widen as the old probability q shrinks:
                #   lower = 0.5 + 0.5 * sqrt(max(1 - 4 * eps_low / q, 0))
                #   upper = 0.5 + 0.5 * sqrt(1 + 4 * eps_high / q)
                q_probs = torch.exp(old_lp).clamp(min=1e-10, max=1.0)
                term_low = 1.0 - (4.0 * self.eps_low) / q_probs
                lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_low, min=0.0))
                term_high = 1.0 + (4.0 * self.eps_high) / q_probs
                upper_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_high, min=0.0))

                ratio = torch.exp(new_token_log_probs - old_lp)
                ratio = torch.clamp(ratio, 0, self.r_max)
                advs_expanded = advs.unsqueeze(1).expand_as(ratio)
                surr1 = ratio * advs_expanded
                clipped_ratio = torch.min(torch.max(ratio, lower_clip), upper_clip)
                surr2 = clipped_ratio * advs_expanded
                element_wise_loss = torch.min(surr1, surr2)

                # Length-normalise each response, then average over the batch.
                masked_loss = element_wise_loss * mask
                response_lens = torch.clamp(mask.sum(dim=1), min=1.0)
                per_response_loss = masked_loss.sum(dim=1) / response_lens
                loss = -per_response_loss.mean()

                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale first so the clipping threshold applies to true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()

                total_loss += loss.item()
                update_steps += 1

        return total_loss / max(update_steps, 1)