| """GRPO Trainer — Group Relative Policy Optimization for coding agents. |
| |
| Implements the CaP-RL training loop from the paper (Section 5): |
| 1. Sample prompts from task environments |
| 2. Generate GROUP_SIZE rollouts per prompt |
| 3. Execute code in sim, get binary rewards |
| 4. Compute group-relative advantages |
| 5. Update policy with GRPO loss + KL penalty |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import time |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch.optim import AdamW |
|
|
| from anima_naka.rl.dataset import GRPODataset |
| from anima_naka.rl.reward import CaPRewardFunction |
|
|
| logger = logging.getLogger("anima_naka.rl") |
|
|
|
|
| @dataclass |
| class GRPOConfig: |
    """GRPO training configuration."""

    base_model: str = "Qwen/Qwen2.5-7B-Instruct"
    tasks: list[str] = field(default_factory=lambda: ["cube_lift", "cube_stack"])
    tier: str = "S1"
    iterations: int = 50
    batch_size: int = 2
    group_size: int = 8
    learning_rate: float = 2e-5
    kl_penalty: float = 0.05
    max_grad_norm: float = 1.0
    max_tokens: int = 2048
    temperature: float = 0.8
    save_every: int = 5
    output_dir: Path = Path("/mnt/artifacts-datai/checkpoints/project_naka")
    log_dir: Path = Path("/mnt/artifacts-datai/logs/project_naka")


class GRPOTrainer:
    """GRPO training loop for CaP-RL.

    Uses gradient checkpointing for memory efficiency on L4 GPUs.
    """

    def __init__(self, config: GRPOConfig):
        self.config = config
        self._model = None
        self._ref_model = None
        self._tokenizer = None
        self._optimizer = None

    def setup(self):
        """Load policy model, reference model, tokenizer, and optimizer."""
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("[RL] Loading base model: %s", self.config.base_model)
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.base_model)
        if self._tokenizer.pad_token is None:
            self._tokenizer.pad_token = self._tokenizer.eos_token

        self._model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        self._model.gradient_checkpointing_enable()
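
        # Frozen copy of the base model: the reference policy for the KL
        # penalty in the GRPO loss. Never updated during training.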
        self._ref_model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            dtype=torch.bfloat16,
            device_map="auto",
        )
        self._ref_model.eval()
        for p in self._ref_model.parameters():
            p.requires_grad = False

        self._optimizer = AdamW(
            [p for p in self._model.parameters() if p.requires_grad],
            lr=self.config.learning_rate,
        )
        logger.info("[RL] Setup complete. Trainable params: %d",
                    sum(p.numel() for p in self._model.parameters() if p.requires_grad))

    def train(self) -> dict:
        """Main GRPO training loop."""
        if self._model is None:
            self.setup()

        dataset = GRPODataset(tasks=self.config.tasks, tier=self.config.tier)
        reward_fn = CaPRewardFunction(tier=self.config.tier)
        metrics_history = []

        for iteration in range(self.config.iterations):
            iter_start = time.time()
            # Step 1: sample prompts from the task environments.
            batch = dataset.sample_batch(self.config.batch_size)

            total_loss = 0.0
            iter_rewards: list[float] = []

            for prompt_data in batch:
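                # Step 2: generate a group of GROUP_SIZE rollouts for this prompt.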
                rollouts, log_probs, ref_log_probs = self._generate_with_logprobs(
                    prompt_data["messages"], n=self.config.group_size
                )
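
                # Step 3: execute each program in sim; reward is binary task success.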
                rewards = []
                for code in rollouts:
                    result = reward_fn.compute(code, seed=prompt_data["seed"])
                    rewards.append(result["score"])
                iter_rewards.extend(rewards)
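
                # Step 4: compute group-relative advantages.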
                advantages = self._compute_advantages(rewards)
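
                # Step 5: GRPO loss = policy-gradient term + KL penalty.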
                loss = self._compute_grpo_loss(
                    prompt_data["messages"], rollouts, log_probs,
                    ref_log_probs, advantages
                )
                total_loss += loss.item()
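
                # Accumulate gradients across the prompts in this batch.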
                loss.backward()
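
            # One optimizer step per iteration, with gradient clipping.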
            torch.nn.utils.clip_grad_norm_(
                self._model.parameters(), self.config.max_grad_norm
            )
            self._optimizer.step()
            self._optimizer.zero_grad()

            avg_reward = float(np.mean(iter_rewards)) if iter_rewards else 0.0
            avg_loss = total_loss / max(len(batch), 1)
            iter_time = time.time() - iter_start

            metrics = {
                "iteration": iteration,
                "avg_reward": avg_reward,
                "avg_loss": avg_loss,
                "time_s": iter_time,
            }
            metrics_history.append(metrics)
            logger.info(
                "[RL] Iter %d/%d: reward=%.3f loss=%.4f time=%.1fs",
                iteration + 1, self.config.iterations,
                avg_reward, avg_loss, iter_time,
            )

            if (iteration + 1) % self.config.save_every == 0:
                self._save_checkpoint(iteration + 1)

        self._save_checkpoint(self.config.iterations)
        return {"metrics": metrics_history}

    def _generate_with_logprobs(
        self, messages: list[dict], n: int
    ) -> tuple[list[str], list[torch.Tensor], list[torch.Tensor]]:
        """Generate n completions; return texts, policy log-probs, ref log-probs.

        Sampling runs under no_grad; policy log-probs are then recomputed
        with a grad-enabled forward pass so the GRPO loss can backpropagate
        into the model.
        """
        # Use the model's chat template rather than joining raw message
        # contents, so the instruct model sees its expected prompt format.
        prompt = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._tokenizer(
            prompt, return_tensors="pt", truncation=True,
            max_length=self.config.max_tokens,
        ).to(self._model.device)
        prompt_len = inputs["input_ids"].shape[1]

        rollouts = []
        log_probs_list = []
        ref_log_probs_list = []

        for _ in range(n):
            with torch.no_grad():
                seq = self._model.generate(
                    **inputs,
                    max_new_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    do_sample=True,
                    pad_token_id=self._tokenizer.eos_token_id,
                )[0]

            # Decode only the newly generated tokens.
            gen_ids = seq[prompt_len:]
            text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
            rollouts.append(text)

            # Summed log-prob of the completion under the current policy
            # (with grad) and the frozen reference (without). Logits at
            # position t predict token t + 1, hence the one-position shift;
            # dividing by the sampling temperature keeps the log-probs
            # consistent with the distribution the tokens were drawn from.
            ids = seq.unsqueeze(0)
            logits = self._model(ids).logits[0, prompt_len - 1:-1]
            token_lp = torch.log_softmax(logits / self.config.temperature, dim=-1)
            log_probs_list.append(
                token_lp.gather(1, gen_ids.unsqueeze(1)).squeeze(1).sum()
            )
            with torch.no_grad():
                ref_logits = self._ref_model(ids).logits[0, prompt_len - 1:-1]
                ref_token_lp = torch.log_softmax(
                    ref_logits / self.config.temperature, dim=-1
                )
                ref_log_probs_list.append(
                    ref_token_lp.gather(1, gen_ids.unsqueeze(1)).squeeze(1).sum()
                )

        return rollouts, log_probs_list, ref_log_probs_list

    def _compute_grpo_loss(
        self,
        messages: list[dict],
        rollouts: list[str],
        log_probs: list[torch.Tensor],
        ref_log_probs: list[torch.Tensor],
        advantages: list[float],
    ) -> torch.Tensor:
        """GRPO loss: -Σ advantage · log_prob + β · KL(policy ‖ ref), group-averaged."""
        loss = torch.zeros((), device=self._model.device)

        for lp, ref_lp, adv in zip(log_probs, ref_log_probs, advantages):
            # Policy-gradient term: raise the log-prob of rollouts with
            # positive advantage, lower it for negative advantage.
            pg_loss = -adv * lp

            # Sequence-level KL penalty. For samples drawn from the policy,
            # E[log p - log p_ref] = KL(policy ‖ ref), so the sampled
            # log-ratio is a single-sample estimate of the KL term.
            kl_loss = self.config.kl_penalty * (lp - ref_lp)

            loss = loss + pg_loss + kl_loss

        return loss / max(len(rollouts), 1)

    def _compute_advantages(self, rewards: list[float]) -> list[float]:
        """Group-relative advantage: (r_i - mean) / (std + eps).

        E.g. binary rewards [1, 0, 0, 1] map to [+1, -1, -1, +1].
        """
        r = np.array(rewards, dtype=np.float64)
        mean = r.mean()
        std = r.std() + 1e-8
        return ((r - mean) / std).tolist()

    def _save_checkpoint(self, iteration: int):
        """Save model and tokenizer to a numbered checkpoint directory."""
        # Check readiness before touching the filesystem.
        if self._model is None or self._tokenizer is None:
            return

        ckpt_dir = self.config.output_dir / f"checkpoint_iter{iteration:04d}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        self._model.save_pretrained(ckpt_dir)
        self._tokenizer.save_pretrained(ckpt_dir)
        logger.info("[RL] Checkpoint saved: %s", ckpt_dir)
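

# Illustrative entry point, not part of the library API: assumes the
# anima_naka dataset/reward modules are importable and that there is
# enough GPU memory for two bf16 copies of the base model.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    trainer = GRPOTrainer(GRPOConfig(iterations=2, group_size=4))
    summary = trainer.train()
    logger.info("[RL] Final metrics: %s", summary["metrics"][-1])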