Nikitasoni22 committed on
Commit
048dc4f
·
1 Parent(s): 3e7bcbe

training issue resolved

Browse files
Files changed (3) hide show
  1. eval_lora.py +26 -1
  2. train.py +409 -50
  3. train_colab.ipynb +35 -26
eval_lora.py CHANGED
@@ -5,6 +5,9 @@ Usage (Colab):
5
  !python eval_lora.py --adapter-path ./cicd_rl_agent_final
6
 
7
  Optional: --base-model must match what you fine-tuned.
 
 
 
8
  """
9
 
10
  import argparse
@@ -96,7 +99,14 @@ def main():
96
  default=True,
97
  help="Compare predicted YAML vs correct_yaml using canonicalized YAML tree",
98
  )
 
 
 
 
 
 
99
  args = p.parse_args()
 
100
 
101
  if not os.path.isdir(args.adapter_path):
102
  print(f"Adapter path not found: {args.adapter_path}")
@@ -146,9 +156,24 @@ def main():
146
  comp = strip_code_fences(raw)
147
  correct = task.get("correct_yaml", "")
148
  label = partial_match_score(comp, correct, task["pipeline_yaml"], args.canonical_compare)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  by_diff[d].append(
150
  {
151
- "id": task["id"],
152
  "label": label,
153
  }
154
  )
 
5
  !python eval_lora.py --adapter-path ./cicd_rl_agent_final
6
 
7
  Optional: --base-model must match what you fine-tuned.
8
+
9
+ Debug a few tasks (raw vs canonical reference):
10
+ !python eval_lora.py --adapter-path ./cicd_rl_agent_final --inspect easy_003,medium_001
11
  """
12
 
13
  import argparse
 
99
  default=True,
100
  help="Compare predicted YAML vs correct_yaml using canonicalized YAML tree",
101
  )
102
+ p.add_argument(
103
+ "--inspect",
104
+ type=str,
105
+ default="",
106
+ help="Comma-separated task ids (e.g. easy_001,medium_002) to print raw model output vs reference",
107
+ )
108
  args = p.parse_args()
109
+ inspect_ids = {s.strip() for s in args.inspect.split(",") if s.strip()}
110
 
111
  if not os.path.isdir(args.adapter_path):
112
  print(f"Adapter path not found: {args.adapter_path}")
 
156
  comp = strip_code_fences(raw)
157
  correct = task.get("correct_yaml", "")
158
  label = partial_match_score(comp, correct, task["pipeline_yaml"], args.canonical_compare)
159
+ tid = task["id"]
160
+ if inspect_ids and tid in inspect_ids:
161
+ pred_c = canonical_yaml(comp) if args.canonical_compare else comp.strip()
162
+ gold_c = canonical_yaml(correct) if args.canonical_compare else correct.strip()
163
+ print(f"\n=== INSPECT {tid} (label={label}) ===\n")
164
+ print("--- raw model output ---")
165
+ print(raw)
166
+ print("--- after strip_code_fences ---")
167
+ print(comp)
168
+ print("--- canonical pred ---")
169
+ print(pred_c)
170
+ print("--- canonical reference (correct_yaml) ---")
171
+ print(gold_c)
172
+ print("--- match ---")
173
+ print("exact canonical match:", pred_c == gold_c)
174
  by_diff[d].append(
175
  {
176
+ "id": tid,
177
  "label": label,
178
  }
179
  )
train.py CHANGED
@@ -1,10 +1,24 @@
1
  """
2
- train.py — CICD RL Agent full training script
3
- Run: python train.py
 
 
 
 
 
 
 
 
 
 
 
 
4
  Requires: pip install unsloth trl datasets transformers
5
  """
6
 
 
7
  import os, re, sys
 
8
  sys.path.insert(0, os.path.dirname(__file__))
9
  try:
10
  import yaml
@@ -23,6 +37,23 @@ NUM_SAMPLES = 512
23
  # GRPO: use `max_completion_length` (TRL); older examples used `max_new_tokens`.
24
  MAX_COMPLETION_TOKENS = 128
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  from cicd_debug_env.tasks import ALL_TASKS
27
  from datasets import Dataset
28
  import random
@@ -58,6 +89,29 @@ def build_dataset():
58
  })
59
  return Dataset.from_list(records)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def _completion_to_text(completion) -> str:
62
  """
63
  Normalize TRL/Unsloth completion payloads to plain text.
@@ -118,7 +172,8 @@ def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **
118
  pred_canon = _canonical_yaml(pred)
119
  correct_canon = _canonical_yaml(correct)
120
  # Strict reward: exact/canonical exact gets high reward; everything else is negative.
121
- rewards.append(3.0 if pred_canon and pred_canon == correct_canon else -1.0)
 
122
  return rewards
123
 
124
  def reward_yaml_structure(completions, prompts, **kwargs):
@@ -137,7 +192,7 @@ def reward_yaml_structure(completions, prompts, **kwargs):
137
  score = 0.4 * int(starts_yaml) + 0.4 * int(has_yaml_keys) + 0.2 * int(line_count_ok)
138
  if has_prose_or_md:
139
  score -= 1.0
140
- rewards.append(score)
141
  return rewards
142
 
143
  def reward_no_hallucination(completions, prompts, **kwargs):
@@ -149,27 +204,266 @@ def reward_no_hallucination(completions, prompts, **kwargs):
149
  for c in completions:
150
  lower = _completion_to_text(c).lower()
151
  bad_hits = sum(1 for p in bad if p in lower)
152
- values.append(-2.0 if bad_hits > 0 else 0.5)
153
  return values
154
 
155
  REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # Colab often sets WANDB_DISABLED in the runtime env.
159
- # If report_to is wandb, this env var causes a hard runtime error in Trainer callbacks.
160
  if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
161
- print("Detected WANDB_DISABLED; unsetting it because report_to='wandb'.")
162
  os.environ.pop("WANDB_DISABLED", None)
163
 
164
  if USE_UNSLOTH:
165
  from unsloth import FastLanguageModel
166
  model, tokenizer = FastLanguageModel.from_pretrained(
167
- model_name=MODEL_NAME, max_seq_length=1024, dtype=None, load_in_4bit=True)
 
168
  model = FastLanguageModel.get_peft_model(
169
- model, r=16,
170
- target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
171
- lora_alpha=16, lora_dropout=0.0, bias="none",
172
- use_gradient_checkpointing="unsloth", random_state=42)
 
 
 
 
 
173
  else:
174
  from transformers import AutoModelForCausalLM, AutoTokenizer
175
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -177,53 +471,118 @@ def main():
177
  if tokenizer.pad_token is None:
178
  tokenizer.pad_token = tokenizer.eos_token
179
 
180
- dataset = build_dataset()
181
- print(f"Dataset: {len(dataset)} samples")
 
182
 
183
- # Prefer wandb logging when available; gracefully fall back if not installed.
184
- use_wandb = True
185
- try:
186
- import wandb # noqa: F401
187
- except Exception:
188
- use_wandb = False
189
- print("wandb is not installed; falling back to report_to='none'.")
190
 
191
- from trl import GRPOTrainer, GRPOConfig
192
- args = GRPOConfig(
193
- output_dir="./cicd_rl_output",
194
- per_device_train_batch_size=BATCH_SIZE,
195
- gradient_accumulation_steps=GRAD_ACCUM,
196
- learning_rate=5e-6, max_steps=MAX_STEPS,
197
- num_generations=4, max_completion_length=MAX_COMPLETION_TOKENS,
198
- logging_steps=5, save_steps=50,
199
- report_to="wandb" if use_wandb else "none", remove_unused_columns=False,
200
- warmup_steps=10, lr_scheduler_type="cosine", optim="adamw_8bit",
201
- )
202
- trainer = GRPOTrainer(
203
- model=model, args=args, reward_funcs=REWARD_FUNCTIONS,
204
- train_dataset=dataset, processing_class=tokenizer)
205
 
206
- print("Starting GRPO training...")
207
- if use_wandb:
208
- import wandb
209
- wandb.init(project="cicd-rl-agent", name="grpo-run-1")
210
- trainer.train()
211
- print("Training complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  save_path = "./cicd_rl_agent_final"
214
- if USE_UNSLOTH:
215
  model.save_pretrained(save_path)
216
  tokenizer.save_pretrained(save_path)
217
- print(f"LoRA adapters saved to {save_path}")
218
- print("Testing post-training inference...")
219
- FastLanguageModel.for_inference(model)
220
- test_input = tokenizer("Fix this YAML: steps:\n - run: npm tset", return_tensors="pt").to("cuda")
221
- out = model.generate(**test_input, max_new_tokens=64)
222
- print(tokenizer.decode(out[0], skip_special_tokens=True))
223
- else:
224
  model.save_pretrained(save_path)
225
  tokenizer.save_pretrained(save_path)
226
- print(f"Model saved to {save_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
  if __name__ == "__main__":
229
  main()
 
1
  """
2
+ train.py — CICD RL Agent: optional SFT (supervised) then GRPO (RL) on CI/CD YAML fixes.
3
+
4
+ Default: short SFT on (prompt → correct_yaml), then GRPO with correctness-heavy rewards.
5
+
6
+ python train.py # SFT (short) + GRPO (same as before)
7
+ python train.py --stages grpo # GRPO only (old behavior, no SFT)
8
+ python train.py --stages sft # SFT only; saves ./cicd_rl_sft_lora
9
+ python train.py --stages sft,grpo --sft-epochs 2
10
+ python train.py --no-final-eval
11
+ python train.py --eval-timeout 90
12
+
13
+ Console: SFT/GRPO log lines (loss/rewards + step X/Y), per-stage times and step counts, then
14
+ a final eval of every task with correct/wrong/timeout, wall time, and reward breakdown.
15
+
16
  Requires: pip install unsloth trl datasets transformers
