# Uploaded by Wolfvin with huggingface_hub: diffusion_llm/training/grpo.py (commit b7cb06e, verified)
"""AAM Diffusion LLM — GRPO Training
Group Relative Policy Optimization (from DeepSeek-R1), adapted for AAM.
No value function needed — uses group-relative advantages.
AAM-specific reward: coherence, evidence-grounding, anti-hallucination.
"""
from __future__ import annotations
import copy
import logging
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
@dataclass
class GRPOConfig:
    """Hyper-parameters for GRPO training.

    Field order and defaults are part of the constructor signature and are
    kept exactly as before.
    """

    # --- fields read by GRPOTrainer in this file ---------------------------
    # Number of samples drawn per training step.
    group_size: int = 8
    # Half-width of the clipping interval: ratio is clamped to [1 - c, 1 + c].
    clip_range: float = 0.2
    # Weight of the KL term in the total loss.
    kl_coeff: float = 0.05
    # Weight of the entropy term (subtracted from the total loss).
    entropy_coeff: float = 0.01

    # --- fields not read by the trainer code visible in this file ----------
    # NOTE(review): presumably generation / return settings; confirm usage
    # elsewhere before relying on them.
    max_new_tokens: int = 512
    temperature: float = 0.7
    gamma: float = 1.0

    # When True, standardized advantages are additionally divided by their
    # largest absolute value.
    use_advantage_normalization: bool = True
    # One of "raw", "centered" or "rank_based"; any other value is treated
    # like "raw" by the trainer.
    reward_shaping: str = "centered"
    # Not read by the trainer code visible in this file.
    policy_loss_type: str = "clipped"
@dataclass
class GRPOGroupResult:
    """Tensors produced by one rollout and consumed by the policy update.

    Instances are created by ``GRPOTrainer._generate_group``; ``advantages``
    starts out as zeros and is overwritten by ``GRPOTrainer.train_step`` once
    the rewards have been shaped.
    """

    # Prompt token ids handed to the trainer.
    prompt_ids: torch.Tensor
    # Token ids recorded as the responses for this rollout.
    response_ids: torch.Tensor
    # Log-probabilities recorded at rollout time.
    log_probs: torch.Tensor
    # Reward scores returned by the reward function.
    rewards: torch.Tensor
    # Advantage estimates derived from ``rewards``.
    advantages: torch.Tensor
    # Detached copy of the rollout log-probabilities (the "old" policy).
    old_log_probs: torch.Tensor
class AAMRewardFunction:
    """Heuristic text reward for AAM responses.

    Each response is scored with cheap surface checks that stand in for
    coherence / evidence-grounding signals:

    - +0.1  if the response is not blank
    - +0.3  if it has 10..200 whitespace-separated words, else +0.05 if it
            has at least one word
    - +0.1  if it contains any reasoning marker (Indonesian or English),
            matched as a case-insensitive substring
    - +1.0  if a reference answer is supplied for that index and one of the
            two texts contains the other (case-insensitive)
    """

    # Substrings that count as evidence of explicit reasoning.
    _REASONING_MARKERS = (
        "karena",
        "oleh karena itu",
        "sebab",
        "sehingga",
        "because",
        "therefore",
        "thus",
    )

    def __call__(
        self,
        responses: List[str],
        prompts: Optional[List[str]] = None,
        reference_answers: Optional[List[str]] = None,
    ) -> torch.Tensor:
        """Score every response and return the rewards as a float32 tensor.

        Args:
            responses: Generated texts, one per sample.
            prompts: Accepted for interface compatibility; not used by the
                scoring rules.
            reference_answers: Optional expected answers aligned by index
                with ``responses``; may be shorter than ``responses``.

        Returns:
            1-D ``torch.float32`` tensor with one reward per response.
        """
        rewards = []
        for i, response in enumerate(responses):
            text = response.strip()
            lowered = text.lower()
            reward = 0.0

            if text:
                reward += 0.1

            word_count = len(text.split())
            if 10 <= word_count <= 200:
                reward += 0.3
            elif word_count > 0:
                reward += 0.05

            if any(marker in lowered for marker in self._REASONING_MARKERS):
                reward += 0.1

            if reference_answers is not None and i < len(reference_answers):
                ref = reference_answers[i].lower().strip()
                # The empty string is a substring of everything, so a blank
                # response or a blank reference must not earn the match bonus.
                if ref and lowered and (ref in lowered or lowered in ref):
                    reward += 1.0

            rewards.append(reward)
        return torch.tensor(rewards, dtype=torch.float32)
