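# NOTE: the next four lines are hosting-page residue captured together with the file
# (apparently repository path, avatar caption, commit message and short commit hash).
# They are kept as comments so the module remains valid Python.
# cicd-rl-agent / train.py
# Nikitasoni22's picture
# training issue resolved
# 048dc4f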
"""
train.py — CICD RL Agent: optional SFT (supervised) then GRPO (RL) on CI/CD YAML fixes.
Default: short SFT on (prompt → correct_yaml), then GRPO with correctness-heavy rewards.
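Usage:
    python train.py                                   # SFT (short) + GRPO (default; same as before)
    python train.py --stages grpo                     # GRPO only (old behavior, no SFT)
    python train.py --stages sft                      # SFT only; saves ./cicd_rl_sft_lora
    python train.py --stages sft,grpo --sft-epochs 2  # SFT for 2 epochs, then GRPO
    python train.py --no-final-eval                   # skip the end-of-run per-task eval
    python train.py --eval-timeout 90                 # label an eval generation "timeout" above 90 s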
Console: SFT/GRPO log lines (loss/rewards + step X/Y), per-stage times and step counts, then
a final eval of every task with correct/wrong/timeout, wall time, and reward breakdown.
Requires: pip install unsloth trl datasets transformers
"""
import argparse
import os, re, sys
import time
sys.path.insert(0, os.path.dirname(__file__))
try:
import yaml
except Exception:
yaml = None
USE_UNSLOTH = True
if USE_UNSLOTH:
import unsloth # noqa: F401 # before trl/transformers (Unsloth fast path)
# Checkpoint name passed to from_pretrained() in main() (both the Unsloth and the plain path).
MODEL_NAME = "unsloth/Qwen2.5-0.5B-Instruct"
# GRPO only: passed as GRPOConfig(max_steps=...) and shown as the "/Y" in the console step counter.
MAX_STEPS = 300
# `num_generations` must divide the global train batch; main() sets num_generations=4
# and 4 divides BATCH_SIZE=4. BATCH_SIZE / GRAD_ACCUM are shared by the SFT and GRPO configs.
BATCH_SIZE = 4
GRAD_ACCUM = 4
# Number of rows build_dataset() samples (with replacement) for GRPO.
NUM_SAMPLES = 512
# GRPO: use `max_completion_length` (TRL); older examples used `max_new_tokens`.
# The same value is also the default `max_new_tokens` for the final per-task eval.
MAX_COMPLETION_TOKENS = 128
# SFT: teach exact gold YAML before RL polish (short run by design).
# SFT_EPOCHS is only the default for --sft-epochs.
SFT_EPOCHS = 1
SFT_LEARNING_RATE = 2e-4
# Passed as SFTConfig(max_length=...).
SFT_MAX_SEQ = 1024
# Number of rows build_sft_dataset() samples (with replacement).
SFT_DATASET_SIZE = 512
# Directory where run_sft() saves the model and tokenizer after SFT.
SFT_OUTPUT = "./cicd_rl_sft_lora"
# Post-training quick eval: mark each task correct / wrong / timeout if generation exceeds this (seconds).
# The generation itself is never interrupted; the label is assigned afterwards from the measured wall time.
EVAL_GEN_TIMEOUT_SEC = 60.0
# Reward mix: GRPO sums per-function rewards; keep correctness as the dominant term.
# Returned by reward_fix_correctness on a canonical match / on anything else.
REWARD_FIX_MATCH = 5.0
REWARD_FIX_MISS = -1.5
# Multiplier applied to the raw shape score in reward_yaml_structure.
REWARD_STRUCT_SCALE = 0.2
# Returned by reward_no_hallucination when no / at least one blacklisted phrase is present.
REWARD_HALLU_GOOD = 0.1
REWARD_HALLU_BAD = -0.35
from cicd_debug_env.tasks import ALL_TASKS
from datasets import Dataset
import random
# System message used as the first chat turn everywhere a prompt is built in this file:
# GRPO rows (build_dataset), SFT examples (build_sft_dataset) and eval generations (_generate_for_task).
SYSTEM_PROMPT = (
"You are an expert DevOps engineer. "
"You receive a broken CI/CD pipeline YAML and error details. "
"Output ONLY the corrected YAML — no explanation, no markdown fences."
)
def build_prompt(task: dict) -> str:
    """Render the user-turn prompt for a single task.

    The prompt is three sections in a fixed order: the error message (empty
    when the task has no ``error_message`` key), the broken pipeline taken
    from the required ``pipeline_yaml`` key, and a trailing header that cues
    the model to reply with YAML only.

    Raises:
        KeyError: if ``task`` has no ``pipeline_yaml`` entry.
    """
    error_text = task.get("error_message", "")
    broken_yaml = task["pipeline_yaml"]
    sections = (
        f"### Error\n{error_text}\n\n",
        f"### Broken Pipeline\n{broken_yaml}\n\n",
        "### Fixed Pipeline (YAML only):\n",
    )
    return "".join(sections)
def build_dataset():
    """Build the GRPO training set by sampling NUM_SAMPLES tasks with replacement.

    Each row first draws a difficulty tier (easy with probability 0.5,
    medium 0.3, hard 0.2) and then a uniformly random task from that tier.
    A row carries the chat-format ``prompt`` (system + user message) plus the
    ``correct_yaml`` and ``pipeline_yaml`` columns that the reward functions
    receive as keyword arguments.
    """
    pools = {
        level: [t for t in ALL_TASKS if t["difficulty"] == level]
        for level in ("easy", "medium", "hard")
    }

    def _sample_row() -> dict:
        # One random.random() followed by one random.choice() per row.
        draw = random.random()
        if draw < 0.5:
            pool = pools["easy"]
        elif draw < 0.8:
            pool = pools["medium"]
        else:
            pool = pools["hard"]
        task = random.choice(pool)
        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_prompt(task)},
        ]
        return {
            "prompt": chat,
            "correct_yaml": task.get("correct_yaml", ""),
            "pipeline_yaml": task["pipeline_yaml"],
        }

    return Dataset.from_list([_sample_row() for _ in range(NUM_SAMPLES)])
def build_sft_dataset(tokenizer) -> Dataset:
    """Build the SFT set: SFT_DATASET_SIZE chat transcripts rendered to plain text.

    Tasks are sampled with replacement using the same 0.5 / 0.3 / 0.2
    easy / medium / hard split as the GRPO dataset. Each example is the
    system + user prompt followed by an assistant turn holding the stripped
    gold ``correct_yaml``, rendered with the tokenizer's chat template
    (no generation prompt) into a single ``text`` column.
    """
    # (upper bound on the random draw, candidate tasks); "hard" is the fallback tier.
    tiers = [
        (0.5, [t for t in ALL_TASKS if t["difficulty"] == "easy"]),
        (0.8, [t for t in ALL_TASKS if t["difficulty"] == "medium"]),
    ]
    hard_pool = [t for t in ALL_TASKS if t["difficulty"] == "hard"]

    rows = []
    for _ in range(SFT_DATASET_SIZE):
        draw = random.random()
        pool = next((tasks for cutoff, tasks in tiers if draw < cutoff), hard_pool)
        task = random.choice(pool)
        answer = (task.get("correct_yaml") or "").strip()
        conversation = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_prompt(task)},
            {"role": "assistant", "content": answer},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        rows.append({"text": rendered})
    return Dataset.from_list(rows)
def _completion_to_text(completion) -> str:
    """Flatten a completion payload into plain text.

    Accepted shapes:
      * ``str`` - returned unchanged.
      * ``dict`` - its ``content`` if that is a string, else its ``text`` if
        that is a string, else ``str(dict)``.
      * ``list`` - every item is converted (strings as-is, dicts via
        ``content`` falling back to ``text`` only when ``content`` is absent,
        anything else via ``str``); ``None`` items and empty pieces are
        dropped and the rest joined with newlines.
      * ``None`` - the empty string.
      * anything else - ``str(completion)``.
    """
    if completion is None:
        return ""
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        for key in ("content", "text"):
            value = completion.get(key)
            if isinstance(value, str):
                return value
        return str(completion)
    if isinstance(completion, list):
        chunks = []
        for piece in completion:
            if isinstance(piece, dict):
                # A present-but-None "content" is skipped; "text" is only a fallback for a missing key.
                piece = piece.get("content", piece.get("text", ""))
            if piece is None:
                continue
            chunks.append(piece if isinstance(piece, str) else str(piece))
        return "\n".join(chunk for chunk in chunks if chunk)
    return str(completion)
def _strip_markdown_fences(text: str) -> str:
    """Return ``text`` trimmed, without a surrounding markdown code fence.

    Only text that *starts* with a fence is unwrapped: the opening fence,
    together with an optional language tag and newline, is removed, then a
    closing fence at the very end. Text that does not start with a fence is
    just stripped of surrounding whitespace.
    """
    body = text.strip()
    if not body.startswith("```"):
        return body
    without_opening = re.sub(r"^```[a-zA-Z0-9_-]*\n?", "", body)
    without_closing = re.sub(r"\n?```$", "", without_opening.strip())
    return without_closing.strip()
def _normalize_yaml_like(text: str) -> str:
    """Drop blank lines and trailing whitespace, then strip the joined result.

    Leading indentation of each kept line is preserved, except that the final
    ``strip()`` also removes whitespace in front of the first kept line.
    """
    kept = []
    for raw_line in text.splitlines():
        trimmed = raw_line.rstrip()
        if trimmed.strip():
            kept.append(trimmed)
    return "\n".join(kept).strip()
def _canonical_yaml(text: str) -> str:
    """Return a comparison-friendly form of a YAML document.

    The text is unfenced and whitespace-normalized first. When PyYAML is
    available and the text parses, the parsed value is re-dumped with
    ``sort_keys=True``; otherwise (PyYAML missing, or any error while
    parsing/dumping) the normalized text itself is returned. Empty input
    yields the empty string.
    """
    cleaned = _normalize_yaml_like(_strip_markdown_fences(text))
    if not cleaned or yaml is None:
        # Nothing to parse, or no parser: fall back to the textual form.
        return cleaned
    try:
        return yaml.safe_dump(yaml.safe_load(cleaned), sort_keys=True).strip()
    except Exception:
        # Best effort: unparseable text is still compared verbatim.
        return cleaned
def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):
    """Strict correctness reward: one value per completion.

    A completion earns REWARD_FIX_MATCH only when its canonical YAML is
    non-empty and equal to the canonical gold YAML; every other completion
    earns REWARD_FIX_MISS. ``prompts``, the values of ``pipeline_yaml`` and
    extra keyword arguments are not inspected (``pipeline_yaml`` still takes
    part in the ``zip`` and so bounds the number of rewards).
    """

    def _score(completion, gold) -> float:
        predicted = _canonical_yaml(_completion_to_text(completion))
        expected = _canonical_yaml(gold)
        if predicted and predicted == expected:
            return REWARD_FIX_MATCH
        return REWARD_FIX_MISS

    return [
        _score(completion, gold)
        for completion, gold, _unused_broken in zip(completions, correct_yaml, pipeline_yaml)
    ]
def reward_yaml_structure(completions, prompts, **kwargs):
    """Shape reward: does the completion look like bare pipeline YAML?

    After removing a surrounding markdown fence, the raw score is
    0.4 (starts with a typical top-level key or list item)
    + 0.4 (contains at least one typical pipeline key)
    + 0.2 (1 to 120 non-blank lines), minus 1.0 when any prose/markdown
    marker is found. The result is scaled by REWARD_STRUCT_SCALE.

    NOTE(review): the marker list contains ``---`` and ``note:``, so a YAML
    document separator or a key ending in "note:" is penalized as well.
    """
    yaml_openers = ("name:", "jobs:", "steps:", "on:", "env:", "- ")
    yaml_keys = ["steps:", "jobs:", "name:", "run:", "uses:", "env:", "with:"]
    wrapper_markers = ["here is", "explanation", "i fixed", "this yaml", "```", "---", "note:"]

    def _shape_score(completion) -> float:
        body = _strip_markdown_fences(_completion_to_text(completion))
        non_blank = sum(1 for line in body.splitlines() if line.strip())
        starts_yaml = body.startswith(yaml_openers)
        has_yaml_keys = any(key in body for key in yaml_keys)
        line_count_ok = 1 <= non_blank <= 120
        # Reward valid YAML shape but heavily penalize markdown/prose wrappers.
        score = 0.4 * int(starts_yaml) + 0.4 * int(has_yaml_keys) + 0.2 * int(line_count_ok)
        lowered = body.lower()
        if any(marker in lowered for marker in wrapper_markers):
            score -= 1.0
        return score * REWARD_STRUCT_SCALE

    return [_shape_score(completion) for completion in completions]
def reward_no_hallucination(completions, prompts, **kwargs):
    """Phrase-blacklist reward: one value per completion.

    The lower-cased completion text is searched for chat-assistant
    boilerplate (refusals, "here is", markdown fences, ...). Any hit yields
    REWARD_HALLU_BAD; a clean completion yields REWARD_HALLU_GOOD.
    """
    banned_phrases = (
        "i cannot", "i am sorry", "as an ai", "here is", "```yaml", "```",
        "explanation:", "note:", "sure!", "of course", "the fixed yaml", "this yaml",
    )

    def _is_clean(completion) -> bool:
        lowered = _completion_to_text(completion).lower()
        return not any(phrase in lowered for phrase in banned_phrases)

    return [
        REWARD_HALLU_GOOD if _is_clean(completion) else REWARD_HALLU_BAD
        for completion in completions
    ]
# Reward callables handed to GRPOTrainer(reward_funcs=...) in main(): correctness, YAML shape, phrase blacklist.
REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]
def _grpo_console_callback(max_steps: int, label: str = "GRPO"):
    """Create a TrainerCallback that prints one compact line per log event.

    The line starts with ``[<label> turn/step <global_step>/<max_steps>]`` and
    lists, in sorted key order, every logged entry whose key mentions
    "reward" or "loss" (case-insensitive) or is exactly ``kl`` /
    ``learning_rate``. Numbers are printed with 6 significant digits; other
    values are printed as-is.
    """
    from transformers import TrainerCallback

    def _is_interesting(key) -> bool:
        lowered = key.lower()
        return "reward" in lowered or "loss" in lowered or key in ("kl", "learning_rate")

    class _GRPOConsoleLogCallback(TrainerCallback):
        def __init__(self) -> None:
            self._max = max_steps
            self._label = label

        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            fields = [f"[{self._label} turn/step {state.global_step}/{self._max}]"]
            for key in sorted(logs):
                if not _is_interesting(key):
                    continue
                value = logs[key]
                if isinstance(value, (int, float)):
                    fields.append(f"{key}={value:.6g}")
                else:
                    fields.append(f"{key}={value}")
            print(" | ".join(fields), flush=True)

    return _GRPOConsoleLogCallback()
def _sft_console_callback():
    """Create a TrainerCallback that prints loss / learning-rate values per log event.

    Each line starts with ``[SFT turn/step <global_step>]`` followed by the
    numeric entries, in sorted key order, whose key contains "loss"
    (case-insensitive) or "learning_rate". Non-numeric values are omitted.
    """
    from transformers import TrainerCallback

    class _SFTConsoleLogCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            if not logs:
                return
            shown = []
            for key in sorted(logs):
                value = logs[key]
                tracked = "loss" in key.lower() or "learning_rate" in key
                if tracked and isinstance(value, (int, float)):
                    shown.append(f" {key}={value:.6g}")
            print(f"[SFT turn/step {state.global_step}]" + "".join(shown), flush=True)

    return _SFTConsoleLogCallback()
def _format_seconds(sec: float) -> str:
    """Format a duration in seconds as ``"12.3s"``, ``"4m 5.0s"`` or ``"1h 2m 3s"``.

    Seconds are shown with one decimal below an hour and as whole seconds
    from an hour upwards. Rounding is applied *before* the value is split
    into units and any overflow is carried into the next unit, so the seconds
    field can never read ``60`` (e.g. 59.96 s is ``"1m 0.0s"``, not
    ``"60.0s"``, and 3659.7 s is ``"1h 1m 0s"``, not ``"1h 0m 60s"``).

    Args:
        sec: Duration in seconds (int or float).

    Returns:
        The human-readable duration string.
    """
    # Decide on the *displayed* value: 59.96 would print as "60.0s" otherwise.
    if round(sec, 1) < 60:
        return f"{sec:.1f}s"

    whole_minutes, seconds = divmod(sec, 60)
    minutes = int(whole_minutes)

    if minutes < 60:
        seconds = round(seconds, 1)
        if seconds >= 60:
            # Rounding pushed the seconds field to a full minute: carry it.
            minutes, seconds = minutes + 1, 0.0
        if minutes < 60:
            return f"{minutes}m {seconds:.1f}s"
        # The carry produced exactly 60 minutes: fall through to the hour format.

    hours, minutes = divmod(minutes, 60)
    seconds = round(seconds)
    if seconds >= 60:
        minutes, seconds = minutes + 1, 0
    if minutes >= 60:
        hours, minutes = hours + 1, 0
    return f"{hours}h {minutes}m {seconds:.0f}s"
def _print_grpo_reward_tail(trainer) -> None:
    """Print the reward / loss values from the last five trainer log rows.

    Reads ``trainer.state.log_history``. For each of the last five rows the
    entries whose key contains "reward" (case-insensitive) or equals "loss"
    are printed together with the row's ``step``; rows without such entries
    are skipped. A placeholder line is printed when there is no history.
    """
    history = getattr(trainer.state, "log_history", None) or []
    if not history:
        print("(No log_history available for reward summary.)", flush=True)
        return

    print("\n--- Last GRPO log entries (rewards) ---", flush=True)
    for entry in history[-5:]:
        picked = {}
        for key, value in entry.items():
            if "reward" in key.lower() or key == "loss":
                picked[key] = value
        if not picked:
            continue
        print(f" step {entry.get('step', '?')}: {picked}", flush=True)
def _set_inference_mode(model) -> None:
    """Put ``model`` into inference mode for the active backend.

    With USE_UNSLOTH set this calls ``FastLanguageModel.for_inference``;
    otherwise it calls ``model.eval()``.
    """
    if not USE_UNSLOTH:
        model.eval()
        return
    # Imported lazily so the plain path never needs Unsloth at this point.
    from unsloth import FastLanguageModel

    FastLanguageModel.for_inference(model)
def _generate_for_task(model, tokenizer, task: dict, max_new_tokens: int) -> str:
    """Greedily generate the model's answer for one task and return only the new text.

    The system + user chat is rendered with the generation prompt appended,
    tokenized, moved to the device of the model's first parameter and passed
    to ``model.generate`` with ``do_sample=False``. The prompt tokens are cut
    off before decoding, and special tokens are skipped.
    """
    import torch

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_prompt(task)},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    device = next(model.parameters()).device
    encoded = tokenizer(prompt_text, return_tensors="pt").to(device)
    with torch.inference_mode():
        generated = model.generate(
            **encoded, max_new_tokens=max_new_tokens, do_sample=False
        )
    # generate() returns prompt + continuation; keep the continuation only.
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
def _eval_task_status(raw: str, task: dict, took_sec: float, timeout_sec: float) -> str:
    """Classify one eval generation as ``"timeout"``, ``"correct"`` or ``"wrong"``.

    ``"timeout"`` takes precedence and is decided purely from the measured
    wall time. Otherwise the generation is ``"correct"`` only when both its
    canonical YAML and the canonical gold YAML are non-empty and equal.
    """
    if took_sec > timeout_sec:
        return "timeout"
    predicted = _canonical_yaml(_strip_markdown_fences(_completion_to_text(raw)))
    expected = _canonical_yaml((task.get("correct_yaml") or "").strip())
    is_match = bool(predicted) and bool(expected) and predicted == expected
    return "correct" if is_match else "wrong"
def run_final_task_eval(
    model,
    tokenizer,
    max_new_tokens: int = MAX_COMPLETION_TOKENS,
    timeout_sec: float = EVAL_GEN_TIMEOUT_SEC,
) -> None:
    """Generate once per task in ALL_TASKS and print a status / reward line for each.

    The model is switched to inference mode first. Every task yields one
    line with its status (correct, wrong, or timeout when the wall time
    exceeded ``timeout_sec``), the generation time, and the three reward
    terms plus their sum. A task whose generation raises is reported as an
    error line and skipped. The generation is never interrupted; "timeout"
    is only a label applied afterwards.
    """
    _set_inference_mode(model)
    banner = (
        f"\n========== EVAL: all {len(ALL_TASKS)} tasks (1 turn each; max_new_tokens={max_new_tokens}, "
        f"timeout if wall time > {timeout_sec}s) =========="
    )
    print(banner, flush=True)

    for task in ALL_TASKS:
        task_id = task.get("id", "?")
        started = time.perf_counter()
        try:
            output = _generate_for_task(model, tokenizer, task, max_new_tokens)
        except Exception as err:  # noqa: BLE001
            elapsed = time.perf_counter() - started
            print(f" {task_id}: error — {err!r} (after {elapsed:.1f}s)", flush=True)
            continue
        elapsed = time.perf_counter() - started

        verdict = _eval_task_status(output, task, elapsed, timeout_sec)
        # Reuse the training reward functions on a batch of one.
        fix_reward = reward_fix_correctness(
            [output], [None], [task.get("correct_yaml", "")], [task["pipeline_yaml"]]
        )[0]
        struct_reward = reward_yaml_structure([output], [None])[0]
        hallu_reward = reward_no_hallucination([output], [None])[0]
        total_reward = fix_reward + struct_reward + hallu_reward

        report = (
            f" {task_id}: {verdict:7s} | t={elapsed:5.2f}s | rewards: total={total_reward:+.2f} "
            f"(fix={fix_reward:+.2f} struct={struct_reward:+.2f} no_hallu={hallu_reward:+.2f})"
        )
        print(report, flush=True)

    print("========== EVAL end ==========\n", flush=True)
def _wandb_ok() -> bool:
    """Return True when ``import wandb`` succeeds, False on any exception."""
    try:
        import wandb  # noqa: F401
    except Exception:
        return False
    return True
def run_sft(model, tokenizer, use_wandb: bool, sft_epochs: float):
    """Run the supervised warm-up stage and save the result to SFT_OUTPUT.

    Builds the text dataset, trains with TRL's ``SFTTrainer`` for
    ``sft_epochs`` epochs (no intermediate checkpoints), then saves the model
    and tokenizer to SFT_OUTPUT.

    Args:
        model: Model to fine-tune in place.
        tokenizer: Tokenizer; also used to render the chat template.
        use_wandb: Report to Weights & Biases when True, otherwise to nothing.
        sft_epochs: Number of training epochs.

    Returns:
        The ``SFTTrainer`` after training (callers read ``state.global_step``).
    """
    from trl import SFTConfig, SFTTrainer

    train_rows = build_sft_dataset(tokenizer)
    print(f"SFT dataset: {len(train_rows)} samples, {sft_epochs} epoch(s)")

    reporter = "wandb" if use_wandb else "none"
    config_kwargs = dict(
        output_dir="./cicd_rl_sft_output",
        per_device_train_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM,
        num_train_epochs=sft_epochs,
        learning_rate=SFT_LEARNING_RATE,
        logging_steps=10,
        save_strategy="no",
        max_length=SFT_MAX_SEQ,
        dataset_text_field="text",
        report_to=reporter,
        remove_unused_columns=False,
        optim="adamw_8bit",
        # Train loss on assistant tokens only (full gold YAML in the assistant turn).
        # NOTE(review): the dataset is pre-rendered plain text, not chat messages;
        # confirm against the installed TRL version that this flag masks anything here.
        assistant_only_loss=True,
    )
    trainer = SFTTrainer(
        model=model,
        args=SFTConfig(**config_kwargs),
        train_dataset=train_rows,
        processing_class=tokenizer,
        callbacks=[_sft_console_callback()],
    )

    if use_wandb:
        import wandb

        wandb.init(project="cicd-rl-agent", name="sft-cicd-yaml", reinit=True)

    print("Starting SFT (supervised: prompt -> correct YAML)...")
    trainer.train()

    for artifact in (model, tokenizer):
        artifact.save_pretrained(SFT_OUTPUT)
    print(f"SFT LoRA saved to {SFT_OUTPUT}")
    return trainer
def _post_train_smoke_unsloth(tokenizer, model) -> None:
    """Switch the Unsloth model to inference mode and print one sample generation.

    The generation (64 new tokens for a tiny broken-YAML prompt) only runs
    when CUDA is available; otherwise a notice is printed instead. The model
    is put into inference mode in both cases.
    """
    import torch
    from unsloth import FastLanguageModel

    print("Testing post-training inference...")
    FastLanguageModel.for_inference(model)

    if torch.cuda.is_available():
        # "npm tset" is a deliberately broken command for the model to fix.
        smoke_prompt = "Fix this YAML: steps:\n - run: npm tset"
        encoded = tokenizer(smoke_prompt, return_tensors="pt").to("cuda")
        with torch.inference_mode():
            generated = model.generate(**encoded, max_new_tokens=64)
        print(tokenizer.decode(generated[0], skip_special_tokens=True))
    else:
        print("(CUDA not available; skip generate smoke test.)")
def main():
    """Command-line entry point: optional SFT, optional GRPO, save, summary, final eval.

    Flow:
      1. Parse and validate the flags. All validation happens before the
         model is loaded, so an invalid ``--stages`` / ``--sft-epochs`` fails
         immediately.
      2. Load the model and tokenizer (Unsloth 4-bit + LoRA, or plain
         transformers when USE_UNSLOTH is off).
      3. Run the requested stages in order: SFT, then GRPO.
      4. Save the model and tokenizer to ``./cicd_rl_agent_final``.
      5. Print a timing summary and, unless ``--no-final-eval`` was given,
         evaluate every task once.

    Exits with status 1 on invalid flags.
    """
    parser = argparse.ArgumentParser(
        description="SFT (optional) + GRPO training for CICD YAML fix agent"
    )
    parser.add_argument(
        "--stages",
        type=str,
        default="sft,grpo",
        help="Comma list: sft, grpo (default: sft,grpo = supervised then RL)",
    )
    parser.add_argument(
        "--sft-epochs",
        type=float,
        default=SFT_EPOCHS,
        # The old help text promised that 0 skips SFT, but a value <= 0 is rejected below.
        help="Number of SFT epochs; must be > 0 when sft is in --stages (use --stages grpo to skip SFT)",
    )
    parser.add_argument(
        "--no-final-eval",
        action="store_true",
        help="Skip end-of-run eval (correct / wrong / timeout per task).",
    )
    parser.add_argument(
        "--eval-timeout",
        type=float,
        default=EVAL_GEN_TIMEOUT_SEC,
        help="Mark task eval as 'timeout' if a single generate() takes longer than this (seconds).",
    )
    args = parser.parse_args()

    wants = {s.strip().lower() for s in args.stages.split(",") if s.strip()}
    if not wants or not wants.issubset({"sft", "grpo"}):
        print("Error: --stages must list one or more of: sft, grpo (e.g. sft,grpo or grpo)")
        sys.exit(1)
    # Checked here, before the slow model load, so a bad flag fails fast.
    if "sft" in wants and args.sft_epochs <= 0:
        print("Error: --sft-epochs must be > 0 when SFT is in --stages")
        sys.exit(1)

    # Colab often sets WANDB_DISABLED in the runtime env.
    # NOTE(review): this overrides an explicit opt-out; consider honouring it by
    # forcing report_to="none" instead of unsetting the variable.
    if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
        print("Detected WANDB_DISABLED; unsetting it because report_to may be 'wandb'.")
        os.environ.pop("WANDB_DISABLED", None)

    if USE_UNSLOTH:
        from unsloth import FastLanguageModel

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=MODEL_NAME, max_seq_length=1024, dtype=None, load_in_4bit=True
        )
        model = FastLanguageModel.get_peft_model(
            model,
            r=16,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_alpha=16,
            lora_dropout=0.0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=42,
        )
    else:
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_wandb = _wandb_ok()
    if not use_wandb:
        print("wandb is not installed; falling back to report_to='none' where applicable.")

    t_start = time.perf_counter()
    sft_time_s = 0.0
    grpo_time_s = 0.0
    sft_steps = 0
    grpo_steps = 0
    sft_trainer = None
    grpo_trainer = None

    if "sft" in wants:
        t0 = time.perf_counter()
        sft_trainer = run_sft(model, tokenizer, use_wandb, float(args.sft_epochs))
        sft_time_s = time.perf_counter() - t0
        sft_steps = getattr(sft_trainer.state, "global_step", 0) if sft_trainer else 0
        print(
            f"--- SFT done: {sft_steps} optimizer turn(s) / step(s), time {_format_seconds(sft_time_s)} ---\n",
            flush=True,
        )

    if "grpo" in wants:
        dataset = build_dataset()
        print(f"GRPO dataset: {len(dataset)} samples")
        from trl import GRPOTrainer, GRPOConfig

        grpo_args = GRPOConfig(
            output_dir="./cicd_rl_output",
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM,
            learning_rate=5e-6,
            max_steps=MAX_STEPS,
            num_generations=4,
            max_completion_length=MAX_COMPLETION_TOKENS,
            logging_steps=5,
            save_steps=50,
            report_to="wandb" if use_wandb else "none",
            remove_unused_columns=False,
            warmup_steps=10,
            lr_scheduler_type="cosine",
            optim="adamw_8bit",
        )
        grpo_trainer = GRPOTrainer(
            model=model,
            args=grpo_args,
            reward_funcs=REWARD_FUNCTIONS,
            train_dataset=dataset,
            processing_class=tokenizer,
            callbacks=[_grpo_console_callback(MAX_STEPS, "GRPO")],
        )
        print("Starting GRPO training... (rewards + loss in log lines; online reward below)\n", flush=True)
        if use_wandb:
            import wandb

            wandb.init(project="cicd-rl-agent", name="grpo-cicd-yaml", reinit=True)
        t0 = time.perf_counter()
        grpo_trainer.train()
        grpo_time_s = time.perf_counter() - t0
        grpo_steps = getattr(grpo_trainer.state, "global_step", 0)
        print("GRPO training complete!", flush=True)
        _print_grpo_reward_tail(grpo_trainer)
        print(
            f"\n--- GRPO done: {grpo_steps} optimizer turn(s) / step(s) (of {MAX_STEPS} max), "
            f"time {_format_seconds(grpo_time_s)} ---\n",
            flush=True,
        )

    save_path = "./cicd_rl_agent_final"
    if "grpo" in wants:
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"Final LoRA saved to {save_path} (SFT+GRPO pipeline end state).")
        if USE_UNSLOTH:
            _post_train_smoke_unsloth(tokenizer, model)
        else:
            print("Non-Unsloth path: inference test skipped.")
    elif "sft" in wants:
        # SFT weights already written in run_sft(); also mirror to default eval path for convenience.
        model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
        print(f"SFT-only run: LoRA is in {SFT_OUTPUT} and copied to {save_path} for eval_lora defaults.")

    total_s = time.perf_counter() - t_start
    print("\n========== TRAINING SUMMARY ==========", flush=True)
    print(f"Total wall time: {_format_seconds(total_s)}", flush=True)
    # A non-zero stage time doubles as the "this stage ran" flag.
    if sft_time_s:
        print(
            f" SFT: time={_format_seconds(sft_time_s)} | turn(s)/step(s) = {sft_steps} | (supervised, loss in [SFT turn/step ...] lines)",
            flush=True,
        )
    if grpo_time_s:
        print(
            f" GRPO: time={_format_seconds(grpo_time_s)} | turn(s)/step(s) = {grpo_steps} | (online rewards in [GRPO turn/step ...] lines)",
            flush=True,
        )
    print(
        " Note: each eval task is a single user→assistant 'turn'; GRPO/SFT 'turns' = optimizer update steps.\n"
        "========================================\n",
        flush=True,
    )

    if not args.no_final_eval and (sft_time_s or grpo_time_s):
        run_final_task_eval(
            model, tokenizer, MAX_COMPLETION_TOKENS, timeout_sec=float(args.eval_timeout)
        )
    elif args.no_final_eval:
        print("Skipped final per-task eval (--no-final-eval).", flush=True)


if __name__ == "__main__":
    main()