| """End-to-end GRPO training on AWS RL Env, driven from Google Colab. | |
| This script is pedagogical β it shows how the moving pieces connect: | |
| [Central Curriculum] --picks task_id--> [G parallel rollouts via GrpoPool] | |
| | | |
| v | |
| [Per-rollout trajectory of (prompt, action, reward)] | |
| | | |
| v | |
| [Group-normalized advantages: A_i = (R_i - mean) / std] | |
| | | |
| v | |
| [PPO-style policy-gradient loss on logprobs] | |
| | | |
| v | |
| [Optimizer step] | |
| Why this is "GRPO" and not vanilla REINFORCE: | |
| GRPO (Group Relative Policy Optimization, DeepSeek) replaces the value | |
| baseline with the **group mean** of rewards. For each task we sample G | |
| trajectories; the advantage for rollout i is A_i = (R_i - mean(R)) / std(R). | |
| This is variance-reduced and critic-free β perfect for our env where G=8. | |
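    For example, with G=4 and total rewards [1.0, 0.0, 0.0, 1.0], the group mean
    is 0.5 and the population std is 0.5, so the advantages are
    [+1.0, -1.0, -1.0, +1.0]: rollouts that beat their siblings are reinforced,
    the rest are pushed down.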

Requirements (install in the Colab cell before running):

    !pip install unsloth trl torch transformers accelerate bitsandbytes httpx websockets

Prerequisites:
    The RL env server must be running somewhere Colab can reach, with
    AWS_RL_ENV_POOL_SIZE=8 set. Easiest:

        docker run -p 8000:8000 -e AWS_RL_ENV_POOL_SIZE=8 aws-rl-env:latest

    and expose port 8000 via `cloudflared tunnel` or `ngrok http 8000`.
    Set BASE_URL below to the public URL of that tunnel.

Run:
    python scripts/grpo_train.py
"""
from __future__ import annotations

import asyncio
import logging
import math
import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn.functional as F

from client import AwsRlEnv
from models import AwsRlAction, Task
from scripts.grpo_pool import GrpoPool
from server.services.curriculum import Curriculum

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Config: tune for your setup
# ---------------------------------------------------------------------------
BASE_URL = os.getenv("AWS_RL_ENV_BASE_URL", "http://localhost:8000")
GROUP_SIZE = int(os.getenv("GRPO_GROUP_SIZE", "8"))  # G in GRPO
NUM_GRPO_STEPS = int(os.getenv("GRPO_NUM_STEPS", "100"))  # outer training steps
MAX_EPISODE_STEPS = int(
    os.getenv("GRPO_MAX_STEPS", "15")
)  # per-rollout step cap (matches MAX_STEPS in env)
LEARNING_RATE = float(os.getenv("GRPO_LR", "5e-6"))
KL_COEFF = float(os.getenv("GRPO_KL", "0.04"))  # KL penalty vs reference model (unused by the simplified loss below)
CLIP_EPS = float(os.getenv("GRPO_CLIP", "0.2"))  # PPO clip for stability
TEMPERATURE = float(os.getenv("GRPO_TEMP", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("GRPO_MAX_NEW", "96"))  # per model generation
MODEL_NAME = os.getenv("GRPO_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
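
# Each setting can be overridden from a Colab cell before launching, e.g.
# (values below are illustrative, not recommendations):
#   %env GRPO_GROUP_SIZE=8
#   %env GRPO_NUM_STEPS=50
#   %env AWS_RL_ENV_BASE_URL=https://your-tunnel-subdomain.trycloudflare.com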

# ---------------------------------------------------------------------------
# Model loading (Unsloth: 4-bit LoRA, Colab-friendly)
# ---------------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load a 4-bit LoRA-wrapped model via Unsloth.

    Unsloth is a drop-in replacement for transformers that roughly doubles
    fine-tuning speed on a single GPU and fits a 1.5B model on a free Colab T4.
    """
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=4096,
        load_in_4bit=True,
    )
    # LoRA wrapping: only these params receive gradients
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
    )
    FastLanguageModel.for_training(model)
    return model, tokenizer

# ---------------------------------------------------------------------------
# Prompt construction & action extraction
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """You are an expert AWS SRE agent. You operate a simulated AWS cloud by \
emitting one AWS CLI command at a time. You will see the task description and the most \
recent command output, then reply with EXACTLY ONE AWS CLI command on a single line \
starting with 'aws '. No explanation, no markdown, no quotes: just the command."""

def build_prompt(tokenizer, task: Task, history: List[Tuple[str, str]]) -> str:
    """Build a chat prompt from the task + command/output history."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"TASK: {task.description}"},
    ]
    for cmd, out in history[-4:]:  # keep last 4 turns to fit context
        messages.append({"role": "assistant", "content": cmd})
        messages.append({"role": "user", "content": f"OUTPUT:\n{out[:400]}"})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
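
# For illustration, build_prompt assembles a message list shaped like this
# (the task text, command, and output below are made up):
#   [{"role": "system",    "content": SYSTEM_PROMPT},
#    {"role": "user",      "content": "TASK: Restart the crashed ECS service in cluster 'prod'"},
#    {"role": "assistant", "content": "aws ecs list-services --cluster prod"},
#    {"role": "user",      "content": "OUTPUT:\n{...truncated to 400 chars...}"}]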

def extract_command(raw: str) -> str:
    """Pull the first `aws …` line out of the model's raw decoded output."""
    for line in raw.splitlines():
        line = line.strip().strip("`").strip()
        if line.startswith("aws "):
            return line
    return "aws help"  # safe fallback so the env always accepts the command

# ---------------------------------------------------------------------------
# Rollout: one trajectory, one env, one task
# ---------------------------------------------------------------------------
@dataclass
class Step:
    """One step of a trajectory. `prompt_ids` + `action_ids` are what we backprop on."""

    prompt_ids: torch.Tensor  # shape [prompt_len]
    action_ids: torch.Tensor  # shape [action_len]
    logprob_sum: torch.Tensor  # scalar: sum of model logprobs over action_ids at sample time
    reward: float


@dataclass
class Trajectory:
    steps: List[Step]
    total_reward: float

def generate_action(
    model,
    tokenizer,
    prompt: str,
    device: torch.device,
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample one command from the model; return (text, prompt_ids, action_ids, logprob_sum)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_ids = inputs["input_ids"][0]
    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
    )
    full_ids = out.sequences[0]
    action_ids = full_ids[prompt_ids.size(0) :]
    # Gather per-token logprobs of the sampled tokens (at generation time)
    if out.scores:
        logits = torch.stack(out.scores, dim=0)  # [T, 1, V]
        logprobs = torch.log_softmax(logits, dim=-1)[:, 0, :]
        token_lp = logprobs.gather(1, action_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum = token_lp.sum()
    else:
        logprob_sum = torch.tensor(0.0, device=device)
    text = tokenizer.decode(action_ids, skip_special_tokens=True)
    return text, prompt_ids, action_ids, logprob_sum

async def run_single_rollout(
    env: AwsRlEnv,
    task: Task,
    model,
    tokenizer,
    device: torch.device,
) -> Trajectory:
    """Drive one env through up to MAX_EPISODE_STEPS, recording every step."""
    result = await env.reset(task=task)
    history: List[Tuple[str, str]] = []
    steps: List[Step] = []
    total_reward = 0.0
    for _ in range(MAX_EPISODE_STEPS):
        prompt = build_prompt(tokenizer, task, history)
        raw, prompt_ids, action_ids, logprob_sum = generate_action(
            model, tokenizer, prompt, device
        )
        command = extract_command(raw)
        result = await env.step(AwsRlAction(command=command))
        reward = float(result.reward)
        total_reward += reward
        steps.append(
            Step(
                prompt_ids=prompt_ids.cpu(),
                action_ids=action_ids.cpu(),
                logprob_sum=logprob_sum.detach().cpu(),
                reward=reward,
            )
        )
        history.append((command, result.observation.command_output or ""))
        if result.done:
            break
    return Trajectory(steps=steps, total_reward=total_reward)

# ---------------------------------------------------------------------------
# GRPO loss: group-normalized advantages, PPO-style clipped ratio
# ---------------------------------------------------------------------------
def compute_group_advantages(rewards: List[float]) -> List[float]:
    """Core GRPO step: subtract the group mean and divide by the group std.

        A_i = (R_i - mean(R_1..G)) / (std(R_1..G) + eps)

    This makes the "baseline" the group's own performance, so no value network
    is needed. If all G rollouts tie, every advantage is zero (no signal, which
    is correct).
    """
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var) + 1e-8
    return [(r - mean) / std for r in rewards]
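
# Edge case worth knowing: if every rollout gets the same reward, e.g.
# [1.0, 1.0, 1.0, 1.0], the std is 0 and only the 1e-8 epsilon keeps the
# division finite; all advantages come out as 0.0, so that group contributes
# no gradient.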

def logprob_under_current_model(
    model, tokenizer, step: Step, device: torch.device
) -> torch.Tensor:
    """Re-score the sampled action under the CURRENT policy (for gradient).

    At rollout time we recorded the old policy's logprob_sum. To get a
    differentiable ratio we have to recompute it now with the current weights.
    """
    full = torch.cat([step.prompt_ids, step.action_ids]).unsqueeze(0).to(device)
    attn = torch.ones_like(full)
    outputs = model(input_ids=full, attention_mask=attn)
    logits = outputs.logits[0, :-1, :]  # predict next token
    targets = full[0, 1:]
    prompt_len = step.prompt_ids.size(0)
    # Only the action tokens contribute to the loss
    action_logits = logits[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
    action_targets = targets[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
    logp = F.log_softmax(action_logits, dim=-1)
    token_logp = logp.gather(1, action_targets.unsqueeze(-1)).squeeze(-1)
    return token_logp.sum()
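
# Index-alignment sketch (tokens are illustrative): with prompt tokens p0..p3
# and action tokens a0 a1, full = [p0 p1 p2 p3 a0 a1]. logits[i] scores
# full[i+1], so logits[3] scores a0 and logits[4] scores a1, i.e. exactly the
# slice logits[prompt_len - 1 : prompt_len - 1 + action_len] taken above.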

def grpo_loss(
    model,
    tokenizer,
    trajectories: List[Trajectory],
    device: torch.device,
) -> torch.Tensor:
    """GRPO objective: maximize clipped advantage-weighted logprob ratio.

        loss = -mean_i[ min(ratio_i * A_i, clip(ratio_i, 1 - eps, 1 + eps) * A_i) ]
    """
    rewards = [t.total_reward for t in trajectories]
    advantages = compute_group_advantages(rewards)
    losses: List[torch.Tensor] = []
    for traj, adv in zip(trajectories, advantages):
        if not traj.steps:
            continue
        adv_t = torch.tensor(adv, device=device, dtype=torch.float32)
        for step in traj.steps:
            new_logp = logprob_under_current_model(model, tokenizer, step, device)
            old_logp = step.logprob_sum.to(device)
            ratio = torch.exp(new_logp - old_logp)
            unclipped = ratio * adv_t
            clipped = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * adv_t
            losses.append(-torch.min(unclipped, clipped))
    if not losses:
        return torch.tensor(0.0, device=device, requires_grad=True)
    return torch.stack(losses).mean()
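
# Clipping example with CLIP_EPS = 0.2: for a step whose advantage is +1.0, a
# ratio of 1.5 contributes min(1.5, 1.2) = 1.2 to the objective, so the update
# cannot chase an already-favoured action arbitrarily far. On the first pass
# over freshly sampled rollouts the ratio is ~1.0, where the clipped and
# unclipped terms coincide.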

# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------
async def train() -> None:
    model, tokenizer = load_model_and_tokenizer()
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE,
    )
    curriculum = Curriculum()
    async with GrpoPool(
        base_url=BASE_URL, size=GROUP_SIZE, curriculum=curriculum
    ) as pool:
        logger.info("Connected pool of %d envs against %s", GROUP_SIZE, BASE_URL)
        for step_idx in range(NUM_GRPO_STEPS):
            # 1) central curriculum picks ONE task for the whole group
            task = curriculum.next_task()
            logger.info(
                "[step %d/%d] task_id=%d tier=%s",
                step_idx + 1,
                NUM_GRPO_STEPS,
                task.task_id,
                task.difficulty.value,
            )
            # 2) launch G parallel rollouts, all on the same task_id
            rollout_coros = [
                run_single_rollout(e, task, model, tokenizer, device) for e in pool.envs
            ]
            trajectories = await asyncio.gather(*rollout_coros)
            rewards = [t.total_reward for t in trajectories]
            logger.info(
                "  rewards: min=%.3f mean=%.3f max=%.3f",
                min(rewards),
                sum(rewards) / len(rewards),
                max(rewards),
            )
            # 3) GRPO loss + update
            model.train()
            loss = grpo_loss(model, tokenizer, trajectories, device)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            logger.info("  loss=%.4f", loss.item())
            # 4) feed result back to curriculum (one record per group, not per rollout)
            pool.record_group_result(task, rewards)
    # Save LoRA adapter
    output_dir = os.getenv("GRPO_OUTPUT_DIR", "./grpo_lora_out")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved LoRA adapter to %s", output_dir)

if __name__ == "__main__":
    asyncio.run(train())
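
# To reuse the saved adapter later (a sketch; the path is whatever GRPO_OUTPUT_DIR
# pointed at, "./grpo_lora_out" by default):
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
#   model = PeftModel.from_pretrained(base, "./grpo_lora_out")
#   tokenizer = AutoTokenizer.from_pretrained("./grpo_lora_out")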