17
  """
18
 
19
+ import argparse
20
  import os, re, sys
21
+ import time
22
  sys.path.insert(0, os.path.dirname(__file__))
23
  try:
24
  import yaml
 
37
  # GRPO: use `max_completion_length` (TRL); older examples used `max_new_tokens`.
38
  MAX_COMPLETION_TOKENS = 128
39
 
40
# SFT: teach exact gold YAML before RL polish (short run by design).
SFT_EPOCHS = 1  # default value of the --sft-epochs CLI flag
SFT_LEARNING_RATE = 2e-4  # passed to SFTConfig(learning_rate=...) in run_sft()
SFT_MAX_SEQ = 1024  # passed to SFTConfig(max_length=...) in run_sft()
SFT_DATASET_SIZE = 512  # number of records sampled by build_sft_dataset()
SFT_OUTPUT = "./cicd_rl_sft_lora"  # directory run_sft() saves the model and tokenizer to

# Post-training quick eval: mark each task correct / wrong / timeout if generation exceeds this (seconds).
EVAL_GEN_TIMEOUT_SEC = 60.0

# Reward mix: GRPO sums per-function rewards; keep correctness as the dominant term.
REWARD_FIX_MATCH = 5.0  # reward_fix_correctness: canonical prediction equals canonical reference
REWARD_FIX_MISS = -1.5  # reward_fix_correctness: every other completion
REWARD_STRUCT_SCALE = 0.2  # multiplier applied to the raw reward_yaml_structure score
REWARD_HALLU_GOOD = 0.1  # reward_no_hallucination: no entry of its `bad` list found
REWARD_HALLU_BAD = -0.35  # reward_no_hallucination: at least one `bad` entry found
56
+
57
  from cicd_debug_env.tasks import ALL_TASKS
58
  from datasets import Dataset
59
  import random
 
89
  })
90
  return Dataset.from_list(records)
91
 
92
+
93
def build_sft_dataset(tokenizer) -> Dataset:
    """Build the supervised fine-tuning dataset of chat-formatted examples.

    Each record is a single ``{"text": ...}`` row holding a system / user /
    assistant conversation rendered with ``tokenizer.apply_chat_template``;
    the assistant turn is the task's stripped ``correct_yaml``.  Tasks are
    drawn with a 50% easy / 30% medium / 20% hard difficulty mix.

    Args:
        tokenizer: Tokenizer exposing ``apply_chat_template``.

    Returns:
        A ``Dataset`` with ``SFT_DATASET_SIZE`` rows and one ``text`` column.

    Raises:
        ValueError: If no task carries a non-empty ``correct_yaml``.
    """
    # A task without a reference fix would yield an empty assistant turn,
    # i.e. a training example that teaches the model to answer with nothing.
    usable = [t for t in ALL_TASKS if (t.get("correct_yaml") or "").strip()]
    if not usable:
        raise ValueError(
            "No task has a non-empty correct_yaml; cannot build the SFT dataset"
        )

    buckets = {
        level: [t for t in usable if t["difficulty"] == level]
        for level in ("easy", "medium", "hard")
    }

    records = []
    for _ in range(SFT_DATASET_SIZE):
        r = random.random()
        level = "easy" if r < 0.5 else "medium" if r < 0.8 else "hard"
        # random.choice raises IndexError on an empty sequence; fall back to
        # the whole usable pool when a difficulty tier has no tasks.
        task = random.choice(buckets[level] or usable)
        gold = task["correct_yaml"].strip()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_prompt(task)},
            {"role": "assistant", "content": gold},
        ]
        # add_generation_prompt=False: the assistant answer is already part of
        # the conversation, so no trailing generation header is appended.
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        records.append({"text": text})
    return Dataset.from_list(records)
113
+
114
+
115
  def _completion_to_text(completion) -> str:
116
  """
117
  Normalize TRL/Unsloth completion payloads to plain text.
 
172
  pred_canon = _canonical_yaml(pred)
173
  correct_canon = _canonical_yaml(correct)
174
  # Strict reward: exact/canonical exact gets high reward; everything else is negative.
175
+ ok = bool(pred_canon and pred_canon == correct_canon)
176
+ rewards.append(REWARD_FIX_MATCH if ok else REWARD_FIX_MISS)
177
  return rewards
178
 
179
  def reward_yaml_structure(completions, prompts, **kwargs):
 
192
  score = 0.4 * int(starts_yaml) + 0.4 * int(has_yaml_keys) + 0.2 * int(line_count_ok)
193
  if has_prose_or_md:
194
  score -= 1.0
195
+ rewards.append(score * REWARD_STRUCT_SCALE)
196
  return rewards
197
 
198
  def reward_no_hallucination(completions, prompts, **kwargs):
 
204
  for c in completions:
205
  lower = _completion_to_text(c).lower()
206
  bad_hits = sum(1 for p in bad if p in lower)
207
+ values.append(REWARD_HALLU_BAD if bad_hits > 0 else REWARD_HALLU_GOOD)
208
  return values
209
 
210
  REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]
211
 
212
+
213
def _grpo_console_callback(max_steps: int, label: str = "GRPO"):
    """Return a Trainer callback that prints reward/loss metrics on each log event.

    Every printed line starts with ``[<label> turn/step <global_step>/<max_steps>]``
    followed by the selected metrics in sorted key order, joined by `` | ``.
    """
    from transformers import TrainerCallback

    exact_names = ("loss", "kl", "learning_rate", "train_loss")

    def _is_selected(name: str) -> bool:
        # Keep reward-like and loss-like keys plus a few exact metric names.
        lowered = name.lower()
        return "reward" in lowered or name in exact_names or "loss" in lowered

    def _render(name, value) -> str:
        # Numbers are shortened to 6 significant digits; anything else is shown as-is.
        if isinstance(value, (int, float)):
            return f"{name}={value:.6g}"
        return f"{name}={value}"

    class _GRPOConsoleLogCallback(TrainerCallback):
        def __init__(self) -> None:
            self._max = max_steps
            self._label = label

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            header = f"[{self._label} turn/step {state.global_step}/{self._max}]"
            metrics = [
                _render(name, logs[name]) for name in sorted(logs) if _is_selected(name)
            ]
            print(" | ".join([header, *metrics]), flush=True)

    return _GRPOConsoleLogCallback()
236
+
237
+
238
def _sft_console_callback():
    """Return a Trainer callback that prints numeric loss / learning-rate logs for SFT."""
    from transformers import TrainerCallback

    class _SFTConsoleLogCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            pieces = [f"[SFT turn/step {state.global_step}]"]
            for name in sorted(logs):
                value = logs[name]
                # "loss" is matched case-insensitively, "learning_rate" as written.
                relevant = "loss" in name.lower() or "learning_rate" in name
                if relevant and isinstance(value, (int, float)):
                    pieces.append(f"{name}={value:.6g}")
            print(" ".join(pieces), flush=True)

    return _SFTConsoleLogCallback()
253
+
254
+
255
def _format_seconds(sec: float) -> str:
    """Format a duration in seconds as ``"12.3s"``, ``"4m 5.6s"`` or ``"1h 2m 3s"``.

    The value is rounded to the displayed precision *before* it is split into
    units, so a component can never be printed as 60 (for example ``59.96``
    becomes ``"1m 0.0s"`` rather than ``"60.0s"``).

    Args:
        sec: Duration in seconds.

    Returns:
        Human-readable duration; tenths of a second below one hour, whole
        seconds from one hour upwards.
    """
    # Work in integer tenths so the carry into minutes is exact.
    tenths = int(round(sec * 10))
    if tenths < 600:
        return f"{tenths / 10:.1f}s"
    minutes, rem_tenths = divmod(tenths, 600)
    if minutes < 60:
        return f"{minutes}m {rem_tenths / 10:.1f}s"
    # From one hour upwards only whole seconds are shown: round once, then split.
    hours, rem_seconds = divmod(int(round(sec)), 3600)
    minutes, seconds = divmod(rem_seconds, 60)
    return f"{hours}h {minutes}m {seconds}s"
263
+
264
+
265
def _print_grpo_reward_tail(trainer) -> None:
    """Print the reward/loss fields of the last five entries in ``trainer.state.log_history``.

    Entries with no key containing "reward" and no "loss" key are skipped.
    """
    history = getattr(trainer.state, "log_history", None) or []
    if not history:
        print("(No log_history available for reward summary.)", flush=True)
        return

    print("\n--- Last GRPO log entries (rewards) ---", flush=True)
    for entry in history[-5:]:
        picked = {}
        for name, value in entry.items():
            if name == "loss" or "reward" in name.lower():
                picked[name] = value
        if picked:
            print(f" step {entry.get('step', '?')}: {picked}", flush=True)
275
+
276
+
277
def _set_inference_mode(model) -> None:
    """Switch ``model`` to inference mode via Unsloth when enabled, else ``model.eval()``."""
    if not USE_UNSLOTH:
        model.eval()
        return

    from unsloth import FastLanguageModel

    FastLanguageModel.for_inference(model)
283
+
284
+
285
def _generate_for_task(model, tokenizer, task: dict, max_new_tokens: int) -> str:
    """Greedy-decode one completion for ``task`` and return only the newly generated text.

    The prompt is the system + user chat rendered with a generation prompt;
    the prompt tokens are sliced off the output before decoding.
    """
    import torch

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(task)},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    # Run on whatever device holds the model's first parameter.
    device = next(model.parameters()).device
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated = model.generate(
            **encoded, max_new_tokens=max_new_tokens, do_sample=False
        )
    prompt_len = encoded["input_ids"].shape[1]
    completion_ids = generated[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True)
303
+
304
+
305
def _eval_task_status(raw: str, task: dict, took_sec: float, timeout_sec: float) -> str:
    """Classify one eval generation as ``"timeout"``, ``"correct"`` or ``"wrong"``.

    ``"timeout"`` is decided purely from the measured wall time; otherwise the
    fence-stripped completion must canonicalize to the same non-empty YAML as
    the task's ``correct_yaml``.
    """
    if took_sec > timeout_sec:
        return "timeout"

    prediction = _strip_markdown_fences(_completion_to_text(raw))
    reference = (task.get("correct_yaml") or "").strip()
    prediction_canon = _canonical_yaml(prediction)
    reference_canon = _canonical_yaml(reference)

    # Both sides must be non-empty, so an empty canonical form never counts as a match.
    is_match = (
        bool(prediction_canon)
        and bool(reference_canon)
        and prediction_canon == reference_canon
    )
    return "correct" if is_match else "wrong"
315
+
316
+
317
def run_final_task_eval(
    model,
    tokenizer,
    max_new_tokens: int = MAX_COMPLETION_TOKENS,
    timeout_sec: float = EVAL_GEN_TIMEOUT_SEC,
) -> None:
    """Generate once per task in ``ALL_TASKS`` and print status, wall time and rewards.

    A task is labelled ``timeout`` when its generation took longer than
    ``timeout_sec``; the generation itself is not interrupted.  A generation
    that raises is reported as an error and the loop moves on to the next task.
    """
    _set_inference_mode(model)
    banner = (
        f"\n========== EVAL: all {len(ALL_TASKS)} tasks (1 turn each; max_new_tokens={max_new_tokens}, "
        f"timeout if wall time > {timeout_sec}s) =========="
    )
    print(banner, flush=True)

    for task in ALL_TASKS:
        tid = task.get("id", "?")
        started = time.perf_counter()
        try:
            raw = _generate_for_task(model, tokenizer, task, max_new_tokens)
        except Exception as e:  # noqa: BLE001
            took = time.perf_counter() - started
            print(f" {tid}: error — {e!r} (after {took:.1f}s)", flush=True)
            continue
        took = time.perf_counter() - started

        status = _eval_task_status(raw, task, took, timeout_sec)

        # Reuse the training reward functions on a batch of one completion.
        completions, prompts = [raw], [None]
        r_fix = reward_fix_correctness(
            completions, prompts, [task.get("correct_yaml", "")], [task["pipeline_yaml"]]
        )[0]
        r_stru = reward_yaml_structure(completions, prompts)[0]
        r_hallu = reward_no_hallucination(completions, prompts)[0]
        r_sum = r_fix + r_stru + r_hallu

        line = (
            f" {tid}: {status:7s} | t={took:5.2f}s | rewards: total={r_sum:+.2f} "
            f"(fix={r_fix:+.2f} struct={r_stru:+.2f} no_hallu={r_hallu:+.2f})"
        )
        print(line, flush=True)

    print("========== EVAL end ==========\n", flush=True)
356
+
357
+
358
def _wandb_ok() -> bool:
    """Return True when ``import wandb`` succeeds, False when it raises any Exception."""
    try:
        import wandb  # noqa: F401
    except Exception:
        return False
    else:
        return True
364
+
365
+
366
def run_sft(model, tokenizer, use_wandb: bool, sft_epochs: float):
    """Run the supervised fine-tuning stage and save the result to ``SFT_OUTPUT``.

    Args:
        model: Model to fine-tune; also saved with ``save_pretrained`` afterwards.
        tokenizer: Tokenizer used to build the dataset and passed to the trainer.
        use_wandb: When True, report to Weights & Biases and start a run named
            ``sft-cicd-yaml``; otherwise ``report_to`` is ``"none"``.
        sft_epochs: Value passed as ``num_train_epochs``.

    Returns:
        The ``SFTTrainer`` after ``train()`` has completed.
    """
    from trl import SFTTrainer, SFTConfig

    sft_data = build_sft_dataset(tokenizer)
    print(f"SFT dataset: {len(sft_data)} samples, {sft_epochs} epoch(s)")

    sft_config = SFTConfig(
        output_dir="./cicd_rl_sft_output",
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=sft_epochs,
        learning_rate=SFT_LEARNING_RATE,
        logging_steps=10,
        # No intermediate checkpoints; the model is saved explicitly after training.
        save_strategy="no",
        max_length=SFT_MAX_SEQ,
        # The dataset has a single pre-rendered chat string per row under "text".
        dataset_text_field="text",
        report_to="wandb" if use_wandb else "none",
        remove_unused_columns=False,
        optim="adamw_8bit",
        # Train loss on assistant tokens only (full gold YAML in the assistant turn).
        # NOTE(review): TRL documents assistant_only_loss as supported only for
        # conversational datasets, while build_sft_dataset() yields a plain "text"
        # column — confirm the installed TRL version accepts this combination.
        assistant_only_loss=True,
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=sft_data,
        processing_class=tokenizer,
        callbacks=[_sft_console_callback()],
    )
    if use_wandb:
        import wandb
        # reinit=True is passed because main() may call wandb.init again for the GRPO stage.
        wandb.init(project="cicd-rl-agent", name="sft-cicd-yaml", reinit=True)
    print("Starting SFT (supervised: prompt -> correct YAML)...")
    trainer.train()
    model.save_pretrained(SFT_OUTPUT)
    tokenizer.save_pretrained(SFT_OUTPUT)
    print(f"SFT LoRA saved to {SFT_OUTPUT}")
    return trainer
404
+
405
+
406
def _post_train_smoke_unsloth(tokenizer, model) -> None:
    """Print one short generation as a post-training smoke test (Unsloth path).

    The model is switched to Unsloth inference mode first; generation is
    skipped with a message when CUDA is not available.
    """
    import torch
    from unsloth import FastLanguageModel

    print("Testing post-training inference...")
    FastLanguageModel.for_inference(model)

    if torch.cuda.is_available():
        prompt = tokenizer(
            "Fix this YAML: steps:\n - run: npm tset", return_tensors="pt"
        ).to("cuda")
        with torch.inference_mode():
            generated = model.generate(**prompt, max_new_tokens=64)
        print(tokenizer.decode(generated[0], skip_special_tokens=True))
    else:
        print("(CUDA not available; skip generate smoke test.)")
419
+
420
+
421
  def main():
422
+ p = argparse.ArgumentParser(description="SFT (optional) + GRPO training for CICD YAML fix agent")
423
+ p.add_argument(
424
+ "--stages",
425
+ type=str,
426
+ default="sft,grpo",
427
+ help="Comma list: sft, grpo (default: sft,grpo = supervised then RL)",
428
+ )
429
+ p.add_argument("--sft-epochs", type=float, default=SFT_EPOCHS, help="SFT pass size (set 0 to skip SFT in code paths that still use --stages; prefer --stages grpo)")
430
+ p.add_argument(
431
+ "--no-final-eval",
432
+ action="store_true",
433
+ help="Skip end-of-run eval (correct / wrong / timeout per task).",
434
+ )
435
+ p.add_argument(
436
+ "--eval-timeout",
437
+ type=float,
438
+ default=EVAL_GEN_TIMEOUT_SEC,
439
+ help="Mark task eval as 'timeout' if a single generate() takes longer than this (seconds).",
440
+ )
441
+ args = p.parse_args()
442
+ wants = {s.strip().lower() for s in args.stages.split(",") if s.strip()}
443
+ if not wants.issubset({"sft", "grpo"}) or not wants:
444
+ print("Error: --stages must list one or more of: sft, grpo (e.g. sft,grpo or grpo)")
445
+ sys.exit(1)
446
+
447
  # Colab often sets WANDB_DISABLED in the runtime env.
 
448
  if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
449
+ print("Detected WANDB_DISABLED; unsetting it because report_to may be 'wandb'.")
450
  os.environ.pop("WANDB_DISABLED", None)
451
 
452
  if USE_UNSLOTH:
453
  from unsloth import FastLanguageModel
454
  model, tokenizer = FastLanguageModel.from_pretrained(
455
+ model_name=MODEL_NAME, max_seq_length=1024, dtype=None, load_in_4bit=True
456
+ )
457
  model = FastLanguageModel.get_peft_model(
458
+ model,
459
+ r=16,
460
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
461
+ lora_alpha=16,
462
+ lora_dropout=0.0,
463
+ bias="none",
464
+ use_gradient_checkpointing="unsloth",
465
+ random_state=42,
466
+ )
467
  else:
468
  from transformers import AutoModelForCausalLM, AutoTokenizer
469
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
471
  if tokenizer.pad_token is None:
472
  tokenizer.pad_token = tokenizer.eos_token
473
 
474
+ use_wandb = _wandb_ok()
475
+ if not use_wandb:
476
+ print("wandb is not installed; falling back to report_to='none' where applicable.")
477
 
478
+ if "sft" in wants and args.sft_epochs <= 0:
479
+ print("Error: --sft-epochs must be > 0 when SFT is in --stages")
480
+ sys.exit(1)
 
 
 
 
481
 
482
+ t_start = time.perf_counter()
483
+ sft_time_s = 0.0
484
+ grpo_time_s = 0.0
485
+ sft_steps = 0
486
+ grpo_steps = 0
487
+ sft_trainer = None
488
+ grpo_trainer = None
 
 
 
 
 
 
 
489
 
490
+ if "sft" in wants:
491
+ t0 = time.perf_counter()
492
+ sft_trainer = run_sft(model, tokenizer, use_wandb, float(args.sft_epochs))
493
+ sft_time_s = time.perf_counter() - t0
494
+ sft_steps = getattr(sft_trainer.state, "global_step", 0) if sft_trainer else 0
495
+ print(
496
+ f"--- SFT done: {sft_steps} optimizer turn(s) / step(s), time {_format_seconds(sft_time_s)} ---\n",
497
+ flush=True,
498
+ )
499
+
500
+ if "grpo" in wants:
501
+ dataset = build_dataset()
502
+ print(f"GRPO dataset: {len(dataset)} samples")
503
+ from trl import GRPOTrainer, GRPOConfig
504
+
505
+ grpo_args = GRPOConfig(
506
+ output_dir="./cicd_rl_output",
507
+ per_device_train_batch_size=BATCH_SIZE,
508
+ gradient_accumulation_steps=GRAD_ACCUM,
509
+ learning_rate=5e-6,
510
+ max_steps=MAX_STEPS,
511
+ num_generations=4,
512
+ max_completion_length=MAX_COMPLETION_TOKENS,
513
+ logging_steps=5,
514
+ save_steps=50,
515
+ report_to="wandb" if use_wandb else "none",
516
+ remove_unused_columns=False,
517
+ warmup_steps=10,
518
+ lr_scheduler_type="cosine",
519
+ optim="adamw_8bit",
520
+ )
521
+ grpo_trainer = GRPOTrainer(
522
+ model=model,
523
+ args=grpo_args,
524
+ reward_funcs=REWARD_FUNCTIONS,
525
+ train_dataset=dataset,
526
+ processing_class=tokenizer,
527
+ callbacks=[_grpo_console_callback(MAX_STEPS, "GRPO")],
528
+ )
529
+ print("Starting GRPO training... (rewards + loss in log lines; online reward below)\n", flush=True)
530
+ if use_wandb:
531
+ import wandb
532
+ wandb.init(project="cicd-rl-agent", name="grpo-cicd-yaml", reinit=True)
533
+ t0 = time.perf_counter()
534
+ grpo_trainer.train()
535
+ grpo_time_s = time.perf_counter() - t0
536
+ grpo_steps = getattr(grpo_trainer.state, "global_step", 0)
537
+ print("GRPO training complete!", flush=True)
538
+ _print_grpo_reward_tail(grpo_trainer)
539
+ print(
540
+ f"\n--- GRPO done: {grpo_steps} optimizer turn(s) / step(s) (of {MAX_STEPS} max), "
541
+ f'time { _format_seconds(grpo_time_s) } ---\n',
542
+ flush=True,
543
+ )
544
 
545
  save_path = "./cicd_rl_agent_final"
546
+ if "grpo" in wants:
547
  model.save_pretrained(save_path)
548
  tokenizer.save_pretrained(save_path)
549
+ print(f"Final LoRA saved to {save_path} (SFT+GRPO pipeline end state).")
550
+ if USE_UNSLOTH:
551
+ _post_train_smoke_unsloth(tokenizer, model)
552
+ else:
553
+ print("Non-Unsloth path: inference test skipped.")
554
+ elif "sft" in wants:
555
+ # SFT weights already written in run_sft(); also mirror to default eval path for convenience.
556
  model.save_pretrained(save_path)
557
  tokenizer.save_pretrained(save_path)
