| """ |
| train.py — CICD RL Agent: optional SFT (supervised) then GRPO (RL) on CI/CD YAML fixes. |
| |
| Default: short SFT on (prompt → correct_yaml), then GRPO with correctness-heavy rewards. |
| |
python train.py                                  # SFT (short) + GRPO (default)
python train.py --stages grpo                    # GRPO only (previous behavior, no SFT)
python train.py --stages sft                     # SFT only; saves ./cicd_rl_sft_lora
python train.py --stages sft,grpo --sft-epochs 2
python train.py --no-final-eval
python train.py --eval-timeout 90
| |
| Console: SFT/GRPO log lines (loss/rewards + step X/Y), per-stage times and step counts, then |
| a final eval of every task with correct/wrong/timeout, wall time, and reward breakdown. |
| |
| Requires: pip install unsloth trl datasets transformers |
| """ |
|
|
| import argparse |
| import os, re, sys |
| import time |
| sys.path.insert(0, os.path.dirname(__file__)) |
| try: |
| import yaml |
| except Exception: |
| yaml = None |
|
|
# Switch between the Unsloth path (FastLanguageModel, 4-bit base + LoRA) and the
# plain transformers path; main(), _set_inference_mode() and the post-training
# smoke test all branch on this flag.
USE_UNSLOTH = True
if USE_UNSLOTH:
    # NOTE(review): imported at module load, before trl/transformers are imported
    # below, presumably so Unsloth can patch them first — keep this ordering.
    import unsloth  # noqa: F401  (imported for its side effects)
# Checkpoint loaded in main().
MODEL_NAME = "unsloth/Qwen2.5-0.5B-Instruct"

# ---- Shared training / GRPO settings ----
MAX_STEPS = 300              # GRPOConfig.max_steps
BATCH_SIZE = 4               # per-device batch size (SFT and GRPO)
GRAD_ACCUM = 4               # gradient-accumulation steps (SFT and GRPO)
NUM_SAMPLES = 512            # rows sampled into the GRPO prompt dataset
MAX_COMPLETION_TOKENS = 128  # GRPO completion cap; default eval max_new_tokens

# ---- Supervised fine-tuning (SFT) stage ----
SFT_EPOCHS = 1                     # default for --sft-epochs
SFT_LEARNING_RATE = 2e-4
SFT_MAX_SEQ = 1024                 # SFTConfig.max_length
SFT_DATASET_SIZE = 512             # rows sampled into the SFT dataset
SFT_OUTPUT = "./cicd_rl_sft_lora"  # save dir for the SFT model + tokenizer

# ---- Final per-task evaluation ----
EVAL_GEN_TIMEOUT_SEC = 60.0  # default for --eval-timeout

# ---- Reward weights ----
REWARD_FIX_MATCH = 5.0     # canonicalized prediction equals the reference YAML
REWARD_FIX_MISS = -1.5     # ... and when it does not
REWARD_STRUCT_SCALE = 0.2  # multiplier on the YAML-structure heuristic
REWARD_HALLU_GOOD = 0.1    # no refusal/prose/markdown phrase found
REWARD_HALLU_BAD = -0.35   # at least one such phrase found
|
|
| from cicd_debug_env.tasks import ALL_TASKS |
| from datasets import Dataset |
| import random |
|
|
# System message used for GRPO prompts, SFT transcripts and final-eval generation.
SYSTEM_PROMPT = " ".join(
    [
        "You are an expert DevOps engineer.",
        "You receive a broken CI/CD pipeline YAML and error details.",
        "Output ONLY the corrected YAML — no explanation, no markdown fences.",
    ]
)
|
|
def build_prompt(task: dict) -> str:
    """Render the user turn for one task.

    Three markdown-headed sections separated by blank lines: the error message
    (empty if the task has none), the broken pipeline YAML, and an open
    "Fixed Pipeline" header the model is expected to complete.
    """
    error_text = task.get("error_message", "")
    broken_yaml = task["pipeline_yaml"]
    sections = [
        f"### Error\n{error_text}",
        f"### Broken Pipeline\n{broken_yaml}",
        "### Fixed Pipeline (YAML only):\n",
    ]
    return "\n\n".join(sections)
