Wolfvin committed on
Commit
b7cb06e
·
verified ·
1 Parent(s): 1d410e4

Upload diffusion_llm/training/grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. diffusion_llm/training/grpo.py +218 -0
diffusion_llm/training/grpo.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """AAM Diffusion LLM — GRPO Training
2
+
3
+ Group Relative Policy Optimization (from DeepSeek-R1), adapted for AAM.
4
+ No value function needed — uses group-relative advantages.
5
+ AAM-specific reward: coherence, evidence-grounding, anti-hallucination.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import copy
11
+ import logging
12
+ import math
13
+ from dataclasses import dataclass
14
+ from typing import Any, Callable, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
@dataclass
class GRPOConfig:
    """Hyperparameters for GRPO training.

    Field order and defaults are part of the constructor interface and must
    not be rearranged.
    """

    # Number of samples drawn per prompt when generating a group.
    group_size: int = 8
    # Policy ratio is clamped to [1 - clip_range, 1 + clip_range].
    clip_range: float = 0.2
    # Weight of the KL penalty added to the total loss.
    kl_coeff: float = 0.05
    # Weight of the entropy bonus subtracted from the total loss.
    entropy_coeff: float = 0.01
    # NOTE(review): not read by any code visible in this file.
    max_new_tokens: int = 512
    # NOTE(review): not read by any code visible in this file.
    temperature: float = 0.7
    # NOTE(review): not read by any code visible in this file.
    gamma: float = 1.0
    # When true, standardized advantages are additionally divided by their
    # largest absolute value.
    use_advantage_normalization: bool = True
    # One of "raw", "centered" or "rank_based"; any other value is treated
    # like "raw" by the trainer.
    reward_shaping: str = "centered"
    # NOTE(review): not read by any code visible in this file.
    policy_loss_type: str = "clipped"
35
+
36
+
37
@dataclass
class GRPOGroupResult:
    """Container for one generated group and the quantities derived from it.

    Tensor shapes are not enforced here; they are whatever the trainer
    produces (presumably batch-first — confirm against the generator).
    """

    # Prompt token ids the group was generated for.
    prompt_ids: torch.Tensor
    # Generated response ids.
    response_ids: torch.Tensor
    # Log-probabilities under the current policy.
    log_probs: torch.Tensor
    # Raw rewards returned by the reward function.
    rewards: torch.Tensor
    # Advantages; created as zeros by the generator and overwritten by the
    # training step once rewards have been shaped.
    advantages: torch.Tensor
    # Detached log-probabilities under the policy that produced the samples.
    old_log_probs: torch.Tensor
45
+
46
+
47
class AAMRewardFunction:
    """Heuristic reward function for AAM responses.

    Each response is scored independently as the sum of:

    - ``+0.1`` if it contains any non-whitespace text;
    - ``+0.3`` if it has between 10 and 200 whitespace-separated words
      (inclusive), otherwise ``+0.05`` if it has at least one word;
    - ``+0.1`` if it contains any reasoning marker as a substring
      (case-insensitive; awarded at most once);
    - ``+1.0`` if a reference answer is supplied for that index and, after
      lower-casing and stripping, one of the two strings contains the other.

    NOTE(review): these are surface heuristics only; no graph evidence is
    consulted, so grounding / hallucination are not actually measured here.
    """

    # Indonesian and English connectives treated as evidence of reasoning.
    _REASONING_MARKERS = (
        "karena",
        "oleh karena itu",
        "sebab",
        "sehingga",
        "because",
        "therefore",
        "thus",
    )

    def __call__(
        self,
        responses: List[str],
        prompts: Optional[List[str]] = None,
        reference_answers: Optional[List[str]] = None,
    ) -> torch.Tensor:
        """Score every response and return the rewards as a float32 tensor.

        Args:
            responses: Generated texts to score.
            prompts: Accepted for interface compatibility; currently unused.
            reference_answers: Optional gold answers aligned by index with
                ``responses``. It may be shorter than ``responses``; the
                extra responses simply get no reference bonus.

        Returns:
            A 1-D ``torch.float32`` tensor with one reward per response.
        """
        rewards = []
        for i, response in enumerate(responses):
            lowered = response.lower()
            reward = 0.0

            if response.strip():
                reward += 0.1

            length = len(response.split())
            if 10 <= length <= 200:
                reward += 0.3
            elif length > 0:
                reward += 0.05

            if any(marker in lowered for marker in self._REASONING_MARKERS):
                reward += 0.1

            if reference_answers is not None and i < len(reference_answers):
                ref = reference_answers[i].lower().strip()
                resp = lowered.strip()
                # The empty string is a substring of everything, so without
                # the non-empty guards an empty response (or an empty
                # reference) would always collect the full match bonus.
                if ref and resp and (ref in resp or resp in ref):
                    reward += 1.0

            rewards.append(reward)

        return torch.tensor(rewards, dtype=torch.float32)
90
+
91
+
92
class GRPOTrainer:
    """GRPO trainer for the AAM Diffusion LLM.

    One ``train_step`` generates a group for the given prompts, shapes the
    rewards, converts them to advantages and performs a single clipped
    policy-gradient update on the trainable parameters of ``model``.

    NOTE(review): generation is still a placeholder. "Responses" are obtained
    by pushing Gaussian noise through ``model.lm_head``, rewards are computed
    on the stringified *prompt* ids, and ``ref_model`` is created but not yet
    consulted by the loss.
    """

    def __init__(
        self,
        model: nn.Module,
        config: Optional["GRPOConfig"] = None,
        reward_fn: Optional[Callable] = None,
    ) -> None:
        """Set up the policy, a frozen reference copy and the optimizer.

        Args:
            model: Policy network. It must expose ``lm_head`` and
                ``config.model.d_model`` (both are used during generation).
            config: Training hyperparameters; defaults to ``GRPOConfig()``.
            reward_fn: Callable invoked as ``reward_fn(responses=[...])`` that
                returns a list or tensor of rewards; defaults to
                ``AAMRewardFunction()``.
        """
        self.model = model
        self.config = config if config is not None else GRPOConfig()
        self.reward_fn = reward_fn if reward_fn is not None else AAMRewardFunction()

        # Frozen snapshot of the initial policy. It is switched to eval mode
        # so that it stays deterministic (no dropout) whenever it is used.
        self.ref_model = copy.deepcopy(model)
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            trainable_params, lr=1e-5, betas=(0.9, 0.95), weight_decay=0.0,
        )

        self.device = next(model.parameters()).device

    def train_step(
        self,
        prompts: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """Run one generate -> reward -> advantage -> update cycle.

        Args:
            prompts: 2-D tensor of prompt ids (its shape is unpacked as
                ``(batch_size, prompt_len)``).
            attention_mask: Forwarded to generation, which currently ignores it.

        Returns:
            Scalar metrics: ``grpo_loss``, ``policy_loss``, ``mean_reward``
            and ``mean_advantage``.
        """
        self.model.train()
        group_result = self._generate_group(
            prompts, attention_mask, self.config.group_size,
        )
        shaped = self._shape_rewards(group_result.rewards)
        group_result.advantages = self._compute_advantages(shaped)
        return self._update_policy(group_result)

    def _generate_group(self, prompts, attention_mask, group_size):
        """Draw ``group_size`` placeholder samples and score the prompts.

        Only the first sample's log-probabilities are returned, so only that
        sample is computed with autograd enabled; this is what lets
        ``_update_policy`` back-propagate into ``lm_head``.

        Raises:
            ValueError: If ``group_size`` is smaller than 1.
        """
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")

        batch_size, prompt_len = prompts.shape
        device = prompts.device
        d_model = self.model.config.model.d_model

        sample_log_probs = []
        for g in range(group_size):
            noise = torch.randn(batch_size, prompt_len, d_model, device=device)
            # NOTE(review): samples after the first are currently discarded;
            # they are kept under no_grad so they do not retain a graph.
            with torch.set_grad_enabled(g == 0):
                logits = self.model.lm_head(noise)
                # Mean over the vocabulary axis of the log-softmax.
                sample_log_probs.append(F.log_softmax(logits, dim=-1).mean(dim=-1))

        policy_log_probs = sample_log_probs[0]

        # NOTE(review): the reward function is fed the prompt ids rendered as
        # strings, not decoded responses (no decoder is available here).
        rewards = self.reward_fn(responses=[str(p.tolist()) for p in prompts])
        if isinstance(rewards, list):
            rewards = torch.tensor(rewards, device=device, dtype=torch.float32)
        else:
            rewards = rewards.to(device)

        return GRPOGroupResult(
            prompt_ids=prompts,
            response_ids=prompts,
            log_probs=policy_log_probs,
            rewards=rewards,
            advantages=torch.zeros_like(rewards),
            old_log_probs=policy_log_probs.detach(),
        )

    def _shape_rewards(self, rewards):
        """Apply the configured reward shaping.

        ``"raw"`` returns the rewards unchanged, ``"centered"`` subtracts the
        mean, and ``"rank_based"`` maps ranks linearly onto ``[-1, 1]`` (ties
        receive distinct ranks in ``argsort`` order). Any other mode returns
        the rewards unchanged.
        """
        mode = self.config.reward_shaping
        if mode == "centered":
            return rewards - rewards.mean()
        if mode == "rank_based":
            order = rewards.argsort()
            ranks = torch.zeros_like(rewards, dtype=torch.float32)
            scale = max(len(rewards) - 1, 1)
            ranks[order] = (
                torch.arange(len(rewards), dtype=torch.float32, device=rewards.device)
                / scale
            )
            return 2 * ranks - 1
        return rewards

    def _compute_advantages(self, rewards):
        """Standardize rewards into advantages.

        Returns zeros when there are fewer than two rewards or when their
        spread is negligible. When ``use_advantage_normalization`` is set the
        standardized values are further divided by their largest magnitude.
        """
        # The sample standard deviation of fewer than two values is NaN,
        # which would slip past the threshold test below and poison the loss.
        if rewards.numel() < 2:
            return torch.zeros_like(rewards)

        mean_reward = rewards.mean()
        std_reward = rewards.std()
        if std_reward < 1e-8:
            return torch.zeros_like(rewards)

        advantages = (rewards - mean_reward) / (std_reward + 1e-8)
        if self.config.use_advantage_normalization:
            max_abs = advantages.abs().max()
            if max_abs > 1e-8:
                advantages = advantages / max_abs
        return advantages

    def _update_policy(self, group_result):
        """Take one clipped-surrogate optimizer step and report metrics.

        Raises:
            RuntimeError: If the loss is not connected to any trainable
                parameter (``group_result.log_probs`` carries no graph).
        """
        self.optimizer.zero_grad()
        advantages = group_result.advantages
        old_log_probs = group_result.old_log_probs.detach()
        new_log_probs = group_result.log_probs

        # Importance ratio pi_new / pi_old; exactly 1 on the first update
        # after sampling, but still differentiable w.r.t. the policy.
        log_ratio = new_log_probs - old_log_probs
        ratio = torch.exp(log_ratio)

        clip_low = 1.0 - self.config.clip_range
        clip_high = 1.0 + self.config.clip_range
        clipped_ratio = torch.clamp(ratio, clip_low, clip_high)

        # Broadcast one advantage per sequence across its positions.
        if ratio.dim() > 1:
            advantages_expanded = advantages.unsqueeze(-1).expand_as(ratio)
        else:
            advantages_expanded = advantages

        surr1 = ratio * advantages_expanded
        surr2 = clipped_ratio * advantages_expanded
        policy_loss = -torch.min(surr1, surr2).mean()

        # Non-negative KL estimator r - log(r) - 1 with r = pi_old / pi_new.
        # NOTE(review): anchored on the sampling policy because the group
        # result carries no reference-model log-probs; wire in ref_model
        # once generation provides them.
        kl_penalty = (torch.exp(-log_ratio) + log_ratio - 1.0).mean()
        # Taken on the current log-probs so the bonus contributes gradient.
        entropy = -(new_log_probs.exp() * new_log_probs).mean()

        total_loss = (
            policy_loss
            + self.config.kl_coeff * kl_penalty
            - self.config.entropy_coeff * entropy
        )

        if not total_loss.requires_grad:
            raise RuntimeError(
                "GRPO loss is detached from the policy parameters; "
                "group_result.log_probs must be computed with autograd enabled."
            )

        total_loss.backward()
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
        self.optimizer.step()

        return {
            "grpo_loss": total_loss.item(),
            "policy_loss": policy_loss.item(),
            "mean_reward": group_result.rewards.mean().item(),
            "mean_advantage": advantages.mean().item(),
        }