558
+ print(f"SFT-only run: LoRA is in {SFT_OUTPUT} and copied to {save_path} for eval_lora defaults.")
559
+
560
+ total_s = time.perf_counter() - t_start
561
+ print("\n========== TRAINING SUMMARY ==========", flush=True)
562
+ print(f"Total wall time: {_format_seconds(total_s)}", flush=True)
563
+ if sft_time_s:
564
+ print(
565
+ f" SFT: time={_format_seconds(sft_time_s)} | turn(s)/step(s) = {sft_steps} | (supervised, loss in [SFT turn/step ...] lines)",
566
+ flush=True,
567
+ )
568
+ if grpo_time_s:
569
+ print(
570
+ f" GRPO: time={_format_seconds(grpo_time_s)} | turn(s)/step(s) = {grpo_steps} | (online rewards in [GRPO turn/step ...] lines)",
571
+ flush=True,
572
+ )
573
+ print(
574
+ " Note: each eval task is a single user→assistant 'turn'; GRPO/SFT 'turns' = optimizer update steps.\n"
575
+ "========================================\n",
576
+ flush=True,
577
+ )
578
+
579
+ if not args.no_final_eval and (sft_time_s or grpo_time_s):
580
+ run_final_task_eval(
581
+ model, tokenizer, MAX_COMPLETION_TOKENS, timeout_sec=float(args.eval_timeout)
582
+ )
583
+ elif args.no_final_eval:
584
+ print("Skipped final per-task eval (--no-final-eval).", flush=True)
585
+
586
 
587
  if __name__ == "__main__":
588
  main()
train_colab.ipynb CHANGED
@@ -9,12 +9,12 @@
9
  },
10
  {
11
  "cell_type": "code",
12
- "execution_count": null,
13
  "metadata": {},
14
- "outputs": [],
15
  "source": [
16
  "!pip install unsloth trl transformers datasets torch wandb pydantic"
17
- ]
 
 
18
  },
19
  {
20
  "cell_type": "markdown",
@@ -25,9 +25,7 @@
25
  },
26
  {
27
  "cell_type": "code",
28
- "execution_count": null,
29
  "metadata": {},
30
- "outputs": [],
31
  "source": [
32
  "import os\n",
33
  "import random\n",
@@ -82,7 +80,9 @@
82
  " return Dataset.from_list(records)\n",
83
  "\n",
84
  "print(f\"Loaded {len(ALL_TASKS)} tasks (easy/medium/hard). Sample task ids:\", [t['id'] for t in ALL_TASKS[:3]], \"...\")"
85
- ]
 
 
86
  },
87
  {
88
  "cell_type": "markdown",
@@ -93,9 +93,7 @@
93
  },
94
  {
95
  "cell_type": "code",
96
- "execution_count": null,
97
  "metadata": {},
98
- "outputs": [],
99
  "source": [
100
  "import torch\n",
101
  "from unsloth import FastLanguageModel\n",
@@ -119,7 +117,9 @@
119
  ")\n",
120
  "if tokenizer.pad_token is None:\n",
121
  " tokenizer.pad_token = tokenizer.eos_token"
122
- ]
 
 
123
  },
124
  {
125
  "cell_type": "markdown",
@@ -130,13 +130,13 @@
130
  },
131
  {
132
  "cell_type": "code",
133
- "execution_count": null,
134
  "metadata": {},
135
- "outputs": [],
136
  "source": [
137
  "train_dataset = build_dataset()\n",
138
  "print(f\"Dataset size: {len(train_dataset)} (target split ~50% easy / 30% medium / 20% hard)\")"
139
- ]
 
 
140
  },
141
  {
142
  "cell_type": "markdown",
@@ -147,9 +147,7 @@
147
  },
148
  {
149
  "cell_type": "code",
150
- "execution_count": null,
151
  "metadata": {},
152
- "outputs": [],
153
  "source": [
154
  "def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):\n",
155
  " \"\"\"How closely the completion matches the reference `correct_yaml` (full match, partial, unchanged, or wrong).\"\"\"\n",
@@ -188,20 +186,29 @@
188
  " return [-0.3 if any(p.lower() in c.lower() for p in bad) else 0.3 for c in completions]\n",
189
  "\n",
190
  "REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]"
191
- ]
 
 
192
  },
193
  {
194
  "cell_type": "markdown",
195
  "metadata": {},
196
  "source": [
197
- "## 🚀 Configure and Run GRPO Training"
 
 
 
 
 
 
 
 
 
198
  ]
199
  },
200
  {
201
  "cell_type": "code",
202
- "execution_count": null,
203
  "metadata": {},
204
- "outputs": [],
205
  "source": [
206
  "import wandb\n",
207
  "from trl import GRPOConfig, GRPOTrainer\n",
@@ -232,7 +239,9 @@
232
  ")\n",
233
  "wandb.init(project=\"cicd-rl-agent\")\n",
234
  "trainer.train()"
235
- ]
 
 
236
  },
237
  {
238
  "cell_type": "markdown",
@@ -243,9 +252,7 @@
243
  },
244
  {
245
  "cell_type": "code",
246
- "execution_count": null,
247
  "metadata": {},
248
- "outputs": [],
249
  "source": [
250
  "import matplotlib.pyplot as plt\n",
251
  "\n",
@@ -269,7 +276,9 @@
269
  "plt.tight_layout()\n",
270
  "plt.savefig(\"reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
271
  "plt.show()"
272
- ]
 
 
273
  },
274
  {
275
  "cell_type": "markdown",
@@ -280,9 +289,7 @@
280
  },