|
|
def make_task_sampler(tasks=None):
    """Return a zero-argument function that draws one task per call.

    Draws follow the difficulty mix used by both datasets: 50% easy, 30% medium,
    20% hard, keyed on each task's ``"difficulty"`` field. Each draw consumes one
    ``random.random()`` and one ``random.choice(...)`` from the global ``random``
    module. This is the same call sequence as the previous inline sampling, so
    seeded runs reproduce. If the selected difficulty has no tasks, the draw
    falls back to the whole pool instead of raising ``IndexError``.

    Args:
        tasks: Iterable of task dicts; defaults to ``ALL_TASKS``.

    Raises (from the returned function):
        ValueError: if the task pool is empty.
    """
    pool = list(ALL_TASKS if tasks is None else tasks)
    easy = [t for t in pool if t["difficulty"] == "easy"]
    medium = [t for t in pool if t["difficulty"] == "medium"]
    hard = [t for t in pool if t["difficulty"] == "hard"]

    def sample() -> dict:
        r = random.random()
        bucket = easy if r < 0.5 else medium if r < 0.8 else hard
        if bucket:
            return random.choice(bucket)
        if not pool:
            raise ValueError("No tasks available to sample from.")
        # Selected difficulty is empty: sample from every task rather than crash.
        return random.choice(pool)

    return sample


def build_dataset():
    """Build the GRPO prompt dataset (NUM_SAMPLES rows).

    Each row has a chat-format ``prompt`` (system + user turn) plus the
    ``correct_yaml`` and ``pipeline_yaml`` columns. Those columns are passed to
    the reward functions as keyword arguments. ``correct_yaml`` is always a
    string ("" when the task has none), so reward helpers never see None.
    """
    sample = make_task_sampler()
    records = []
    for _ in range(NUM_SAMPLES):
        task = sample()
        records.append({
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_prompt(task)},
            ],
            "correct_yaml": task.get("correct_yaml") or "",
            "pipeline_yaml": task["pipeline_yaml"],
        })
    return Dataset.from_list(records)


def build_sft_dataset(tokenizer) -> "Dataset":
    """Build the SFT dataset (SFT_DATASET_SIZE rows) of rendered chat transcripts.

    Uses the same system/user turns as inference, with the stripped
    ``correct_yaml`` as the assistant turn. It is rendered with the tokenizer's
    chat template (no generation prompt) into a single ``text`` column.
    """
    sample = make_task_sampler()
    records = []
    for _ in range(SFT_DATASET_SIZE):
        task = sample()
        gold = (task.get("correct_yaml") or "").strip()
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_prompt(task)},
            {"role": "assistant", "content": gold},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        records.append({"text": text})
    return Dataset.from_list(records)
|
|
|
|
def _completion_to_text(completion) -> str:
    """Flatten a TRL/Unsloth completion payload into plain text.

    Accepted shapes:
      * ``None`` -> "".
      * ``str`` -> returned unchanged.
      * ``dict`` -> its ``"content"`` (or, failing that, ``"text"``) when that
        value is a string; otherwise ``str(dict)``.
      * ``list`` of chunks (strings or message dicts) -> non-empty chunk texts
        joined with newlines. A dict chunk contributes ``content`` (or ``text``),
        coerced with ``str`` when it is neither a string nor None.
      * anything else -> ``str(completion)``.
    """
    if completion is None:
        return ""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        for key in ("content", "text"):
            if isinstance(completion.get(key), str):
                return completion[key]
        return str(completion)
    if not isinstance(completion, list):
        return str(completion)

    def piece_text(piece):
        # Text for one chunk, or None when the chunk contributes nothing.
        if isinstance(piece, str):
            return piece
        if isinstance(piece, dict):
            body = piece.get("content", piece.get("text", ""))
            return body if body is None or isinstance(body, str) else str(body)
        return None if piece is None else str(piece)

    texts = (piece_text(piece) for piece in completion)
    return "\n".join(t for t in texts if t)
