chane335 commited on
Commit
ee951aa
·
verified ·
1 Parent(s): 47e04de

Run 4: trainable safety primitive — FS/Git/DB simulators, integrated deploy task, tech-only training

Browse files
tests/test_pipeline_orchestration.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the pipeline orchestrator's wiring and control flow.
2
+
3
+ These tests replace each stage's ``run_*`` function with a fake so we can
4
+ verify:
5
+ * Artifact paths are passed correctly between stages
6
+ * A failing gate aborts the pipeline (bail_on_failure=True)
7
+ * ``--from`` and ``--only`` flags skip the right stages
8
+ * ``pipeline_summary.json`` is written with the right shape
9
+
10
+ Run on CPU only.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import sys
16
+ from pathlib import Path
17
+ from unittest.mock import patch
18
+
19
+ _ROOT = Path(__file__).resolve().parent.parent
20
+ if str(_ROOT) not in sys.path:
21
+ sys.path.insert(0, str(_ROOT))
22
+
23
+ from training.config import TrainingConfig
24
+ from training.pipeline import STAGES, run_pipeline
25
+
26
+
27
+ def _fake_stage(ok: bool = True, extra: dict | None = None):
28
+ def fake(config, *args, **kwargs):
29
+ return {"ok": ok, **(extra or {})}
30
+ return fake
31
+
32
+
33
+ def test_stages_list_is_ordered():
34
+ """Pipeline stages run in this exact order: sft → gate → grpo → eval."""
35
+ assert STAGES == ["sft", "gate", "grpo", "eval"]
36
+
37
+
38
+ def test_pipeline_runs_all_stages_when_all_pass():
39
+ """Happy path: every stage returns ok=True, pipeline completes."""
40
+ cfg = TrainingConfig()
41
+
42
+ with patch("training.stages.stage_1_sft.run_sft", _fake_stage(True)), \
43
+ patch("training.stages.stage_2_gate.run_gate", _fake_stage(True, {"coverage": 1.0})), \
44
+ patch("training.stages.stage_3_grpo.run_grpo", _fake_stage(True, {"mean_reward": 0.8})), \
45
+ patch("training.stages.stage_4_eval.run_eval", _fake_stage(True)):
46
+ summary = run_pipeline(cfg, list(STAGES), bail_on_failure=True)
47
+
48
+ assert summary["final_status"] == "completed"
49
+ assert set(summary["stages"].keys()) == set(STAGES)
50
+ for stage in STAGES:
51
+ assert summary["stages"][stage]["ok"] is True
52
+
53
+
54
+ def test_pipeline_bails_when_gate_fails():
55
+ """If the gate fails, GRPO and eval must NOT run — this is the whole
56
+ point of the gate: fail fast, don't burn GPU on a broken SFT."""
57
+ cfg = TrainingConfig()
58
+
59
+ grpo_called = [False]
60
+ eval_called = [False]
61
+
62
+ def track_grpo(*args, **kwargs):
63
+ grpo_called[0] = True
64
+ return {"ok": True}
65
+
66
+ def track_eval(*args, **kwargs):
67
+ eval_called[0] = True
68
+ return {"ok": True}
69
+
70
+ with patch("training.stages.stage_1_sft.run_sft", _fake_stage(True)), \
71
+ patch("training.stages.stage_2_gate.run_gate", _fake_stage(False, {"coverage": 0.5})), \
72
+ patch("training.stages.stage_3_grpo.run_grpo", track_grpo), \
73
+ patch("training.stages.stage_4_eval.run_eval", track_eval):
74
+ summary = run_pipeline(cfg, list(STAGES), bail_on_failure=True)
75
+
76
+ assert summary["final_status"] == "failed_at_gate"
77
+ assert grpo_called[0] is False, "GRPO ran even though gate failed!"
78
+ assert eval_called[0] is False, "Eval ran even though gate failed!"
79
+
80
+
81
+ def test_pipeline_bails_when_sft_fails():
82
+ """Even earlier: if SFT fails (loss too high), nothing downstream runs."""
83
+ cfg = TrainingConfig()
84
+
85
+ gate_called = [False]
86
+
87
+ with patch("training.stages.stage_1_sft.run_sft", _fake_stage(False, {"final_training_loss": 2.5})), \
88
+ patch("training.stages.stage_2_gate.run_gate", lambda *a, **k: gate_called.__setitem__(0, True) or {"ok": True}):
89
+ summary = run_pipeline(cfg, list(STAGES), bail_on_failure=True)
90
+
91
+ assert summary["final_status"] == "failed_at_sft"
92
+ assert gate_called[0] is False
93
+
94
+
95
+ def test_pipeline_no_bail_runs_all_stages_even_on_failure():
96
+ """With bail_on_failure=False, each stage runs regardless of prior
97
+ failures. Used for post-mortem runs where we want partial artifacts."""
98
+ cfg = TrainingConfig()
99
+
100
+ with patch("training.stages.stage_1_sft.run_sft", _fake_stage(False)), \
101
+ patch("training.stages.stage_2_gate.run_gate", _fake_stage(False)), \
102
+ patch("training.stages.stage_3_grpo.run_grpo", _fake_stage(False)), \
103
+ patch("training.stages.stage_4_eval.run_eval", _fake_stage(True)):
104
+ summary = run_pipeline(cfg, list(STAGES), bail_on_failure=False)
105
+
106
+ assert summary["final_status"] == "completed"
107
+ assert all(stage in summary["stages"] for stage in STAGES)
108
+
109
+
110
+ def test_pipeline_with_subset_of_stages():
111
+ """``--only grpo`` or ``--from gate`` narrows the stage list. Pipeline
112
+ runs exactly those stages."""
113
+ cfg = TrainingConfig()
114
+
115
+ with patch("training.stages.stage_3_grpo.run_grpo", _fake_stage(True)):
116
+ summary = run_pipeline(cfg, ["grpo"], bail_on_failure=True)
117
+
118
+ assert list(summary["stages"].keys()) == ["grpo"]
119
+ assert summary["final_status"] == "completed"
120
+
121
+
122
+ def test_exception_in_stage_surfaces_cleanly():
123
+ """If a stage's run function raises (not returns ok=False), the
124
+ orchestrator must catch it and record ``final_status=fatal``."""
125
+ cfg = TrainingConfig()
126
+
127
+ def raiser(*args, **kwargs):
128
+ raise RuntimeError("simulated stage crash")
129
+
130
+ with patch("training.stages.stage_1_sft.run_sft", raiser):
131
+ summary = run_pipeline(cfg, ["sft"], bail_on_failure=True)
132
+
133
+ assert summary["final_status"] == "fatal"
134
+ assert "error" in summary["stages"]["sft"]
135
+
136
+
137
+ def test_pipeline_summary_is_json_serializable():
138
+ """The final summary must round-trip through JSON so it can be written
139
+ to artifacts/pipeline_summary.json."""
140
+ cfg = TrainingConfig()
141
+
142
+ with patch("training.stages.stage_1_sft.run_sft", _fake_stage(True, {"custom_metric": 0.42})):
143
+ summary = run_pipeline(cfg, ["sft"], bail_on_failure=True)
144
+
145
+ # This serialization is what pipeline.py main() does; if it fails,
146
+ # the artifact won't be written.
147
+ s = json.dumps(summary, default=str)
148
+ assert len(s) > 10
149
+ # And re-parses
150
+ parsed = json.loads(s)
151
+ assert parsed["final_status"] == "completed"
tests/test_rewards.py CHANGED
@@ -217,3 +217,38 @@ def test_reward_funcs_are_shape_compatible_with_trl():
217
  assert isinstance(out, list)
