# Wolfvin's picture
# Upload diffusion_llm/training/dapo.py with huggingface_hub
# 9b970dd verified
"""AAM Diffusion LLM — DAPO Training
Decoupled Clip & Dynamic Sampling Policy Optimization (Yu et al., 2025).
Four improvements over GRPO:
1. Decoupled Clip (asymmetric epsilon)
2. Dynamic Sampling (filter zero-variance groups)
3. Token-Level Policy Gradient Loss
4. Overlong Filtering
"""
from __future__ import annotations
import copy
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
@dataclass
class DAPOConfig:
    """Hyperparameters for DAPO training.

    Only ``clip_ratio_low``, ``clip_ratio_high`` and
    ``num_responses_per_prompt`` are validated on construction; every other
    field is stored as given.

    Raises:
        ValueError: If either clip ratio is not strictly positive, or if
            ``num_responses_per_prompt`` is smaller than 2.
    """

    # Decoupled (asymmetric) clip range for the importance ratio.
    clip_ratio_low: float = 0.2
    clip_ratio_high: float = 0.28

    # Toggles named after the DAPO improvements listed in the module docstring.
    dynamic_sampling: bool = True
    token_level_loss: bool = True
    overlong_filter: bool = True

    # Sampling / rollout settings.
    max_response_length: int = 2048
    num_responses_per_prompt: int = 8

    # Loss weighting and reward handling.
    kl_coefficient: float = 0.1
    discount_factor: float = 1.0
    use_reward_normalization: bool = True
    use_advantage_normalization: bool = True

    # Optimisation settings.
    learning_rate: float = 1e-6
    reference_model_freeze: bool = True
    entropy_coefficient: float = 0.01
    max_grad_norm: float = 1.0
    temperature: float = 0.7

    # One of "raw", "centered" or "rank_based"; not validated here.
    reward_shaping: str = "centered"

    def __post_init__(self) -> None:
        """Validate the clip ratios and the group size, in that order."""
        for field_name in ("clip_ratio_low", "clip_ratio_high"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")

        group_size = self.num_responses_per_prompt
        if group_size < 2:
            raise ValueError(f"num_responses_per_prompt must be >= 2, got {group_size}")
class DAPOTrainer:
    """DAPO Trainer for AAM Diffusion LLM.

    Holds the policy model, an optional reference model for the KL penalty,
    the reward function and the optimizer, and provides the DAPO loss plus
    the reward-shaping and dynamic-sampling helpers.
    """

    def __init__(
        self,
        config: DAPOConfig,
        policy_model: nn.Module,
        reference_model: Optional[nn.Module] = None,
        reward_fn: Optional[Callable] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        """Set up the models and the optimizer.

        Args:
            config: Training hyperparameters.
            policy_model: The model being optimised.
            reference_model: Model used for the KL penalty. When omitted and
                ``config.kl_coefficient > 0`` a deep copy of ``policy_model``
                is used; otherwise no reference model is kept.
            reward_fn: Stored on the trainer; not called in this class's
                visible methods.
            optimizer: Optimizer to use. When ``None`` an AdamW optimizer is
                created over the policy parameters that require gradients.
        """
        self.config = config
        self.policy_model = policy_model
        self.reward_fn = reward_fn

        if reference_model is not None:
            self.reference_model = reference_model
        elif config.kl_coefficient > 0:
            self.reference_model = copy.deepcopy(policy_model)
        else:
            self.reference_model = None

        if self.reference_model is not None and config.reference_model_freeze:
            for param in self.reference_model.parameters():
                param.requires_grad = False

        # Explicit None check: never depend on the truthiness of a
        # caller-supplied optimizer object.
        if optimizer is None:
            trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=config.learning_rate,
                betas=(0.9, 0.95),
                weight_decay=0.01,
            )
        self.optimizer = optimizer

        # A parameterless module would otherwise leak a bare StopIteration.
        first_param = next(policy_model.parameters(), None)
        self.device = first_param.device if first_param is not None else torch.device("cpu")

    def compute_dapo_loss(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        ref_log_probs: Optional[torch.Tensor],
        rewards: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the clipped DAPO objective with optional KL and entropy terms.

        Tokens are reduced over the last dimension of ``log_probs``.

        Args:
            log_probs: Log-probabilities under the current policy.
            old_log_probs: Log-probabilities under the sampling policy.
            ref_log_probs: Log-probabilities under the reference model, or
                ``None`` to skip the KL penalty.
            rewards: Rewards. A 1-D tensor is treated as one value per
                sequence and broadcast over tokens; any other rank is used
                as-is and must broadcast against ``log_probs``.
            attention_mask: Mask (or weights) selecting the tokens that
                contribute to the loss.

        Returns:
            ``(loss, metrics)`` where ``metrics`` holds Python floats under
            the ``dapo/*`` keys.
        """
        cfg = self.config

        ratio = torch.exp(log_probs - old_log_probs)
        advantages = self._compute_advantages(rewards)
        if advantages.dim() == 1:
            advantages = advantages.unsqueeze(-1).expand_as(log_probs)

        # Decoupled clip: the lower and upper bounds use different epsilons.
        clipped_ratio = torch.clamp(ratio, 1.0 - cfg.clip_ratio_low, 1.0 + cfg.clip_ratio_high)
        surrogate = torch.min(ratio * advantages, clipped_ratio * advantages)
        per_token_loss = -surrogate * attention_mask
        tokens_per_seq = attention_mask.sum(dim=-1).clamp(min=1)

        if cfg.token_level_loss:
            # Token-level loss: every valid token in the batch carries the
            # same weight, so normalise by the total token count rather than
            # averaging per-sequence means.
            policy_loss = per_token_loss.sum() / attention_mask.sum().clamp(min=1)
        else:
            # Sequence-level loss: mean over tokens first, then over sequences.
            policy_loss = (per_token_loss.sum(dim=-1) / tokens_per_seq).mean()

        kl_penalty = torch.tensor(0.0, device=log_probs.device)
        if ref_log_probs is not None and cfg.kl_coefficient > 0:
            # NOTE(review): the log-ratio is weighted by the policy probability
            # of the sampled token only; confirm this is the intended KL
            # estimator.
            kl_per_token = torch.exp(log_probs) * (log_probs - ref_log_probs) * attention_mask
            kl_penalty = cfg.kl_coefficient * (kl_per_token.sum(dim=-1) / tokens_per_seq).mean()

        entropy = torch.tensor(0.0, device=log_probs.device)
        if cfg.entropy_coefficient > 0:
            # NOTE(review): uses the sampled token's probability only, not the
            # full distribution; confirm this proxy is intended.
            per_token_entropy = -torch.exp(log_probs) * log_probs * attention_mask
            entropy = (per_token_entropy.sum(dim=-1) / tokens_per_seq).mean()

        loss = policy_loss + kl_penalty - cfg.entropy_coefficient * entropy

        with torch.no_grad():
            metrics = {
                "dapo/policy_loss": policy_loss.item(),
                "dapo/kl_penalty": kl_penalty.item(),
                "dapo/entropy": entropy.item(),
                "dapo/loss": loss.item(),
                # float() keeps bool/int reward tensors from failing in mean().
                "dapo/mean_reward": rewards.float().mean().item(),
            }
        return loss, metrics

    def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        """Turn rewards into advantages using statistics over the whole tensor.

        Rewards are optionally shaped with ``config.reward_shaping`` and then
        optionally standardised (zero mean, unit standard deviation). The
        standardisation is skipped when the standard deviation is at most
        ``1e-8``. Tensors with a single element are returned unshaped.
        """
        cfg = self.config

        # mean()/std() are only defined for floating tensors.
        if not rewards.is_floating_point():
            rewards = rewards.float()

        if cfg.use_reward_normalization and rewards.numel() > 1:
            rewards = self._shape_rewards(rewards, cfg.reward_shaping)

        advantages = rewards.clone()
        if cfg.use_advantage_normalization and advantages.numel() > 1:
            adv_std = advantages.std()
            if adv_std > 1e-8:
                advantages = (advantages - advantages.mean()) / (adv_std + 1e-8)
        return advantages

    def _shape_rewards(self, rewards: torch.Tensor, strategy: str) -> torch.Tensor:
        """Apply a reward-shaping strategy.

        Args:
            rewards: Reward tensor of any shape.
            strategy: ``"raw"`` returns the input, ``"centered"`` subtracts the
                global mean, ``"rank_based"`` maps the ranks of all elements
                linearly onto ``[-1, 1]`` (as float32). Any other value
                returns the input unchanged.

        Returns:
            The shaped rewards, with the same shape as ``rewards``.
        """
        if strategy == "centered":
            return rewards - rewards.mean()

        if strategy == "rank_based":
            # Rank over the flattened tensor so inputs of any rank work; for
            # 1-D input this is identical to ranking the tensor directly.
            flat = rewards.reshape(-1)
            count = flat.numel()
            positions = torch.arange(count, dtype=torch.float32, device=rewards.device)
            ranks = torch.empty(count, dtype=torch.float32, device=rewards.device)
            ranks[flat.argsort()] = positions / max(count - 1, 1)
            return (2.0 * ranks - 1.0).reshape(rewards.shape)

        # "raw" and unrecognised strategies leave the rewards untouched.
        return rewards

    def filter_prompts(
        self,
        prompts: List[str],
        responses: List[List[str]],
        rewards: torch.Tensor,
    ) -> Tuple[List[str], List[List[str]], torch.Tensor, Dict[str, int]]:
        """Dynamic sampling: drop prompts whose response rewards do not vary.

        Args:
            prompts: One entry per prompt.
            responses: Responses per prompt, aligned with ``prompts``.
            rewards: Rewards whose first dimension is aligned with ``prompts``.
                With more than one dimension, the standard deviation over the
                last dimension decides whether a prompt is kept.

        Returns:
            ``(prompts, responses, rewards, stats)`` restricted to the kept
            prompts, where ``stats`` has the keys ``"filtered"`` and
            ``"total"``. When dynamic sampling is disabled the inputs are
            returned unchanged. When no prompt passes the filter the inputs
            are also returned unchanged while ``stats["filtered"]`` equals
            ``stats["total"]``.
        """
        total = len(prompts)
        if not self.config.dynamic_sampling:
            return prompts, responses, rewards, {"filtered": 0, "total": total}

        if rewards.dim() == 1:
            # NOTE(review): with one reward per prompt there is no group
            # variance to measure; this keeps prompts with a positive reward.
            # Confirm that is the intended 1-D behaviour.
            has_variance = rewards > 1e-6
        else:
            has_variance = rewards.std(dim=-1) > 1e-6

        valid_indices = has_variance.nonzero(as_tuple=True)[0]
        if len(valid_indices) == 0:
            # Fall back to the unfiltered batch; the stats still flag that
            # nothing passed the filter.
            return prompts, responses, rewards, {"filtered": total, "total": total}

        # Plain ints for list indexing; the tensor indexes the rewards.
        kept = valid_indices.tolist()
        filtered_prompts = [prompts[i] for i in kept]
        filtered_responses = [responses[i] for i in kept]
        filtered_rewards = rewards[valid_indices]
        stats = {"filtered": total - len(kept), "total": total}
        return filtered_prompts, filtered_responses, filtered_rewards, stats