|
|
def _strip_markdown_fences(text: str) -> str:
    """Trim whitespace and, if the text opens with a ``` fence, drop the
    opening fence line (with optional language tag) and a closing fence."""
    body = text.strip()
    if not body.startswith("```"):
        return body
    body = re.sub(r"^```[a-zA-Z0-9_-]*\n?", "", body)
    body = re.sub(r"\n?```$", "", body.strip())
    return body.strip()
|
|
def _normalize_yaml_like(text: str) -> str:
    """Remove blank lines and trailing whitespace, then trim the whole result
    (so the first line also loses its leading indentation)."""
    kept = (row.rstrip() for row in text.splitlines() if row.strip())
    return "\n".join(kept).strip()
|
|
def _canonical_yaml(text: "str | None") -> str:
    """Return a comparison key for YAML text, or "" when there is nothing to compare.

    The text goes through these steps:
      1. Strip markdown fences.
      2. Drop blank lines and trailing whitespace.
      3. If PyYAML is available, parse with ``yaml.safe_load`` and re-dump
         with ``sort_keys=True``, so layout and key-order differences do not
         affect equality.

    If PyYAML is unavailable, or parsing/dumping fails, the normalized text is
    returned instead. ``None`` or empty input yields "" rather than raising
    (e.g. a task whose ``correct_yaml`` is None).
    """
    if not text:
        return ""
    stripped = _normalize_yaml_like(_strip_markdown_fences(text))
    if not stripped:
        return ""
    if yaml is None:
        return stripped
    try:
        parsed = yaml.safe_load(stripped)
        return yaml.safe_dump(parsed, sort_keys=True).strip()
    except Exception:
        # Best-effort: model output is frequently not valid YAML, and a reward
        # function must not crash a training step, so compare text instead.
        return stripped
|
|
def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):
    """Main correctness reward.

    Returns REWARD_FIX_MATCH when a completion's canonical YAML is non-empty
    and equal to the reference's canonical YAML, else REWARD_FIX_MISS.
    ``pipeline_yaml`` is only zipped alongside (which bounds the batch length);
    ``prompts`` and extra kwargs are accepted for the TRL reward signature.
    """
    def score(completion, reference):
        predicted = _canonical_yaml(_completion_to_text(completion))
        expected = _canonical_yaml(reference)
        matched = bool(predicted and predicted == expected)
        return REWARD_FIX_MATCH if matched else REWARD_FIX_MISS

    return [
        score(completion, reference)
        for completion, reference, _ in zip(completions, correct_yaml, pipeline_yaml)
    ]
|
|
def reward_yaml_structure(completions, prompts, **kwargs):
    """Small shaping reward for output that looks like bare YAML.

    Per completion, after fence stripping, the score is built as follows:
      * +0.4 if it starts with a common top-level YAML token;
      * +0.4 if it mentions a common CI key;
      * +0.2 if it has 1-120 non-blank lines;
      * -1.0 if it contains prose/markdown markers.
    The total is scaled by REWARD_STRUCT_SCALE.
    """
    leading_tokens = ("name:", "jobs:", "steps:", "on:", "env:", "- ")
    ci_keys = ("steps:", "jobs:", "name:", "run:", "uses:", "env:", "with:")
    prose_markers = ("here is", "explanation", "i fixed", "this yaml", "```", "---", "note:")

    def score(completion):
        body = _strip_markdown_fences(_completion_to_text(completion))
        non_blank = sum(1 for row in body.splitlines() if row.strip())
        looks_yaml = body.startswith(leading_tokens)
        mentions_keys = any(key in body for key in ci_keys)
        sane_length = 1 <= non_blank <= 120
        value = 0.4 * int(looks_yaml) + 0.4 * int(mentions_keys) + 0.2 * int(sane_length)
        lowered = body.lower()
        if any(marker in lowered for marker in prose_markers):
            value -= 1.0
        return value * REWARD_STRUCT_SCALE

    return [score(completion) for completion in completions]
|
|
def reward_no_hallucination(completions, prompts, **kwargs):
    """REWARD_HALLU_BAD if a completion contains any refusal/chit-chat/markdown
    phrase (case-insensitive substring match), else REWARD_HALLU_GOOD."""
    red_flags = (
        "i cannot", "i am sorry", "as an ai", "here is", "```yaml", "```",
        "explanation:", "note:", "sure!", "of course", "the fixed yaml", "this yaml",
    )

    def flagged(completion) -> bool:
        lowered = _completion_to_text(completion).lower()
        return any(flag in lowered for flag in red_flags)

    return [REWARD_HALLU_BAD if flagged(c) else REWARD_HALLU_GOOD for c in completions]
|
|
# Reward functions handed to GRPOTrainer.
REWARD_FUNCTIONS = [
    reward_fix_correctness,   # dominant signal: exact (canonicalized) fix
    reward_yaml_structure,    # small shaping toward bare YAML
    reward_no_hallucination,  # penalty for prose/markdown/refusals
]
|
|
|
|
def _grpo_console_callback(max_steps: int, label: str = "GRPO"):
    """Build a TrainerCallback that prints one line per log event.

    The line has a ``[<label> turn/step N/max_steps]`` header followed by
    ``key=value`` pairs (sorted by key) for reward-, loss-, kl- and
    learning-rate entries. Numeric values are formatted with ``.6g``.
    """
    from transformers import TrainerCallback

    def wanted(key) -> bool:
        lowered = key.lower()
        return (
            "reward" in lowered
            or "loss" in lowered
            or key in ("loss", "kl", "learning_rate", "train_loss")
        )

    class _GRPOConsoleLogCallback(TrainerCallback):
        def __init__(self) -> None:
            self._max = max_steps
            self._label = label

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            header = f"[{self._label} turn/step {state.global_step}/{self._max}]"
            fields = []
            for key in sorted(logs):
                if not wanted(key):
                    continue
                value = logs[key]
                if isinstance(value, (int, float)):
                    fields.append(f"{key}={value:.6g}")
                else:
                    fields.append(f"{key}={value}")
            print(" | ".join([header, *fields]), flush=True)

    return _GRPOConsoleLogCallback()
|
|
|
|
def _sft_console_callback():
    """Build a TrainerCallback that prints ``[SFT turn/step N]`` followed by the
    numeric loss/learning-rate log entries (sorted by key, ``.6g`` format)."""
    from transformers import TrainerCallback

    class _SFTConsoleLogCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            pieces = [f"[SFT turn/step {state.global_step}]"]
            for key in sorted(logs):
                value = logs[key]
                relevant = "loss" in key.lower() or "learning_rate" in key
                if relevant and isinstance(value, (int, float)):
                    pieces.append(f"{key}={value:.6g}")
            print(" ".join(pieces), flush=True)

    return _SFTConsoleLogCallback()
|
|
|
|
def _format_seconds(sec: float) -> str:
    """Format a duration as ``"12.3s"``, ``"4m 5.6s"`` or ``"1h 2m 3s"``.

    Rounding is applied before splitting into units, so carries propagate
    correctly. For example, 119.96 s renders as ``"2m 0.0s"``, never
    ``"1m 60.0s"``, and 3659.7 s as ``"1h 1m 0s"``, not ``"1h 0m 60s"``.
    """
    # Work in tenths of a second for the sub-hour formats.
    tenths = round(sec * 10)
    if tenths < 600:
        return f"{tenths / 10:.1f}s"
    minutes, rem_tenths = divmod(tenths, 600)
    if minutes < 60:
        return f"{minutes}m {rem_tenths / 10:.1f}s"
    # Hour format shows whole seconds only.
    hours, rem = divmod(round(sec), 3600)
    minutes, seconds = divmod(rem, 60)
    return f"{hours}h {minutes}m {seconds}s"
|
|
|
|
def _print_grpo_reward_tail(trainer) -> None:
    """Print the loss/reward fields of the last five ``trainer.state.log_history``
    entries, skipping entries that contain neither."""
    history = getattr(trainer.state, "log_history", None) or []
    if history:
        print("\n--- Last GRPO log entries (rewards) ---", flush=True)
        for entry in history[-5:]:
            picked = {
                key: val
                for key, val in entry.items()
                if key == "loss" or "reward" in key.lower()
            }
            if picked:
                print(f" step {entry.get('step', '?')}: {picked}", flush=True)
    else:
        print("(No log_history available for reward summary.)", flush=True)
|
|
|
|
def _set_inference_mode(model) -> None:
    """Prepare the model for generation: Unsloth's ``for_inference`` when
    USE_UNSLOTH is set, otherwise ``model.eval()``."""
    if not USE_UNSLOTH:
        model.eval()
        return
    from unsloth import FastLanguageModel
    FastLanguageModel.for_inference(model)
|
|
|
|
def _generate_for_task(model, tokenizer, task: dict, max_new_tokens: int) -> str:
    """Greedy-decode one answer for ``task`` using the training chat format.

    Returns only the newly generated text (the prompt tokens are sliced off),
    decoded without special tokens.
    """
    import torch

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(task)},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    device = next(model.parameters()).device
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.inference_mode():
        generated = model.generate(
            **encoded, max_new_tokens=max_new_tokens, do_sample=False
        )
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
|
|
|
|
def _eval_task_status(raw: str, task: dict, took_sec: float, timeout_sec: float) -> str:
    """Classify one eval generation as "timeout", "correct" or "wrong".

    "timeout" takes precedence whenever ``took_sec`` exceeds ``timeout_sec``.
    Otherwise the answer is "correct" only if both the prediction and the
    task's ``correct_yaml`` canonicalize to the same non-empty string.
    """
    if took_sec > timeout_sec:
        return "timeout"
    predicted = _canonical_yaml(_strip_markdown_fences(_completion_to_text(raw)))
    expected = _canonical_yaml((task.get("correct_yaml") or "").strip())
    matched = bool(predicted) and bool(expected) and predicted == expected
    return "correct" if matched else "wrong"
|
|
|
|
def run_final_task_eval(
    model,
    tokenizer,
    max_new_tokens: int = MAX_COMPLETION_TOKENS,
    timeout_sec: float = EVAL_GEN_TIMEOUT_SEC,
) -> None:
    """Generate once per task in ALL_TASKS and print a verdict per task plus a tally.

    Each task gets one greedy generation and is labelled correct / wrong /
    timeout. "timeout" only marks a generate() call whose wall time exceeded
    ``timeout_sec``; the call itself is not interrupted. An exception during
    generation is printed as "error" and the loop continues. The printed reward
    breakdown uses the same three reward functions as GRPO training.
    """
    from collections import Counter

    _set_inference_mode(model)
    print(
        f"\n========== EVAL: all {len(ALL_TASKS)} tasks (1 turn each; max_new_tokens={max_new_tokens}, "
        f"timeout if wall time > {timeout_sec}s) ==========",
        flush=True,
    )
    tally = Counter()
    for task in ALL_TASKS:
        tid = task.get("id", "?")
        # A task may carry correct_yaml=None; the reward helpers expect a string.
        gold = task.get("correct_yaml") or ""
        t0 = time.perf_counter()
        try:
            raw = _generate_for_task(model, tokenizer, task, max_new_tokens)
        except Exception as e:
            took = time.perf_counter() - t0
            tally["error"] += 1
            print(
                f" {tid}: error — {e!r} (after {took:.1f}s)",
                flush=True,
            )
            continue
        took = time.perf_counter() - t0
        status = _eval_task_status(raw, task, took, timeout_sec)
        tally[status] += 1
        r_fix = reward_fix_correctness([raw], [None], [gold], [task["pipeline_yaml"]])[0]
        r_stru = reward_yaml_structure([raw], [None])[0]
        r_hallu = reward_no_hallucination([raw], [None])[0]
        r_sum = r_fix + r_stru + r_hallu
        print(
            f" {tid}: {status:7s} | t={took:5.2f}s | rewards: total={r_sum:+.2f} "
            f"(fix={r_fix:+.2f} struct={r_stru:+.2f} no_hallu={r_hallu:+.2f})",
            flush=True,
        )
    summary = ", ".join(f"{k}={tally[k]}" for k in ("correct", "wrong", "timeout", "error"))
    print(f" summary: {summary} (of {len(ALL_TASKS)} tasks)", flush=True)
    print("========== EVAL end ==========\n", flush=True)
|
|
|
|
def _wandb_ok() -> bool:
    """Return True when the ``wandb`` package imports without raising."""
    try:
        import wandb  # noqa: F401  (imported only to probe availability)
    except Exception:
        return False
    else:
        return True
|
|
|
|
def run_sft(model, tokenizer, use_wandb: bool, sft_epochs: float):
    """Supervised warm-up on (prompt -> correct YAML) chat transcripts.

    Builds the SFT dataset and trains with TRL's SFTTrainer for ``sft_epochs``
    epochs (``save_strategy="no"``, so no intermediate checkpoints). It then
    saves model and tokenizer to SFT_OUTPUT and returns the trainer so the
    caller can read ``trainer.state``.
    """
    from trl import SFTTrainer, SFTConfig

    train_rows = build_sft_dataset(tokenizer)
    print(f"SFT dataset: {len(train_rows)} samples, {sft_epochs} epoch(s)")

    config = SFTConfig(
        output_dir="./cicd_rl_sft_output",
        num_train_epochs=sft_epochs,
        learning_rate=SFT_LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        optim="adamw_8bit",
        max_length=SFT_MAX_SEQ,
        dataset_text_field="text",
        remove_unused_columns=False,
        # NOTE(review): the dataset provides pre-rendered "text", not
        # conversational "messages"; TRL documents assistant_only_loss for
        # conversational datasets — confirm the prompt is actually masked here.
        assistant_only_loss=True,
        logging_steps=10,
        save_strategy="no",
        report_to="wandb" if use_wandb else "none",
    )
    trainer = SFTTrainer(
        model=model,
        args=config,
        train_dataset=train_rows,
        processing_class=tokenizer,
        callbacks=[_sft_console_callback()],
    )
    if use_wandb:
        import wandb

        wandb.init(project="cicd-rl-agent", name="sft-cicd-yaml", reinit=True)
    print("Starting SFT (supervised: prompt -> correct YAML)...")
    trainer.train()
    for artifact in (model, tokenizer):
        artifact.save_pretrained(SFT_OUTPUT)
    print(f"SFT LoRA saved to {SFT_OUTPUT}")
    return trainer
|
|
|
|
def _post_train_smoke_unsloth(tokenizer, model) -> None:
    """Quick generate() sanity check after training (Unsloth path only).

    Switches the model to Unsloth inference mode and, only when CUDA is
    available, decodes up to 64 new tokens for a fixed raw prompt (no chat
    template) and prints the full decoded sequence.
    """
    import torch
    from unsloth import FastLanguageModel

    print("Testing post-training inference...")
    FastLanguageModel.for_inference(model)
    if torch.cuda.is_available():
        probe = tokenizer("Fix this YAML: steps:\n - run: npm tset", return_tensors="pt").to("cuda")
        with torch.inference_mode():
            generated = model.generate(**probe, max_new_tokens=64)
        print(tokenizer.decode(generated[0], skip_special_tokens=True))
    else:
        print("(CUDA not available; skip generate smoke test.)")
|
|
|
|
def _parse_cli_args(argv=None):
    """Parse and validate CLI arguments before any model is loaded.

    Returns:
        ``(args, wants)``, where ``wants`` is the requested subset of
        ``{"sft", "grpo"}``.

    Exits with status 1 after printing an error in two cases: ``--stages`` is
    empty or names an unknown stage, or SFT is requested with
    ``--sft-epochs <= 0``.
    """
    p = argparse.ArgumentParser(description="SFT (optional) + GRPO training for CICD YAML fix agent")
    p.add_argument(
        "--stages",
        type=str,
        default="sft,grpo",
        help="Comma list: sft, grpo (default: sft,grpo = supervised then RL)",
    )
    p.add_argument(
        "--sft-epochs",
        type=float,
        default=SFT_EPOCHS,
        help="Number of SFT epochs (fractions allowed); must be > 0 when 'sft' is in --stages. "
        "To skip SFT, use --stages grpo.",
    )
    p.add_argument(
        "--no-final-eval",
        action="store_true",
        help="Skip end-of-run eval (correct / wrong / timeout per task).",
    )
    p.add_argument(
        "--eval-timeout",
        type=float,
        default=EVAL_GEN_TIMEOUT_SEC,
        help="Mark task eval as 'timeout' if a single generate() takes longer than this (seconds).",
    )
    args = p.parse_args(argv)
    wants = {s.strip().lower() for s in args.stages.split(",") if s.strip()}
    if not wants or not wants.issubset({"sft", "grpo"}):
        print("Error: --stages must list one or more of: sft, grpo (e.g. sft,grpo or grpo)")
        sys.exit(1)
    if "sft" in wants and args.sft_epochs <= 0:
        print("Error: --sft-epochs must be > 0 when SFT is in --stages")
        sys.exit(1)
    return args, wants


def _resolve_use_wandb() -> bool:
    """Decide whether trainers report to Weights & Biases.

    A truthy ``WANDB_DISABLED`` (1/true/yes/on) is honored as an explicit
    opt-out; report_to then becomes "none", which avoids the wandb callback
    rather than deleting the user's setting. Otherwise wandb is used when it
    imports.
    """
    if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
        print("Detected WANDB_DISABLED; wandb logging is off (report_to='none').")
        return False
    if _wandb_ok():
        return True
    print("wandb is not installed; falling back to report_to='none' where applicable.")
    return False


def _load_model_and_tokenizer():
    """Load MODEL_NAME and its tokenizer; guarantee a pad token exists.

    With USE_UNSLOTH, the base model is loaded in 4-bit and wrapped in an r=16
    LoRA adapter on the attention and MLP projections. Otherwise a plain
    transformers AutoModelForCausalLM is loaded.
    """
    if USE_UNSLOTH:
        from unsloth import FastLanguageModel

        model, tokenizer = FastLanguageModel.from_pretrained(
            # Same limit SFT trains with (SFTConfig.max_length).
            model_name=MODEL_NAME, max_seq_length=SFT_MAX_SEQ, dtype=None, load_in_4bit=True
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0.0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=42,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _run_grpo_stage(model, tokenizer, use_wandb: bool):
    """Build the GRPO dataset and trainer, train for MAX_STEPS, and report progress.

    Returns:
        ``(trainer, elapsed_seconds, global_step)``.
    """
    dataset = build_dataset()
    print(f"GRPO dataset: {len(dataset)} samples")
    from trl import GRPOTrainer, GRPOConfig

    grpo_args = GRPOConfig(
        output_dir="./cicd_rl_output",
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=5e-6,
        max_steps=MAX_STEPS,
        num_generations=4,
        max_completion_length=MAX_COMPLETION_TOKENS,
        logging_steps=5,
        save_steps=50,
        report_to="wandb" if use_wandb else "none",
        remove_unused_columns=False,
        warmup_steps=10,
        lr_scheduler_type="cosine",
        optim="adamw_8bit",
    )
    grpo_trainer = GRPOTrainer(
        model=model,
        args=grpo_args,
        reward_funcs=REWARD_FUNCTIONS,
        train_dataset=dataset,
        processing_class=tokenizer,
        callbacks=[_grpo_console_callback(MAX_STEPS, "GRPO")],
    )
    print("Starting GRPO training... (rewards + loss in log lines; online reward below)\n", flush=True)
    if use_wandb:
        import wandb

        wandb.init(project="cicd-rl-agent", name="grpo-cicd-yaml", reinit=True)
    t0 = time.perf_counter()
    grpo_trainer.train()
    elapsed = time.perf_counter() - t0
    steps = getattr(grpo_trainer.state, "global_step", 0)
    print("GRPO training complete!", flush=True)
    _print_grpo_reward_tail(grpo_trainer)
    print(
        f"\n--- GRPO done: {steps} optimizer turn(s) / step(s) (of {MAX_STEPS} max), "
        f"time {_format_seconds(elapsed)} ---\n",
        flush=True,
    )
    return grpo_trainer, elapsed, steps


def _save_final_model(model, tokenizer, wants) -> None:
    """Save the end-state model/tokenizer to ./cicd_rl_agent_final.

    After GRPO, also run the Unsloth smoke test.
    """
    save_path = "./cicd_rl_agent_final"
    if "grpo" in wants:
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Final LoRA saved to {save_path} (SFT+GRPO pipeline end state).")
        if USE_UNSLOTH:
            _post_train_smoke_unsloth(tokenizer, model)
        else:
            print("Non-Unsloth path: inference test skipped.")
    elif "sft" in wants:
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"SFT-only run: LoRA is in {SFT_OUTPUT} and copied to {save_path} for eval_lora defaults.")


