"""AAM Diffusion LLM — DAPO Training.

Decoupled Clip & Dynamic Sampling Policy Optimization (Yu et al., 2025).

Four improvements over GRPO:
    1. Decoupled Clip (asymmetric epsilon)
    2. Dynamic Sampling (filter zero-variance groups)
    3. Token-Level Policy Gradient Loss
    4. Overlong Filtering
"""

from __future__ import annotations

import copy
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)

# Strategies recognised by DAPOTrainer._shape_rewards; any other value falls
# through to the raw (unshaped) rewards.
_REWARD_SHAPING_STRATEGIES = ("raw", "centered", "rank_based")


@dataclass
class DAPOConfig:
    """Hyper-parameters for :class:`DAPOTrainer`.

    Raises:
        ValueError: From ``__post_init__`` if a clip ratio is not positive or
            fewer than two responses per prompt are requested.
    """

    # Lower clip epsilon: the importance ratio is clamped to >= 1 - clip_ratio_low.
    clip_ratio_low: float = 0.2
    # Upper clip epsilon, decoupled from the lower one: ratio clamped to <= 1 + clip_ratio_high.
    clip_ratio_high: float = 0.28
    # Drop prompts whose responses all received the same reward (see filter_prompts).
    dynamic_sampling: bool = True
    # True: average the policy loss over every valid token in the batch.
    # False: average per sequence first, then over sequences.
    token_level_loss: bool = True
    # Overlong-filtering settings; not read by compute_dapo_loss / filter_prompts.
    overlong_filter: bool = True
    max_response_length: int = 2048
    # Responses sampled per prompt; validated to be >= 2.
    num_responses_per_prompt: int = 8
    # Weight of the KL penalty; > 0 also makes DAPOTrainer deep-copy the
    # policy as a reference model when none is supplied.
    kl_coefficient: float = 0.1
    # Not read by compute_dapo_loss / filter_prompts.
    discount_factor: float = 1.0
    # Apply `reward_shaping` to rewards before computing advantages.
    use_reward_normalization: bool = True
    # Standardise advantages (zero mean, unit std) across the whole tensor.
    use_advantage_normalization: bool = True
    # Learning rate of the default AdamW optimizer.
    learning_rate: float = 1e-6
    # Set requires_grad=False on every reference-model parameter.
    reference_model_freeze: bool = True
    # Weight of the entropy term subtracted from the loss.
    entropy_coefficient: float = 0.01
    # Not read by compute_dapo_loss / filter_prompts.
    max_grad_norm: float = 1.0
    temperature: float = 0.7
    # One of _REWARD_SHAPING_STRATEGIES; unknown values mean "raw".
    reward_shaping: str = "centered"

    def __post_init__(self) -> None:
        """Validate the configuration after dataclass initialisation."""
        if self.clip_ratio_low <= 0:
            raise ValueError(f"clip_ratio_low must be positive, got {self.clip_ratio_low}")
        if self.clip_ratio_high <= 0:
            raise ValueError(f"clip_ratio_high must be positive, got {self.clip_ratio_high}")
        if self.num_responses_per_prompt < 2:
            raise ValueError(f"num_responses_per_prompt must be >= 2, got {self.num_responses_per_prompt}")
        if self.reward_shaping not in _REWARD_SHAPING_STRATEGIES:
            # Not an error (kept backward compatible), but a typo here would
            # otherwise silently disable reward shaping.
            logger.warning(
                "Unknown reward_shaping %r (expected one of %s); rewards will be used unshaped.",
                self.reward_shaping,
                _REWARD_SHAPING_STRATEGIES,
            )