281
  {
282
  "cell_type": "code",
283
- "execution_count": null,
284
  "metadata": {},
285
- "outputs": [],
286
  "source": [
287
  "def generate_yaml(model, tok, task: dict) -> str:\n",
288
  " FastLanguageModel.for_inference(model)\n",
@@ -322,7 +329,9 @@
322
  " print(out_train[:800])\n",
323
  " print(f\"\\nBase matches correct_yaml: {ok_base}\")\n",
324
  " print(f\"Trained matches correct_yaml: {ok_train}\")"
325
- ]
 
 
326
  }
327
  ],
328
  "metadata": {
@@ -338,4 +347,4 @@
338
  },
339
  "nbformat": 4,
340
  "nbformat_minor": 4
341
- }
 
9
  },
10
  {
11
  "cell_type": "code",
 
12
  "metadata": {},
 
13
  "source": [
14
  "!pip install unsloth trl transformers datasets torch wandb pydantic"
15
+ ],
16
+ "execution_count": null,
17
+ "outputs": []
18
  },
19
  {
20
  "cell_type": "markdown",
 
25
  },
26
  {
27
  "cell_type": "code",
 
28
  "metadata": {},
 
29
  "source": [
30
  "import os\n",
31
  "import random\n",
 
80
  " return Dataset.from_list(records)\n",
81
  "\n",
82
  "print(f\"Loaded {len(ALL_TASKS)} tasks (easy/medium/hard). Sample task ids:\", [t['id'] for t in ALL_TASKS[:3]], \"...\")"
83
+ ],
84
+ "execution_count": null,
85
+ "outputs": []
86
  },
87
  {
88
  "cell_type": "markdown",
 
93
  },
94
  {
95
  "cell_type": "code",
 
96
  "metadata": {},
 
97
  "source": [
98
  "import torch\n",
99
  "from unsloth import FastLanguageModel\n",
 
117
  ")\n",
118
  "if tokenizer.pad_token is None:\n",
119
  " tokenizer.pad_token = tokenizer.eos_token"
120
+ ],
121
+ "execution_count": null,
122
+ "outputs": []
123
  },
124
  {
125
  "cell_type": "markdown",
 
130
  },
131
  {
132
  "cell_type": "code",
 
133
  "metadata": {},
 
134
  "source": [
135
  "train_dataset = build_dataset()\n",
136
  "print(f\"Dataset size: {len(train_dataset)} (target split ~50% easy / 30% medium / 20% hard)\")"
137
+ ],
138
+ "execution_count": null,
139
+ "outputs": []
140
  },
141
  {
142
  "cell_type": "markdown",
 
147
  },
148
  {
149
  "cell_type": "code",
 
150
  "metadata": {},
 
151
  "source": [
152
  "def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):\n",
153
  " \"\"\"How closely the completion matches the reference `correct_yaml` (full match, partial, unchanged, or wrong).\"\"\"\n",
 
186
  " return [-0.3 if any(p.lower() in c.lower() for p in bad) else 0.3 for c in completions]\n",
187
  "\n",
188
  "REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]"
189
+ ],
190
+ "execution_count": null,
191
+ "outputs": []
192
  },
193
  {
194
  "cell_type": "markdown",
195
  "metadata": {},
196
  "source": [
197
+ "## 🚀 Training: SFT + GRPO (recommended) or GRPO in-notebook\n",
198
+ "\n",
199
+ "**Best path (matches `train.py` in the repo):** in the repo root run:\n",
200
+ "`!cd $REPO_DIR && python train.py` \n",
201
+ "Default is a short **supervised (SFT)** pass on exact `correct_yaml`, then **GRPO** with correctness-weighted rewards. \n",
202
+ "- GRPO only (old one-stage): `python train.py --stages grpo` \n",
203
+ "- SFT only: `python train.py --stages sft` \n",
204
+ "- Two SFT epochs: `python train.py --sft-epochs 2`\n",
205
+ "\n",
206
+ "**Alternative below:** the next cell runs **GRPO only** in the notebook (no SFT), like older Colab flows."
207
  ]
208
  },
209
  {
210
  "cell_type": "code",
 
211
  "metadata": {},
 
212
  "source": [
213
  "import wandb\n",
214
  "from trl import GRPOConfig, GRPOTrainer\n",
 
239
  ")\n",
240
  "wandb.init(project=\"cicd-rl-agent\")\n",
241
  "trainer.train()"
242
+ ],
243
+ "execution_count": null,
244
+ "outputs": []
245
  },
246
  {
247
  "cell_type": "markdown",
 
252
  },
253
  {
254
  "cell_type": "code",
 
255
  "metadata": {},
 
256
  "source": [
257
  "import matplotlib.pyplot as plt\n",
258
  "\n",
 
276
  "plt.tight_layout()\n",
277
  "plt.savefig(\"reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
278
  "plt.show()"
279
+ ],
280
+ "execution_count": null,
281
+ "outputs": []
282
  },
283
  {
284
  "cell_type": "markdown",
 
289
  },
290
  {
291
  "cell_type": "code",
 
292
  "metadata": {},
 
293
  "source": [
294
  "def generate_yaml(model, tok, task: dict) -> str:\n",
295
  " FastLanguageModel.for_inference(model)\n",
 
329
  " print(out_train[:800])\n",
330
  " print(f\"\\nBase matches correct_yaml: {ok_base}\")\n",
331
  " print(f\"Trained matches correct_yaml: {ok_train}\")"
332
+ ],
333
+ "execution_count": null,
334
+ "outputs": []
335
  }
336
  ],
337
  "metadata": {
 
347
  },
348
  "nbformat": 4,
349
  "nbformat_minor": 4
350
+ }