def _print_training_summary(total_s, sft_time_s, sft_steps, grpo_time_s, grpo_steps) -> None:
    """Print total wall time and per-stage time/step counts for the stages that ran."""
    print("\n========== TRAINING SUMMARY ==========", flush=True)
    print(f"Total wall time: {_format_seconds(total_s)}", flush=True)
    if sft_time_s:
        print(
            f" SFT: time={_format_seconds(sft_time_s)} | turn(s)/step(s) = {sft_steps} | (supervised, loss in [SFT turn/step ...] lines)",
            flush=True,
        )
    if grpo_time_s:
        print(
            f" GRPO: time={_format_seconds(grpo_time_s)} | turn(s)/step(s) = {grpo_steps} | (online rewards in [GRPO turn/step ...] lines)",
            flush=True,
        )
    print(
        " Note: each eval task is a single user→assistant 'turn'; GRPO/SFT 'turns' = optimizer update steps.\n"
        "========================================\n",
        flush=True,
    )


def main():
    """CLI entry point.

    Steps: validate args, load the model, optionally run SFT and then GRPO,
    save the result, print a summary, and optionally run the per-task eval.
    """
    args, wants = _parse_cli_args()
    use_wandb = _resolve_use_wandb()
    model, tokenizer = _load_model_and_tokenizer()

    t_start = time.perf_counter()
    sft_time_s = 0.0
    grpo_time_s = 0.0
    sft_steps = 0
    grpo_steps = 0

    if "sft" in wants:
        t0 = time.perf_counter()
        sft_trainer = run_sft(model, tokenizer, use_wandb, float(args.sft_epochs))
        sft_time_s = time.perf_counter() - t0
        sft_steps = getattr(sft_trainer.state, "global_step", 0) if sft_trainer else 0
        print(
            f"--- SFT done: {sft_steps} optimizer turn(s) / step(s), time {_format_seconds(sft_time_s)} ---\n",
            flush=True,
        )

    if "grpo" in wants:
        _, grpo_time_s, grpo_steps = _run_grpo_stage(model, tokenizer, use_wandb)

    _save_final_model(model, tokenizer, wants)
    _print_training_summary(
        time.perf_counter() - t_start, sft_time_s, sft_steps, grpo_time_s, grpo_steps
    )

    if not args.no_final_eval and (sft_time_s or grpo_time_s):
        run_final_task_eval(
            model, tokenizer, MAX_COMPLETION_TOKENS, timeout_sec=float(args.eval_timeout)
        )
    elif args.no_final_eval:
        print("Skipped final per-task eval (--no-final-eval).", flush=True)
|
|
|
|
# Run the training pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|