natnael kahssay Claude Sonnet 4.6 committed on
Commit
ded7690
Β·
1 Parent(s): bb5a5ec

feat: RFC 005 interactive rollout wrapper + multi-turn GRPO training

Browse files

rollout_wrapper.py:
- run_episode() runs a full interactive episode via vLLM
- model generates ONE tool call at a time, sees tool result, then decides next
- captures (context, completion, logprobs) per turn as a Trajectory
- true reactive multi-turn β€” not blind planning

train_rfc005.py:
- collects N_EPISODES in parallel via ThreadPoolExecutor
- re-scores each turn with HF model for differentiable logprobs
- GRPO loss = -advantage * sum(logprobs across all turns in episode)
- Unsloth syncs HF weights β†’ vLLM after each optimizer.step() automatically

Upgrade from train.py:
before: model generates all tool calls at once, never sees results
now: model reacts to each tool result before deciding the next call

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

training/rollout_wrapper.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RFC 005 interactive rollout wrapper.
3
+
4
+ Runs a full multi-turn episode where the model sees tool results at each step.
5
+ Unlike the single-completion approach in train.py, the model:
6
+ - generates ONE tool call at a time
7
+ - sees the actual result before deciding the next move
8
+ - is reactive, not planning blind
9
+
10
+ Returns a Trajectory: list of (context, completion, logprobs) per turn + final reward.
11
+ The training loop re-scores each turn with the HF model to get differentiable logprobs
12
+ and computes GRPO loss across the full trajectory.
13
+ """
14
+
15
+ import json
16
+ import os
17
+ import requests
18
+ from dataclasses import dataclass, field
19
+
20
# Task/environment server (reset + step endpoints); override via ENV_URL.
ENV_URL = os.environ.get("ENV_URL", "https://http--moa-rl-env--7b2fgcxb6gxp.code.run")
# vLLM OpenAI-compatible server used for generation; override via VLLM_URL.
VLLM_URL = os.environ.get("VLLM_URL", "http://localhost:8001")
# Model name passed to the vLLM chat-completions endpoint.
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/gpt-oss-20b-instruct")
# Hard cap on model generations per episode; after this a "submit" is forced.
MAX_TURNS = 8
# HTTP timeout in seconds, used for both env and vLLM requests.
TIMEOUT = 120

# System prompt demanding exactly ONE single-line JSON tool call per response;
# _parse_tool_call depends on this format.
SYSTEM_PROMPT = """\
You are a TypeScript coding agent. Fix broken source files using tools.

Emit exactly ONE tool call per response as a JSON object on its own line:
{"tool": "read", "params": {"path": "src/foo.ts"}}
{"tool": "edit", "params": {"path": "src/foo.ts", "old_string": "...", "new_string": "..."}}
{"tool": "bash", "params": {"cmd": "npx tsc --noEmit 2>&1 | head -10"}}
{"tool": "submit", "params": {}}

One JSON object. No prose. No markdown fences.\
"""
37
+
38
+
39
@dataclass
class Turn:
    """One model generation step within an episode.

    Captured so the training loop can re-run (messages, completion) through
    the HF model for differentiable logprobs; the vLLM logprobs stored here
    are reference-only.
    """
    messages: list[dict]  # full conversation context fed into this generation
    completion: str  # what the model generated
    logprobs: list[float]  # per-token logprobs returned by vLLM (for reference)
45
+
46
+
47
@dataclass
class Trajectory:
    """A complete episode: sequence of turns + final reward.

    The single scalar reward applies to the whole episode; GRPO spreads its
    advantage across the logprobs of every turn.
    """
    turns: list[Turn] = field(default_factory=list)  # one Turn per generation
    reward: float = 0.0  # final episode reward from the env (0.0 until set)
52
+
53
+
54
+ # ── env helpers ────────────────────────────────────────────────────────────────
55
+
56
def _env_reset() -> dict:
    """Start a fresh episode on the env server and return its first observation.

    The server may wrap the payload as {"observation": ...}; unwrap when it does,
    otherwise return the raw JSON body.
    """
    response = requests.post(f"{ENV_URL}/reset", json={}, timeout=TIMEOUT)
    response.raise_for_status()
    payload = response.json()
    return payload.get("observation", payload)
61
+
62
+
63
def _env_step(tool: str, params: dict) -> dict:
    """Execute one tool call against the env server.

    Returns the observation dict with the step reward merged in under
    "reward" (0.0 when the server reports none).
    """
    body = {"action": {"tool": tool, "params": params}}
    response = requests.post(f"{ENV_URL}/step", json=body, timeout=TIMEOUT)
    response.raise_for_status()

    payload = response.json()
    observation = payload.get("observation", payload)
    observation["reward"] = payload.get("reward", 0.0)
    return observation
74
+
75
+
76
+ # ── vLLM generation ────────────────────────────────────────────────────────────
77
+
78
def _vllm_generate(messages: list[dict]) -> tuple[str, list[float]]:
    """
    Generate one assistant turn via vLLM's OpenAI-compatible API, with logprobs.

    Returns (completion_text, per_token_logprobs).
    completion_text is "" when the server returns a null content field, and
    per_token_logprobs is [] when the logprobs block is missing or null —
    OpenAI-style responses use null rather than omitting the key, which would
    crash a plain .get("logprobs", {}).get(...) chain with an AttributeError.
    """
    r = requests.post(
        f"{VLLM_URL}/v1/chat/completions",
        json={
            "model": MODEL_NAME,
            "messages": messages,
            "max_tokens": 256,
            "temperature": 0.7,
            "logprobs": True,
            "top_logprobs": 1,
        },
        timeout=TIMEOUT,
    )
    r.raise_for_status()
    result = r.json()
    choice = result["choices"][0]
    # Null-safe extraction: both "content" and "logprobs" may be JSON null.
    text = choice["message"]["content"] or ""
    lp_data = (choice.get("logprobs") or {}).get("content") or []
    logprobs = [entry["logprob"] for entry in lp_data]
    return text, logprobs
102
+
103
+
104
+ # ── prompt helpers ─────────────────────────────────────────────────────────────
105
+
106
def _initial_messages(obs: dict) -> list[dict]:
    """Build the opening chat context for an episode from the reset observation.

    Includes any triggering user messages, the task description, the broken
    file's path, and up to 1500 chars of the test file, then instructs the
    model to begin by reading the file.
    """
    preamble = ""
    triggering = obs.get("user_messages", [])
    if triggering:
        quoted = "\n".join(f" > {m}" for m in triggering)
        preamble = "User messages that triggered this task:\n" + quoted + "\n\n"

    task_prompt = (
        f"{preamble}"
        f"Task: {obs['task']}\n\n"
        f"File to fix: {obs['broken_file_path']}\n\n"
        "Tests that must pass:\n"
        f"```ts\n{obs.get('test_file_content', '')[:1500]}\n```\n\n"
        "Start by reading the file."
    )

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": task_prompt},
    ]
125
+
126
+
127
+ def _parse_tool_call(text: str) -> tuple[str, dict] | None:
128
+ for line in text.splitlines():
129
+ line = line.strip()
130
+ if not line.startswith("{"):
131
+ continue
132
+ try:
133
+ obj = json.loads(line)
134
+ if "tool" in obj and "params" in obj:
135
+ return obj["tool"], obj["params"]
136
+ except json.JSONDecodeError:
137
+ pass
138
+ return None
139
+
140
+
141
+ # ── episode runner ─────────────────────────────────────────────────────────────
142
+
143
def run_episode() -> Trajectory:
    """
    Run one full interactive episode against the env via vLLM.

    Each iteration: generate one assistant turn, record it (context snapshot,
    completion, vLLM logprobs), parse the tool call, execute it against the
    env, and feed the tool result back into the conversation so the model can
    react to it. The episode ends when the model emits no valid tool call
    (reward 0.0), the env reports done, or MAX_TURNS is exhausted — in which
    case a "submit" is forced to obtain the final reward.

    Difference from single-completion train.py:
      Before: model generates ALL tool calls blindly upfront
      Now:    model generates ONE tool call, sees the result, then decides next move
    """
    trajectory = Trajectory()
    conversation = _initial_messages(_env_reset())

    turns_taken = 0
    while turns_taken < MAX_TURNS:
        turns_taken += 1
        completion, token_logprobs = _vllm_generate(conversation)

        # Snapshot the context BEFORE appending this completion, so the
        # trainer re-scores exactly what the model saw at this step.
        trajectory.turns.append(Turn(
            messages=list(conversation),
            completion=completion,
            logprobs=token_logprobs,
        ))

        call = _parse_tool_call(completion)
        if call is None:
            # No valid tool call — terminate the episode with zero reward.
            trajectory.reward = 0.0
            return trajectory

        tool, params = call
        conversation.append({"role": "assistant", "content": completion})

        step_obs = _env_step(tool, params)
        if step_obs.get("done", False):
            trajectory.reward = step_obs.get("reward", 0.0)
            return trajectory

        # Inject the tool output so the next generation can react to it.
        result_text = step_obs.get("tool_result", "")
        conversation.append({
            "role": "user",
            "content": f"[{tool} result]\n{result_text}",
        })

    # Turn budget exhausted — force a submit to collect the final reward.
    final_obs = _env_step("submit", {})
    trajectory.reward = final_obs.get("reward", 0.0)
    return trajectory
training/train_rfc005.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RFC 005 training loop β€” true interactive multi-turn GRPO.
3
+
4
+ The model generates one tool call at a time and sees tool results before
5
+ deciding the next move. This is what train.py can't do with standard GRPOTrainer.
6
+
7
+ How it works:
8
+ 1. rollout_wrapper.run_episode() runs N parallel episodes via vLLM
9
+ - at each turn: generate β†’ execute tool β†’ inject result β†’ continue
10
+ - captures (context, completion, vllm_logprobs) per turn
11
+ 2. HF model re-scores each turn: forward pass on (context, completion)
12
+ β†’ differentiable token logprobs
13
+ 3. GRPO loss:
14
+ advantage_i = (reward_i - mean_reward) / (std_reward + 1e-8)
15
+ loss = -mean( advantage_i * sum(logprob of tokens in turn t, for all t in episode i) )
16
+ 4. optimizer.step()
17
+ 5. Unsloth syncs updated HF weights β†’ vLLM automatically
18
+
19
+ The key upgrade over train.py:
20
+ train.py β†’ model plans blind (generates all tool calls at once, never sees results)
21
+ this file β†’ model reacts (one call at a time, sees actual output each step)
22
+ """
23
+
24
+ import os
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from concurrent.futures import ThreadPoolExecutor
28
+ from unsloth import FastLanguageModel
29
+
30
+ from rollout_wrapper import run_episode, Trajectory
31
+
32
# Model name must match the one the vLLM server loads; override via env var.
MODEL_NAME = os.environ.get("MODEL_NAME", "unsloth/gpt-oss-20b-instruct")
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/output/moa-rl-grpo-rfc005")
N_EPISODES = int(os.environ.get("N_EPISODES", "4"))  # episodes per training step (GRPO needs variance)
MAX_STEPS = int(os.environ.get("MAX_STEPS", "300"))  # number of optimizer steps
LR = float(os.environ.get("LR", "5e-6"))  # AdamW learning rate for the LoRA params


# ── model ──────────────────────────────────────────────────────────────────────

print(f"Loading {MODEL_NAME}...")
# Base model in bf16 (4-bit disabled) with a 4096-token context window.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    max_seq_length = 4096,
    load_in_4bit = False,
    dtype = torch.bfloat16,
)
# Attach LoRA adapters (rank 16) on the attention and MLP projections;
# gradient checkpointing via Unsloth's variant to fit long contexts.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    use_gradient_checkpointing = "unsloth",
    random_state = 42,
)

# Start vLLM inside Unsloth (syncs weights automatically after each optimizer step)
# NOTE(review): relies on Unsloth's GRPO patch performing the HF→vLLM weight
# sync; confirm against the installed Unsloth version's behavior.
from unsloth import PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

# Plain AdamW over all (LoRA-wrapped) model parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
63
+
64
+
65
+ # ── GRPO loss over a trajectory ────────────────────────────────────────────────
66
+
67
def score_turn(messages: list[dict], completion: str) -> torch.Tensor:
    """
    Re-score one turn with the HF model to get differentiable token logprobs.

    vLLM logprobs are used for episode collection (fast generation);
    the logprobs computed here drive the actual gradient update.

    Args:
        messages: conversation context the completion was generated from.
        completion: raw assistant text generated for this turn.

    Returns:
        Scalar tensor — the sum of log P(token) over the completion tokens,
        differentiable when the model is in training mode.
    """
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize = False,
        add_generation_prompt = True,
    )

    # Tokenize prompt and completion SEPARATELY, then concatenate the ids.
    # Tokenizing prompt+completion as one string can merge tokens across the
    # boundary, so a prompt_len measured on the prompt alone would misalign
    # and the wrong token span would be scored.
    prompt_ids = tokenizer(prompt_text, return_tensors="pt")["input_ids"]
    completion_ids = tokenizer(
        completion, return_tensors="pt", add_special_tokens=False
    )["input_ids"]
    input_ids = torch.cat([prompt_ids, completion_ids], dim=1).to(model.device)
    prompt_len = prompt_ids.shape[1]

    # Track gradients only in training mode; eval-mode calls skip autograd.
    grad_ctx = torch.enable_grad() if model.training else torch.no_grad()
    with grad_ctx:
        logits = model(input_ids=input_ids).logits  # (1, seq_len, vocab)

    # Logits at position t predict token t+1, hence the one-position shift:
    # score only the completion tokens, never the prompt.
    comp_logits = logits[0, prompt_len - 1 : -1, :]  # (comp_len, vocab)
    comp_ids = input_ids[0, prompt_len:]             # (comp_len,)

    log_probs = F.log_softmax(comp_logits, dim=-1)
    token_lps = log_probs[torch.arange(comp_ids.shape[0]), comp_ids]
    return token_lps.sum()  # scalar: total logprob of this completion
96
+
97
+
98
def grpo_loss(trajectories: list[Trajectory]) -> torch.Tensor:
    """
    Compute GRPO loss across N trajectories.

    advantage_i = (reward_i - mean) / (std + 1e-8)
    loss = -mean_i( advantage_i * sum_t( logprob(turn t in episode i) ) )

    Turns are re-scored with the HF model (score_turn), so the result is
    differentiable w.r.t. the policy parameters.
    """
    rewards = torch.tensor([t.reward for t in trajectories], dtype=torch.float32)
    mean_r = rewards.mean()
    std_r = rewards.std()
    # torch.std() is unbiased and yields NaN for a single sample; treat a
    # one-trajectory group as zero-variance instead of poisoning the loss.
    if torch.isnan(std_r):
        std_r = torch.zeros(())
    advantages = (rewards - mean_r) / (std_r + 1e-8)

    losses = []
    for traj, adv in zip(trajectories, advantages):
        # Episode-level logprob: sum over every turn in this episode.
        total_lp = sum(
            score_turn(turn.messages, turn.completion)
            for turn in traj.turns
        )
        losses.append(-adv * total_lp)

    return torch.stack(losses).mean()
120
+
121
+
122
# ── training loop ──────────────────────────────────────────────────────────────

print(f"RFC 005 training: {N_EPISODES} episodes/step × {MAX_STEPS} steps")
print(f"Model: {MODEL_NAME} → {OUTPUT_DIR}")

for step in range(MAX_STEPS):
    model.train()

    # Collect N episodes in parallel via vLLM; rollouts are I/O-bound
    # (HTTP to vLLM and the env server), so threads are sufficient.
    with ThreadPoolExecutor(max_workers=N_EPISODES) as pool:
        trajectories = list(pool.map(lambda _: run_episode(), range(N_EPISODES)))

    rewards = [t.reward for t in trajectories]
    mean_r = sum(rewards) / len(rewards)

    # GRPO loss + optimizer step
    loss = grpo_loss(trajectories)
    optimizer.zero_grad()
    loss.backward()
    # Clip to stabilize updates against high-variance episode returns.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # Unsloth automatically syncs updated weights → vLLM after optimizer.step()

    print(
        f"step {step+1:4d}/{MAX_STEPS} | "
        f"loss {loss.item():.4f} | "
        f"rewards {[f'{r:.2f}' for r in rewards]} | "
        f"mean {mean_r:.3f}"
    )

    # Periodic checkpoint every 50 steps (adapter weights + tokenizer).
    if (step + 1) % 50 == 0:
        model.save_pretrained(f"{OUTPUT_DIR}/step-{step+1}")
        tokenizer.save_pretrained(f"{OUTPUT_DIR}/step-{step+1}")
        print(f" → checkpoint saved")

# Final save after the full run.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Done. Saved to {OUTPUT_DIR}")