class GRPOTrainer:
    """GRPO trainer for the AAM Diffusion LLM.

    One ``train_step`` samples ``config.group_size`` responses per prompt,
    scores them with ``reward_fn``, converts the scores into advantages that
    are normalised inside each prompt's group, and applies a clipped
    policy-gradient update with a KL penalty towards a frozen copy of the
    initial model.

    NOTE(review): rollouts are still a stand-in. Responses are sampled from
    ``model.lm_head`` applied to Gaussian latents, so they do not depend on
    the prompt tokens or on ``attention_mask``, and they reach ``reward_fn``
    as the ``str()`` of a token-id list because no tokenizer is visible here.
    Replace ``_generate_group`` with the model's real sampling path before
    relying on the rewards.
    """

    def __init__(
        self,
        model: nn.Module,
        config: Optional["GRPOConfig"] = None,
        reward_fn: Optional[Callable] = None,
    ) -> None:
        """Set up the policy, a frozen reference copy and the optimizer.

        Args:
            model: Policy to train. Must expose ``lm_head`` and
                ``config.model.d_model``.
            config: Hyper-parameters; a default ``GRPOConfig`` when omitted.
            reward_fn: Callable invoked as ``reward_fn(responses=[...])`` that
                returns one reward per response (list or tensor); defaults to
                ``AAMRewardFunction``.
        """
        self.model = model
        self.config = config if config is not None else GRPOConfig()
        self.reward_fn = reward_fn if reward_fn is not None else AAMRewardFunction()

        # Frozen snapshot of the initial policy, used only for the KL term.
        self.ref_model = copy.deepcopy(model)
        self.ref_model.eval()
        for param in self.ref_model.parameters():
            param.requires_grad = False

        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            trainable_params, lr=1e-5, betas=(0.9, 0.95), weight_decay=0.0,
        )
        self.device = next(model.parameters()).device
        # Latents of the most recent rollout; the update step replays them to
        # obtain log-probabilities that carry gradient.
        self._rollout_latents: Optional[torch.Tensor] = None

    def train_step(self, prompts: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Dict[str, float]:
        """Run one rollout + policy update and return scalar metrics.

        Args:
            prompts: 2-D tensor of prompt token ids, ``(batch, prompt_len)``.
            attention_mask: Forwarded to ``_generate_group``; currently unused
                there.

        Returns:
            Dict with ``grpo_loss``, ``policy_loss``, ``mean_reward``,
            ``mean_advantage``, ``kl_penalty`` and ``entropy``.
        """
        self.model.train()
        group_size = self.config.group_size
        group_result = self._generate_group(prompts, attention_mask, group_size)

        # Rewards are laid out group-major: row g holds sample g of every
        # prompt, so column b is the group belonging to prompt b.
        batch_size = prompts.shape[0]
        per_prompt = group_result.rewards.reshape(group_size, batch_size)
        advantages = torch.zeros_like(per_prompt)
        for b in range(batch_size):
            shaped = self._shape_rewards(per_prompt[:, b])
            advantages[:, b] = self._compute_advantages(shaped)
        group_result.advantages = advantages.reshape(-1)

        return self._update_policy(group_result)

    @staticmethod
    def _gather_log_probs(log_probs: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        """Pick the log-probability of each chosen token along the last axis."""
        return log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

    def _generate_group(self, prompts, attention_mask, group_size) -> "GRPOGroupResult":
        """Sample ``group_size`` responses per prompt and score them.

        The returned tensors have a leading dimension of
        ``group_size * batch`` in group-major order. ``attention_mask`` is
        accepted for interface compatibility and not used.

        Raises:
            ValueError: if ``group_size < 1``, or the reward function does not
                return exactly one reward per sampled response.
        """
        if group_size < 1:
            raise ValueError(f"group_size must be >= 1, got {group_size}")
        batch_size, prompt_len = prompts.shape
        device = prompts.device
        num_samples = group_size * batch_size
        d_model = self.model.config.model.d_model

        with torch.no_grad():
            latents = torch.randn(num_samples, prompt_len, d_model, device=device)
            log_probs = F.log_softmax(self.model.lm_head(latents), dim=-1)
            # torch.multinomial needs a 2-D input: flatten, sample, restore.
            flat_probs = log_probs.exp().reshape(-1, log_probs.shape[-1])
            response_ids = torch.multinomial(flat_probs, num_samples=1).reshape(
                num_samples, prompt_len
            )
            old_log_probs = self._gather_log_probs(log_probs, response_ids)
        self._rollout_latents = latents

        rewards = self.reward_fn(responses=[str(ids.tolist()) for ids in response_ids])
        if isinstance(rewards, (list, tuple)):
            rewards = torch.tensor(rewards, device=device, dtype=torch.float32)
        else:
            rewards = rewards.to(device=device, dtype=torch.float32)
        rewards = rewards.reshape(-1)
        if rewards.numel() != num_samples:
            raise ValueError(
                f"reward_fn returned {rewards.numel()} rewards for {num_samples} responses"
            )

        return GRPOGroupResult(
            prompt_ids=prompts,
            response_ids=response_ids,
            log_probs=old_log_probs,
            rewards=rewards,
            advantages=torch.zeros_like(rewards),
            old_log_probs=old_log_probs.detach(),
        )

    def _shape_rewards(self, rewards):
        """Apply ``config.reward_shaping`` to a 1-D reward vector.

        ``"raw"`` returns the input, ``"centered"`` subtracts the mean and
        ``"rank_based"`` maps ranks linearly onto ``[-1, 1]``. Any other value
        falls back to the raw rewards.
        """
        mode = self.config.reward_shaping
        if mode == "centered":
            return rewards - rewards.mean()
        if mode == "rank_based":
            sorted_indices = rewards.argsort()
            ranks = torch.zeros_like(rewards, dtype=torch.float32)
            # max(..., 1) keeps the single-element case from dividing by zero.
            ranks[sorted_indices] = torch.arange(
                len(rewards), dtype=torch.float32, device=rewards.device
            ) / max(len(rewards) - 1, 1)
            return 2 * ranks - 1
        return rewards

    def _compute_advantages(self, rewards):
        """Standardise rewards into advantages.

        Returns zeros when there are fewer than two rewards or when they are
        (numerically) all equal. With ``config.use_advantage_normalization``
        the result is additionally scaled so its largest magnitude is 1.
        """
        # std() of a single element is NaN, which would slip past the
        # "std < eps" guard below and poison the loss.
        if rewards.numel() < 2:
            return torch.zeros_like(rewards)
        mean_reward = rewards.mean()
        std_reward = rewards.std()
        if std_reward < 1e-8:
            return torch.zeros_like(rewards)
        advantages = (rewards - mean_reward) / (std_reward + 1e-8)
        if self.config.use_advantage_normalization:
            max_abs = advantages.abs().max()
            if max_abs > 1e-8:
                advantages = advantages / max_abs
        return advantages

    def _update_policy(self, group_result):
        """Take one optimizer step on the clipped GRPO objective.

        The rollout latents are replayed through the current ``lm_head`` so
        the importance ratio, the KL penalty against the frozen reference and
        the entropy bonus all carry gradient.

        Raises:
            RuntimeError: if no rollout has been recorded by
                ``_generate_group``.
        """
        latents = getattr(self, "_rollout_latents", None)
        if latents is None:
            raise RuntimeError(
                "_update_policy requires the latents recorded by _generate_group"
            )
        response_ids = group_result.response_ids
        old_log_probs = group_result.old_log_probs.detach()
        advantages = group_result.advantages.detach()

        self.optimizer.zero_grad()
        log_probs = F.log_softmax(self.model.lm_head(latents), dim=-1)
        new_log_probs = self._gather_log_probs(log_probs, response_ids)
        with torch.no_grad():
            ref_log_probs = self._gather_log_probs(
                F.log_softmax(self.ref_model.lm_head(latents), dim=-1), response_ids
            )

        ratio = torch.exp(new_log_probs - old_log_probs)
        clip_low = 1.0 - self.config.clip_range
        clip_high = 1.0 + self.config.clip_range
        clipped_ratio = torch.clamp(ratio, clip_low, clip_high)

        # One advantage per response, broadcast over its token positions.
        if ratio.dim() > advantages.dim():
            advantages_expanded = advantages.unsqueeze(-1).expand_as(ratio)
        else:
            advantages_expanded = advantages
        surr1 = ratio * advantages_expanded
        surr2 = clipped_ratio * advantages_expanded
        policy_loss = -torch.min(surr1, surr2).mean()

        # Non-negative KL estimator used by GRPO: exp(d) - d - 1 with
        # d = log pi_ref - log pi_theta.
        ref_delta = ref_log_probs - new_log_probs
        kl_penalty = (torch.exp(ref_delta) - ref_delta - 1.0).mean()
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

        total_loss = (
            policy_loss
            + self.config.kl_coeff * kl_penalty
            - self.config.entropy_coeff * entropy
        )
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            [p for p in self.model.parameters() if p.requires_grad], max_norm=1.0
        )
        self.optimizer.step()

        return {
            "grpo_loss": total_loss.item(),
            "policy_loss": policy_loss.item(),
            "mean_reward": group_result.rewards.mean().item(),
            "mean_advantage": advantages.mean().item(),
            "kl_penalty": kl_penalty.item(),
            "entropy": entropy.item(),
        }