218
  assert len(out) == len(completions)
219
  assert all(isinstance(x, float) for x in out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
  assert isinstance(out, list)
218
  assert len(out) == len(completions)
219
  assert all(isinstance(x, float) for x in out)
220
+
221
+
222
+ def test_wrappers_survive_trl_keyword_calling_convention():
223
+ """Regression test for the Run 5 round 2 crash.
224
+
225
+ TRL calls reward functions as
226
+ ``fn(prompts=[...], completions=[...], task_id=[...], seed=[...])``.
227
+ Both wrappers (text pack funcs and the env wrapper) must handle this
228
+ without raising "got multiple values for argument 'prompts'"."""
229
+ pack = build_reward_pack(total_episodes=100)
230
+ completions = ['<action id="fs_ls"/><reversibility level="R1"/>']
231
+
232
+ # Text reward — TRL-style keyword call
233
+ for fn in pack.funcs:
234
+ scores = fn(
235
+ prompts=["some prompt"],
236
+ completions=completions,
237
+ task_id=["task_log_cleanup"],
238
+ seed=[0],
239
+ )
240
+ assert len(scores) == 1
241
+
242
+ # Env wrapper — the function that actually triggered the bug
243
+ def fake_env_reward(prompts, completions, **_):
244
+ return [0.5] * len(completions)
245
+
246
+ wrapped = weighted_environmental_reward(fake_env_reward, pack)
247
+ scores = wrapped(
248
+ prompts=["some prompt"],
249
+ completions=completions,
250
+ task_id=["task_log_cleanup"],
251
+ seed=[0],
252
+ )
253
+ assert len(scores) == 1
254
+ assert scores[0] > 0 # schedule weight * 0.5 > 0
tests/test_trl_integration.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mock-TRL integration tests for the GRPO reward pipeline.
2
+
3
+ Run 5 round 2 crashed with:
4
+ ``reward_environmental() got multiple values for argument 'prompts'``
5
+
6
+ That bug was invisible to unit tests because no test ever invoked the reward
7
+ functions the way TRL's GRPOTrainer actually invokes them:
8
+
9
+ fn(prompts=[...], completions=[...], task_id=[...], seed=[...])
10
+
11
+ These tests simulate that calling convention. If any reward function in the
12
+ full pack (pure-text + env-wrapped) chokes on TRL-style kwargs, the test
13
+ fails before push — not after 40 minutes of GPU time.
14
+
15
+ This file runs on CPU only. No unsloth, no trl dependency.
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import sys
20
+ from pathlib import Path
21
+ from typing import Any, Dict, List
22
+
23
+ # Ensure project root on sys.path
24
+ _ROOT = Path(__file__).resolve().parent.parent
25
+ if str(_ROOT) not in sys.path:
26
+ sys.path.insert(0, str(_ROOT))
27
+
28
+ from training.rewards import build_reward_pack, weighted_environmental_reward
29
+ from training.stages.stage_3_grpo import _build_prompt_records, _make_task_reward
30
+
31
+
32
+ class FakeGRPOTrainer:
33
+ """Simulates the TRL GRPOTrainer's reward-function calling convention.
34
+
35
+ Real TRL calls:
36
+ for fn in reward_funcs:
37
+ fn(prompts=prompts, completions=completions, **extra_columns)
38
+
39
+ We mirror that exactly. Every reward function that survives a call from
40
+ this fake trainer is guaranteed to survive TRL.
41
+ """
42
+
43
+ def __init__(self, reward_funcs: List, dataset_rows: List[Dict[str, Any]], num_generations: int = 2):
44
+ self.reward_funcs = reward_funcs
45
+ self.dataset_rows = dataset_rows
46
+ self.num_generations = num_generations
47
+
48
+ def simulate_one_step(self, completions: List[str]) -> List[List[float]]:
49
+ """Invoke every reward function with realistic TRL-style kwargs."""
50
+ n = len(completions)
51
+ batch = self.dataset_rows[:n]
52
+ prompts = [r["prompt"] for r in batch]
53
+ task_ids = [r["task_id"] for r in batch]
54
+ seeds = [r["seed"] for r in batch]
55
+
56
+ all_rewards = []
57
+ for fn in self.reward_funcs:
58
+ rewards = fn(
59
+ prompts=prompts,
60
+ completions=completions,
61
+ task_id=task_ids,
62
+ seed=seeds,
63
+ )
64
+ assert isinstance(rewards, list), f"{fn.__name__} returned {type(rewards)}"
65
+ assert len(rewards) == n, f"{fn.__name__} returned {len(rewards)} scores for {n} completions"
66
+ all_rewards.append(rewards)
67
+ return all_rewards
68
+
69
+
70
+ # ─────────────────────────────────────────────────────────────────────────────
71
+ # The test that would have caught Run 5 round 2
72
+ # ─────────────────────────────────────────────────────────────────────────────
73
+
74
+
75
+ def test_full_reward_pack_survives_trl_calling_convention(tmp_path):
76
+ """End-to-end regression: the EXACT reward list stage 3 hands to TRL
77
+ must survive a simulated TRL-style call. This is the test that would
78
+ have caught the duplicate-prompts bug locally."""
79
+ pack = build_reward_pack(total_episodes=50)
80
+
81
+ # Build the same env reward that stage 3 builds
82
+ task_reward, training_log = _make_task_reward(tmp_path / "grpo_artifacts")
83
+ all_reward_funcs = pack.funcs + [weighted_environmental_reward(task_reward, pack)]
84
+
85
+ # Generate a real prompt dataset (no GPU needed — uses PermanenceEnv)
86
+ dataset_rows = _build_prompt_records(total_episodes=8, domain="devtools")
87
+
88
+ # Realistic completions the model might produce
89
+ completions = [
90
+ '<thinking>list first</thinking><action id="fs_ls" path="/var/log"/><reversibility level="R1" confidence="0.99"/>',
91
+ '<thinking>snapshot</thinking><action id="fs_snapshot"/><reversibility level="R2" confidence="0.95"/>',
92
+ ]
93
+
94
+ trainer = FakeGRPOTrainer(all_reward_funcs, dataset_rows, num_generations=2)
95
+
96
+ # If any reward function raises on the TRL calling convention, this
97
+ # fails. This is the test that Run 5 round 2 would have failed.
98
+ all_rewards = trainer.simulate_one_step(completions)
99
+
100
+ # Every reward function returned the right number of scores
101
+ for scores in all_rewards:
102
+ assert len(scores) == len(completions)
103
+
104
+
105
+ def test_env_wrapper_does_not_double_pass_prompts(tmp_path):
106
+ """Narrower version of the above — directly tests the wrapper that
107
+ broke in Run 5 round 2."""
108
+ pack = build_reward_pack(total_episodes=10)
109
+ task_reward, _ = _make_task_reward(tmp_path / "grpo")
110
+ wrapped = weighted_environmental_reward(task_reward, pack)
111
+
112
+ # Invoke with the exact kwargs TRL passes
113
+ completions = ['<action id="fs_ls"/><reversibility level="R1"/>']
114
+ result = wrapped(
115
+ prompts=["some prompt"],
116
+ completions=completions,
117
+ task_id=["task_log_cleanup"],
118
+ seed=[0],
119
+ )
120
+ assert isinstance(result, list)
121
+ assert len(result) == 1
122
+
123
+
124
+ def test_text_reward_accepts_trl_kwargs_without_positional_completions():
125
+ """Make sure make_weighted wrapper also survives keyword-only calls."""
126
+ pack = build_reward_pack(total_episodes=10)
127
+ for fn in pack.funcs:
128
+ # TRL doesn't always pass completions positionally — test the
129
+ # keyword path explicitly.
130
+ result = fn(
131
+ prompts=["p1", "p2"],
132
+ completions=["c1", "c2"],
133
+ task_id=["t1", "t2"],
134
+ seed=[0, 1],
135
+ )
136
+ assert len(result) == 2
137
+
138
+
139
+ def test_build_prompt_records_returns_usable_dataset_shape():
140
+ """Stage 3 calls ``Dataset.from_list(_build_prompt_records(...))``.
141
+ The records must be a list of dicts with the required keys."""
142
+ rows = _build_prompt_records(total_episodes=5, domain="devtools")
143
+ assert len(rows) == 5
144
+ required_keys = {"prompt", "episode", "task_id", "seed"}
145
+ for r in rows:
146
+ assert required_keys.issubset(r.keys())
147
+ assert isinstance(r["prompt"], str)
148
+ assert r["prompt"] # non-empty
149
+ assert r["task_id"].startswith("task_")
150
+
151
+
152
+ def test_task_reward_writes_training_log_entries(tmp_path):
153
+ """Stage 3's env reward appends to ``training_log``. Verify the log
154
+ accumulates entries in the right shape."""
155
+ pack = build_reward_pack(total_episodes=10)
156
+ task_reward, training_log = _make_task_reward(tmp_path / "grpo")
157
+
158
+ completions = ['<action id="fs_ls" path="/var/log"/><reversibility level="R1"/>']
159
+ task_reward(
160
+ prompts=["p"],
161
+ completions=completions,
162
+ task_id=["task_log_cleanup"],
163
+ seed=[0],
164
+ )
165
+ assert len(training_log) >= 1
166
+ # Each entry has the structured fields the dashboard and eval rely on
167
+ last = training_log[-1]
168
+ for k in ("task_id", "seed", "reward", "completion_length"):
169
+ assert k in last, f"missing key {k} in training_log entry"
training/rewards.py CHANGED
@@ -211,14 +211,20 @@ def build_reward_pack(total_episodes: int = 300) -> RewardPack:
211
  ep_counter = [0]
212
 
213
  def make_weighted(fn: Callable[..., List[float]], weight_fn: Callable[[int], float]) -> Callable[..., List[float]]:
214
- def wrapped(completions: List[str], **kwargs) -> List[float]:
215
- # Length monitor sees every completion that passes through here.
 
 
 
 
216
  for c in completions:
217
  monitor.observe(c)
218
  w = weight_fn(ep_counter[0])
219
  if w == 0.0:
220
  return [0.0] * len(completions)
221
- raw = fn(completions, **kwargs)
 
 
222
  return [w * r for r in raw]
223
 
224
  wrapped.__name__ = fn.__name__
@@ -236,16 +242,25 @@ def weighted_environmental_reward(
236
  ) -> Callable[..., List[float]]:
237
  """Wrap an environmental reward fn with the schedule's env weight.
238
 
239
- Stage 3 calls this after constructing the env reward so it participates
240
- in the dynamic weighting.
 
 
 
241
  """
242
- def wrapped(completions: List[str], **kwargs) -> List[float]:
 
 
 
 
 
243
  for c in completions:
244
  pack.length_monitor.observe(c)
245
  w = pack.schedule.weight_environmental(pack.episode_counter[0])
246
  if w == 0.0:
247
  return [0.0] * len(completions)
248
- raw = raw_fn(completions, **kwargs)
 
249
  return [w * r for r in raw]
250
 
251
  wrapped.__name__ = raw_fn.__name__
 
211
  ep_counter = [0]
212
 
213
  def make_weighted(fn: Callable[..., List[float]], weight_fn: Callable[[int], float]) -> Callable[..., List[float]]:
214
+ def wrapped(completions: List[str] | None = None, **kwargs) -> List[float]:
215
+ # Handle completions-as-positional-or-kwarg so TRL's
216
+ # ``prompts=..., completions=...`` calling convention doesn't
217
+ # cause an arg-conflict when forwarding to inner functions.
218
+ if completions is None:
219
+ completions = kwargs.pop("completions", [])
220
  for c in completions:
221
  monitor.observe(c)
222
  w = weight_fn(ep_counter[0])
223
  if w == 0.0:
224
  return [0.0] * len(completions)
225
+ # ``reward_format`` accepts ``**_`` so it absorbs everything —
226
+ # passing completions as a kwarg is safe and collision-free.
227
+ raw = fn(completions=completions, **kwargs)
228
  return [w * r for r in raw]
229
 
230
  wrapped.__name__ = fn.__name__
 
242
  ) -> Callable[..., List[float]]:
243
  """Wrap an environmental reward fn with the schedule's env weight.
244
 
245
+ The wrapped function forwards ALL kwargs straight through (without
246
+ making completions a positional arg) so TRL's usual ``prompts=...``
247
+ keyword does not collide with the wrapped function's positional
248
+ ``prompts`` parameter. Run 5 round 2 crashed on exactly this bug —
249
+ the fix is to forward every arg by keyword only.
250
  """
251
+
252
+ def wrapped(completions: List[str] | None = None, **kwargs) -> List[float]:
253
+ # Handle both calling conventions: TRL usually passes completions
254
+ # as a keyword arg; older callers may pass it positionally.
255
+ if completions is None:
256
+ completions = kwargs.pop("completions", [])
257
  for c in completions:
258
  pack.length_monitor.observe(c)
259
  w = pack.schedule.weight_environmental(pack.episode_counter[0])
260
  if w == 0.0:
261
  return [0.0] * len(completions)
262
+ # Forward by keyword only — never by position — so no arg conflicts.
263
+ raw = raw_fn(completions=completions, **kwargs)
264
  return [w * r for r in raw]
265
 
266
  wrapped.__name__ = raw_fn.__name__
training/stages/stage_1_sft.py CHANGED
@@ -26,12 +26,9 @@ import sys
26
  from pathlib import Path
