"""GRPO Trainer — Group Relative Policy Optimization for coding agents.
Implements the CaP-RL training loop from the paper (Section 5):
1. Sample prompts from task environments
2. Generate GROUP_SIZE rollouts per prompt
3. Execute code in sim, get binary rewards
4. Compute group-relative advantages
5. Update policy with GRPO loss + KL penalty
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import torch
from torch.optim import AdamW
from anima_naka.rl.dataset import GRPODataset
from anima_naka.rl.reward import CaPRewardFunction
logger = logging.getLogger("anima_naka.rl")
@dataclass
class GRPOConfig:
"""GRPO training configuration."""
base_model: str = "Qwen/Qwen2.5-7B-Instruct"
tasks: list[str] = field(default_factory=lambda: ["cube_lift", "cube_stack"])
tier: str = "S1"
iterations: int = 50
batch_size: int = 2
group_size: int = 8
learning_rate: float = 2e-5
kl_penalty: float = 0.05
max_grad_norm: float = 1.0
max_tokens: int = 2048
temperature: float = 0.8
save_every: int = 5
output_dir: Path = Path("/mnt/artifacts-datai/checkpoints/project_naka")
log_dir: Path = Path("/mnt/artifacts-datai/logs/project_naka")
class GRPOTrainer:
"""GRPO training loop for CaP-RL.
Uses LoRA for memory efficiency on L4 GPUs.
"""
def __init__(self, config: GRPOConfig):
self.config = config
self._model = None
self._ref_model = None
self._tokenizer = None
self._optimizer = None
def setup(self):
"""Load model, tokenizer, optimizer."""
from transformers import AutoModelForCausalLM, AutoTokenizer
logger.info("[RL] Loading base model: %s", self.config.base_model)
self._tokenizer = AutoTokenizer.from_pretrained(self.config.base_model)
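        # Some tokenizers ship without a pad token; fall back to EOS so padded batches work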
if self._tokenizer.pad_token is None:
self._tokenizer.pad_token = self._tokenizer.eos_token
self._model = AutoModelForCausalLM.from_pretrained(
self.config.base_model,
dtype=torch.bfloat16,
device_map="auto",
)
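        # Trade compute for memory: recompute activations during backward instead of caching them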
self._model.gradient_checkpointing_enable()
# Reference model (frozen) for KL penalty
self._ref_model = AutoModelForCausalLM.from_pretrained(
self.config.base_model,
dtype=torch.bfloat16,
device_map="auto",
)
self._ref_model.eval()
for p in self._ref_model.parameters():
p.requires_grad = False
self._optimizer = AdamW(
[p for p in self._model.parameters() if p.requires_grad],
lr=self.config.learning_rate,
)
logger.info("[RL] Setup complete. Trainable params: %d",
sum(p.numel() for p in self._model.parameters() if p.requires_grad))
def train(self) -> dict:
"""Main GRPO training loop."""
if self._model is None:
self.setup()
dataset = GRPODataset(tasks=self.config.tasks, tier=self.config.tier)
reward_fn = CaPRewardFunction(tier=self.config.tier)
metrics_history = []
for iteration in range(self.config.iterations):
iter_start = time.time()
batch = dataset.sample_batch(self.config.batch_size)
total_loss = 0.0
iter_rewards: list[float] = []
for prompt_data in batch:
# Generate rollouts
rollouts, log_probs = self._generate_with_logprobs(
prompt_data["messages"], n=self.config.group_size
)
# Compute rewards
rewards = []
for code in rollouts:
result = reward_fn.compute(code, seed=prompt_data["seed"])
rewards.append(result["score"])
iter_rewards.extend(rewards)
# Compute advantages
advantages = self._compute_advantages(rewards)
# GRPO loss: -Σ advantage * log_prob + KL penalty
loss = self._compute_grpo_loss(
prompt_data["messages"], rollouts, log_probs, advantages
)
total_loss += loss.item()
# Backprop
loss.backward()
# Gradient step
torch.nn.utils.clip_grad_norm_(
self._model.parameters(), self.config.max_grad_norm
)
self._optimizer.step()
self._optimizer.zero_grad()
avg_reward = float(np.mean(iter_rewards)) if iter_rewards else 0.0
avg_loss = total_loss / max(len(batch), 1)
iter_time = time.time() - iter_start
metrics = {
"iteration": iteration,
"avg_reward": avg_reward,
"avg_loss": avg_loss,
"time_s": iter_time,
}
metrics_history.append(metrics)
logger.info(
"[RL] Iter %d/%d: reward=%.3f loss=%.4f time=%.1fs",
iteration + 1, self.config.iterations,
avg_reward, avg_loss, iter_time,
)
if (iteration + 1) % self.config.save_every == 0:
self._save_checkpoint(iteration + 1)
self._save_checkpoint(self.config.iterations)
return {"metrics": metrics_history}
def _generate_with_logprobs(
self, messages: list[dict], n: int
) -> tuple[list[str], list[torch.Tensor]]:
"""Generate n completions with log probabilities."""
prompt = "\n".join(m["content"] for m in messages)
inputs = self._tokenizer(
prompt, return_tensors="pt", truncation=True,
max_length=self.config.max_tokens,
).to(self._model.device)
rollouts = []
log_probs_list = []
for _ in range(n):
with torch.no_grad():
outputs = self._model.generate(
**inputs,
max_new_tokens=self.config.max_tokens,
temperature=self.config.temperature,
do_sample=True,
pad_token_id=self._tokenizer.eos_token_id,
return_dict_in_generate=True,
output_scores=True,
)
# Decode generated text
gen_ids = outputs.sequences[0][inputs["input_ids"].shape[1]:]
text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
rollouts.append(text)
            # Log-probs of the sampled tokens. These are computed under no_grad,
            # so they are detached sampling-time values; the loss recomputes
            # log-probs with gradients before the policy update.
if outputs.scores:
scores = torch.stack(outputs.scores, dim=0) # (seq_len, 1, vocab)
log_probs = torch.log_softmax(scores[:, 0, :], dim=-1)
selected = log_probs.gather(1, gen_ids[:len(scores)].unsqueeze(1))
log_probs_list.append(selected.squeeze(1).sum())
else:
log_probs_list.append(torch.tensor(0.0, device=self._model.device))
return rollouts, log_probs_list
    def _compute_grpo_loss(
        self,
        messages: list[dict],
        rollouts: list[str],
        log_probs: list[torch.Tensor],
        advantages: list[float],
    ) -> torch.Tensor:
        """Compute GRPO loss: -Σ advantage * log_prob + KL penalty.

        The sampling-time ``log_probs`` are detached, so per-rollout log-probs
        are recomputed here with a teacher-forced forward pass (letting
        gradients reach the policy); the frozen reference model provides the
        KL penalty.
        """
        prompt = "\n".join(m["content"] for m in messages)
        prompt_len = self._tokenizer(prompt, truncation=True,
            max_length=self.config.max_tokens, return_tensors="pt")["input_ids"].shape[1]
        loss = torch.tensor(0.0, device=self._model.device, requires_grad=True)
        for code, adv in zip(rollouts, advantages):
            if adv == 0.0:
                continue
            enc = self._tokenizer(prompt + code, return_tensors="pt", truncation=True,
                max_length=2 * self.config.max_tokens).to(self._model.device)
            labels = enc["input_ids"][:, 1:]  # next-token targets
            logits = self._model(**enc).logits[:, :-1, :]
            token_lp = torch.log_softmax(logits, dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
            lp = token_lp[:, prompt_len - 1:].sum()  # completion tokens only
            with torch.no_grad():
                ref_logits = self._ref_model(**enc).logits[:, :-1, :]
                ref_token_lp = torch.log_softmax(ref_logits, dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1)
                ref_lp = ref_token_lp[:, prompt_len - 1:].sum()
            pg_loss = -adv * lp  # policy gradient term
            kl_loss = self.config.kl_penalty * (lp - ref_lp)  # sample estimate of KL(policy || ref)
            loss = loss + pg_loss + kl_loss
        return loss / max(len(rollouts), 1)
def _compute_advantages(self, rewards: list[float]) -> list[float]:
"""Group-relative advantage: (r_i - mean) / (std + eps)."""
r = np.array(rewards, dtype=np.float64)
mean = r.mean()
std = r.std() + 1e-8
return ((r - mean) / std).tolist()
    def _save_checkpoint(self, iteration: int):
        """Save model checkpoint."""
        if self._model is None or self._tokenizer is None:
            return
        ckpt_dir = self.config.output_dir / f"checkpoint_iter{iteration:04d}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        self._model.save_pretrained(ckpt_dir)
        self._tokenizer.save_pretrained(ckpt_dir)
        logger.info("[RL] Checkpoint saved: %s", ckpt_dir)