Commit ·
048dc4f
1
Parent(s): 3e7bcbe
training issue resolved
Browse files- eval_lora.py +26 -1
- train.py +409 -50
- train_colab.ipynb +35 -26
eval_lora.py
CHANGED
|
@@ -5,6 +5,9 @@ Usage (Colab):
|
|
| 5 |
!python eval_lora.py --adapter-path ./cicd_rl_agent_final
|
| 6 |
|
| 7 |
Optional: --base-model must match what you fine-tuned.
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
import argparse
|
|
@@ -96,7 +99,14 @@ def main():
|
|
| 96 |
default=True,
|
| 97 |
help="Compare predicted YAML vs correct_yaml using canonicalized YAML tree",
|
| 98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
args = p.parse_args()
|
|
|
|
| 100 |
|
| 101 |
if not os.path.isdir(args.adapter_path):
|
| 102 |
print(f"Adapter path not found: {args.adapter_path}")
|
|
@@ -146,9 +156,24 @@ def main():
|
|
| 146 |
comp = strip_code_fences(raw)
|
| 147 |
correct = task.get("correct_yaml", "")
|
| 148 |
label = partial_match_score(comp, correct, task["pipeline_yaml"], args.canonical_compare)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
by_diff[d].append(
|
| 150 |
{
|
| 151 |
-
"id":
|
| 152 |
"label": label,
|
| 153 |
}
|
| 154 |
)
|
|
|
|
| 5 |
!python eval_lora.py --adapter-path ./cicd_rl_agent_final
|
| 6 |
|
| 7 |
Optional: --base-model must match what you fine-tuned.
|
| 8 |
+
|
| 9 |
+
Debug a few tasks (raw vs canonical reference):
|
| 10 |
+
!python eval_lora.py --adapter-path ./cicd_rl_agent_final --inspect easy_003,medium_001
|
| 11 |
"""
|
| 12 |
|
| 13 |
import argparse
|
|
|
|
| 99 |
default=True,
|
| 100 |
help="Compare predicted YAML vs correct_yaml using canonicalized YAML tree",
|
| 101 |
)
|
| 102 |
+
p.add_argument(
|
| 103 |
+
"--inspect",
|
| 104 |
+
type=str,
|
| 105 |
+
default="",
|
| 106 |
+
help="Comma-separated task ids (e.g. easy_001,medium_002) to print raw model output vs reference",
|
| 107 |
+
)
|
| 108 |
args = p.parse_args()
|
| 109 |
+
inspect_ids = {s.strip() for s in args.inspect.split(",") if s.strip()}
|
| 110 |
|
| 111 |
if not os.path.isdir(args.adapter_path):
|
| 112 |
print(f"Adapter path not found: {args.adapter_path}")
|
|
|
|
| 156 |
comp = strip_code_fences(raw)
|
| 157 |
correct = task.get("correct_yaml", "")
|
| 158 |
label = partial_match_score(comp, correct, task["pipeline_yaml"], args.canonical_compare)
|
| 159 |
+
tid = task["id"]
|
| 160 |
+
if inspect_ids and tid in inspect_ids:
|
| 161 |
+
pred_c = canonical_yaml(comp) if args.canonical_compare else comp.strip()
|
| 162 |
+
gold_c = canonical_yaml(correct) if args.canonical_compare else correct.strip()
|
| 163 |
+
print(f"\n=== INSPECT {tid} (label={label}) ===\n")
|
| 164 |
+
print("--- raw model output ---")
|
| 165 |
+
print(raw)
|
| 166 |
+
print("--- after strip_code_fences ---")
|
| 167 |
+
print(comp)
|
| 168 |
+
print("--- canonical pred ---")
|
| 169 |
+
print(pred_c)
|
| 170 |
+
print("--- canonical reference (correct_yaml) ---")
|
| 171 |
+
print(gold_c)
|
| 172 |
+
print("--- match ---")
|
| 173 |
+
print("exact canonical match:", pred_c == gold_c)
|
| 174 |
by_diff[d].append(
|
| 175 |
{
|
| 176 |
+
"id": tid,
|
| 177 |
"label": label,
|
| 178 |
}
|
| 179 |
)
|
train.py
CHANGED
|
@@ -1,10 +1,24 @@
|
|
| 1 |
"""
|
| 2 |
-
train.py — CICD RL Agent
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
Requires: pip install unsloth trl datasets transformers
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import os, re, sys
|
|
|
|
| 8 |
sys.path.insert(0, os.path.dirname(__file__))
|
| 9 |
try:
|
| 10 |
import yaml
|
|
@@ -23,6 +37,23 @@ NUM_SAMPLES = 512
|
|
| 23 |
# GRPO: use `max_completion_length` (TRL); older examples used `max_new_tokens`.
|
| 24 |
MAX_COMPLETION_TOKENS = 128
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
from cicd_debug_env.tasks import ALL_TASKS
|
| 27 |
from datasets import Dataset
|
| 28 |
import random
|
|
@@ -58,6 +89,29 @@ def build_dataset():
|
|
| 58 |
})
|
| 59 |
return Dataset.from_list(records)
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def _completion_to_text(completion) -> str:
|
| 62 |
"""
|
| 63 |
Normalize TRL/Unsloth completion payloads to plain text.
|
|
@@ -118,7 +172,8 @@ def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **
|
|
| 118 |
pred_canon = _canonical_yaml(pred)
|
| 119 |
correct_canon = _canonical_yaml(correct)
|
| 120 |
# Strict reward: exact/canonical exact gets high reward; everything else is negative.
|
| 121 |
-
|
|
|
|
| 122 |
return rewards
|
| 123 |
|
| 124 |
def reward_yaml_structure(completions, prompts, **kwargs):
|
|
@@ -137,7 +192,7 @@ def reward_yaml_structure(completions, prompts, **kwargs):
|
|
| 137 |
score = 0.4 * int(starts_yaml) + 0.4 * int(has_yaml_keys) + 0.2 * int(line_count_ok)
|
| 138 |
if has_prose_or_md:
|
| 139 |
score -= 1.0
|
| 140 |
-
rewards.append(score)
|
| 141 |
return rewards
|
| 142 |
|
| 143 |
def reward_no_hallucination(completions, prompts, **kwargs):
|
|
@@ -149,27 +204,266 @@ def reward_no_hallucination(completions, prompts, **kwargs):
|
|
| 149 |
for c in completions:
|
| 150 |
lower = _completion_to_text(c).lower()
|
| 151 |
bad_hits = sum(1 for p in bad if p in lower)
|
| 152 |
-
values.append(
|
| 153 |
return values
|
| 154 |
|
| 155 |
REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
# Colab often sets WANDB_DISABLED in the runtime env.
|
| 159 |
-
# If report_to is wandb, this env var causes a hard runtime error in Trainer callbacks.
|
| 160 |
if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
|
| 161 |
-
print("Detected WANDB_DISABLED; unsetting it because report_to
|
| 162 |
os.environ.pop("WANDB_DISABLED", None)
|
| 163 |
|
| 164 |
if USE_UNSLOTH:
|
| 165 |
from unsloth import FastLanguageModel
|
| 166 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 167 |
-
model_name=MODEL_NAME, max_seq_length=1024, dtype=None, load_in_4bit=True
|
|
|
|
| 168 |
model = FastLanguageModel.get_peft_model(
|
| 169 |
-
model,
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
else:
|
| 174 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 175 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
@@ -177,53 +471,118 @@ def main():
|
|
| 177 |
if tokenizer.pad_token is None:
|
| 178 |
tokenizer.pad_token = tokenizer.eos_token
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
import wandb # noqa: F401
|
| 187 |
-
except Exception:
|
| 188 |
-
use_wandb = False
|
| 189 |
-
print("wandb is not installed; falling back to report_to='none'.")
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
logging_steps=5, save_steps=50,
|
| 199 |
-
report_to="wandb" if use_wandb else "none", remove_unused_columns=False,
|
| 200 |
-
warmup_steps=10, lr_scheduler_type="cosine", optim="adamw_8bit",
|
| 201 |
-
)
|
| 202 |
-
trainer = GRPOTrainer(
|
| 203 |
-
model=model, args=args, reward_funcs=REWARD_FUNCTIONS,
|
| 204 |
-
train_dataset=dataset, processing_class=tokenizer)
|
| 205 |
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
save_path = "./cicd_rl_agent_final"
|
| 214 |
-
if
|
| 215 |
model.save_pretrained(save_path)
|
| 216 |
tokenizer.save_pretrained(save_path)
|
| 217 |
-
print(f"LoRA
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
model.save_pretrained(save_path)
|
| 225 |
tokenizer.save_pretrained(save_path)
|
| 226 |
-
print(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
if __name__ == "__main__":
|
| 229 |
main()
|
|
|
|
| 1 |
"""
|
| 2 |
+
train.py — CICD RL Agent: optional SFT (supervised) then GRPO (RL) on CI/CD YAML fixes.
|
| 3 |
+
|
| 4 |
+
Default: short SFT on (prompt → correct_yaml), then GRPO with correctness-heavy rewards.
|
| 5 |
+
|
| 6 |
+
python train.py # SFT (short) + GRPO (same as before)
|
| 7 |
+
python train.py --stages grpo # GRPO only (old behavior, no SFT)
|
| 8 |
+
python train.py --stages sft # SFT only; saves ./cicd_rl_sft_lora
|
| 9 |
+
train.py --stages sft,grpo --sft-epochs 2
|
| 10 |
+
train.py --no-final-eval
|
| 11 |
+
train.py --eval-timeout 90
|
| 12 |
+
|
| 13 |
+
Console: SFT/GRPO log lines (loss/rewards + step X/Y), per-stage times and step counts, then
|
| 14 |
+
a final eval of every task with correct/wrong/timeout, wall time, and reward breakdown.
|
| 15 |
+
|
| 16 |
Requires: pip install unsloth trl datasets transformers
|
| 17 |
"""
|
| 18 |
|
| 19 |
+
import argparse
|
| 20 |
import os, re, sys
|
| 21 |
+
import time
|
| 22 |
sys.path.insert(0, os.path.dirname(__file__))
|
| 23 |
try:
|
| 24 |
import yaml
|
|
|
|
| 37 |
# GRPO: use `max_completion_length` (TRL); older examples used `max_new_tokens`.
|
| 38 |
MAX_COMPLETION_TOKENS = 128
|
| 39 |
|
| 40 |
+
# SFT: teach exact gold YAML before RL polish (short run by design).
|
| 41 |
+
SFT_EPOCHS = 1
|
| 42 |
+
SFT_LEARNING_RATE = 2e-4
|
| 43 |
+
SFT_MAX_SEQ = 1024
|
| 44 |
+
SFT_DATASET_SIZE = 512
|
| 45 |
+
SFT_OUTPUT = "./cicd_rl_sft_lora"
|
| 46 |
+
|
| 47 |
+
# Post-training quick eval: mark each task correct / wrong / timeout if generation exceeds this (seconds).
|
| 48 |
+
EVAL_GEN_TIMEOUT_SEC = 60.0
|
| 49 |
+
|
| 50 |
+
# Reward mix: GRPO sums per-function rewards; keep correctness as the dominant term.
|
| 51 |
+
REWARD_FIX_MATCH = 5.0
|
| 52 |
+
REWARD_FIX_MISS = -1.5
|
| 53 |
+
REWARD_STRUCT_SCALE = 0.2
|
| 54 |
+
REWARD_HALLU_GOOD = 0.1
|
| 55 |
+
REWARD_HALLU_BAD = -0.35
|
| 56 |
+
|
| 57 |
from cicd_debug_env.tasks import ALL_TASKS
|
| 58 |
from datasets import Dataset
|
| 59 |
import random
|
|
|
|
| 89 |
})
|
| 90 |
return Dataset.from_list(records)
|
| 91 |
|
| 92 |
+
|
| 93 |
+
def build_sft_dataset(tokenizer) -> Dataset:
|
| 94 |
+
"""Supervised (prompt, assistant) = same chat format as inference; target is exact correct_yaml."""
|
| 95 |
+
easy = [t for t in ALL_TASKS if t["difficulty"] == "easy"]
|
| 96 |
+
medium = [t for t in ALL_TASKS if t["difficulty"] == "medium"]
|
| 97 |
+
hard = [t for t in ALL_TASKS if t["difficulty"] == "hard"]
|
| 98 |
+
records = []
|
| 99 |
+
for _ in range(SFT_DATASET_SIZE):
|
| 100 |
+
r = random.random()
|
| 101 |
+
task = random.choice(easy if r < 0.5 else medium if r < 0.8 else hard)
|
| 102 |
+
gold = (task.get("correct_yaml") or "").strip()
|
| 103 |
+
messages = [
|
| 104 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 105 |
+
{"role": "user", "content": build_prompt(task)},
|
| 106 |
+
{"role": "assistant", "content": gold},
|
| 107 |
+
]
|
| 108 |
+
text = tokenizer.apply_chat_template(
|
| 109 |
+
messages, tokenize=False, add_generation_prompt=False
|
| 110 |
+
)
|
| 111 |
+
records.append({"text": text})
|
| 112 |
+
return Dataset.from_list(records)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
def _completion_to_text(completion) -> str:
|
| 116 |
"""
|
| 117 |
Normalize TRL/Unsloth completion payloads to plain text.
|
|
|
|
| 172 |
pred_canon = _canonical_yaml(pred)
|
| 173 |
correct_canon = _canonical_yaml(correct)
|
| 174 |
# Strict reward: exact/canonical exact gets high reward; everything else is negative.
|
| 175 |
+
ok = bool(pred_canon and pred_canon == correct_canon)
|
| 176 |
+
rewards.append(REWARD_FIX_MATCH if ok else REWARD_FIX_MISS)
|
| 177 |
return rewards
|
| 178 |
|
| 179 |
def reward_yaml_structure(completions, prompts, **kwargs):
|
|
|
|
| 192 |
score = 0.4 * int(starts_yaml) + 0.4 * int(has_yaml_keys) + 0.2 * int(line_count_ok)
|
| 193 |
if has_prose_or_md:
|
| 194 |
score -= 1.0
|
| 195 |
+
rewards.append(score * REWARD_STRUCT_SCALE)
|
| 196 |
return rewards
|
| 197 |
|
| 198 |
def reward_no_hallucination(completions, prompts, **kwargs):
|
|
|
|
| 204 |
for c in completions:
|
| 205 |
lower = _completion_to_text(c).lower()
|
| 206 |
bad_hits = sum(1 for p in bad if p in lower)
|
| 207 |
+
values.append(REWARD_HALLU_BAD if bad_hits > 0 else REWARD_HALLU_GOOD)
|
| 208 |
return values
|
| 209 |
|
| 210 |
REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]
|
| 211 |
|
| 212 |
+
|
| 213 |
+
def _grpo_console_callback(max_steps: int, label: str = "GRPO"):
|
| 214 |
+
from transformers import TrainerCallback
|
| 215 |
+
|
| 216 |
+
class _GRPOConsoleLogCallback(TrainerCallback):
|
| 217 |
+
def __init__(self) -> None:
|
| 218 |
+
self._max = max_steps
|
| 219 |
+
self._label = label
|
| 220 |
+
|
| 221 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 222 |
+
if not logs:
|
| 223 |
+
return
|
| 224 |
+
parts = [f"[{self._label} turn/step {state.global_step}/{self._max}]"]
|
| 225 |
+
for k in sorted(logs.keys()):
|
| 226 |
+
kl = k.lower()
|
| 227 |
+
if "reward" in kl or k in ("loss", "kl", "learning_rate", "train_loss") or "loss" in kl:
|
| 228 |
+
v = logs[k]
|
| 229 |
+
if isinstance(v, (int, float)):
|
| 230 |
+
parts.append(f"{k}={v:.6g}")
|
| 231 |
+
else:
|
| 232 |
+
parts.append(f"{k}={v}")
|
| 233 |
+
print(" | ".join(parts), flush=True)
|
| 234 |
+
|
| 235 |
+
return _GRPOConsoleLogCallback()
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _sft_console_callback():
|
| 239 |
+
from transformers import TrainerCallback
|
| 240 |
+
|
| 241 |
+
class _SFTConsoleLogCallback(TrainerCallback):
|
| 242 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 243 |
+
if not logs:
|
| 244 |
+
return
|
| 245 |
+
line = f"[SFT turn/step {state.global_step}]"
|
| 246 |
+
for k, v in sorted(logs.items()):
|
| 247 |
+
if "loss" in k.lower() or "learning_rate" in k:
|
| 248 |
+
if isinstance(v, (int, float)):
|
| 249 |
+
line += f" {k}={v:.6g}"
|
| 250 |
+
print(line, flush=True)
|
| 251 |
+
|
| 252 |
+
return _SFTConsoleLogCallback()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _format_seconds(sec: float) -> str:
|
| 256 |
+
if sec < 60:
|
| 257 |
+
return f"{sec:.1f}s"
|
| 258 |
+
m, s = int(sec // 60), sec % 60
|
| 259 |
+
if m < 60:
|
| 260 |
+
return f"{m}m {s:.1f}s"
|
| 261 |
+
h, m = m // 60, m % 60
|
| 262 |
+
return f"{h}h {m}m {s:.0f}s"
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _print_grpo_reward_tail(trainer) -> None:
|
| 266 |
+
hist = getattr(trainer.state, "log_history", None) or []
|
| 267 |
+
if not hist:
|
| 268 |
+
print("(No log_history available for reward summary.)", flush=True)
|
| 269 |
+
return
|
| 270 |
+
print("\n--- Last GRPO log entries (rewards) ---", flush=True)
|
| 271 |
+
for row in hist[-5:]:
|
| 272 |
+
rbits = {k: v for k, v in row.items() if "reward" in k.lower() or k == "loss"}
|
| 273 |
+
if rbits:
|
| 274 |
+
print(f" step {row.get('step', '?')}: {rbits}", flush=True)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def _set_inference_mode(model) -> None:
|
| 278 |
+
if USE_UNSLOTH:
|
| 279 |
+
from unsloth import FastLanguageModel
|
| 280 |
+
FastLanguageModel.for_inference(model)
|
| 281 |
+
else:
|
| 282 |
+
model.eval()
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _generate_for_task(model, tokenizer, task: dict, max_new_tokens: int) -> str:
|
| 286 |
+
import torch
|
| 287 |
+
messages = [
|
| 288 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 289 |
+
{"role": "user", "content": build_prompt(task)},
|
| 290 |
+
]
|
| 291 |
+
text = tokenizer.apply_chat_template(
|
| 292 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 293 |
+
)
|
| 294 |
+
dev = next(model.parameters()).device
|
| 295 |
+
inputs = tokenizer(text, return_tensors="pt").to(dev)
|
| 296 |
+
with torch.inference_mode():
|
| 297 |
+
out = model.generate(
|
| 298 |
+
**inputs, max_new_tokens=max_new_tokens, do_sample=False
|
| 299 |
+
)
|
| 300 |
+
return tokenizer.decode(
|
| 301 |
+
out[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def _eval_task_status(raw: str, task: dict, took_sec: float, timeout_sec: float) -> str:
|
| 306 |
+
if took_sec > timeout_sec:
|
| 307 |
+
return "timeout"
|
| 308 |
+
pred = _strip_markdown_fences(_completion_to_text(raw))
|
| 309 |
+
gold = (task.get("correct_yaml") or "").strip()
|
| 310 |
+
p_can = _canonical_yaml(pred)
|
| 311 |
+
g_can = _canonical_yaml(gold)
|
| 312 |
+
if p_can and g_can and p_can == g_can:
|
| 313 |
+
return "correct"
|
| 314 |
+
return "wrong"
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def run_final_task_eval(
|
| 318 |
+
model,
|
| 319 |
+
tokenizer,
|
| 320 |
+
max_new_tokens: int = MAX_COMPLETION_TOKENS,
|
| 321 |
+
timeout_sec: float = EVAL_GEN_TIMEOUT_SEC,
|
| 322 |
+
) -> None:
|
| 323 |
+
"""One generation per task; labels: correct, wrong, or timeout (if wall time > timeout_sec)."""
|
| 324 |
+
_set_inference_mode(model)
|
| 325 |
+
print(
|
| 326 |
+
f"\n========== EVAL: all {len(ALL_TASKS)} tasks (1 turn each; max_new_tokens={max_new_tokens}, "
|
| 327 |
+
f"timeout if wall time > {timeout_sec}s) ==========",
|
| 328 |
+
flush=True,
|
| 329 |
+
)
|
| 330 |
+
for task in ALL_TASKS:
|
| 331 |
+
tid = task.get("id", "?")
|
| 332 |
+
t0 = time.perf_counter()
|
| 333 |
+
try:
|
| 334 |
+
raw = _generate_for_task(model, tokenizer, task, max_new_tokens)
|
| 335 |
+
except Exception as e: # noqa: BLE001
|
| 336 |
+
took = time.perf_counter() - t0
|
| 337 |
+
print(
|
| 338 |
+
f" {tid}: error — {e!r} (after {took:.1f}s)",
|
| 339 |
+
flush=True,
|
| 340 |
+
)
|
| 341 |
+
continue
|
| 342 |
+
took = time.perf_counter() - t0
|
| 343 |
+
status = _eval_task_status(raw, task, took, timeout_sec)
|
| 344 |
+
r_fix = reward_fix_correctness(
|
| 345 |
+
[raw], [None], [task.get("correct_yaml", "")], [task["pipeline_yaml"]]
|
| 346 |
+
)[0]
|
| 347 |
+
r_stru = reward_yaml_structure([raw], [None])[0]
|
| 348 |
+
r_hallu = reward_no_hallucination([raw], [None])[0]
|
| 349 |
+
r_sum = r_fix + r_stru + r_hallu
|
| 350 |
+
print(
|
| 351 |
+
f" {tid}: {status:7s} | t={took:5.2f}s | rewards: total={r_sum:+.2f} "
|
| 352 |
+
f"(fix={r_fix:+.2f} struct={r_stru:+.2f} no_hallu={r_hallu:+.2f})",
|
| 353 |
+
flush=True,
|
| 354 |
+
)
|
| 355 |
+
print("========== EVAL end ==========\n", flush=True)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def _wandb_ok() -> bool:
|
| 359 |
+
try:
|
| 360 |
+
import wandb # noqa: F401
|
| 361 |
+
return True
|
| 362 |
+
except Exception:
|
| 363 |
+
return False
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def run_sft(model, tokenizer, use_wandb: bool, sft_epochs: float):
|
| 367 |
+
from trl import SFTTrainer, SFTConfig
|
| 368 |
+
|
| 369 |
+
sft_data = build_sft_dataset(tokenizer)
|
| 370 |
+
print(f"SFT dataset: {len(sft_data)} samples, {sft_epochs} epoch(s)")
|
| 371 |
+
|
| 372 |
+
sft_config = SFTConfig(
|
| 373 |
+
output_dir="./cicd_rl_sft_output",
|
| 374 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 375 |
+
gradient_accumulation_steps=GRAD_ACCUM,
|
| 376 |
+
num_train_epochs=sft_epochs,
|
| 377 |
+
learning_rate=SFT_LEARNING_RATE,
|
| 378 |
+
logging_steps=10,
|
| 379 |
+
save_strategy="no",
|
| 380 |
+
max_length=SFT_MAX_SEQ,
|
| 381 |
+
dataset_text_field="text",
|
| 382 |
+
report_to="wandb" if use_wandb else "none",
|
| 383 |
+
remove_unused_columns=False,
|
| 384 |
+
optim="adamw_8bit",
|
| 385 |
+
# Train loss on assistant tokens only (full gold YAML in the assistant turn).
|
| 386 |
+
assistant_only_loss=True,
|
| 387 |
+
)
|
| 388 |
+
trainer = SFTTrainer(
|
| 389 |
+
model=model,
|
| 390 |
+
args=sft_config,
|
| 391 |
+
train_dataset=sft_data,
|
| 392 |
+
processing_class=tokenizer,
|
| 393 |
+
callbacks=[_sft_console_callback()],
|
| 394 |
+
)
|
| 395 |
+
if use_wandb:
|
| 396 |
+
import wandb
|
| 397 |
+
wandb.init(project="cicd-rl-agent", name="sft-cicd-yaml", reinit=True)
|
| 398 |
+
print("Starting SFT (supervised: prompt -> correct YAML)...")
|
| 399 |
+
trainer.train()
|
| 400 |
+
model.save_pretrained(SFT_OUTPUT)
|
| 401 |
+
tokenizer.save_pretrained(SFT_OUTPUT)
|
| 402 |
+
print(f"SFT LoRA saved to {SFT_OUTPUT}")
|
| 403 |
+
return trainer
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def _post_train_smoke_unsloth(tokenizer, model) -> None:
|
| 407 |
+
import torch
|
| 408 |
+
from unsloth import FastLanguageModel
|
| 409 |
+
|
| 410 |
+
print("Testing post-training inference...")
|
| 411 |
+
FastLanguageModel.for_inference(model)
|
| 412 |
+
if not torch.cuda.is_available():
|
| 413 |
+
print("(CUDA not available; skip generate smoke test.)")
|
| 414 |
+
return
|
| 415 |
+
test_input = tokenizer("Fix this YAML: steps:\n - run: npm tset", return_tensors="pt").to("cuda")
|
| 416 |
+
with torch.inference_mode():
|
| 417 |
+
out = model.generate(**test_input, max_new_tokens=64)
|
| 418 |
+
print(tokenizer.decode(out[0], skip_special_tokens=True))
|
| 419 |
+
|
| 420 |
+
|
| 421 |
def main():
|
| 422 |
+
p = argparse.ArgumentParser(description="SFT (optional) + GRPO training for CICD YAML fix agent")
|
| 423 |
+
p.add_argument(
|
| 424 |
+
"--stages",
|
| 425 |
+
type=str,
|
| 426 |
+
default="sft,grpo",
|
| 427 |
+
help="Comma list: sft, grpo (default: sft,grpo = supervised then RL)",
|
| 428 |
+
)
|
| 429 |
+
p.add_argument("--sft-epochs", type=float, default=SFT_EPOCHS, help="SFT pass size (set 0 to skip SFT in code paths that still use --stages; prefer --stages grpo)")
|
| 430 |
+
p.add_argument(
|
| 431 |
+
"--no-final-eval",
|
| 432 |
+
action="store_true",
|
| 433 |
+
help="Skip end-of-run eval (correct / wrong / timeout per task).",
|
| 434 |
+
)
|
| 435 |
+
p.add_argument(
|
| 436 |
+
"--eval-timeout",
|
| 437 |
+
type=float,
|
| 438 |
+
default=EVAL_GEN_TIMEOUT_SEC,
|
| 439 |
+
help="Mark task eval as 'timeout' if a single generate() takes longer than this (seconds).",
|
| 440 |
+
)
|
| 441 |
+
args = p.parse_args()
|
| 442 |
+
wants = {s.strip().lower() for s in args.stages.split(",") if s.strip()}
|
| 443 |
+
if not wants.issubset({"sft", "grpo"}) or not wants:
|
| 444 |
+
print("Error: --stages must list one or more of: sft, grpo (e.g. sft,grpo or grpo)")
|
| 445 |
+
sys.exit(1)
|
| 446 |
+
|
| 447 |
# Colab often sets WANDB_DISABLED in the runtime env.
|
|
|
|
| 448 |
if os.environ.get("WANDB_DISABLED", "").strip().lower() in {"1", "true", "yes", "on"}:
|
| 449 |
+
print("Detected WANDB_DISABLED; unsetting it because report_to may be 'wandb'.")
|
| 450 |
os.environ.pop("WANDB_DISABLED", None)
|
| 451 |
|
| 452 |
if USE_UNSLOTH:
|
| 453 |
from unsloth import FastLanguageModel
|
| 454 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 455 |
+
model_name=MODEL_NAME, max_seq_length=1024, dtype=None, load_in_4bit=True
|
| 456 |
+
)
|
| 457 |
model = FastLanguageModel.get_peft_model(
|
| 458 |
+
model,
|
| 459 |
+
r=16,
|
| 460 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
| 461 |
+
lora_alpha=16,
|
| 462 |
+
lora_dropout=0.0,
|
| 463 |
+
bias="none",
|
| 464 |
+
use_gradient_checkpointing="unsloth",
|
| 465 |
+
random_state=42,
|
| 466 |
+
)
|
| 467 |
else:
|
| 468 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 469 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
|
|
|
| 471 |
if tokenizer.pad_token is None:
|
| 472 |
tokenizer.pad_token = tokenizer.eos_token
|
| 473 |
|
| 474 |
+
use_wandb = _wandb_ok()
|
| 475 |
+
if not use_wandb:
|
| 476 |
+
print("wandb is not installed; falling back to report_to='none' where applicable.")
|
| 477 |
|
| 478 |
+
if "sft" in wants and args.sft_epochs <= 0:
|
| 479 |
+
print("Error: --sft-epochs must be > 0 when SFT is in --stages")
|
| 480 |
+
sys.exit(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
|
| 482 |
+
t_start = time.perf_counter()
|
| 483 |
+
sft_time_s = 0.0
|
| 484 |
+
grpo_time_s = 0.0
|
| 485 |
+
sft_steps = 0
|
| 486 |
+
grpo_steps = 0
|
| 487 |
+
sft_trainer = None
|
| 488 |
+
grpo_trainer = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
+
if "sft" in wants:
|
| 491 |
+
t0 = time.perf_counter()
|
| 492 |
+
sft_trainer = run_sft(model, tokenizer, use_wandb, float(args.sft_epochs))
|
| 493 |
+
sft_time_s = time.perf_counter() - t0
|
| 494 |
+
sft_steps = getattr(sft_trainer.state, "global_step", 0) if sft_trainer else 0
|
| 495 |
+
print(
|
| 496 |
+
f"--- SFT done: {sft_steps} optimizer turn(s) / step(s), time {_format_seconds(sft_time_s)} ---\n",
|
| 497 |
+
flush=True,
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if "grpo" in wants:
|
| 501 |
+
dataset = build_dataset()
|
| 502 |
+
print(f"GRPO dataset: {len(dataset)} samples")
|
| 503 |
+
from trl import GRPOTrainer, GRPOConfig
|
| 504 |
+
|
| 505 |
+
grpo_args = GRPOConfig(
|
| 506 |
+
output_dir="./cicd_rl_output",
|
| 507 |
+
per_device_train_batch_size=BATCH_SIZE,
|
| 508 |
+
gradient_accumulation_steps=GRAD_ACCUM,
|
| 509 |
+
learning_rate=5e-6,
|
| 510 |
+
max_steps=MAX_STEPS,
|
| 511 |
+
num_generations=4,
|
| 512 |
+
max_completion_length=MAX_COMPLETION_TOKENS,
|
| 513 |
+
logging_steps=5,
|
| 514 |
+
save_steps=50,
|
| 515 |
+
report_to="wandb" if use_wandb else "none",
|
| 516 |
+
remove_unused_columns=False,
|
| 517 |
+
warmup_steps=10,
|
| 518 |
+
lr_scheduler_type="cosine",
|
| 519 |
+
optim="adamw_8bit",
|
| 520 |
+
)
|
| 521 |
+
grpo_trainer = GRPOTrainer(
|
| 522 |
+
model=model,
|
| 523 |
+
args=grpo_args,
|
| 524 |
+
reward_funcs=REWARD_FUNCTIONS,
|
| 525 |
+
train_dataset=dataset,
|
| 526 |
+
processing_class=tokenizer,
|
| 527 |
+
callbacks=[_grpo_console_callback(MAX_STEPS, "GRPO")],
|
| 528 |
+
)
|
| 529 |
+
print("Starting GRPO training... (rewards + loss in log lines; online reward below)\n", flush=True)
|
| 530 |
+
if use_wandb:
|
| 531 |
+
import wandb
|
| 532 |
+
wandb.init(project="cicd-rl-agent", name="grpo-cicd-yaml", reinit=True)
|
| 533 |
+
t0 = time.perf_counter()
|
| 534 |
+
grpo_trainer.train()
|
| 535 |
+
grpo_time_s = time.perf_counter() - t0
|
| 536 |
+
grpo_steps = getattr(grpo_trainer.state, "global_step", 0)
|
| 537 |
+
print("GRPO training complete!", flush=True)
|
| 538 |
+
_print_grpo_reward_tail(grpo_trainer)
|
| 539 |
+
print(
|
| 540 |
+
f"\n--- GRPO done: {grpo_steps} optimizer turn(s) / step(s) (of {MAX_STEPS} max), "
|
| 541 |
+
f'time { _format_seconds(grpo_time_s) } ---\n',
|
| 542 |
+
flush=True,
|
| 543 |
+
)
|
| 544 |
|
| 545 |
save_path = "./cicd_rl_agent_final"
|
| 546 |
+
if "grpo" in wants:
|
| 547 |
model.save_pretrained(save_path)
|
| 548 |
tokenizer.save_pretrained(save_path)
|
| 549 |
+
print(f"Final LoRA saved to {save_path} (SFT+GRPO pipeline end state).")
|
| 550 |
+
if USE_UNSLOTH:
|
| 551 |
+
_post_train_smoke_unsloth(tokenizer, model)
|
| 552 |
+
else:
|
| 553 |
+
print("Non-Unsloth path: inference test skipped.")
|
| 554 |
+
elif "sft" in wants:
|
| 555 |
+
# SFT weights already written in run_sft(); also mirror to default eval path for convenience.
|
| 556 |
model.save_pretrained(save_path)
|
| 557 |
tokenizer.save_pretrained(save_path)
|
| 558 |
+
print(f"SFT-only run: LoRA is in {SFT_OUTPUT} and copied to {save_path} for eval_lora defaults.")
|
| 559 |
+
|
| 560 |
+
total_s = time.perf_counter() - t_start
|
| 561 |
+
print("\n========== TRAINING SUMMARY ==========", flush=True)
|
| 562 |
+
print(f"Total wall time: {_format_seconds(total_s)}", flush=True)
|
| 563 |
+
if sft_time_s:
|
| 564 |
+
print(
|
| 565 |
+
f" SFT: time={_format_seconds(sft_time_s)} | turn(s)/step(s) = {sft_steps} | (supervised, loss in [SFT turn/step ...] lines)",
|
| 566 |
+
flush=True,
|
| 567 |
+
)
|
| 568 |
+
if grpo_time_s:
|
| 569 |
+
print(
|
| 570 |
+
f" GRPO: time={_format_seconds(grpo_time_s)} | turn(s)/step(s) = {grpo_steps} | (online rewards in [GRPO turn/step ...] lines)",
|
| 571 |
+
flush=True,
|
| 572 |
+
)
|
| 573 |
+
print(
|
| 574 |
+
" Note: each eval task is a single user→assistant 'turn'; GRPO/SFT 'turns' = optimizer update steps.\n"
|
| 575 |
+
"========================================\n",
|
| 576 |
+
flush=True,
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
if not args.no_final_eval and (sft_time_s or grpo_time_s):
|
| 580 |
+
run_final_task_eval(
|
| 581 |
+
model, tokenizer, MAX_COMPLETION_TOKENS, timeout_sec=float(args.eval_timeout)
|
| 582 |
+
)
|
| 583 |
+
elif args.no_final_eval:
|
| 584 |
+
print("Skipped final per-task eval (--no-final-eval).", flush=True)
|
| 585 |
+
|
| 586 |
|
| 587 |
if __name__ == "__main__":
|
| 588 |
main()
|
train_colab.ipynb
CHANGED
|
@@ -9,12 +9,12 @@
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
| 12 |
-
"execution_count": null,
|
| 13 |
"metadata": {},
|
| 14 |
-
"outputs": [],
|
| 15 |
"source": [
|
| 16 |
"!pip install unsloth trl transformers datasets torch wandb pydantic"
|
| 17 |
-
]
|
|
|
|
|
|
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "markdown",
|
|
@@ -25,9 +25,7 @@
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
| 28 |
-
"execution_count": null,
|
| 29 |
"metadata": {},
|
| 30 |
-
"outputs": [],
|
| 31 |
"source": [
|
| 32 |
"import os\n",
|
| 33 |
"import random\n",
|
|
@@ -82,7 +80,9 @@
|
|
| 82 |
" return Dataset.from_list(records)\n",
|
| 83 |
"\n",
|
| 84 |
"print(f\"Loaded {len(ALL_TASKS)} tasks (easy/medium/hard). Sample task ids:\", [t['id'] for t in ALL_TASKS[:3]], \"...\")"
|
| 85 |
-
]
|
|
|
|
|
|
|
| 86 |
},
|
| 87 |
{
|
| 88 |
"cell_type": "markdown",
|
|
@@ -93,9 +93,7 @@
|
|
| 93 |
},
|
| 94 |
{
|
| 95 |
"cell_type": "code",
|
| 96 |
-
"execution_count": null,
|
| 97 |
"metadata": {},
|
| 98 |
-
"outputs": [],
|
| 99 |
"source": [
|
| 100 |
"import torch\n",
|
| 101 |
"from unsloth import FastLanguageModel\n",
|
|
@@ -119,7 +117,9 @@
|
|
| 119 |
")\n",
|
| 120 |
"if tokenizer.pad_token is None:\n",
|
| 121 |
" tokenizer.pad_token = tokenizer.eos_token"
|
| 122 |
-
]
|
|
|
|
|
|
|
| 123 |
},
|
| 124 |
{
|
| 125 |
"cell_type": "markdown",
|
|
@@ -130,13 +130,13 @@
|
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
| 133 |
-
"execution_count": null,
|
| 134 |
"metadata": {},
|
| 135 |
-
"outputs": [],
|
| 136 |
"source": [
|
| 137 |
"train_dataset = build_dataset()\n",
|
| 138 |
"print(f\"Dataset size: {len(train_dataset)} (target split ~50% easy / 30% medium / 20% hard)\")"
|
| 139 |
-
]
|
|
|
|
|
|
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"cell_type": "markdown",
|
|
@@ -147,9 +147,7 @@
|
|
| 147 |
},
|
| 148 |
{
|
| 149 |
"cell_type": "code",
|
| 150 |
-
"execution_count": null,
|
| 151 |
"metadata": {},
|
| 152 |
-
"outputs": [],
|
| 153 |
"source": [
|
| 154 |
"def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):\n",
|
| 155 |
" \"\"\"How closely the completion matches the reference `correct_yaml` (full match, partial, unchanged, or wrong).\"\"\"\n",
|
|
@@ -188,20 +186,29 @@
|
|
| 188 |
" return [-0.3 if any(p.lower() in c.lower() for p in bad) else 0.3 for c in completions]\n",
|
| 189 |
"\n",
|
| 190 |
"REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]"
|
| 191 |
-
]
|
|
|
|
|
|
|
| 192 |
},
|
| 193 |
{
|
| 194 |
"cell_type": "markdown",
|
| 195 |
"metadata": {},
|
| 196 |
"source": [
|
| 197 |
-
"## 🚀
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
]
|
| 199 |
},
|
| 200 |
{
|
| 201 |
"cell_type": "code",
|
| 202 |
-
"execution_count": null,
|
| 203 |
"metadata": {},
|
| 204 |
-
"outputs": [],
|
| 205 |
"source": [
|
| 206 |
"import wandb\n",
|
| 207 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
|
@@ -232,7 +239,9 @@
|
|
| 232 |
")\n",
|
| 233 |
"wandb.init(project=\"cicd-rl-agent\")\n",
|
| 234 |
"trainer.train()"
|
| 235 |
-
]
|
|
|
|
|
|
|
| 236 |
},
|
| 237 |
{
|
| 238 |
"cell_type": "markdown",
|
|
@@ -243,9 +252,7 @@
|
|
| 243 |
},
|
| 244 |
{
|
| 245 |
"cell_type": "code",
|
| 246 |
-
"execution_count": null,
|
| 247 |
"metadata": {},
|
| 248 |
-
"outputs": [],
|
| 249 |
"source": [
|
| 250 |
"import matplotlib.pyplot as plt\n",
|
| 251 |
"\n",
|
|
@@ -269,7 +276,9 @@
|
|
| 269 |
"plt.tight_layout()\n",
|
| 270 |
"plt.savefig(\"reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
|
| 271 |
"plt.show()"
|
| 272 |
-
]
|
|
|
|
|
|
|
| 273 |
},
|
| 274 |
{
|
| 275 |
"cell_type": "markdown",
|
|
@@ -280,9 +289,7 @@
|
|
| 280 |
},
|
| 281 |
{
|
| 282 |
"cell_type": "code",
|
| 283 |
-
"execution_count": null,
|
| 284 |
"metadata": {},
|
| 285 |
-
"outputs": [],
|
| 286 |
"source": [
|
| 287 |
"def generate_yaml(model, tok, task: dict) -> str:\n",
|
| 288 |
" FastLanguageModel.for_inference(model)\n",
|
|
@@ -322,7 +329,9 @@
|
|
| 322 |
" print(out_train[:800])\n",
|
| 323 |
" print(f\"\\nBase matches correct_yaml: {ok_base}\")\n",
|
| 324 |
" print(f\"Trained matches correct_yaml: {ok_train}\")"
|
| 325 |
-
]
|
|
|
|
|
|
|
| 326 |
}
|
| 327 |
],
|
| 328 |
"metadata": {
|
|
@@ -338,4 +347,4 @@
|
|
| 338 |
},
|
| 339 |
"nbformat": 4,
|
| 340 |
"nbformat_minor": 4
|
| 341 |
-
}
|
|
|
|
| 9 |
},
|
| 10 |
{
|
| 11 |
"cell_type": "code",
|
|
|
|
| 12 |
"metadata": {},
|
|
|
|
| 13 |
"source": [
|
| 14 |
"!pip install unsloth trl transformers datasets torch wandb pydantic"
|
| 15 |
+
],
|
| 16 |
+
"execution_count": null,
|
| 17 |
+
"outputs": []
|
| 18 |
},
|
| 19 |
{
|
| 20 |
"cell_type": "markdown",
|
|
|
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"cell_type": "code",
|
|
|
|
| 28 |
"metadata": {},
|
|
|
|
| 29 |
"source": [
|
| 30 |
"import os\n",
|
| 31 |
"import random\n",
|
|
|
|
| 80 |
" return Dataset.from_list(records)\n",
|
| 81 |
"\n",
|
| 82 |
"print(f\"Loaded {len(ALL_TASKS)} tasks (easy/medium/hard). Sample task ids:\", [t['id'] for t in ALL_TASKS[:3]], \"...\")"
|
| 83 |
+
],
|
| 84 |
+
"execution_count": null,
|
| 85 |
+
"outputs": []
|
| 86 |
},
|
| 87 |
{
|
| 88 |
"cell_type": "markdown",
|
|
|
|
| 93 |
},
|
| 94 |
{
|
| 95 |
"cell_type": "code",
|
|
|
|
| 96 |
"metadata": {},
|
|
|
|
| 97 |
"source": [
|
| 98 |
"import torch\n",
|
| 99 |
"from unsloth import FastLanguageModel\n",
|
|
|
|
| 117 |
")\n",
|
| 118 |
"if tokenizer.pad_token is None:\n",
|
| 119 |
" tokenizer.pad_token = tokenizer.eos_token"
|
| 120 |
+
],
|
| 121 |
+
"execution_count": null,
|
| 122 |
+
"outputs": []
|
| 123 |
},
|
| 124 |
{
|
| 125 |
"cell_type": "markdown",
|
|
|
|
| 130 |
},
|
| 131 |
{
|
| 132 |
"cell_type": "code",
|
|
|
|
| 133 |
"metadata": {},
|
|
|
|
| 134 |
"source": [
|
| 135 |
"train_dataset = build_dataset()\n",
|
| 136 |
"print(f\"Dataset size: {len(train_dataset)} (target split ~50% easy / 30% medium / 20% hard)\")"
|
| 137 |
+
],
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"outputs": []
|
| 140 |
},
|
| 141 |
{
|
| 142 |
"cell_type": "markdown",
|
|
|
|
| 147 |
},
|
| 148 |
{
|
| 149 |
"cell_type": "code",
|
|
|
|
| 150 |
"metadata": {},
|
|
|
|
| 151 |
"source": [
|
| 152 |
"def reward_fix_correctness(completions, prompts, correct_yaml, pipeline_yaml, **kwargs):\n",
|
| 153 |
" \"\"\"How closely the completion matches the reference `correct_yaml` (full match, partial, unchanged, or wrong).\"\"\"\n",
|
|
|
|
| 186 |
" return [-0.3 if any(p.lower() in c.lower() for p in bad) else 0.3 for c in completions]\n",
|
| 187 |
"\n",
|
| 188 |
"REWARD_FUNCTIONS = [reward_fix_correctness, reward_yaml_structure, reward_no_hallucination]"
|
| 189 |
+
],
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"outputs": []
|
| 192 |
},
|
| 193 |
{
|
| 194 |
"cell_type": "markdown",
|
| 195 |
"metadata": {},
|
| 196 |
"source": [
|
| 197 |
+
"## 🚀 Training: SFT + GRPO (recommended) or GRPO in-notebook\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"**Best path (matches `train.py` in the repo):** in the repo root run:\n",
|
| 200 |
+
"`!cd $REPO_DIR && python train.py` \n",
|
| 201 |
+
"Default is a short **supervised (SFT)** pass on exact `correct_yaml`, then **GRPO** with correctness-weighted rewards. \n",
|
| 202 |
+
"- GRPO only (old one-stage): `python train.py --stages grpo` \n",
|
| 203 |
+
"- SFT only: `python train.py --stages sft` \n",
|
| 204 |
+
"- Two SFT epochs: `python train.py --sft-epochs 2`\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"**Alternative below:** the next cell runs **GRPO only** in the notebook (no SFT), like older Colab flows."
|
| 207 |
]
|
| 208 |
},
|
| 209 |
{
|
| 210 |
"cell_type": "code",
|
|
|
|
| 211 |
"metadata": {},
|
|
|
|
| 212 |
"source": [
|
| 213 |
"import wandb\n",
|
| 214 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
|
|
|
| 239 |
")\n",
|
| 240 |
"wandb.init(project=\"cicd-rl-agent\")\n",
|
| 241 |
"trainer.train()"
|
| 242 |
+
],
|
| 243 |
+
"execution_count": null,
|
| 244 |
+
"outputs": []
|
| 245 |
},
|
| 246 |
{
|
| 247 |
"cell_type": "markdown",
|
|
|
|
| 252 |
},
|
| 253 |
{
|
| 254 |
"cell_type": "code",
|
|
|
|
| 255 |
"metadata": {},
|
|
|
|
| 256 |
"source": [
|
| 257 |
"import matplotlib.pyplot as plt\n",
|
| 258 |
"\n",
|
|
|
|
| 276 |
"plt.tight_layout()\n",
|
| 277 |
"plt.savefig(\"reward_curve.png\", dpi=150, bbox_inches=\"tight\")\n",
|
| 278 |
"plt.show()"
|
| 279 |
+
],
|
| 280 |
+
"execution_count": null,
|
| 281 |
+
"outputs": []
|
| 282 |
},
|
| 283 |
{
|
| 284 |
"cell_type": "markdown",
|
|
|
|
| 289 |
},
|
| 290 |
{
|
| 291 |
"cell_type": "code",
|
|
|
|
| 292 |
"metadata": {},
|
|
|
|
| 293 |
"source": [
|
| 294 |
"def generate_yaml(model, tok, task: dict) -> str:\n",
|
| 295 |
" FastLanguageModel.for_inference(model)\n",
|
|
|
|
| 329 |
" print(out_train[:800])\n",
|
| 330 |
" print(f\"\\nBase matches correct_yaml: {ok_base}\")\n",
|
| 331 |
" print(f\"Trained matches correct_yaml: {ok_train}\")"
|
| 332 |
+
],
|
| 333 |
+
"execution_count": null,
|
| 334 |
+
"outputs": []
|
| 335 |
}
|
| 336 |
],
|
| 337 |
"metadata": {
|
|
|
|
| 347 |
},
|
| 348 |
"nbformat": 4,
|
| 349 |
"nbformat_minor": 4
|
| 350 |
+
}
|