# MultiModal / dcpo.py
# (Hugging Face file-view header: commit e576d4e "Update dcpo.py" by szxllm, 15.6 kB.)
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import logging
import hashlib
from math_verifier import MathReward
# Module-level logger. NOTE(review): not referenced by DCPOTrainer in the code shown here.
logger = logging.getLogger(__name__)
class DCPOTrainer:
    """Policy-gradient trainer with probability-dependent clipping and SAS advantages.

    Workflow:
      1. ``generate_and_prepare`` samples ``group_size`` responses per prompt,
         scores them with ``self.math_verifier`` and records the actor's
         per-token log-probs (the "old" policy).
      2. ``train_step`` buffers rollouts; once ``gradient_accumulation_steps``
         rollouts are collected it standardizes rewards with Smooth Advantage
         Standardization (per-prompt running statistics in ``sas_stats``) and
         runs ``grpo_epochs`` passes of minibatch updates. Each minibatch takes
         its own optimizer step, so ``gradient_accumulation_steps`` is the number
         of rollouts buffered per update phase, not a gradient-accumulation count.

    The clipped surrogate bounds the importance ratio ``r = p/q`` per token,
    using the old token probability ``q``:
        lower = 0.5 + 0.5 * sqrt(max(1 - 4 * eps_low / q, 0))
        upper = 0.5 + 0.5 * sqrt(1 + 4 * eps_high / q)
    The raw ratio is additionally clamped to ``[0, r_max]``.

    Args:
        actor_model: policy model, optionally wrapped (anything with ``.module``,
            e.g. DDP). Called as ``actor(input_data, attention_mask=...,
            position_ids=...)`` and must return a mapping with ``'logits'``; the
            unwrapped model must provide ``generate``.
        ref_model: optional reference model. It is frozen (eval mode,
            ``requires_grad_(False)``) and not otherwise used by this class.
        tokenizer: tokenizer exposing ``pad_token_id``, ``__call__`` and
            ``batch_decode``.
        learning_rate: AdamW learning rate (weight decay is fixed at 0.01).
        group_size: responses sampled per prompt.
        eps_low / eps_high: parameters of the lower/upper clip bounds above.
        r_max: hard cap on the importance ratio.
        grpo_epochs: passes over each collected dataset.
        max_grad_norm: gradient-norm clipping threshold.
        use_amp: enable CUDA autocast and gradient scaling.
        gradient_accumulation_steps: rollouts buffered before an update phase.
        inner_batch_size: minibatch size for the update phase.
        use_reference_comparison: forwarded to the reward verifier.
        use_progressive_reward: use ``ProgressiveMathReward`` (step-dependent)
            instead of ``MathReward``.
        phase1_steps / phase2_steps: forwarded to ``ProgressiveMathReward``.
    """

    def __init__(
        self,
        actor_model,
        ref_model,
        tokenizer,
        learning_rate: float = 1e-6,
        group_size: int = 4,
        eps_low: float = 0.16,
        eps_high: float = 0.2,
        r_max: float = 10.0,
        grpo_epochs: int = 1,
        max_grad_norm: float = 1.0,
        use_amp: bool = True,
        gradient_accumulation_steps: int = 1,
        inner_batch_size: int = 4,
        use_reference_comparison: bool = True,
        use_progressive_reward: bool = False,
        phase1_steps: int = 2000,
        phase2_steps: int = 4000
    ):
        self.actor = actor_model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.use_progressive_reward = use_progressive_reward
        if use_progressive_reward:
            # Imported lazily so the module is only required when enabled.
            from progressive_reward import ProgressiveMathReward
            self.math_verifier = ProgressiveMathReward(
                use_reference_comparison=use_reference_comparison,
                phase1_steps=phase1_steps,
                phase2_steps=phase2_steps,
                verbose=True
            )
        else:
            self.math_verifier = MathReward(
                use_reference_comparison=use_reference_comparison
            )
        self.group_size = group_size
        self.eps_low = eps_low
        self.eps_high = eps_high
        self.r_max = r_max
        self.grpo_epochs = grpo_epochs
        self.use_amp = use_amp
        self.max_grad_norm = max_grad_norm
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.inner_batch_size = inner_batch_size
        self.experience_buffer = []
        self.current_step = 0
        self.device = next(self._get_unwrapped_model(actor_model).parameters()).device
        self.optimizer = torch.optim.AdamW(
            self.actor.parameters(),
            lr=learning_rate,
            weight_decay=0.01
        )
        self.scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
        # `is not None` rather than truthiness: container modules define __len__
        # and an empty one would be falsy.
        if self.ref_model is not None:
            self.ref_model.eval()
            self.ref_model.requires_grad_(False)
        # Per-prompt running reward statistics for SAS, keyed by MD5 of the prompt.
        self.sas_stats = {}

    def _get_stable_hash(self, text):
        """Return the MD5 hex digest of ``text``.

        Unlike the salted built-in ``hash``, this is stable across processes, so
        ``sas_stats`` keys survive checkpoint/restore.
        """
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def state_dict(self):
        """Return optimizer, scaler, SAS statistics and step counter for checkpointing."""
        return {
            'optimizer_state_dict': self.optimizer.state_dict(),
            'sas_stats': self.sas_stats,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'current_step': self.current_step
        }

    def load_state_dict(self, state_dict):
        """Restore state produced by ``state_dict``; missing keys are skipped.

        ``torch.optim.Optimizer.load_state_dict`` already casts optimizer state to
        each parameter's device and dtype, and keeps Adam's ``step`` counter where
        PyTorch expects it. So no manual device transfer is done here. Moving
        ``step`` to the GPU would force a host sync per parameter in
        non-capturable AdamW.
        """
        if 'optimizer_state_dict' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if 'sas_stats' in state_dict:
            self.sas_stats = state_dict['sas_stats']
        if state_dict.get('scaler_state_dict') is not None:
            self.scaler.load_state_dict(state_dict['scaler_state_dict'])
        if 'current_step' in state_dict:
            self.current_step = state_dict['current_step']

    def update_step(self, step):
        """Record the global step and forward it to a progressive reward verifier."""
        self.current_step = step
        if self.use_progressive_reward:
            self.math_verifier.update_step(step)

    def _get_unwrapped_model(self, model):
        """Return ``model.module`` for wrapped models, otherwise ``model``."""
        if hasattr(model, 'module'):
            return model.module
        return model

    @staticmethod
    def _token_log_probs(logits, targets):
        """Log-probability of each target token.

        ``logits`` is (batch, T, vocab) and ``targets`` is (batch, T). The
        log-softmax is taken in float32, so precision does not depend on whether
        the caller is inside an autocast region.
        """
        log_probs = F.log_softmax(logits.float(), dim=-1)
        return torch.gather(log_probs, -1, targets.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def _build_position_ids(attention_mask):
        """Position ids counting from 0 at each row's first attended token.

        For each row with ``valid_len`` attended tokens, the ``valid_len``
        positions starting at the first attended token get ``0..valid_len-1``.
        All other positions (e.g. left padding) are 0.
        """
        batch_size, seq_len = attention_mask.shape
        device = attention_mask.device
        position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
        for i in range(batch_size):
            non_pad_positions = (attention_mask[i] == 1).nonzero(as_tuple=True)[0]
            if len(non_pad_positions) > 0:
                start_pos = non_pad_positions[0].item()
                valid_len = len(non_pad_positions)
                position_ids[i, start_pos:start_pos + valid_len] = torch.arange(valid_len, device=device)
        return position_ids

    @torch.no_grad()
    def generate_and_prepare(self, prompt_batch, max_gen_len=512, temperature=1.0, top_p=0.95):
        """Sample ``group_size`` responses per prompt and build a rollout record.

        Args:
            prompt_batch: mapping with ``'prompt'`` (list of prompt strings) and
                ``'ground_truth'`` (one answer per prompt).
            max_gen_len: maximum number of generated tokens.
            temperature: sampling temperature.
            top_p: nucleus-sampling threshold (default 0.95).

        Returns:
            dict with ``prompts_text``, CPU tensors ``sequences`` (prompt +
            response ids, ``group_size`` consecutive rows per prompt),
            ``old_log_probs`` (float32 log-prob of token ``j+1`` at index ``j``),
            ``rewards``, ``attention_mask`` and ``position_ids``, plus the padded
            ``prompt_length`` shared by all rows.
        """
        self.actor.eval()
        prompts_text = prompt_batch['prompt']
        ground_truths = prompt_batch['ground_truth']
        inputs = self.tokenizer(
            prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left"
        ).to(self.device)
        prompts_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        prompt_len = int(prompts_ids.shape[1])
        prompts_ids_repeated = prompts_ids.repeat_interleave(self.group_size, dim=0)
        attention_mask_repeated = attention_mask.repeat_interleave(self.group_size, dim=0)
        input_data = {
            'segments': [{'type': 'text', 'data': prompts_ids_repeated, 'modality_id': 0}],
            'attention_mask': attention_mask_repeated
        }
        # Generate with the unwrapped model (wrappers such as DDP do not expose generate).
        unwrapped_actor = self._get_unwrapped_model(self.actor)
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            generated_ids = unwrapped_actor.generate(
                input_data,
                max_new_tokens=max_gen_len,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id
            )
        # Assumes generate() returns only the new tokens (they are appended to the prompt).
        sequences = torch.cat([prompts_ids_repeated, generated_ids], dim=1)
        decoded_responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # The verifier is given responses that start with a "<think>" tag; add it when missing.
        full_responses_for_reward = [
            r if r.strip().startswith("<think>") else "<think>\n" + r.strip()
            for r in decoded_responses
        ]
        expanded_gts = [gt for gt in ground_truths for _ in range(self.group_size)]
        raw_rewards = self.math_verifier.compute_rewards(full_responses_for_reward, expanded_gts)
        rewards_tensor = torch.tensor(raw_rewards, device=self.device, dtype=torch.float32)
        # NOTE(review): if pad_token_id == eos_token_id, the EOS token is masked out
        # here and therefore never receives a training signal -- confirm tokenizer setup.
        gen_mask = (generated_ids != self.tokenizer.pad_token_id).long()
        full_attention_mask = torch.cat([attention_mask_repeated, gen_mask], dim=1)
        position_ids = self._build_position_ids(full_attention_mask)
        full_input_data = {'segments': [{'type': 'text', 'data': sequences, 'modality_id': 0}]}
        with torch.amp.autocast('cuda', enabled=self.use_amp):
            actor_out = self.actor(
                full_input_data,
                attention_mask=full_attention_mask,
                position_ids=position_ids
            )
        per_token_log_probs = self._token_log_probs(actor_out['logits'][:, :-1, :], sequences[:, 1:])
        return {
            'prompts_text': prompts_text,
            'sequences': sequences.detach().cpu(),
            'old_log_probs': per_token_log_probs.detach().cpu(),
            'rewards': rewards_tensor.cpu(),
            'attention_mask': full_attention_mask.cpu(),
            'position_ids': position_ids.cpu(),
            'prompt_length': prompt_len
        }

    def _update_sas_stats(self, prompt_text, new_rewards):
        """Fold one group of rewards into the prompt's running mean/variance.

        Treats the history as ``i - 1`` equally sized groups and pools their
        variance with the new group's (within-group + between-group terms).

        Returns:
            (mu_new, std_new, mu_total, std_total): statistics of the new group and
            of all groups seen for this prompt (stds include a 1e-8 epsilon).
        """
        prompt_hash = self._get_stable_hash(prompt_text)
        mu_new = new_rewards.mean().item()
        var_new = new_rewards.var(unbiased=False).item() if len(new_rewards) > 1 else 0.0
        if prompt_hash not in self.sas_stats:
            self.sas_stats[prompt_hash] = {
                'i': 1,
                'mu_total': mu_new,
                'var_total': var_new
            }
            return mu_new, np.sqrt(var_new + 1e-8), mu_new, np.sqrt(var_new + 1e-8)
        stats = self.sas_stats[prompt_hash]
        i = stats['i'] + 1
        mu_old = stats['mu_total']
        var_old = stats['var_total']
        mu_total = (mu_new + (i - 1) * mu_old) / i
        # Between-group contribution of the mean shift.
        term3 = ((i - 1) / i) * (mu_old - mu_new) ** 2
        var_total = (var_new + (i - 1) * var_old + term3) / i
        stats['i'] = i
        stats['mu_total'] = mu_total
        stats['var_total'] = var_total
        return mu_new, np.sqrt(var_new + 1e-8), mu_total, np.sqrt(var_total + 1e-8)

    def _compute_sas_advantages(self, experience_batch):
        """Smooth Advantage Standardization for every group in a rollout.

        The advantage standardized by the new group's statistics and the one
        standardized by the prompt's cumulative statistics are blended with
        weights (i-1)/i and 1/i in both directions. Per element, the blend with
        the smaller magnitude is kept.

        Returns:
            1-D tensor aligned with ``experience_batch['rewards']`` (empty when
            there are no prompts).
        """
        prompts = experience_batch['prompts_text']
        rewards = experience_batch['rewards'].view(-1, self.group_size)
        final_advantages = []
        for idx, prompt in enumerate(prompts):
            group_rewards = rewards[idx]
            mu_new, std_new, mu_total, std_total = self._update_sas_stats(prompt, group_rewards)
            A_new = (group_rewards - mu_new) / (std_new + 1e-8)
            A_total = (group_rewards - mu_total) / (std_total + 1e-8)
            i = self.sas_stats[self._get_stable_hash(prompt)]['i']
            SA_new = ((i - 1) / i) * A_new + (1 / i) * A_total
            SA_total = (1 / i) * A_new + ((i - 1) / i) * A_total
            mask = (torch.abs(SA_new) < torch.abs(SA_total)).float()
            final_advantages.append(mask * SA_new + (1 - mask) * SA_total)
        if not final_advantages:
            return rewards.new_zeros(0)
        return torch.cat(final_advantages)

    def _collate_experience_buffer(self):
        """Compute advantages for all buffered rollouts and pack them into a CPU dataset.

        Rollouts are right-padded to a common length (pad id for tokens, 0 for
        log-probs / masks / positions). The buffer is cleared afterwards.
        Dataset columns: sequences, old_log_probs, advantages, prompt lengths,
        attention masks, position ids.
        """
        buffer = self.experience_buffer
        for exp in buffer:
            exp['advantages'] = self._compute_sas_advantages(exp).detach()
        max_seq_len = max(e['sequences'].size(1) for e in buffer)
        max_lp_len = max(e['old_log_probs'].size(1) for e in buffer)

        def pad_tensor(t, target_len, val):
            return F.pad(t, (0, target_len - t.size(1)), value=val)

        pad_id = self.tokenizer.pad_token_id
        # Kept on CPU to save GPU memory; minibatches are moved to the device lazily.
        dataset = TensorDataset(
            torch.cat([pad_tensor(e['sequences'], max_seq_len, pad_id) for e in buffer], dim=0),
            torch.cat([pad_tensor(e['old_log_probs'], max_lp_len, 0.0) for e in buffer], dim=0),
            torch.cat([e['advantages'] for e in buffer], dim=0),
            torch.tensor([e['prompt_length'] for e in buffer for _ in range(len(e['sequences']))]),
            torch.cat([pad_tensor(e['attention_mask'], max_seq_len, 0) for e in buffer], dim=0),
            torch.cat([pad_tensor(e['position_ids'], max_seq_len, 0) for e in buffer], dim=0),
        )
        self.experience_buffer = []
        return dataset

    def _build_loss_mask(self, targets, attn_masks, p_lens):
        """Float mask selecting response targets that should contribute to the loss.

        Target index ``j`` predicts token ``j + 1``, so the first response target of
        a row whose prompt spans ``p`` tokens is at index ``p - 1``. Targets equal
        to the pad token, or whose token is not attended (``attn_masks[:, 1:] == 0``,
        which also covers batch-level right padding), are excluded. Validity is
        taken from the masks, not from ``old_log_prob != 0``, which would wrongly
        drop tokens whose old log-prob is exactly 0.0.
        """
        num_targets = targets.size(1)
        start = (p_lens - 1).clamp(min=0).unsqueeze(1)
        positions = torch.arange(num_targets, device=targets.device).unsqueeze(0)
        in_response = positions >= start
        not_pad = targets != self.tokenizer.pad_token_id
        attended = attn_masks[:, 1:] != 0
        return (in_response & not_pad & attended).float()

    def _compute_policy_loss(self, new_lp, old_lp, advs, mask):
        """Clipped surrogate loss (to minimize) with probability-dependent bounds.

        Per-response token means of ``min(r * A, clip(r) * A)`` are averaged over
        the batch and negated. ``r`` is clamped to ``[0, r_max]``; clip bounds are
        described in the class docstring.
        """
        q_probs = torch.exp(old_lp).clamp(min=1e-10, max=1.0)
        term_low = 1.0 - (4.0 * self.eps_low) / q_probs
        lower_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_low, min=0.0))
        term_high = 1.0 + (4.0 * self.eps_high) / q_probs
        upper_clip = 0.5 + 0.5 * torch.sqrt(torch.clamp(term_high, min=0.0))
        ratio = torch.clamp(torch.exp(new_lp - old_lp), 0, self.r_max)
        advs_expanded = advs.unsqueeze(1).expand_as(ratio)
        surr1 = ratio * advs_expanded
        clipped_ratio = torch.min(torch.max(ratio, lower_clip), upper_clip)
        surr2 = clipped_ratio * advs_expanded
        masked_loss = torch.min(surr1, surr2) * mask
        response_lens = torch.clamp(mask.sum(dim=1), min=1.0)
        per_response_loss = masked_loss.sum(dim=1) / response_lens
        return -per_response_loss.mean()

    def train_step(self, experience):
        """Buffer a rollout and, once enough are collected, run the update phase.

        Args:
            experience: record returned by ``generate_and_prepare``.

        Returns:
            None while fewer than ``gradient_accumulation_steps`` rollouts are
            buffered; otherwise the mean minibatch loss over the update phase.
        """
        self.experience_buffer.append(experience)
        if len(self.experience_buffer) < self.gradient_accumulation_steps:
            return None
        dataset = self._collate_experience_buffer()
        self.actor.train()
        dataloader = DataLoader(dataset, batch_size=self.inner_batch_size, shuffle=True)
        total_loss = 0.0
        update_steps = 0
        for _ in range(self.grpo_epochs):
            for batch in dataloader:
                seqs, old_lp, advs, p_lens, attn_masks, pos_ids = [b.to(self.device) for b in batch]
                input_data = {'segments': [{'type': 'text', 'data': seqs, 'modality_id': 0}]}
                with torch.amp.autocast('cuda', enabled=self.use_amp):
                    outputs = self.actor(
                        input_data,
                        attention_mask=attn_masks,
                        position_ids=pos_ids
                    )
                targets = seqs[:, 1:]
                new_token_log_probs = self._token_log_probs(outputs['logits'][:, :-1, :], targets)
                mask = self._build_loss_mask(targets, attn_masks, p_lens)
                loss = self._compute_policy_loss(new_token_log_probs, old_lp, advs, mask)
                self.optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                total_loss += loss.item()
                update_steps += 1
        return total_loss / max(update_steps, 1)