"""GRPO Trainer — Group Relative Policy Optimization for coding agents. Implements the CaP-RL training loop from the paper (Section 5): 1. Sample prompts from task environments 2. Generate GROUP_SIZE rollouts per prompt 3. Execute code in sim, get binary rewards 4. Compute group-relative advantages 5. Update policy with GRPO loss + KL penalty """ from __future__ import annotations import logging import time from dataclasses import dataclass, field from pathlib import Path import numpy as np import torch from torch.optim import AdamW from anima_naka.rl.dataset import GRPODataset from anima_naka.rl.reward import CaPRewardFunction logger = logging.getLogger("anima_naka.rl") @dataclass class GRPOConfig: """GRPO training configuration.""" base_model: str = "Qwen/Qwen2.5-7B-Instruct" tasks: list[str] = field(default_factory=lambda: ["cube_lift", "cube_stack"]) tier: str = "S1" iterations: int = 50 batch_size: int = 2 group_size: int = 8 learning_rate: float = 2e-5 kl_penalty: float = 0.05 max_grad_norm: float = 1.0 max_tokens: int = 2048 temperature: float = 0.8 save_every: int = 5 output_dir: Path = Path("/mnt/artifacts-datai/checkpoints/project_naka") log_dir: Path = Path("/mnt/artifacts-datai/logs/project_naka") class GRPOTrainer: """GRPO training loop for CaP-RL. Uses LoRA for memory efficiency on L4 GPUs. """ def __init__(self, config: GRPOConfig): self.config = config self._model = None self._ref_model = None self._tokenizer = None self._optimizer = None def setup(self): """Load model, tokenizer, optimizer.""" from transformers import AutoModelForCausalLM, AutoTokenizer logger.info("[RL] Loading base model: %s", self.config.base_model) self._tokenizer = AutoTokenizer.from_pretrained(self.config.base_model) if self._tokenizer.pad_token is None: self._tokenizer.pad_token = self._tokenizer.eos_token self._model = AutoModelForCausalLM.from_pretrained( self.config.base_model, dtype=torch.bfloat16, device_map="auto", ) self._model.gradient_checkpointing_enable() # Reference model (frozen) for KL penalty self._ref_model = AutoModelForCausalLM.from_pretrained( self.config.base_model, dtype=torch.bfloat16, device_map="auto", ) self._ref_model.eval() for p in self._ref_model.parameters(): p.requires_grad = False self._optimizer = AdamW( [p for p in self._model.parameters() if p.requires_grad], lr=self.config.learning_rate, ) logger.info("[RL] Setup complete. 

    def train(self) -> dict:
        """Main GRPO training loop."""
        if self._model is None:
            self.setup()

        dataset = GRPODataset(tasks=self.config.tasks, tier=self.config.tier)
        reward_fn = CaPRewardFunction(tier=self.config.tier)
        metrics_history = []

        for iteration in range(self.config.iterations):
            iter_start = time.time()
            batch = dataset.sample_batch(self.config.batch_size)
            total_loss = 0.0
            iter_rewards: list[float] = []

            for prompt_data in batch:
                # Generate rollouts; policy log probs are differentiable,
                # reference log probs are detached.
                rollouts, log_probs, ref_log_probs = self._generate_with_logprobs(
                    prompt_data["messages"], n=self.config.group_size
                )

                # Compute rewards by executing each rollout in sim
                rewards = []
                for code in rollouts:
                    result = reward_fn.compute(code, seed=prompt_data["seed"])
                    rewards.append(result["score"])
                iter_rewards.extend(rewards)

                # Group-relative advantages
                advantages = self._compute_advantages(rewards)

                # GRPO loss: -Σ advantage * log_prob + KL penalty
                loss = self._compute_grpo_loss(
                    prompt_data["messages"],
                    rollouts,
                    log_probs,
                    ref_log_probs,
                    advantages,
                )
                total_loss += loss.item()

                # Accumulate gradients across the prompts in this batch
                loss.backward()

            # Gradient step
            torch.nn.utils.clip_grad_norm_(
                self._model.parameters(), self.config.max_grad_norm
            )
            self._optimizer.step()
            self._optimizer.zero_grad()

            avg_reward = float(np.mean(iter_rewards)) if iter_rewards else 0.0
            avg_loss = total_loss / max(len(batch), 1)
            iter_time = time.time() - iter_start

            metrics = {
                "iteration": iteration,
                "avg_reward": avg_reward,
                "avg_loss": avg_loss,
                "time_s": iter_time,
            }
            metrics_history.append(metrics)

            logger.info(
                "[RL] Iter %d/%d: reward=%.3f loss=%.4f time=%.1fs",
                iteration + 1,
                self.config.iterations,
                avg_reward,
                avg_loss,
                iter_time,
            )

            if (iteration + 1) % self.config.save_every == 0:
                self._save_checkpoint(iteration + 1)

        self._save_checkpoint(self.config.iterations)
        return {"metrics": metrics_history}

    def _generate_with_logprobs(
        self, messages: list[dict], n: int
    ) -> tuple[list[str], list[torch.Tensor], list[torch.Tensor]]:
        """Generate n completions and score them under the policy and reference model.

        Returns the decoded rollouts, the differentiable summed log prob of each
        rollout under the current policy, and the detached summed log prob under
        the frozen reference model.
        """
        prompt = "\n".join(m["content"] for m in messages)
        inputs = self._tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.config.max_tokens,
        ).to(self._model.device)
        prompt_len = inputs["input_ids"].shape[1]

        rollouts = []
        log_probs_list = []
        ref_log_probs_list = []
        for _ in range(n):
            # Sampling itself needs no gradients
            with torch.no_grad():
                outputs = self._model.generate(
                    **inputs,
                    max_new_tokens=self.config.max_tokens,
                    temperature=self.config.temperature,
                    do_sample=True,
                    pad_token_id=self._tokenizer.eos_token_id,
                    return_dict_in_generate=True,
                )

            # Decode generated text
            gen_ids = outputs.sequences[0][prompt_len:]
            text = self._tokenizer.decode(gen_ids, skip_special_tokens=True)
            rollouts.append(text)

            # Re-score the full sequence with gradients enabled so the GRPO loss
            # can backpropagate into the policy (generate() scores are detached).
            full_ids = outputs.sequences
            attention_mask = torch.ones_like(full_ids)
            logits = self._model(full_ids, attention_mask=attention_mask).logits
            # Logits at position t predict token t + 1, so shift by one.
            token_logits = logits[0, prompt_len - 1 : -1, :]
            token_log_probs = torch.log_softmax(token_logits.float(), dim=-1)
            selected = token_log_probs.gather(
                1, gen_ids.to(token_log_probs.device).unsqueeze(1)
            ).squeeze(1)
            log_probs_list.append(selected.sum())

            # Reference-model log prob for the KL penalty (no gradients)
            with torch.no_grad():
                ref_logits = self._ref_model(
                    full_ids.to(self._ref_model.device),
                    attention_mask=attention_mask.to(self._ref_model.device),
                ).logits
                ref_token_logits = ref_logits[0, prompt_len - 1 : -1, :]
                ref_token_log_probs = torch.log_softmax(ref_token_logits.float(), dim=-1)
                ref_selected = ref_token_log_probs.gather(
                    1, gen_ids.to(ref_token_log_probs.device).unsqueeze(1)
                ).squeeze(1)
                ref_log_probs_list.append(ref_selected.sum())

        return rollouts, log_probs_list, ref_log_probs_list

    def _compute_grpo_loss(
        self,
        messages: list[dict],
        rollouts: list[str],
        log_probs: list[torch.Tensor],
        ref_log_probs: list[torch.Tensor],
        advantages: list[float],
    ) -> torch.Tensor:
        """Compute GRPO loss: -Σ advantage * log_prob + KL penalty to the reference model."""
        # requires_grad so .backward() is safe even when every advantage is zero
        loss = torch.tensor(0.0, device=self._model.device, requires_grad=True)
        for lp, ref_lp, adv in zip(log_probs, ref_log_probs, advantages):
            if adv == 0.0:
                continue
            # Policy gradient: -advantage * log_prob
            pg_loss = -adv * lp
            # KL penalty toward the frozen reference model, approximated by the
            # sequence-level log ratio log pi(y|x) - log pi_ref(y|x)
            kl_loss = self.config.kl_penalty * (lp - ref_lp.to(lp.device))
            loss = loss + pg_loss + kl_loss

        return loss / max(len(rollouts), 1)

    def _compute_advantages(self, rewards: list[float]) -> list[float]:
        """Group-relative advantage: (r_i - mean) / (std + eps)."""
        r = np.array(rewards, dtype=np.float64)
        mean = r.mean()
        std = r.std() + 1e-8
        return ((r - mean) / std).tolist()

    def _save_checkpoint(self, iteration: int):
        """Save model and tokenizer checkpoint."""
        if self._model is None or self._tokenizer is None:
            return
        ckpt_dir = self.config.output_dir / f"checkpoint_iter{iteration:04d}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        self._model.save_pretrained(ckpt_dir)
        self._tokenizer.save_pretrained(ckpt_dir)
        logger.info("[RL] Checkpoint saved: %s", ckpt_dir)
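

# Minimal usage sketch, not part of the training library: it assumes the
# anima_naka package and the default task environments are available; the
# overridden iteration count below is illustrative, not a canonical setting.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = GRPOConfig(
        tasks=["cube_lift", "cube_stack"],
        iterations=10,
        group_size=8,
    )
    trainer = GRPOTrainer(config)
    results = trainer.train()

    # train() returns {"metrics": [...]}, one dict per iteration with keys
    # iteration, avg_reward, avg_loss, time_s.
    for m in results["metrics"]:
        print(
            f"iter {m['iteration']:3d}  "
            f"reward={m['avg_reward']:.3f}  loss={m['avg_loss']:.4f}"
        )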