Spaces:
Running
feat: RFC 005 interactive rollout wrapper + multi-turn GRPO training
rollout_wrapper.py:
- run_episode() runs a full interactive episode via vLLM
- model generates ONE tool call at a time, sees tool result, then decides next
- captures (context, completion, logprobs) per turn as a Trajectory
- true reactive multi-turn — not blind planning
train_rfc005.py:
- collects N_EPISODES in parallel via ThreadPoolExecutor
- re-scores each turn with HF model for differentiable logprobs
- GRPO loss = -advantage * sum(logprobs across all turns in episode)
- Unsloth syncs HF weights → vLLM after each optimizer.step() automatically
Upgrade from train.py:
before: model generates all tool calls at once, never sees results
now: model reacts to each tool result before deciding the next call
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- training/rollout_wrapper.py +197 -0
- training/train_rfc005.py +160 -0
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RFC 005 interactive rollout wrapper.
|
| 3 |
+
|
| 4 |
+
Runs a full multi-turn episode where the model sees tool results at each step.
|
| 5 |
+
Unlike the single-completion approach in train.py, the model:
|
| 6 |
+
- generates ONE tool call at a time
|
| 7 |
+
- sees the actual result before deciding the next move
|
| 8 |
+
- is reactive, not planning blind
|
| 9 |
+
|
| 10 |
+
Returns a Trajectory: list of (context, completion, logprobs) per turn + final reward.
|
| 11 |
+
The training loop re-scores each turn with the HF model to get differentiable logprobs
|
| 12 |
+
and computes GRPO loss across the full trajectory.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import requests
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
|
| 20 |
+
ENV_URL = os.environ.get("ENV_URL", "https://http--moa-rl-env--7b2fgcxb6gxp.code.run")
|
| 21 |
+
VLLM_URL = os.environ.get("VLLM_URL", "http://localhost:8001")
|
| 22 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/gpt-oss-20b-instruct")
|
| 23 |
+
MAX_TURNS = 8
|
| 24 |
+
TIMEOUT = 120
|
| 25 |
+
|
| 26 |
+
SYSTEM_PROMPT = """\
|
| 27 |
+
You are a TypeScript coding agent. Fix broken source files using tools.
|
| 28 |
+
|
| 29 |
+
Emit exactly ONE tool call per response as a JSON object on its own line:
|
| 30 |
+
{"tool": "read", "params": {"path": "src/foo.ts"}}
|
| 31 |
+
{"tool": "edit", "params": {"path": "src/foo.ts", "old_string": "...", "new_string": "..."}}
|
| 32 |
+
{"tool": "bash", "params": {"cmd": "npx tsc --noEmit 2>&1 | head -10"}}
|
| 33 |
+
{"tool": "submit", "params": {}}
|
| 34 |
+
|
| 35 |
+
One JSON object. No prose. No markdown fences.\
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class Turn:
    """One model generation step within an episode."""
    # Full conversation context fed into this generation (a snapshot taken
    # at generation time, not a live reference to the evolving message list).
    messages: list[dict]
    # Raw text the model generated at this step.
    completion: str
    # Per-token logprobs returned by vLLM (for reference only; training
    # re-scores each turn with the HF model for differentiable logprobs).
    logprobs: list[float]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class Trajectory:
    """A complete episode: sequence of turns + final reward."""
    # Turns in generation order; default_factory avoids a shared mutable default.
    turns: list[Turn] = field(default_factory=list)
    # Scalar episode reward set when the episode ends (0.0 when the model
    # fails to produce a valid tool call).
    reward: float = 0.0
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ── env helpers ────────────────────────────────────────────────────────────────
|
| 55 |
+
|
| 56 |
+
def _env_reset() -> dict:
    """Reset the remote environment and return the initial observation dict."""
    resp = requests.post(f"{ENV_URL}/reset", json={}, timeout=TIMEOUT)
    resp.raise_for_status()
    payload = resp.json()
    # Some responses wrap the observation under an "observation" key;
    # fall back to the whole payload when they don't.
    return payload.get("observation", payload)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _env_step(tool: str, params: dict) -> dict:
    """
    Execute one tool action against the env.

    Returns the observation dict with the step reward merged in under "reward"
    (0.0 when the response carries no reward).
    """
    body = {"action": {"tool": tool, "params": params}}
    resp = requests.post(f"{ENV_URL}/step", json=body, timeout=TIMEOUT)
    resp.raise_for_status()

    payload = resp.json()
    observation = payload.get("observation", payload)
    # Reward lives at the top level of the response; surface it on the
    # observation so callers only deal with one dict.
    observation["reward"] = payload.get("reward", 0.0)
    return observation
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ── vLLM generation ────────────────────────────────────────────────────────────
|
| 77 |
+
|
| 78 |
+
def _vllm_generate(messages: list[dict]) -> tuple[str, list[float]]:
    """
    Request one chat completion from the vLLM server with logprobs enabled.

    Returns (completion_text, per_token_logprobs).
    """
    payload = {
        "model": MODEL_NAME,
        "messages": messages,
        "max_tokens": 256,
        "temperature": 0.7,
        "logprobs": True,
        "top_logprobs": 1,
    }
    resp = requests.post(
        f"{VLLM_URL}/v1/chat/completions",
        json=payload,
        timeout=TIMEOUT,
    )
    resp.raise_for_status()

    choice = resp.json()["choices"][0]
    text = choice["message"]["content"]
    # OpenAI-style schema: choice.logprobs.content is a list of per-token
    # entries, each carrying a "logprob" field.
    entries = choice.get("logprobs", {}).get("content", [])
    token_logprobs = [entry["logprob"] for entry in entries] if entries else []
    return text, token_logprobs
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ── prompt helpers ─────────────────────────────────────────────────────────────
|
| 105 |
+
|
| 106 |
+
def _initial_messages(obs: dict) -> list[dict]:
    """Build the system + opening user message from a reset observation."""
    triggers = obs.get("user_messages", [])
    preamble = ""
    if triggers:
        quoted = "\n".join(f" > {m}" for m in triggers)
        preamble = "User messages that triggered this task:\n" + quoted + "\n\n"

    # Test file content is truncated so the opening prompt stays bounded.
    task_brief = (
        f"{preamble}"
        f"Task: {obs['task']}\n\n"
        f"File to fix: {obs['broken_file_path']}\n\n"
        "Tests that must pass:\n"
        f"```ts\n{obs.get('test_file_content', '')[:1500]}\n```\n\n"
        "Start by reading the file."
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": task_brief},
    ]
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _parse_tool_call(text: str) -> tuple[str, dict] | None:
|
| 128 |
+
for line in text.splitlines():
|
| 129 |
+
line = line.strip()
|
| 130 |
+
if not line.startswith("{"):
|
| 131 |
+
continue
|
| 132 |
+
try:
|
| 133 |
+
obj = json.loads(line)
|
| 134 |
+
if "tool" in obj and "params" in obj:
|
| 135 |
+
return obj["tool"], obj["params"]
|
| 136 |
+
except json.JSONDecodeError:
|
| 137 |
+
pass
|
| 138 |
+
return None
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ── episode runner ─────────────────────────────────────────────────────────────
|
| 142 |
+
|
| 143 |
+
def run_episode() -> Trajectory:
    """
    Run one full interactive episode.

    At each turn the model sees all previous tool results — true reactive multi-turn.
    Captures logprobs at every generation step so GRPO loss can be computed
    across the full trajectory.

    Difference from single-completion train.py:
      Before: model generates ALL tool calls blindly upfront
      Now:    model generates ONE tool call, sees the result, then decides next move

    Returns:
        Trajectory with one Turn per generation and the final episode reward
        (0.0 when the model emits no parseable tool call).
    """
    traj = Trajectory()
    obs = _env_reset()
    messages = _initial_messages(obs)

    for _ in range(MAX_TURNS):
        completion, logprobs = _vllm_generate(messages)

        # Record the turn BEFORE parsing, so even a malformed final
        # completion is kept in the trajectory for training.
        traj.turns.append(Turn(
            messages = list(messages),  # snapshot of context at this step
            # (shallow copy suffices: message dicts are only appended, never mutated)
            completion = completion,
            logprobs = logprobs,
        ))

        parsed = _parse_tool_call(completion)
        if parsed is None:
            # Model produced no valid tool call — end with zero reward
            traj.reward = 0.0
            return traj

        tool, params = parsed

        # Append model turn to conversation
        messages.append({"role": "assistant", "content": completion})

        # Execute against env
        step_obs = _env_step(tool, params)
        done = step_obs.get("done", False)

        if done:
            traj.reward = step_obs.get("reward", 0.0)
            return traj

        # Feed tool result back so model can react to it.
        # Injected as a "user" message since the chat API has no tool role here.
        tool_result = step_obs.get("tool_result", "")
        messages.append({
            "role": "user",
            "content": f"[{tool} result]\n{tool_result}",
        })

    # Max turns hit — force submit so the episode still yields a terminal reward
    obs_final = _env_step("submit", {})
    traj.reward = obs_final.get("reward", 0.0)
    return traj
|
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
RFC 005 training loop β true interactive multi-turn GRPO.
|
| 3 |
+
|
| 4 |
+
The model generates one tool call at a time and sees tool results before
|
| 5 |
+
deciding the next move. This is what train.py can't do with standard GRPOTrainer.
|
| 6 |
+
|
| 7 |
+
How it works:
|
| 8 |
+
1. rollout_wrapper.run_episode() runs N parallel episodes via vLLM
|
| 9 |
+
- at each turn: generate β execute tool β inject result β continue
|
| 10 |
+
- captures (context, completion, vllm_logprobs) per turn
|
| 11 |
+
2. HF model re-scores each turn: forward pass on (context, completion)
|
| 12 |
+
β differentiable token logprobs
|
| 13 |
+
3. GRPO loss:
|
| 14 |
+
advantage_i = (reward_i - mean_reward) / (std_reward + 1e-8)
|
| 15 |
+
loss = -mean( advantage_i * sum(logprob of tokens in turn t, for all t in episode i) )
|
| 16 |
+
4. optimizer.step()
|
| 17 |
+
5. Unsloth syncs updated HF weights β vLLM automatically
|
| 18 |
+
|
| 19 |
+
The key upgrade over train.py:
|
| 20 |
+
train.py β model plans blind (generates all tool calls at once, never sees results)
|
| 21 |
+
this file β model reacts (one call at a time, sees actual output each step)
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import os
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 28 |
+
from unsloth import FastLanguageModel
|
| 29 |
+
|
| 30 |
+
from rollout_wrapper import run_episode, Trajectory
|
| 31 |
+
|
| 32 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/gpt-oss-20b-instruct")
|
| 33 |
+
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/output/moa-rl-grpo-rfc005")
|
| 34 |
+
N_EPISODES = int(os.environ.get("N_EPISODES", "4")) # episodes per training step (GRPO needs variance)
|
| 35 |
+
MAX_STEPS = int(os.environ.get("MAX_STEPS", "300"))
|
| 36 |
+
LR = float(os.environ.get("LR", "5e-6"))
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ── model ──────────────────────────────────────────────────────────────────────
|
| 40 |
+
|
| 41 |
+
print(f"Loading {MODEL_NAME}...")
|
| 42 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 43 |
+
model_name = MODEL_NAME,
|
| 44 |
+
max_seq_length = 4096,
|
| 45 |
+
load_in_4bit = False,
|
| 46 |
+
dtype = torch.bfloat16,
|
| 47 |
+
)
|
| 48 |
+
model = FastLanguageModel.get_peft_model(
|
| 49 |
+
model,
|
| 50 |
+
r = 16,
|
| 51 |
+
lora_alpha = 16,
|
| 52 |
+
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
| 53 |
+
"gate_proj", "up_proj", "down_proj"],
|
| 54 |
+
use_gradient_checkpointing = "unsloth",
|
| 55 |
+
random_state = 42,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Start vLLM inside Unsloth (syncs weights automatically after each optimizer step)
|
| 59 |
+
from unsloth import PatchFastRL
|
| 60 |
+
PatchFastRL("GRPO", FastLanguageModel)
|
| 61 |
+
|
| 62 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ── GRPO loss over a trajectory ────────────────────────────────────────────────
|
| 66 |
+
|
| 67 |
+
def score_turn(messages: list[dict], completion: str) -> torch.Tensor:
    """
    Re-score one turn with the HF model to get differentiable token logprobs.

    vLLM logprobs are used for episode collection (fast generation).
    HF logprobs are used here for the actual gradient update.

    Args:
        messages: conversation context fed into this generation.
        completion: text the model generated for this turn.

    Returns:
        0-dim tensor — total logprob of the completion tokens, differentiable
        w.r.t. model parameters whenever grad mode is enabled by the caller.
    """
    # Format the context exactly as the model was trained on.
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize = False,
        add_generation_prompt = True,
    )
    full_text = prompt_text + completion

    # BUGFIX: the chat template text already contains any special tokens, so
    # disable add_special_tokens for BOTH encodings. Previously the default
    # tokenizer behavior could prepend e.g. a BOS to each encoding, making
    # prompt_len misaligned with the full sequence and mis-slicing the
    # completion tokens by one.
    inputs = tokenizer(
        full_text, return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_len = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False
    )["input_ids"].shape[1]

    # Gradients must flow through this forward pass for GRPO. The previous
    # conditional no_grad/enable_grad wrapper was a no-op during training and
    # silently killed gradients in eval mode; the caller controls grad mode.
    logits = model(**inputs).logits  # (1, seq_len, vocab)

    # Only score the completion tokens: logits at position i predict token i+1,
    # hence the -1 shift on the logits slice.
    comp_logits = logits[0, prompt_len - 1 : -1, :]  # (comp_len, vocab)
    comp_ids = inputs["input_ids"][0, prompt_len:]   # (comp_len,)

    log_probs = F.log_softmax(comp_logits, dim=-1)
    token_lps = log_probs.gather(1, comp_ids.unsqueeze(1)).squeeze(1)
    return token_lps.sum()  # scalar: total logprob of this completion
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def grpo_loss(trajectories: list[Trajectory]) -> torch.Tensor:
    """
    Compute GRPO loss across N trajectories.

    advantage_i = (reward_i - mean) / (std + 1e-8)
    loss = -mean_i( advantage_i * sum_t( logprob(turn t in episode i) ) )

    Args:
        trajectories: episodes collected this step (must be non-empty).

    Returns:
        0-dim loss tensor, differentiable through score_turn's forward passes.
    """
    rewards = torch.tensor([t.reward for t in trajectories], dtype=torch.float32)
    mean_r = rewards.mean()
    # BUGFIX: use the population std (unbiased=False). The default .std()
    # divides by n-1, which is NaN for a single trajectory (N_EPISODES is
    # env-configurable) and does not match the GRPO advantage formulation.
    std_r = rewards.std(unbiased=False) + 1e-8
    advantages = (rewards - mean_r) / std_r

    losses = []
    for traj, adv in zip(trajectories, advantages):
        # Sum (differentiable) logprobs across all turns in this episode.
        total_lp = sum(
            score_turn(turn.messages, turn.completion)
            for turn in traj.turns
        )
        losses.append(-adv * total_lp)

    return torch.stack(losses).mean()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ── training loop ──────────────────────────────────────────────────────────────
|
| 123 |
+
|
| 124 |
+
print(f"RFC 005 training: {N_EPISODES} episodes/step Γ {MAX_STEPS} steps")
|
| 125 |
+
print(f"Model: {MODEL_NAME} β {OUTPUT_DIR}")
|
| 126 |
+
|
| 127 |
+
for step in range(MAX_STEPS):
|
| 128 |
+
model.train()
|
| 129 |
+
|
| 130 |
+
# Collect N episodes in parallel via vLLM
|
| 131 |
+
with ThreadPoolExecutor(max_workers=N_EPISODES) as pool:
|
| 132 |
+
trajectories = list(pool.map(lambda _: run_episode(), range(N_EPISODES)))
|
| 133 |
+
|
| 134 |
+
rewards = [t.reward for t in trajectories]
|
| 135 |
+
mean_r = sum(rewards) / len(rewards)
|
| 136 |
+
|
| 137 |
+
# GRPO loss + optimizer step
|
| 138 |
+
loss = grpo_loss(trajectories)
|
| 139 |
+
optimizer.zero_grad()
|
| 140 |
+
loss.backward()
|
| 141 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
| 142 |
+
optimizer.step()
|
| 143 |
+
|
| 144 |
+
# Unsloth automatically syncs updated weights β vLLM after optimizer.step()
|
| 145 |
+
|
| 146 |
+
print(
|
| 147 |
+
f"step {step+1:4d}/{MAX_STEPS} | "
|
| 148 |
+
f"loss {loss.item():.4f} | "
|
| 149 |
+
f"rewards {[f'{r:.2f}' for r in rewards]} | "
|
| 150 |
+
f"mean {mean_r:.3f}"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if (step + 1) % 50 == 0:
|
| 154 |
+
model.save_pretrained(f"{OUTPUT_DIR}/step-{step+1}")
|
| 155 |
+
tokenizer.save_pretrained(f"{OUTPUT_DIR}/step-{step+1}")
|
| 156 |
+
print(f" β checkpoint saved")
|
| 157 |
+
|
| 158 |
+
model.save_pretrained(OUTPUT_DIR)
|
| 159 |
+
tokenizer.save_pretrained(OUTPUT_DIR)
|
| 160 |
+
print(f"Done. Saved to {OUTPUT_DIR}")
|