class DAPOTrainer:
    """DAPO Trainer for AAM Diffusion LLM.

    Bundles the policy model, an optional reference model for the KL penalty,
    an optional reward function and the optimizer, and provides the DAPO loss
    (:meth:`compute_dapo_loss`) and the dynamic-sampling filter
    (:meth:`filter_prompts`).
    """

    def __init__(
        self,
        config: DAPOConfig,
        policy_model: nn.Module,
        reference_model: Optional[nn.Module] = None,
        reward_fn: Optional[Callable] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
    ) -> None:
        """Set up the models and the optimizer.

        Args:
            config: Training hyper-parameters.
            policy_model: Model being optimised. Must have at least one
                parameter, because ``self.device`` is read from the first one.
            reference_model: Model used for the KL penalty. When ``None`` and
                ``config.kl_coefficient > 0``, a deep copy of ``policy_model``
                taken at construction time is used instead.
            reward_fn: Stored as ``self.reward_fn``; not called by
                ``compute_dapo_loss`` or ``filter_prompts``.
            optimizer: Optimizer to use. When ``None``, an AdamW optimizer over
                the policy parameters with ``requires_grad=True`` is created.
        """
        self.config = config
        self.policy_model = policy_model
        self.reward_fn = reward_fn

        if reference_model is not None:
            self.reference_model = reference_model
        elif config.kl_coefficient > 0:
            # Snapshot of the policy as it is right now.
            self.reference_model = copy.deepcopy(policy_model)
        else:
            self.reference_model = None

        # NOTE(review): the reference model is not switched to eval() here; if it
        # contains dropout, confirm the caller handles that.
        if self.reference_model is not None and config.reference_model_freeze:
            for param in self.reference_model.parameters():
                param.requires_grad = False

        if optimizer is None:
            trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
            optimizer = torch.optim.AdamW(
                trainable_params,
                lr=config.learning_rate,
                betas=(0.9, 0.95),
                weight_decay=0.01,
            )
        self.optimizer = optimizer

        self.device = next(policy_model.parameters()).device

    def compute_dapo_loss(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        ref_log_probs: Optional[torch.Tensor],
        rewards: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Compute the DAPO objective for one batch.

        Args:
            log_probs: Per-token log-probabilities under the current policy;
                presumably shaped ``(batch, seq_len)`` like ``attention_mask``.
            old_log_probs: Same-shaped log-probabilities from the sampling
                policy (denominator of the importance ratio).
            ref_log_probs: Same-shaped reference-model log-probabilities, or
                ``None`` to skip the KL penalty.
            rewards: Either 1-D (one value per row, broadcast over the last
                dimension of ``log_probs``) or a tensor that already broadcasts
                against ``log_probs``.
            attention_mask: Multiplicative mask over tokens; presumably 1 for
                response tokens and 0 for padding.

        Returns:
            ``(loss, metrics)``. ``loss`` is a scalar tensor equal to
            ``policy_loss + kl_penalty - entropy_coefficient * entropy``.
            ``metrics`` maps ``"dapo/..."`` names to Python floats.
        """
        cfg = self.config
        mask = attention_mask
        # Valid tokens per row; clamped so fully-masked rows do not divide by zero.
        tokens_per_row = mask.sum(dim=-1).clamp(min=1)

        ratio = torch.exp(log_probs - old_log_probs)

        advantages = self._compute_advantages(rewards)
        if advantages.dim() == 1:
            # One advantage per row, shared by every token of that row.
            advantages = advantages.unsqueeze(-1).expand_as(log_probs)

        # Decoupled clip: asymmetric lower/upper bounds on the ratio.
        clipped_ratio = torch.clamp(ratio, 1.0 - cfg.clip_ratio_low, 1.0 + cfg.clip_ratio_high)
        per_token_loss = -torch.min(ratio * advantages, clipped_ratio * advantages) * mask

        if cfg.token_level_loss:
            # Token-level loss: every valid token in the batch carries equal
            # weight, so longer responses contribute proportionally more.
            policy_loss = per_token_loss.sum() / mask.sum().clamp(min=1)
        else:
            # Sequence-level loss: mean over each row's tokens, then over rows.
            policy_loss = (per_token_loss.sum(dim=-1) / tokens_per_row).mean()

        kl_penalty = torch.tensor(0.0, device=log_probs.device)
        if ref_log_probs is not None and cfg.kl_coefficient > 0:
            # NOTE(review): with per-sampled-token log-probs, p * (log p - log q)
            # is a probability-weighted log-ratio rather than a standard sampled
            # KL estimator (k1 / k3); confirm this is intended.
            kl_per_token = torch.exp(log_probs) * (log_probs - ref_log_probs) * mask
            kl_penalty = cfg.kl_coefficient * (kl_per_token.sum(dim=-1) / tokens_per_row).mean()

        entropy = torch.tensor(0.0, device=log_probs.device)
        if cfg.entropy_coefficient > 0:
            # NOTE(review): -p log p evaluated only at the sampled token, not
            # over the full vocabulary; confirm this entropy proxy is intended.
            per_token_entropy = -torch.exp(log_probs) * log_probs * mask
            entropy = (per_token_entropy.sum(dim=-1) / tokens_per_row).mean()

        loss = policy_loss + kl_penalty - cfg.entropy_coefficient * entropy

        metrics = {
            "dapo/policy_loss": policy_loss.item(),
            "dapo/kl_penalty": kl_penalty.item(),
            "dapo/entropy": entropy.item(),
            "dapo/loss": loss.item(),
            "dapo/mean_reward": rewards.mean().item(),
        }
        return loss, metrics

    def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
        """Turn rewards into advantages.

        Optionally shapes the rewards (``use_reward_normalization``), then
        optionally standardises them across the whole tensor
        (``use_advantage_normalization``). Both steps are skipped for a
        single-element tensor. Standardisation is also skipped when the std
        is <= 1e-8.
        """
        cfg = self.config
        if cfg.use_reward_normalization and rewards.numel() > 1:
            rewards = self._shape_rewards(rewards, cfg.reward_shaping)
        advantages = rewards.clone()
        if cfg.use_advantage_normalization and advantages.numel() > 1:
            adv_std = advantages.std()
            if adv_std > 1e-8:
                advantages = (advantages - advantages.mean()) / (adv_std + 1e-8)
        return advantages

    def _shape_rewards(self, rewards: torch.Tensor, strategy: str) -> torch.Tensor:
        """Apply a reward-shaping strategy.

        Strategies:
            ``"raw"``: return ``rewards`` unchanged.
            ``"centered"``: subtract the mean.
            ``"rank_based"``: replace each reward by its rank, mapped linearly
                onto ``[-1, 1]``. Ties receive distinct ranks in argsort order.
            Anything else: return ``rewards`` unchanged.

        NOTE(review): ``"rank_based"`` indexes with ``argsort`` over
        ``len(rewards)`` entries, so it is written for 1-D rewards only.
        """
        if strategy == "raw":
            return rewards
        if strategy == "centered":
            return rewards - rewards.mean()
        if strategy == "rank_based":
            sorted_indices = rewards.argsort()
            ranks = torch.zeros_like(rewards, dtype=torch.float32)
            ranks[sorted_indices] = torch.arange(
                len(rewards), dtype=torch.float32, device=rewards.device
            ) / max(len(rewards) - 1, 1)
            return 2.0 * ranks - 1.0
        return rewards

    def filter_prompts(
        self,
        prompts: List[str],
        responses: List[List[str]],
        rewards: torch.Tensor,
    ) -> Tuple[List[str], List[List[str]], torch.Tensor, Dict[str, int]]:
        """Dynamic sampling: drop prompts whose responses all scored the same.

        Args:
            prompts: One entry per prompt.
            responses: Responses for each prompt, aligned with ``prompts``.
            rewards: Rewards indexed by prompt along dim 0. Every remaining
                dimension is treated as that prompt's group of responses.

        Returns:
            ``(prompts, responses, rewards, stats)`` restricted to prompts
            whose group reward std exceeds 1e-6. ``stats`` is
            ``{"filtered": <dropped count>, "total": len(prompts)}``.

        Edge cases:
            If dynamic sampling is disabled, the inputs are returned unchanged.
            The same happens when ``prompts`` is empty. It also happens when
            ``rewards`` has fewer than 2 dims, because a single reward per
            prompt carries no within-group spread. In all three cases
            ``filtered`` is 0. If every prompt has zero variance, the inputs
            are returned unchanged (so the batch is never empty), but
            ``filtered`` equals ``total`` so callers can detect it.
        """
        total = len(prompts)
        if not self.config.dynamic_sampling or total == 0:
            return prompts, responses, rewards, {"filtered": 0, "total": total}

        if rewards.dim() < 2:
            # Previously this path kept only prompts with reward > 1e-6, which
            # filters by sign rather than variance. With no group dimension,
            # zero-variance groups cannot be identified, so keep everything.
            logger.debug("filter_prompts: rewards have no group dimension; skipping dynamic sampling.")
            return prompts, responses, rewards, {"filtered": 0, "total": total}

        # Spread of each prompt's group of rewards (all dims after the first).
        group_std = rewards.flatten(start_dim=1).std(dim=-1)
        keep = (group_std > 1e-6).nonzero(as_tuple=True)[0]

        if keep.numel() == 0:
            logger.warning(
                "filter_prompts: all %d prompts have zero reward variance; returning the batch unfiltered.",
                total,
            )
            return prompts, responses, rewards, {"filtered": total, "total": total}

        # Convert once to Python ints instead of indexing lists with tensor scalars.
        keep_list = keep.tolist()
        filtered_prompts = [prompts[i] for i in keep_list]
        filtered_responses = [responses[i] for i in keep_list]
        filtered_rewards = rewards[keep]
        return (
            filtered_prompts,
            filtered_responses,
            filtered_rewards,
            {"filtered": total - len(keep_list), "total": total},
        )