27
  from typing import Any, Dict, List
28
 
29
- # Unsloth must be imported first for its transformers/trl patches to take.
30
- from unsloth import FastLanguageModel # noqa: F401 (imported for side effects)
31
-
32
- from datasets import Dataset
33
- from transformers import TrainingArguments
34
- from trl import SFTTrainer
35
 
36
  # Project imports
37
  _ROOT = Path(__file__).resolve().parent.parent.parent
@@ -48,7 +45,20 @@ MAX_PROMPT_LENGTH = 768
48
  MAX_COMPLETION_LENGTH = 280
49
 
50
 
51
- def _load_warmup_dataset(path: Path) -> Dataset:
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  if not path.exists():
53
  raise FileNotFoundError(f"warmup traces not found at {path}")
54
  records: List[Dict[str, str]] = []
@@ -70,7 +80,7 @@ def _load_warmup_dataset(path: Path) -> Dataset:
70
  )
71
  if not records:
72
  raise ValueError(f"no usable records in {path}")
73
- return Dataset.from_list(records)
74
 
75
 
76
  def run_sft(
@@ -79,12 +89,15 @@ def run_sft(
79
  artifacts_dir: Path = ARTIFACTS_DIR,
80
  ) -> Dict[str, Any]:
81
  """Run SFT and return the metrics dict that is also written to disk."""
 
 
 
 
 
82
  artifacts_dir.mkdir(parents=True, exist_ok=True)
83
  dataset = _load_warmup_dataset(warmup_path)
84
  n_traces = len(dataset)
85
 
86
- from unsloth import FastLanguageModel as _FLM
87
-
88
  model, tokenizer = _FLM.from_pretrained(
89
  model_name=config.model_name,
90
  max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,
 
26
  from pathlib import Path
27
  from typing import Any, Dict, List
28
 
29
+ # IMPORTANT: heavy deps (unsloth, trl, datasets) imported INSIDE ``run_sft``
30
+ # so the module stays importable on CPU-only machines and the pure-python
31
+ # helpers (``_load_warmup_dataset``) are unit-testable.
 
 
 
32
 
33
  # Project imports
34
  _ROOT = Path(__file__).resolve().parent.parent.parent
 
45
  MAX_COMPLETION_LENGTH = 280
46
 
47
 
48
+ def _load_warmup_dataset(path: Path):
49
+ """Load JSONL warmup traces as a ``datasets.Dataset``.
50
+
51
+ Imported heavy dep ``datasets`` inside the function so this module is
52
+ importable on CPU-only machines (tests exercise JSONL parsing directly
53
+ via ``_load_warmup_records`` below without materializing a Dataset).
54
+ """
55
+ from datasets import Dataset
56
+ records = _load_warmup_records(path)
57
+ return Dataset.from_list(records)
58
+
59
+
60
+ def _load_warmup_records(path: Path) -> List[Dict[str, str]]:
61
+ """Pure-python JSONL loader. Unit-testable, no heavy deps."""
62
  if not path.exists():
63
  raise FileNotFoundError(f"warmup traces not found at {path}")
64
  records: List[Dict[str, str]] = []
 
80
  )
81
  if not records:
82
  raise ValueError(f"no usable records in {path}")
83
+ return records
84
 
85
 
86
  def run_sft(
 
89
  artifacts_dir: Path = ARTIFACTS_DIR,
90
  ) -> Dict[str, Any]:
91
  """Run SFT and return the metrics dict that is also written to disk."""
92
+ # Heavy imports deferred so the module is importable without a GPU.
93
+ from unsloth import FastLanguageModel as _FLM
94
+ from transformers import TrainingArguments
95
+ from trl import SFTTrainer
96
+
97
  artifacts_dir.mkdir(parents=True, exist_ok=True)
98
  dataset = _load_warmup_dataset(warmup_path)
99
  n_traces = len(dataset)
100
 
 
 
101
  model, tokenizer = _FLM.from_pretrained(
102
  model_name=config.model_name,
103
  max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,
training/stages/stage_2_gate.py CHANGED
@@ -31,7 +31,8 @@ import sys
31
  from pathlib import Path
32
  from typing import Any, Dict, List
33
 
34
- from unsloth import FastLanguageModel # noqa: F401 patches transformers
 
35
 
36
  _ROOT = Path(__file__).resolve().parent.parent.parent
37
  if str(_ROOT) not in sys.path:
 
31
  from pathlib import Path
32
  from typing import Any, Dict, List
33
 
34
+ # Heavy deps loaded inside ``run_gate`` so this module stays importable
35
+ # without a GPU.
36
 
37
  _ROOT = Path(__file__).resolve().parent.parent.parent
38
  if str(_ROOT) not in sys.path:
training/stages/stage_3_grpo.py CHANGED
@@ -29,10 +29,12 @@ import sys
29
  from pathlib import Path
30
  from typing import Any, Dict, List, Optional
31
 
32
- from unsloth import FastLanguageModel # noqa: F401 patches transformers
33
-
34
- from datasets import Dataset
35
- from trl import GRPOConfig, GRPOTrainer
 
 
36
 
37
  _ROOT = Path(__file__).resolve().parent.parent.parent
38
  if str(_ROOT) not in sys.path:
@@ -53,11 +55,12 @@ MAX_PROMPT_LENGTH = 768
53
  MAX_COMPLETION_LENGTH = 280
54
 
55
 
56
- def _build_prompt_dataset(total_episodes: int, domain: str = "devtools") -> Dataset:
57
  """One observation per episode, reset fresh so scenarios vary.
58
 
59
- The ``domain`` filter constrains the curriculum to a single domain so
60
- training is focused. Pass ``None`` for a mixed run.
 
61
  """
62
  env = PermanenceEnv(config={"domain": domain})
63
  rows = []
@@ -71,7 +74,7 @@ def _build_prompt_dataset(total_episodes: int, domain: str = "devtools") -> Data
71
  "seed": ep,
72
  }
73
  )
