Wolfvin commited on
Commit
9b970dd
·
verified ·
1 Parent(s): fa50230

Upload diffusion_llm/training/dapo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/training/dapo.py +187 -0
diffusion_llm/training/dapo.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AAM Diffusion LLM — DAPO Training
2
+
3
+ Decoupled Clip & Dynamic Sampling Policy Optimization (Yu et al., 2025).
4
+ Four improvements over GRPO:
5
+ 1. Decoupled Clip (asymmetric epsilon)
6
+ 2. Dynamic Sampling (filter zero-variance groups)
7
+ 3. Token-Level Policy Gradient Loss
8
+ 4. Overlong Filtering
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import copy
14
+ import logging
15
+ from dataclasses import dataclass
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ @dataclass
26
+ class DAPOConfig:
27
+ clip_ratio_low: float = 0.2
28
+ clip_ratio_high: float = 0.28
29
+ dynamic_sampling: bool = True
30
+ token_level_loss: bool = True
31
+ overlong_filter: bool = True
32
+ max_response_length: int = 2048
33
+ num_responses_per_prompt: int = 8
34
+ kl_coefficient: float = 0.1
35
+ discount_factor: float = 1.0
36
+ use_reward_normalization: bool = True
37
+ use_advantage_normalization: bool = True
38
+ learning_rate: float = 1e-6
39
+ reference_model_freeze: bool = True
40
+ entropy_coefficient: float = 0.01
41
+ max_grad_norm: float = 1.0
42
+ temperature: float = 0.7
43
+ reward_shaping: str = "centered"
44
+
45
+ def __post_init__(self) -> None:
46
+ if self.clip_ratio_low <= 0:
47
+ raise ValueError(f"clip_ratio_low must be positive, got {self.clip_ratio_low}")
48
+ if self.clip_ratio_high <= 0:
49
+ raise ValueError(f"clip_ratio_high must be positive, got {self.clip_ratio_high}")
50
+ if self.num_responses_per_prompt < 2:
51
+ raise ValueError(f"num_responses_per_prompt must be >= 2, got {self.num_responses_per_prompt}")
52
+
53
+
54
+ class DAPOTrainer:
55
+ """DAPO Trainer for AAM Diffusion LLM."""
56
+
57
+ def __init__(
58
+ self,
59
+ config: DAPOConfig,
60
+ policy_model: nn.Module,
61
+ reference_model: Optional[nn.Module] = None,
62
+ reward_fn: Optional[Callable] = None,
63
+ optimizer: Optional[torch.optim.Optimizer] = None,
64
+ ) -> None:
65
+ self.config = config
66
+ self.policy_model = policy_model
67
+ self.reward_fn = reward_fn
68
+
69
+ if reference_model is not None:
70
+ self.reference_model = reference_model
71
+ elif config.kl_coefficient > 0:
72
+ self.reference_model = copy.deepcopy(policy_model)
73
+ else:
74
+ self.reference_model = None
75
+
76
+ if self.reference_model is not None and config.reference_model_freeze:
77
+ for param in self.reference_model.parameters():
78
+ param.requires_grad = False
79
+
80
+ trainable_params = [p for p in policy_model.parameters() if p.requires_grad]
81
+ self.optimizer = optimizer or torch.optim.AdamW(
82
+ trainable_params, lr=config.learning_rate, betas=(0.9, 0.95), weight_decay=0.01,
83
+ )
84
+
85
+ self.device = next(policy_model.parameters()).device
86
+
87
+ def compute_dapo_loss(
88
+ self,
89
+ log_probs: torch.Tensor,
90
+ old_log_probs: torch.Tensor,
91
+ ref_log_probs: torch.Tensor,
92
+ rewards: torch.Tensor,
93
+ attention_mask: torch.Tensor,
94
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
95
+ cfg = self.config
96
+
97
+ log_ratio = log_probs - old_log_probs
98
+ ratio = torch.exp(log_ratio)
99
+
100
+ advantages = self._compute_advantages(rewards)
101
+ advantages_expanded = advantages.unsqueeze(-1).expand_as(log_probs) if advantages.dim() == 1 else advantages
102
+
103
+ clipped_ratio = torch.clamp(ratio, 1.0 - cfg.clip_ratio_low, 1.0 + cfg.clip_ratio_high)
104
+
105
+ surr1 = ratio * advantages_expanded
106
+ surr2 = clipped_ratio * advantages_expanded
107
+
108
+ if cfg.token_level_loss:
109
+ per_token_loss = -torch.min(surr1, surr2) * attention_mask
110
+ num_valid_tokens = attention_mask.sum(dim=-1, keepdim=True).clamp(min=1)
111
+ policy_loss = (per_token_loss.sum(dim=-1) / num_valid_tokens.squeeze(-1)).mean()
112
+ else:
113
+ per_token_loss = -torch.min(surr1, surr2) * attention_mask
114
+ seq_loss = per_token_loss.sum(dim=-1) / attention_mask.sum(dim=-1).clamp(min=1)
115
+ policy_loss = seq_loss.mean()
116
+
117
+ kl_penalty = torch.tensor(0.0, device=log_probs.device)
118
+ if ref_log_probs is not None and cfg.kl_coefficient > 0:
119
+ kl_per_token = torch.exp(log_probs) * (log_probs - ref_log_probs) * attention_mask
120
+ kl_penalty = cfg.kl_coefficient * (kl_per_token.sum(dim=-1) / attention_mask.sum(dim=-1).clamp(min=1)).mean()
121
+
122
+ entropy = torch.tensor(0.0, device=log_probs.device)
123
+ if cfg.entropy_coefficient > 0:
124
+ per_token_entropy = -torch.exp(log_probs) * log_probs * attention_mask
125
+ entropy = (per_token_entropy.sum(dim=-1) / attention_mask.sum(dim=-1).clamp(min=1)).mean()
126
+
127
+ loss = policy_loss + kl_penalty - cfg.entropy_coefficient * entropy
128
+
129
+ with torch.no_grad():
130
+ metrics = {
131
+ "dapo/policy_loss": policy_loss.item(),
132
+ "dapo/kl_penalty": kl_penalty.item() if isinstance(kl_penalty, torch.Tensor) else kl_penalty,
133
+ "dapo/entropy": entropy.item() if isinstance(entropy, torch.Tensor) else entropy,
134
+ "dapo/loss": loss.item(),
135
+ "dapo/mean_reward": rewards.mean().item(),
136
+ }
137
+
138
+ return loss, metrics
139
+
140
+ def _compute_advantages(self, rewards: torch.Tensor) -> torch.Tensor:
141
+ cfg = self.config
142
+ if cfg.use_reward_normalization and rewards.numel() > 1:
143
+ rewards = self._shape_rewards(rewards, cfg.reward_shaping)
144
+ advantages = rewards.clone()
145
+ if cfg.use_advantage_normalization and advantages.numel() > 1:
146
+ adv_std = advantages.std()
147
+ if adv_std > 1e-8:
148
+ advantages = (advantages - advantages.mean()) / (adv_std + 1e-8)
149
+ return advantages
150
+
151
+ def _shape_rewards(self, rewards: torch.Tensor, strategy: str) -> torch.Tensor:
152
+ if strategy == "raw":
153
+ return rewards
154
+ if strategy == "centered":
155
+ return rewards - rewards.mean()
156
+ if strategy == "rank_based":
157
+ sorted_indices = rewards.argsort()
158
+ ranks = torch.zeros_like(rewards, dtype=torch.float32)
159
+ ranks[sorted_indices] = torch.arange(len(rewards), dtype=torch.float32, device=rewards.device) / max(len(rewards) - 1, 1)
160
+ return 2.0 * ranks - 1.0
161
+ return rewards
162
+
163
+ def filter_prompts(
164
+ self,
165
+ prompts: List[str],
166
+ responses: List[List[str]],
167
+ rewards: torch.Tensor,
168
+ ) -> Tuple[List[str], List[List[str]], torch.Tensor, Dict[str, int]]:
169
+ if not self.config.dynamic_sampling:
170
+ return prompts, responses, rewards, {"filtered": 0, "total": len(prompts)}
171
+
172
+ if rewards.dim() == 1:
173
+ has_variance = rewards > 1e-6
174
+ else:
175
+ reward_std_per_prompt = rewards.std(dim=-1)
176
+ has_variance = reward_std_per_prompt > 1e-6
177
+
178
+ valid_indices = has_variance.nonzero(as_tuple=True)[0]
179
+ if len(valid_indices) == 0:
180
+ return prompts, responses, rewards, {"filtered": len(prompts), "total": len(prompts)}
181
+
182
+ filtered_prompts = [prompts[i] for i in valid_indices]
183
+ filtered_responses = [responses[i] for i in valid_indices]
184
+ filtered_rewards = rewards[valid_indices]
185
+ num_filtered = len(prompts) - len(valid_indices)
186
+
187
+ return filtered_prompts, filtered_responses, filtered_rewards, {"filtered": num_filtered, "total": len(prompts)}