"""End-to-end GRPO training on AWS RL Env, driven from Google Colab.

This script is pedagogical; it shows how the moving pieces connect:

    [Central Curriculum] --picks task_id--> [G parallel rollouts via GrpoPool]
                                                     |
                                                     v
                          [Per-rollout trajectory of (prompt, action, reward)]
                                                     |
                                                     v
                          [Group-normalized advantages: A_i = (R_i - mean) / std]
                                                     |
                                                     v
                          [PPO-style policy-gradient loss on logprobs]
                                                     |
                                                     v
                                             [Optimizer step]

Why this is "GRPO" and not vanilla REINFORCE:
    GRPO (Group Relative Policy Optimization, DeepSeek) replaces the value
    baseline with the **group mean** of rewards. For each task we sample G
    trajectories; the advantage for rollout i is A_i = (R_i - mean(R)) / std(R).
    This is variance-reduced and critic-free, a good fit for our env where G=8.
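
    Tiny worked example (made-up rewards, G=4): rewards [2, 0, 1, 1] have mean 1
    and std ~0.71, giving advantages of roughly [+1.41, -1.41, 0, 0].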

Requirements (install in the Colab cell before running):
    !pip install unsloth trl torch transformers accelerate bitsandbytes httpx websockets

Prerequisites:
    The RL env server must be running somewhere Colab can reach, with
    AWS_RL_ENV_POOL_SIZE=8 set. Easiest:
        docker run -p 8000:8000 -e AWS_RL_ENV_POOL_SIZE=8 aws-rl-env:latest
    And expose port 8000 via `cloudflared tunnel` or `ngrok http 8000`.

    Set BASE_URL below to the public URL of that tunnel.

Run:
    python scripts/grpo_train.py
"""

from __future__ import annotations

import asyncio
import logging
import math
import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn.functional as F

from client import AwsRlEnv
from models import AwsRlAction, Task
from scripts.grpo_pool import GrpoPool
from server.services.curriculum import Curriculum

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Config: tune for your setup
# ---------------------------------------------------------------------------

BASE_URL = os.getenv("AWS_RL_ENV_BASE_URL", "http://localhost:8000")
GROUP_SIZE = int(os.getenv("GRPO_GROUP_SIZE", "8"))  # G in GRPO
NUM_GRPO_STEPS = int(os.getenv("GRPO_NUM_STEPS", "100"))  # outer training steps
MAX_EPISODE_STEPS = int(
    os.getenv("GRPO_MAX_STEPS", "15")
)  # per-rollout step cap (matches MAX_STEPS in env)
LEARNING_RATE = float(os.getenv("GRPO_LR", "5e-6"))
KL_COEFF = float(os.getenv("GRPO_KL", "0.04"))  # KL penalty vs reference model (not applied in this minimal loss)
CLIP_EPS = float(os.getenv("GRPO_CLIP", "0.2"))  # PPO clip for stability
TEMPERATURE = float(os.getenv("GRPO_TEMP", "0.9"))
MAX_NEW_TOKENS = int(os.getenv("GRPO_MAX_NEW", "96"))  # max new tokens per generation call
MODEL_NAME = os.getenv("GRPO_MODEL", "Qwen/Qwen2.5-1.5B-Instruct")
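
# Example shell overrides before launching (hypothetical values):
#   export GRPO_NUM_STEPS=200 GRPO_LR=1e-5 GRPO_MODEL="Qwen/Qwen2.5-3B-Instruct"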


# ---------------------------------------------------------------------------
# Model loading (Unsloth: 4-bit LoRA, Colab-friendly)
# ---------------------------------------------------------------------------


def load_model_and_tokenizer():
    """Load a 4-bit LoRA-wrapped model via Unsloth.

    Unsloth is a drop-in replacement for transformers that speeds up single-GPU
    fine-tuning by roughly 2x and fits a 1.5B model on a free Colab T4.
    """
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=4096,
        load_in_4bit=True,
    )
    # LoRA wrapping: only these params receive gradients
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        lora_alpha=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
    )
    FastLanguageModel.for_training(model)
    return model, tokenizer


# ---------------------------------------------------------------------------
# Prompt construction & action extraction
# ---------------------------------------------------------------------------

SYSTEM_PROMPT = """You are an expert AWS SRE agent. You operate a simulated AWS cloud by \
emitting one AWS CLI command at a time. You will see the task description and the most \
recent command output, then reply with EXACTLY ONE AWS CLI command on a single line \
starting with 'aws '. No explanation, no markdown, no quotes; just the command."""
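
# A valid model reply is then a single line such as (hypothetical example):
#   aws s3api list-buckets --query "Buckets[].Name"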


def build_prompt(tokenizer, task: Task, history: List[Tuple[str, str]]) -> str:
    """Build a chat prompt from the task + command/output history."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"TASK: {task.description}"},
    ]
    for cmd, out in history[-4:]:  # keep last 4 turns to fit context
        messages.append({"role": "assistant", "content": cmd})
        messages.append({"role": "user", "content": f"OUTPUT:\n{out[:400]}"})
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def extract_command(raw: str) -> str:
    """Pull the first `aws …` line out of the model's raw decoded output."""
    for line in raw.splitlines():
        line = line.strip().strip("`").strip()
        if line.startswith("aws "):
            return line
    return "aws help"  # safe fallback so env always accepts the command


# ---------------------------------------------------------------------------
# Rollout β€” one trajectory, one env, one task
# ---------------------------------------------------------------------------


@dataclass
class Step:
    """One step of a trajectory. `prompt_ids` + `action_ids` are what we backprop on."""

    prompt_ids: torch.Tensor  # shape [prompt_len]
    action_ids: torch.Tensor  # shape [action_len]
    logprob_sum: torch.Tensor  # scalar: sum of model logprobs over action_ids at sample time
    reward: float


@dataclass
class Trajectory:
    steps: List[Step]
    total_reward: float


@torch.no_grad()
def generate_action(
    model,
    tokenizer,
    prompt: str,
    device: torch.device,
) -> Tuple[str, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sample one command from the model; return (text, prompt_ids, action_ids, logprob_sum)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_ids = inputs["input_ids"][0]
    out = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
    )
    full_ids = out.sequences[0]
    action_ids = full_ids[prompt_ids.size(0) :]
    # Gather per-token logprobs of the sampled tokens (at generation time)
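    # NOTE: out.scores are the processed logits (post temperature / top-p warping), so
    # this logprob_sum reflects the sampling distribution rather than the raw model one.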
    if out.scores:
        logits = torch.stack(out.scores, dim=0)  # [T, 1, V]
        logprobs = torch.log_softmax(logits, dim=-1)[:, 0, :]
        token_lp = logprobs.gather(1, action_ids.unsqueeze(-1)).squeeze(-1)
        logprob_sum = token_lp.sum()
    else:
        logprob_sum = torch.tensor(0.0, device=device)
    text = tokenizer.decode(action_ids, skip_special_tokens=True)
    return text, prompt_ids, action_ids, logprob_sum


async def run_single_rollout(
    env: AwsRlEnv,
    task: Task,
    model,
    tokenizer,
    device: torch.device,
) -> Trajectory:
    """Drive one env through up to MAX_EPISODE_STEPS, recording every step."""
    result = await env.reset(task=task)
    history: List[Tuple[str, str]] = []
    steps: List[Step] = []
    total_reward = 0.0

    for _ in range(MAX_EPISODE_STEPS):
        prompt = build_prompt(tokenizer, task, history)
        raw, prompt_ids, action_ids, logprob_sum = generate_action(
            model, tokenizer, prompt, device
        )
        command = extract_command(raw)
        result = await env.step(AwsRlAction(command=command))
        reward = float(result.reward)
        total_reward += reward
        steps.append(
            Step(
                prompt_ids=prompt_ids.cpu(),
                action_ids=action_ids.cpu(),
                logprob_sum=logprob_sum.detach().cpu(),
                reward=reward,
            )
        )
        history.append((command, result.observation.command_output or ""))
        if result.done:
            break

    return Trajectory(steps=steps, total_reward=total_reward)


# ---------------------------------------------------------------------------
# GRPO loss β€” group-normalized advantages, PPO-style clipped ratio
# ---------------------------------------------------------------------------


def compute_group_advantages(rewards: List[float]) -> List[float]:
    """Core GRPO step: subtract the group mean and divide by group std.

    A_i = (R_i - mean(R_1..G)) / (std(R_1..G) + eps)

    This makes the "baseline" the group's own performance β€” no value network
    needed. If all G rollouts tied, advantages are zero (no signal, correct).
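
    Doctest (runnable via `python -m doctest`), with made-up rewards:

    >>> [round(a, 2) for a in compute_group_advantages([1.0, 0.0, 0.5, 0.5])]
    [1.41, -1.41, 0.0, 0.0]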
    """
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var) + 1e-8
    return [(r - mean) / std for r in rewards]


def logprob_under_current_model(
    model, tokenizer, step: Step, device: torch.device
) -> torch.Tensor:
    """Re-score the sampled action under the CURRENT policy (for gradient).

    At rollout time we recorded the old policy's logprob_sum. To get a
    differentiable ratio we have to recompute it now with the current weights.
    """
    full = torch.cat([step.prompt_ids, step.action_ids]).unsqueeze(0).to(device)
    attn = torch.ones_like(full)
    outputs = model(input_ids=full, attention_mask=attn)
    logits = outputs.logits[0, :-1, :]  # predict next token
    targets = full[0, 1:]
    prompt_len = step.prompt_ids.size(0)
    # Only the action tokens contribute to the loss
    action_logits = logits[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
    action_targets = targets[prompt_len - 1 : prompt_len - 1 + step.action_ids.size(0)]
    logp = F.log_softmax(action_logits, dim=-1)
    token_logp = logp.gather(1, action_targets.unsqueeze(-1)).squeeze(-1)
    return token_logp.sum()


def grpo_loss(
    model,
    tokenizer,
    trajectories: List[Trajectory],
    device: torch.device,
) -> torch.Tensor:
    """GRPO objective: maximize clipped advantage-weighted logprob ratio.

    loss = -mean_i [ min(ratio_i * A_i, clip(ratio_i, 1-eps, 1+eps) * A_i) ]
    """
    rewards = [t.total_reward for t in trajectories]
    advantages = compute_group_advantages(rewards)

    losses: List[torch.Tensor] = []
    for traj, adv in zip(trajectories, advantages):
        if not traj.steps:
            continue
        adv_t = torch.tensor(adv, device=device, dtype=torch.float32)
        for step in traj.steps:
            new_logp = logprob_under_current_model(model, tokenizer, step, device)
            old_logp = step.logprob_sum.to(device)
            ratio = torch.exp(new_logp - old_logp)
            unclipped = ratio * adv_t
            clipped = torch.clamp(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS) * adv_t
            losses.append(-torch.min(unclipped, clipped))

    if not losses:
        return torch.tensor(0.0, device=device, requires_grad=True)
    return torch.stack(losses).mean()


# ---------------------------------------------------------------------------
# Main training loop
# ---------------------------------------------------------------------------


async def train() -> None:
    model, tokenizer = load_model_and_tokenizer()
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE,
    )

    curriculum = Curriculum()
    async with GrpoPool(
        base_url=BASE_URL, size=GROUP_SIZE, curriculum=curriculum
    ) as pool:
        logger.info("Connected pool of %d envs against %s", GROUP_SIZE, BASE_URL)

        for step_idx in range(NUM_GRPO_STEPS):
            # 1) central curriculum picks ONE task for the whole group
            task = curriculum.next_task()
            logger.info(
                "[step %d/%d] task_id=%d tier=%s",
                step_idx + 1,
                NUM_GRPO_STEPS,
                task.task_id,
                task.difficulty.value,
            )

            # 2) launch G parallel rollouts, all on the same task_id
            rollout_coros = [
                run_single_rollout(e, task, model, tokenizer, device) for e in pool.envs
            ]
            trajectories = await asyncio.gather(*rollout_coros)
            rewards = [t.total_reward for t in trajectories]
            logger.info(
                "  rewards: min=%.3f mean=%.3f max=%.3f",
                min(rewards),
                sum(rewards) / len(rewards),
                max(rewards),
            )

            # 3) GRPO loss + update
            model.train()
            loss = grpo_loss(model, tokenizer, trajectories, device)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            logger.info("  loss=%.4f", loss.item())

            # 4) feed result back to curriculum (one record per group, not per rollout)
            pool.record_group_result(task, rewards)

    # Save LoRA adapter
    output_dir = os.getenv("GRPO_OUTPUT_DIR", "./grpo_lora_out")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    logger.info("Saved LoRA adapter to %s", output_dir)


if __name__ == "__main__":
    asyncio.run(train())