74
- return Dataset.from_list(rows)
75
 
76
 
77
  def _make_task_reward(artifacts_dir: Path):
@@ -149,6 +152,11 @@ def run_grpo(
149
  sft_dir: Path = SFT_DIR,
150
  grpo_dir: Path = GRPO_DIR,
151
  ) -> Dict[str, Any]:
 
 
 
 
 
152
  grpo_dir.mkdir(parents=True, exist_ok=True)
153
  adapter_dir = sft_dir / "adapter"
154
  if not adapter_dir.exists():
@@ -164,8 +172,6 @@ def run_grpo(
164
  "Fix SFT or bump warmup traces before running GRPO."
165
  )
166
 
167
- from unsloth import FastLanguageModel as _FLM
168
-
169
  model, tokenizer = _FLM.from_pretrained(
170
  model_name=str(adapter_dir),
171
  max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,
@@ -205,7 +211,8 @@ def run_grpo(
205
  max_grad_norm=config.gradient_clip,
206
  )
207
 
208
- prompt_dataset = _build_prompt_dataset(config.total_episodes, domain=config.domain)
 
209
  trainer = GRPOTrainer(
210
  model=model,
211
  reward_funcs=all_reward_funcs,
 
29
  from pathlib import Path
30
  from typing import Any, Dict, List, Optional
31
 
32
+ # IMPORTANT: unsloth / trl / datasets are imported INSIDE ``run_grpo`` so this
33
+ # module is importable on machines without a GPU. The pure-python helpers
34
+ # below (``_build_prompt_dataset``, ``_make_task_reward``) therefore are
35
+ # fully unit-testable without those heavy packages. This is what the Run 5
36
+ # round 2 crash taught us: the reward-function glue code must be exercised
37
+ # in the local test suite.
38
 
39
  _ROOT = Path(__file__).resolve().parent.parent.parent
40
  if str(_ROOT) not in sys.path:
 
55
  MAX_COMPLETION_LENGTH = 280
56
 
57
 
58
+ def _build_prompt_records(total_episodes: int, domain: str = "devtools") -> List[Dict[str, Any]]:
59
  """One observation per episode, reset fresh so scenarios vary.
60
 
61
+ Returns plain list of dicts ``run_grpo`` wraps these into a
62
+ ``datasets.Dataset`` before handing to TRL. Splitting the two concerns
63
+ keeps this function testable without the heavy ``datasets`` dependency.
64
  """
65
  env = PermanenceEnv(config={"domain": domain})
66
  rows = []
 
74
  "seed": ep,
75
  }
76
  )
77
+ return rows
78
 
79
 
80
  def _make_task_reward(artifacts_dir: Path):
 
152
  sft_dir: Path = SFT_DIR,
153
  grpo_dir: Path = GRPO_DIR,
154
  ) -> Dict[str, Any]:
155
+ # Heavy imports deferred so the module is importable without a GPU.
156
+ from unsloth import FastLanguageModel as _FLM # noqa: F401 — patches trl
157
+ from datasets import Dataset
158
+ from trl import GRPOConfig, GRPOTrainer
159
+
160
  grpo_dir.mkdir(parents=True, exist_ok=True)
161
  adapter_dir = sft_dir / "adapter"
162
  if not adapter_dir.exists():
 
172
  "Fix SFT or bump warmup traces before running GRPO."
173
  )
174
 
 
 
175
  model, tokenizer = _FLM.from_pretrained(
176
  model_name=str(adapter_dir),
177
  max_seq_length=MAX_PROMPT_LENGTH + MAX_COMPLETION_LENGTH,
 
211
  max_grad_norm=config.gradient_clip,
212
  )
213
 
214
+ prompt_records = _build_prompt_records(config.total_episodes, domain=config.domain)
215
+ prompt_dataset = Dataset.from_list(prompt_records)
216
  trainer = GRPOTrainer(
217
  model=model,
218
  reward_funcs=all_reward_funcs,
training/stages/stage_4_eval.py CHANGED
@@ -29,7 +29,8 @@ import sys
29
  from pathlib import Path
30
  from typing import Any, Callable, Dict, List, Optional, Tuple
31
 
32
- from unsloth import FastLanguageModel # noqa: F401
 
33
 
34
  _ROOT = Path(__file__).resolve().parent.parent.parent
35
  if str(_ROOT) not in sys.path:
 
29
  from pathlib import Path
30
  from typing import Any, Callable, Dict, List, Optional, Tuple
31
 
32
+ # Heavy deps loaded inside ``run_eval`` so this module stays importable
33
+ # without a GPU.
34
 
35
  _ROOT = Path(__file__).resolve().parent.parent.parent
36
  if str(_ROOT) not in sys.path: