Spaces:
Running
Running
Fix HF Jobs bootstrap (pin transformers/trl, drop torchao stack); add reward and trainer JSONL logging; stabilize launch_job.
Browse files- launch_job.py +86 -11
- ultimate_sota_training.py +339 -100
launch_job.py
CHANGED
|
@@ -1,13 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from huggingface_hub import HfApi
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
)
|
| 11 |
-
print("JOB_ID:", job.job_id)
|
| 12 |
-
except Exception as e:
|
| 13 |
-
print("FAILED:", str(e))
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Submit ultimate_sota_training.py to Hugging Face GPU Jobs (HfApi.run_job).
|
| 3 |
+
|
| 4 |
+
The Job command must be a single robust shell line (semicolon-separated). Hugging Face
|
| 5 |
+
has been observed to flatten multiline `bash -lc` payloads, which breaks `set` and can
|
| 6 |
+
leave the job stuck or failing silently.
|
| 7 |
+
|
| 8 |
+
Requires: huggingface_hub, `huggingface-cli login`.
|
| 9 |
+
|
| 10 |
+
Secrets: if SKIP_HUB_PUSH is not 1, the job requests Hub secret name HF_TOKEN mapped into
|
| 11 |
+
the container as env HF_TOKEN (Settings → Access Tokens / Job secrets).
|
| 12 |
+
|
| 13 |
+
Environment (optional):
|
| 14 |
+
HF_JOB_NAMESPACE default: whoami
|
| 15 |
+
HF_JOB_FLAVOR default: l4x1 (often faster than T4 for this workload; override with t4-small to save $)
|
| 16 |
+
HF_JOB_IMAGE default: pytorch CUDA 12.4 devel
|
| 17 |
+
HF_JOB_TIMEOUT default: 8h
|
| 18 |
+
TRAIN_REPO_GIT_URL, OPENENV_BASE_URL
|
| 19 |
+
TRAIN_MAX_STEPS default: 80 (faster run; raise for stronger fit)
|
| 20 |
+
ROWS_PER_TASK default: 32
|
| 21 |
+
GRPO_NUM_GENERATIONS default: 6
|
| 22 |
+
SKIP_HUB_PUSH default: 0
|
| 23 |
+
"""
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import os
|
| 27 |
+
import shlex
|
| 28 |
+
|
| 29 |
from huggingface_hub import HfApi
|
| 30 |
+
|
| 31 |
+
_DEFAULT_REPO = "https://huggingface.co/spaces/md896/sql-debug-env"
|
| 32 |
+
_REPO_URL = os.environ.get("TRAIN_REPO_GIT_URL", _DEFAULT_REPO)
|
| 33 |
+
_OPENENV = os.environ.get("OPENENV_BASE_URL", "https://md896-sql-debug-env.hf.space")
|
| 34 |
+
_MAX_STEPS = os.environ.get("TRAIN_MAX_STEPS", "80")
|
| 35 |
+
_ROWS = os.environ.get("ROWS_PER_TASK", "32")
|
| 36 |
+
_NUM_GEN = os.environ.get("GRPO_NUM_GENERATIONS", "6")
|
| 37 |
+
_SKIP_PUSH = os.environ.get("SKIP_HUB_PUSH", "0")
|
| 38 |
+
_TIMEOUT = os.environ.get("HF_JOB_TIMEOUT", "8h")
|
| 39 |
+
# l4x1: newer GPU, good for Unsloth; use HF_JOB_FLAVOR=t4-small if queue or cost is better for you
|
| 40 |
+
_FLAVOR = os.environ.get("HF_JOB_FLAVOR", "l4x1")
|
| 41 |
+
_IMAGE = os.environ.get(
|
| 42 |
+
"HF_JOB_IMAGE",
|
| 43 |
+
"pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
|
| 44 |
+
)
|
| 45 |
+
_NAMESPACE = os.environ.get("HF_JOB_NAMESPACE")
|
| 46 |
+
|
| 47 |
+
_SECRETS = None
|
| 48 |
+
if _SKIP_PUSH.strip().lower() not in ("1", "true", "yes"):
|
| 49 |
+
_SECRETS = {"HF_TOKEN": "HF_TOKEN"}
|
| 50 |
+
|
| 51 |
+
# One line only — survives UI/API newline flattening.
|
| 52 |
+
_bash = (
|
| 53 |
+
"set -euxo pipefail; "
|
| 54 |
+
"export DEBIAN_FRONTEND=noninteractive; "
|
| 55 |
+
"apt-get update -qq && apt-get install -y -qq git ca-certificates; "
|
| 56 |
+
"export PIP_BREAK_SYSTEM_PACKAGES=1; "
|
| 57 |
+
f"rm -rf train-repo; git clone {shlex.quote(_REPO_URL)} train-repo; "
|
| 58 |
+
"cd train-repo; "
|
| 59 |
+
"python -u ultimate_sota_training.py"
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
_job_env = {
|
| 63 |
+
"OPENENV_BASE_URL": _OPENENV,
|
| 64 |
+
"TRAIN_MAX_STEPS": _MAX_STEPS,
|
| 65 |
+
"ROWS_PER_TASK": _ROWS,
|
| 66 |
+
"GRPO_NUM_GENERATIONS": _NUM_GEN,
|
| 67 |
+
"SKIP_HUB_PUSH": _SKIP_PUSH,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
if __name__ == "__main__":
|
| 71 |
+
api = HfApi()
|
| 72 |
+
ns = _NAMESPACE or api.whoami()["name"]
|
| 73 |
+
job = api.run_job(
|
| 74 |
+
image=_IMAGE,
|
| 75 |
+
command=["bash", "-lc", _bash],
|
| 76 |
+
flavor=_FLAVOR,
|
| 77 |
+
namespace=ns,
|
| 78 |
+
timeout=_TIMEOUT,
|
| 79 |
+
secrets=_SECRETS,
|
| 80 |
+
env=_job_env,
|
| 81 |
+
)
|
| 82 |
+
print("JOB_ID:", job.id)
|
| 83 |
+
print("JOB_URL:", job.url)
|
| 84 |
+
print("FLAVOR:", _FLAVOR, "| TRAIN_MAX_STEPS:", _MAX_STEPS, "| ROWS_PER_TASK:", _ROWS)
|
| 85 |
+
print(
|
| 86 |
+
"Note: SCHEDULING is Hugging Face queue time, not your script. "
|
| 87 |
+
"Cancel stuck jobs and retry, or try HF_JOB_FLAVOR=t4-small / t4-medium."
|
| 88 |
)
|
|
|
|
|
|
|
|
|
ultimate_sota_training.py
CHANGED
|
@@ -1,10 +1,27 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
- system Python may be externally managed (PEP-668) → uses --break-system-packages
|
| 9 |
- preinstalled CUDA/PyTorch stacks can conflict with optional vision packages
|
| 10 |
|
|
@@ -16,17 +33,23 @@ Key stability choices:
|
|
| 16 |
|
| 17 |
from __future__ import annotations
|
| 18 |
|
|
|
|
| 19 |
import json
|
|
|
|
| 20 |
import os
|
| 21 |
import random
|
| 22 |
import re
|
| 23 |
import subprocess
|
| 24 |
import sys
|
| 25 |
import time
|
|
|
|
| 26 |
from dataclasses import dataclass
|
| 27 |
from pathlib import Path
|
| 28 |
from typing import Any, Dict, List, Optional
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
|
| 32 |
return subprocess.run(cmd, check=check)
|
|
@@ -41,6 +64,7 @@ def bootstrap_deps() -> None:
|
|
| 41 |
Best-effort dependency bootstrap for ephemeral HF containers.
|
| 42 |
|
| 43 |
Set SKIP_BOOTSTRAP=1 to disable.
|
|
|
|
| 44 |
"""
|
| 45 |
if os.environ.get("SKIP_BOOTSTRAP") == "1":
|
| 46 |
return
|
|
@@ -53,31 +77,39 @@ def bootstrap_deps() -> None:
|
|
| 53 |
# (PEP-668). Prefer an explicit opt-out for all pip ops in ephemeral jobs.
|
| 54 |
os.environ.setdefault("PIP_BREAK_SYSTEM_PACKAGES", "1")
|
| 55 |
|
| 56 |
-
print("
|
| 57 |
|
| 58 |
# Text-only run: torchvision/torchaudio are not required and are a common source
|
| 59 |
# of crashes when torch versions shift in container images.
|
| 60 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 61 |
|
| 62 |
-
|
|
|
|
| 63 |
_pip(
|
| 64 |
[
|
| 65 |
"install",
|
| 66 |
"--break-system-packages",
|
| 67 |
"httpx>=0.27.0",
|
| 68 |
"datasets>=3.4.1,<4.4.0",
|
| 69 |
-
"trl>=0.18.2,<=0.22.2",
|
| 70 |
-
"mergekit",
|
| 71 |
-
"llm-blender",
|
| 72 |
-
"weave",
|
| 73 |
-
"wandb",
|
| 74 |
"matplotlib",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
]
|
| 76 |
)
|
| 77 |
|
| 78 |
-
# Unsloth (and its dependency set) can be fast-moving; install from git.
|
| 79 |
-
# Build isolation/resolution can sometimes change torch; removing torchvision
|
| 80 |
-
# above keeps transformers imports stable for text-only workloads.
|
| 81 |
_pip(
|
| 82 |
[
|
| 83 |
"install",
|
|
@@ -86,10 +118,32 @@ def bootstrap_deps() -> None:
|
|
| 86 |
]
|
| 87 |
)
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
bootstrap_deps()
|
| 95 |
|
|
@@ -126,71 +180,146 @@ import transformers.utils.hub
|
|
| 126 |
if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
|
| 127 |
transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
|
| 128 |
|
|
|
|
| 129 |
from trl import GRPOConfig, GRPOTrainer
|
| 130 |
from unsloth import FastLanguageModel
|
| 131 |
|
| 132 |
-
# --- 1. CONFIGURATION ---
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
|
|
|
|
|
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
After you have finished thinking, you MUST output the exact fixed SQL query inside <sql> tags.
|
| 144 |
Do not output any markdown blocks like ```sql.
|
| 145 |
|
| 146 |
Example:
|
| 147 |
-
<
|
| 148 |
-
I
|
| 149 |
-
</
|
| 150 |
<sql>
|
| 151 |
-
WITH OrderTotals AS (
|
|
|
|
| 152 |
</sql>"""
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
resp = client.post("/reset", json={"task_id": t_id})
|
|
|
|
| 162 |
obs = resp.json()["observation"]
|
| 163 |
-
|
| 164 |
prompt = (
|
| 165 |
-
f"{
|
| 166 |
f"Task: {obs['task_description']}\n"
|
| 167 |
f"Broken Query: {obs['original_query']}\n\n"
|
| 168 |
-
"Provide your <think> and <sql> output:"
|
| 169 |
)
|
| 170 |
-
|
| 171 |
-
for _ in range(40):
|
| 172 |
rows.append({"prompt": prompt, "task_id": t_id})
|
| 173 |
-
|
| 174 |
if not rows:
|
| 175 |
-
raise RuntimeError("Failed to
|
|
|
|
| 176 |
return Dataset.from_list(rows)
|
| 177 |
|
| 178 |
-
# --- 3. MULTI-REWARD SHAPING (
|
|
|
|
|
|
|
|
|
|
| 179 |
|
| 180 |
def extract_xml_tag(text, tag):
|
| 181 |
pattern = f"<{tag}>(.*?)</{tag}>"
|
| 182 |
match = re.search(pattern, text, re.DOTALL)
|
| 183 |
return match.group(1).strip() if match else None
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
def format_reward_func(completions, **kwargs):
|
| 186 |
-
"""Reward 1:
|
|
|
|
| 187 |
rewards = []
|
| 188 |
for comp in completions:
|
| 189 |
-
has_think = extract_xml_tag(comp,
|
| 190 |
has_sql = extract_xml_tag(comp, "sql") is not None
|
| 191 |
rewards.append(0.1 if (has_think and has_sql) else 0.0)
|
|
|
|
| 192 |
return rewards
|
| 193 |
|
|
|
|
| 194 |
def syntax_reward_func(completions, **kwargs):
|
| 195 |
"""Reward 2: Does the SQL look like valid code? (+0.2)"""
|
| 196 |
rewards = []
|
|
@@ -200,29 +329,85 @@ def syntax_reward_func(completions, **kwargs):
|
|
| 200 |
rewards.append(0.2)
|
| 201 |
else:
|
| 202 |
rewards.append(0.0)
|
|
|
|
| 203 |
return rewards
|
| 204 |
|
|
|
|
| 205 |
def execution_reward_func(completions, task_id, **kwargs):
|
| 206 |
-
"""Reward 3:
|
| 207 |
-
rewards = []
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
for query, t_id in zip(completions, task_id):
|
| 210 |
sql = extract_xml_tag(query, "sql")
|
| 211 |
if not sql:
|
| 212 |
-
rewards.append(0.0)
|
| 213 |
continue
|
| 214 |
-
|
|
|
|
| 215 |
try:
|
| 216 |
-
client.post("/reset", json={"task_id": t_id})
|
| 217 |
-
|
| 218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
except Exception:
|
| 220 |
reward = 0.0
|
| 221 |
-
|
| 222 |
-
reward += random.uniform(-1e-6, 1e-6)
|
| 223 |
rewards.append(reward)
|
|
|
|
| 224 |
return rewards
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
|
| 227 |
|
| 228 |
@dataclass(frozen=True)
|
|
@@ -329,17 +514,42 @@ def plot_reward_curve(reward_series: List[tuple[float, float]], paths: ArtifactP
|
|
| 329 |
_ensure_dir(paths.root)
|
| 330 |
plt.tight_layout()
|
| 331 |
plt.savefig(paths.reward_curve_png, dpi=200)
|
| 332 |
-
print(f"
|
|
|
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
-
|
|
|
|
| 336 |
def run_sota_train():
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 341 |
model_name=MODEL_NAME,
|
| 342 |
-
max_seq_length=
|
| 343 |
load_in_4bit=True,
|
| 344 |
)
|
| 345 |
|
|
@@ -357,10 +567,7 @@ def run_sota_train():
|
|
| 357 |
|
| 358 |
def quick_exec_eval(max_items: int = 8) -> float:
|
| 359 |
"""
|
| 360 |
-
Quick before/after check:
|
| 361 |
-
- sample a few prompts
|
| 362 |
-
- generate <think>/<sql>
|
| 363 |
-
- score via live execution reward
|
| 364 |
"""
|
| 365 |
subset = train_dataset.select(range(min(max_items, len(train_dataset))))
|
| 366 |
prompts = subset["prompt"]
|
|
@@ -382,39 +589,60 @@ def run_sota_train():
|
|
| 382 |
rewards = execution_reward_func(completions, task_ids)
|
| 383 |
return float(sum(rewards) / max(len(rewards), 1))
|
| 384 |
|
| 385 |
-
print("
|
| 386 |
baseline_avg_reward = quick_exec_eval()
|
| 387 |
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 401 |
|
| 402 |
trainer = GRPOTrainer(
|
| 403 |
model=model,
|
| 404 |
-
reward_funcs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
args=training_args,
|
| 406 |
train_dataset=train_dataset,
|
| 407 |
processing_class=tokenizer,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
)
|
| 409 |
|
| 410 |
-
print("
|
| 411 |
trainer.train()
|
| 412 |
|
| 413 |
-
print("
|
| 414 |
post_avg_reward = quick_exec_eval()
|
| 415 |
|
| 416 |
# --- Save artifacts (real logs/plots) ---
|
| 417 |
-
artifacts = ArtifactPaths(root=Path(
|
| 418 |
log_history = getattr(trainer.state, "log_history", []) or []
|
| 419 |
save_log_history(log_history, artifacts)
|
| 420 |
reward_series = extract_reward_series(log_history)
|
|
@@ -427,9 +655,16 @@ def run_sota_train():
|
|
| 427 |
metrics = {}
|
| 428 |
metrics.update(
|
| 429 |
{
|
|
|
|
|
|
|
|
|
|
| 430 |
"baseline_avg_reward": baseline_avg_reward,
|
| 431 |
"post_avg_reward": post_avg_reward,
|
| 432 |
"delta_avg_reward": post_avg_reward - baseline_avg_reward,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
}
|
| 434 |
)
|
| 435 |
metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
|
@@ -447,22 +682,26 @@ def run_sota_train():
|
|
| 447 |
out_path = artifacts.root / "before_after_avg_reward.png"
|
| 448 |
plt.tight_layout()
|
| 449 |
plt.savefig(out_path, dpi=200)
|
| 450 |
-
print(f"
|
| 451 |
except Exception as e:
|
| 452 |
-
print(f"
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
print("
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
if __name__ == "__main__":
|
| 468 |
run_sota_train()
|
|
|
|
| 1 |
"""
|
| 2 |
+
Unsloth + OpenEnv GRPO training (production-oriented).
|
| 3 |
+
|
| 4 |
+
Produces real training artifacts (trainer log_history, metrics JSON, reward plots) and
|
| 5 |
+
optional Hub push of LoRA weights. Every execution reward calls your live Space (or
|
| 6 |
+
local server) at OPENENV_BASE_URL — not a mock.
|
| 7 |
+
|
| 8 |
+
Environment (control cost vs quality on HF Jobs / local GPU):
|
| 9 |
+
OPENENV_BASE_URL — OpenEnv HTTP root (default: Space URL from openenv.yaml)
|
| 10 |
+
OPENENV_TASK_IDS — Comma list; if unset, uses GET /tasks from the server
|
| 11 |
+
ROWS_PER_TASK — GRPO rows per task_id (default: 48)
|
| 12 |
+
OPENENV_REQUEST_TIMEOUT_SEC — HTTP timeout for reset/step (default: 120)
|
| 13 |
+
REASONING_XML_TAG — XML tag name for chain-of-thought (default: think)
|
| 14 |
+
TRAIN_MAX_STEPS — GRPO optimizer steps (default: 200; was 30 for smoke)
|
| 15 |
+
TRAIN_NUM_EPOCHS, TRAIN_LR, GRPO_NUM_GENERATIONS, GRPO_MAX_COMPLETION_LEN
|
| 16 |
+
PER_DEVICE_TRAIN_BS, GRAD_ACCUM
|
| 17 |
+
TRL_REPORT_TO — none | wandb | tensorboard (auto: wandb if key else tensorboard)
|
| 18 |
+
BOOTSTRAP_*_VERSION — pin transformers / accelerate / trl for HF Jobs (see bootstrap_deps)
|
| 19 |
+
Artifacts: artifacts/reward_components.jsonl, artifacts/trainer_on_log.jsonl, tensorboard/
|
| 20 |
+
HF_HUB_REPO_ID — push target (default md896/sota-sql-agent-7b)
|
| 21 |
+
SKIP_HUB_PUSH=1 — do not push after train
|
| 22 |
+
HF_TOKEN / HUGGING_FACE_HUB_TOKEN — Hub auth for push
|
| 23 |
+
|
| 24 |
+
Designed for Hugging Face Jobs / Spaces where:
|
| 25 |
- system Python may be externally managed (PEP-668) → uses --break-system-packages
|
| 26 |
- preinstalled CUDA/PyTorch stacks can conflict with optional vision packages
|
| 27 |
|
|
|
|
| 33 |
|
| 34 |
from __future__ import annotations
|
| 35 |
|
| 36 |
+
import contextvars
|
| 37 |
import json
|
| 38 |
+
import math
|
| 39 |
import os
|
| 40 |
import random
|
| 41 |
import re
|
| 42 |
import subprocess
|
| 43 |
import sys
|
| 44 |
import time
|
| 45 |
+
import uuid
|
| 46 |
from dataclasses import dataclass
|
| 47 |
from pathlib import Path
|
| 48 |
from typing import Any, Dict, List, Optional
|
| 49 |
|
| 50 |
+
# Set by TrainerCallback so reward funcs can tag JSONL rows with the real global_step.
|
| 51 |
+
CURRENT_GRPO_STEP: contextvars.ContextVar[int] = contextvars.ContextVar("CURRENT_GRPO_STEP", default=-1)
|
| 52 |
+
|
| 53 |
|
| 54 |
def _run(cmd: List[str], *, check: bool = True) -> subprocess.CompletedProcess:
|
| 55 |
return subprocess.run(cmd, check=check)
|
|
|
|
| 64 |
Best-effort dependency bootstrap for ephemeral HF containers.
|
| 65 |
|
| 66 |
Set SKIP_BOOTSTRAP=1 to disable.
|
| 67 |
+
Pins: BOOTSTRAP_TRANSFORMERS_VERSION, BOOTSTRAP_ACCELERATE_VERSION, BOOTSTRAP_TRL_VERSION.
|
| 68 |
"""
|
| 69 |
if os.environ.get("SKIP_BOOTSTRAP") == "1":
|
| 70 |
return
|
|
|
|
| 77 |
# (PEP-668). Prefer an explicit opt-out for all pip ops in ephemeral jobs.
|
| 78 |
os.environ.setdefault("PIP_BREAK_SYSTEM_PACKAGES", "1")
|
| 79 |
|
| 80 |
+
print("Bootstrapping dependencies...")
|
| 81 |
|
| 82 |
# Text-only run: torchvision/torchaudio are not required and are a common source
|
| 83 |
# of crashes when torch versions shift in container images.
|
| 84 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 85 |
|
| 86 |
+
_pip(["uninstall", "-y", "torchao"], check=False)
|
| 87 |
+
|
| 88 |
_pip(
|
| 89 |
[
|
| 90 |
"install",
|
| 91 |
"--break-system-packages",
|
| 92 |
"httpx>=0.27.0",
|
| 93 |
"datasets>=3.4.1,<4.4.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
"matplotlib",
|
| 95 |
+
"tensorboard",
|
| 96 |
+
"wandb",
|
| 97 |
+
]
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
_tf = os.environ.get("BOOTSTRAP_TRANSFORMERS_VERSION", "4.48.3")
|
| 101 |
+
_acc = os.environ.get("BOOTSTRAP_ACCELERATE_VERSION", "0.34.2")
|
| 102 |
+
_trl = os.environ.get("BOOTSTRAP_TRL_VERSION", "0.18.2")
|
| 103 |
+
_pip(
|
| 104 |
+
[
|
| 105 |
+
"install",
|
| 106 |
+
"--break-system-packages",
|
| 107 |
+
f"transformers=={_tf}",
|
| 108 |
+
f"accelerate=={_acc}",
|
| 109 |
+
f"trl=={_trl}",
|
| 110 |
]
|
| 111 |
)
|
| 112 |
|
|
|
|
|
|
|
|
|
|
| 113 |
_pip(
|
| 114 |
[
|
| 115 |
"install",
|
|
|
|
| 118 |
]
|
| 119 |
)
|
| 120 |
|
| 121 |
+
_pip(
|
| 122 |
+
[
|
| 123 |
+
"install",
|
| 124 |
+
"--break-system-packages",
|
| 125 |
+
"--force-reinstall",
|
| 126 |
+
"--no-deps",
|
| 127 |
+
f"transformers=={_tf}",
|
| 128 |
+
f"accelerate=={_acc}",
|
| 129 |
+
]
|
| 130 |
+
)
|
| 131 |
+
_pip(["install", "--break-system-packages", "--no-deps", f"trl=={_trl}"])
|
| 132 |
+
|
| 133 |
+
_pip(["uninstall", "-y", "torchao"], check=False)
|
| 134 |
_pip(["uninstall", "--break-system-packages", "-y", "torchvision", "torchaudio"], check=False)
|
| 135 |
|
| 136 |
+
try:
|
| 137 |
+
import accelerate # noqa: F401
|
| 138 |
+
import transformers # noqa: F401
|
| 139 |
+
from trl import GRPOConfig as _BootstrapGRPOConfig # noqa: F401
|
| 140 |
+
|
| 141 |
+
_ = _BootstrapGRPOConfig
|
| 142 |
+
except Exception as e:
|
| 143 |
+
raise RuntimeError(
|
| 144 |
+
"Post-bootstrap import check failed. Adjust BOOTSTRAP_*_VERSION or SKIP_BOOTSTRAP=1."
|
| 145 |
+
) from e
|
| 146 |
+
|
| 147 |
|
| 148 |
bootstrap_deps()
|
| 149 |
|
|
|
|
| 180 |
if not hasattr(transformers.utils.hub, "TRANSFORMERS_CACHE"):
|
| 181 |
transformers.utils.hub.TRANSFORMERS_CACHE = "/tmp"
|
| 182 |
|
| 183 |
+
from transformers import TrainerCallback
|
| 184 |
from trl import GRPOConfig, GRPOTrainer
|
| 185 |
from unsloth import FastLanguageModel
|
| 186 |
|
| 187 |
+
# --- 1. CONFIGURATION (env-first; defaults match openenv.yaml) ---
|
| 188 |
+
_DEFAULT_OPENENV_BASE = "https://md896-sql-debug-env.hf.space"
|
| 189 |
+
BYPASS_HEADERS: Dict[str, str] = {}
|
| 190 |
+
|
| 191 |
+
MODEL_NAME = os.environ.get("TRAIN_MODEL_NAME", "unsloth/Qwen2.5-Coder-7B-Instruct")
|
| 192 |
+
|
| 193 |
|
| 194 |
+
def get_bridge_url() -> str:
|
| 195 |
+
return os.environ.get("OPENENV_BASE_URL", _DEFAULT_OPENENV_BASE).rstrip("/")
|
| 196 |
|
| 197 |
+
|
| 198 |
+
def get_request_timeout() -> float:
|
| 199 |
+
return float(os.environ.get("OPENENV_REQUEST_TIMEOUT_SEC", "120"))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def build_system_prompt() -> str:
|
| 203 |
+
"""Single prompt template for every task (easy → expert); tag name is configurable."""
|
| 204 |
+
tag = os.environ.get("REASONING_XML_TAG", "think")
|
| 205 |
+
return f"""You are an elite SQL engineer. You fix broken SQLite analytics queries using the task description and the broken query.
|
| 206 |
+
You MUST output your reasoning process inside <{tag}> tags.
|
| 207 |
After you have finished thinking, you MUST output the exact fixed SQL query inside <sql> tags.
|
| 208 |
Do not output any markdown blocks like ```sql.
|
| 209 |
|
| 210 |
Example:
|
| 211 |
+
<{tag}>
|
| 212 |
+
I will check joins, filters, and aggregation, then write a corrected SELECT or WITH query.
|
| 213 |
+
</{tag}>
|
| 214 |
<sql>
|
| 215 |
+
WITH OrderTotals AS (SELECT order_id, SUM(amount) AS total FROM line_items GROUP BY order_id)
|
| 216 |
+
SELECT o.id, ot.total FROM orders o JOIN OrderTotals ot ON o.id = ot.order_id;
|
| 217 |
</sql>"""
|
| 218 |
|
| 219 |
+
|
| 220 |
+
def _fetch_task_ids(client: httpx.Client) -> List[str]:
|
| 221 |
+
raw = os.environ.get("OPENENV_TASK_IDS", "").strip()
|
| 222 |
+
if raw:
|
| 223 |
+
return [x.strip() for x in raw.split(",") if x.strip()]
|
| 224 |
+
r = client.get("/tasks", timeout=get_request_timeout())
|
| 225 |
+
r.raise_for_status()
|
| 226 |
+
body = r.json()
|
| 227 |
+
tasks = body.get("tasks") or []
|
| 228 |
+
ids = [t["task_id"] for t in tasks if t.get("task_id")]
|
| 229 |
+
if not ids:
|
| 230 |
+
raise RuntimeError("/tasks returned no task_id entries")
|
| 231 |
+
return ids
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def make_real_dataset() -> Dataset:
|
| 235 |
+
bridge = get_bridge_url()
|
| 236 |
+
timeout = get_request_timeout()
|
| 237 |
+
rows_per_task = max(1, int(os.environ.get("ROWS_PER_TASK", "48")))
|
| 238 |
+
system = build_system_prompt()
|
| 239 |
+
|
| 240 |
+
print(f"Connecting to OpenEnv at {bridge} (timeout={timeout}s)...")
|
| 241 |
+
rows: List[Dict[str, Any]] = []
|
| 242 |
+
|
| 243 |
+
with httpx.Client(base_url=bridge, headers=BYPASS_HEADERS, timeout=timeout) as client:
|
| 244 |
+
h = client.get("/health", timeout=min(30.0, timeout))
|
| 245 |
+
h.raise_for_status()
|
| 246 |
+
print(f"OpenEnv health: {h.json()}")
|
| 247 |
+
|
| 248 |
+
task_ids = _fetch_task_ids(client)
|
| 249 |
+
print(f"Training task_ids ({len(task_ids)}): {task_ids}")
|
| 250 |
+
|
| 251 |
+
for t_id in task_ids:
|
| 252 |
resp = client.post("/reset", json={"task_id": t_id})
|
| 253 |
+
resp.raise_for_status()
|
| 254 |
obs = resp.json()["observation"]
|
| 255 |
+
|
| 256 |
prompt = (
|
| 257 |
+
f"{system}\n\n"
|
| 258 |
f"Task: {obs['task_description']}\n"
|
| 259 |
f"Broken Query: {obs['original_query']}\n\n"
|
| 260 |
+
f"Provide your <{os.environ.get('REASONING_XML_TAG', 'think')}> and <sql> output:"
|
| 261 |
)
|
| 262 |
+
for _ in range(rows_per_task):
|
|
|
|
| 263 |
rows.append({"prompt": prompt, "task_id": t_id})
|
| 264 |
+
|
| 265 |
if not rows:
|
| 266 |
+
raise RuntimeError("Failed to build dataset (no rows).")
|
| 267 |
+
print(f"Dataset: {len(rows)} prompts ({rows_per_task} per task).")
|
| 268 |
return Dataset.from_list(rows)
|
| 269 |
|
| 270 |
+
# --- 3. MULTI-REWARD SHAPING + JSONL logging (per-component batch stats) ---
|
| 271 |
+
|
| 272 |
+
_REWARD_COMPONENTS_JSONL: Optional[Path] = None
|
| 273 |
+
|
| 274 |
|
| 275 |
def extract_xml_tag(text, tag):
|
| 276 |
pattern = f"<{tag}>(.*?)</{tag}>"
|
| 277 |
match = re.search(pattern, text, re.DOTALL)
|
| 278 |
return match.group(1).strip() if match else None
|
| 279 |
|
| 280 |
+
|
| 281 |
+
def _reward_batch_stats(values: List[float]) -> Dict[str, float]:
|
| 282 |
+
if not values:
|
| 283 |
+
return {"mean": 0.0, "std": 0.0, "min": 0.0, "max": 0.0}
|
| 284 |
+
n = len(values)
|
| 285 |
+
mean = sum(values) / n
|
| 286 |
+
var = sum((x - mean) ** 2 for x in values) / max(n - 1, 1)
|
| 287 |
+
return {"mean": mean, "std": math.sqrt(var), "min": min(values), "max": max(values)}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def _append_jsonl(path: Path, row: Dict[str, Any]) -> None:
|
| 291 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 292 |
+
with path.open("a", encoding="utf-8") as f:
|
| 293 |
+
f.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _log_reward_component(name: str, values: List[float]) -> None:
|
| 297 |
+
if _REWARD_COMPONENTS_JSONL is None:
|
| 298 |
+
return
|
| 299 |
+
_append_jsonl(
|
| 300 |
+
_REWARD_COMPONENTS_JSONL,
|
| 301 |
+
{
|
| 302 |
+
"time_epoch_s": time.time(),
|
| 303 |
+
"global_step": CURRENT_GRPO_STEP.get(),
|
| 304 |
+
"reward_component": name,
|
| 305 |
+
"n": len(values),
|
| 306 |
+
**_reward_batch_stats(values),
|
| 307 |
+
},
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
def format_reward_func(completions, **kwargs):
|
| 312 |
+
"""Reward 1: CoT + sql XML tags (+0.1). Tag name follows REASONING_XML_TAG."""
|
| 313 |
+
tag = os.environ.get("REASONING_XML_TAG", "think")
|
| 314 |
rewards = []
|
| 315 |
for comp in completions:
|
| 316 |
+
has_think = extract_xml_tag(comp, tag) is not None
|
| 317 |
has_sql = extract_xml_tag(comp, "sql") is not None
|
| 318 |
rewards.append(0.1 if (has_think and has_sql) else 0.0)
|
| 319 |
+
_log_reward_component("format_xml", rewards)
|
| 320 |
return rewards
|
| 321 |
|
| 322 |
+
|
| 323 |
def syntax_reward_func(completions, **kwargs):
|
| 324 |
"""Reward 2: Does the SQL look like valid code? (+0.2)"""
|
| 325 |
rewards = []
|
|
|
|
| 329 |
rewards.append(0.2)
|
| 330 |
else:
|
| 331 |
rewards.append(0.0)
|
| 332 |
+
_log_reward_component("syntax_select_with", rewards)
|
| 333 |
return rewards
|
| 334 |
|
| 335 |
+
|
| 336 |
def execution_reward_func(completions, task_id, **kwargs):
|
| 337 |
+
"""Reward 3: live OpenEnv submit_query against the real Space/API (not a stub)."""
|
| 338 |
+
rewards: List[float] = []
|
| 339 |
+
base = get_bridge_url()
|
| 340 |
+
timeout = get_request_timeout()
|
| 341 |
+
with httpx.Client(base_url=base, headers=BYPASS_HEADERS, timeout=timeout) as client:
|
| 342 |
for query, t_id in zip(completions, task_id):
|
| 343 |
sql = extract_xml_tag(query, "sql")
|
| 344 |
if not sql:
|
| 345 |
+
rewards.append(0.0)
|
| 346 |
continue
|
| 347 |
+
|
| 348 |
+
session_headers = {"X-Session-Id": str(uuid.uuid4())}
|
| 349 |
try:
|
| 350 |
+
r0 = client.post("/reset", json={"task_id": t_id}, headers=session_headers)
|
| 351 |
+
r0.raise_for_status()
|
| 352 |
+
resp = client.post(
|
| 353 |
+
"/step",
|
| 354 |
+
json={"action": {"action_type": "submit_query", "query": sql}},
|
| 355 |
+
headers=session_headers,
|
| 356 |
+
)
|
| 357 |
+
resp.raise_for_status()
|
| 358 |
+
reward = float(resp.json().get("reward", 0.0))
|
| 359 |
except Exception:
|
| 360 |
reward = 0.0
|
| 361 |
+
|
| 362 |
+
reward += random.uniform(-1e-6, 1e-6)
|
| 363 |
rewards.append(reward)
|
| 364 |
+
_log_reward_component("openenv_execution", rewards)
|
| 365 |
return rewards
|
| 366 |
|
| 367 |
+
|
| 368 |
+
def length_shape_reward_func(completions, **kwargs):
|
| 369 |
+
"""Reward 4: soft preference for shorter completions (bounded; does not replace execution reward)."""
|
| 370 |
+
cap = float(os.environ.get("COMPLETION_SOFT_CHAR_CAP", "3500"))
|
| 371 |
+
bonus_max = float(os.environ.get("LENGTH_BONUS_MAX", "0.05"))
|
| 372 |
+
rewards: List[float] = []
|
| 373 |
+
for comp in completions:
|
| 374 |
+
L = len(comp) if comp else 0
|
| 375 |
+
if L <= 0:
|
| 376 |
+
rewards.append(0.0)
|
| 377 |
+
else:
|
| 378 |
+
rewards.append(bonus_max * max(0.0, 1.0 - min(L, cap) / cap))
|
| 379 |
+
_log_reward_component("length_shape", rewards)
|
| 380 |
+
return rewards
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class GrpoStepContextCallback(TrainerCallback):
|
| 384 |
+
"""Expose true global_step to reward funcs for JSONL alignment."""
|
| 385 |
+
|
| 386 |
+
def on_step_begin(self, args, state, control, **kwargs):
|
| 387 |
+
CURRENT_GRPO_STEP.set(int(state.global_step))
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class JsonlOnLogCallback(TrainerCallback):
|
| 391 |
+
"""Mirror every trainer `logs` dict to JSONL (loss, learning_rate, reward keys, etc.)."""
|
| 392 |
+
|
| 393 |
+
def __init__(self, path: Path):
|
| 394 |
+
self.path = path
|
| 395 |
+
self.path.parent.mkdir(parents=True, exist_ok=True)
|
| 396 |
+
self._fp = path.open("w", encoding="utf-8")
|
| 397 |
+
|
| 398 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 399 |
+
if not logs:
|
| 400 |
+
return
|
| 401 |
+
row: Dict[str, Any] = {"global_step": int(state.global_step), **dict(logs)}
|
| 402 |
+
self._fp.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
|
| 403 |
+
self._fp.flush()
|
| 404 |
+
|
| 405 |
+
def on_train_end(self, args, state, control, **kwargs):
|
| 406 |
+
try:
|
| 407 |
+
self._fp.close()
|
| 408 |
+
except Exception:
|
| 409 |
+
pass
|
| 410 |
+
|
| 411 |
# --- 3b. ARTIFACTS / PLOTS (REAL, FROM LOGS) ---
|
| 412 |
|
| 413 |
@dataclass(frozen=True)
|
|
|
|
| 514 |
_ensure_dir(paths.root)
|
| 515 |
plt.tight_layout()
|
| 516 |
plt.savefig(paths.reward_curve_png, dpi=200)
|
| 517 |
+
print(f"Saved {paths.reward_curve_png}")
|
| 518 |
+
|
| 519 |
|
| 520 |
+
def _resolve_report_to() -> str:
|
| 521 |
+
raw = os.environ.get("TRL_REPORT_TO", "").strip().lower()
|
| 522 |
+
if raw in ("", "auto"):
|
| 523 |
+
if os.environ.get("WANDB_API_KEY"):
|
| 524 |
+
return "wandb"
|
| 525 |
+
return "tensorboard"
|
| 526 |
+
if raw in ("false", "no", "off", "none"):
|
| 527 |
+
return "none"
|
| 528 |
+
return raw
|
| 529 |
|
| 530 |
+
|
| 531 |
+
# --- 4. Unsloth GRPO training loop (live OpenEnv rewards) ---
|
| 532 |
def run_sota_train():
|
| 533 |
+
global _REWARD_COMPONENTS_JSONL
|
| 534 |
+
|
| 535 |
+
max_steps = int(os.environ.get("TRAIN_MAX_STEPS", "200"))
|
| 536 |
+
out_dir = os.environ.get("OUTPUT_DIR", "./sota_results")
|
| 537 |
+
artifacts_early = Path(out_dir) / "artifacts"
|
| 538 |
+
_ensure_dir(artifacts_early)
|
| 539 |
+
_REWARD_COMPONENTS_JSONL = artifacts_early / "reward_components.jsonl"
|
| 540 |
+
_REWARD_COMPONENTS_JSONL.write_text("", encoding="utf-8")
|
| 541 |
+
|
| 542 |
+
print(f"Starting Unsloth GRPO on {MODEL_NAME}...")
|
| 543 |
+
print(
|
| 544 |
+
f"OpenEnv={get_bridge_url()} | max_steps={max_steps} | "
|
| 545 |
+
f"rows_per_task={os.environ.get('ROWS_PER_TASK', '48')} | "
|
| 546 |
+
f"report_to={_resolve_report_to()}"
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
max_seq = int(os.environ.get("MAX_SEQ_LENGTH", "1024"))
|
| 550 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 551 |
model_name=MODEL_NAME,
|
| 552 |
+
max_seq_length=max_seq,
|
| 553 |
load_in_4bit=True,
|
| 554 |
)
|
| 555 |
|
|
|
|
| 567 |
|
| 568 |
def quick_exec_eval(max_items: int = 8) -> float:
|
| 569 |
"""
|
| 570 |
+
Quick before/after check: sample prompts, generate CoT + sql, score via live OpenEnv.
|
|
|
|
|
|
|
|
|
|
| 571 |
"""
|
| 572 |
subset = train_dataset.select(range(min(max_items, len(train_dataset))))
|
| 573 |
prompts = subset["prompt"]
|
|
|
|
| 589 |
rewards = execution_reward_func(completions, task_ids)
|
| 590 |
return float(sum(rewards) / max(len(rewards), 1))
|
| 591 |
|
| 592 |
+
print("Quick baseline eval (pre-train)...")
|
| 593 |
baseline_avg_reward = quick_exec_eval()
|
| 594 |
|
| 595 |
+
report_to = _resolve_report_to()
|
| 596 |
+
tb_dir = Path(out_dir) / "tensorboard"
|
| 597 |
+
if report_to == "tensorboard":
|
| 598 |
+
_ensure_dir(tb_dir)
|
| 599 |
+
|
| 600 |
+
_cfg: Dict[str, Any] = dict(
|
| 601 |
+
output_dir=out_dir,
|
| 602 |
+
learning_rate=float(os.environ.get("TRAIN_LR", "5e-6")),
|
| 603 |
+
per_device_train_batch_size=int(os.environ.get("PER_DEVICE_TRAIN_BS", "1")),
|
| 604 |
+
gradient_accumulation_steps=int(os.environ.get("GRAD_ACCUM", "2")),
|
| 605 |
+
num_generations=int(os.environ.get("GRPO_NUM_GENERATIONS", "8")),
|
| 606 |
+
max_completion_length=int(os.environ.get("GRPO_MAX_COMPLETION_LEN", "512")),
|
| 607 |
+
temperature=float(os.environ.get("GRPO_TEMPERATURE", "0.9")),
|
| 608 |
+
num_train_epochs=int(os.environ.get("TRAIN_NUM_EPOCHS", "1")),
|
| 609 |
+
max_steps=max_steps,
|
| 610 |
+
logging_steps=int(os.environ.get("LOGGING_STEPS", "1")),
|
| 611 |
+
logging_first_step=True,
|
| 612 |
+
report_to=report_to,
|
| 613 |
)
|
| 614 |
+
if report_to == "tensorboard":
|
| 615 |
+
_cfg["logging_dir"] = str(tb_dir)
|
| 616 |
+
training_args = GRPOConfig(**_cfg)
|
| 617 |
+
|
| 618 |
+
trainer_logs_path = artifacts_early / "trainer_on_log.jsonl"
|
| 619 |
+
trainer_logs_path.write_text("", encoding="utf-8")
|
| 620 |
|
| 621 |
trainer = GRPOTrainer(
|
| 622 |
model=model,
|
| 623 |
+
reward_funcs=[
|
| 624 |
+
format_reward_func,
|
| 625 |
+
syntax_reward_func,
|
| 626 |
+
execution_reward_func,
|
| 627 |
+
length_shape_reward_func,
|
| 628 |
+
],
|
| 629 |
args=training_args,
|
| 630 |
train_dataset=train_dataset,
|
| 631 |
processing_class=tokenizer,
|
| 632 |
+
callbacks=[
|
| 633 |
+
GrpoStepContextCallback(),
|
| 634 |
+
JsonlOnLogCallback(trainer_logs_path),
|
| 635 |
+
],
|
| 636 |
)
|
| 637 |
|
| 638 |
+
print("Training with live execution rewards against OpenEnv...")
|
| 639 |
trainer.train()
|
| 640 |
|
| 641 |
+
print("Quick eval (post-train)...")
|
| 642 |
post_avg_reward = quick_exec_eval()
|
| 643 |
|
| 644 |
# --- Save artifacts (real logs/plots) ---
|
| 645 |
+
artifacts = ArtifactPaths(root=Path(out_dir) / "artifacts")
|
| 646 |
log_history = getattr(trainer.state, "log_history", []) or []
|
| 647 |
save_log_history(log_history, artifacts)
|
| 648 |
reward_series = extract_reward_series(log_history)
|
|
|
|
| 655 |
metrics = {}
|
| 656 |
metrics.update(
|
| 657 |
{
|
| 658 |
+
"openenv_base_url": get_bridge_url(),
|
| 659 |
+
"train_max_steps": max_steps,
|
| 660 |
+
"model_name": MODEL_NAME,
|
| 661 |
"baseline_avg_reward": baseline_avg_reward,
|
| 662 |
"post_avg_reward": post_avg_reward,
|
| 663 |
"delta_avg_reward": post_avg_reward - baseline_avg_reward,
|
| 664 |
+
"reward_components_jsonl": str(artifacts_early / "reward_components.jsonl"),
|
| 665 |
+
"trainer_on_log_jsonl": str(artifacts_early / "trainer_on_log.jsonl"),
|
| 666 |
+
"tensorboard_dir": str(tb_dir) if report_to == "tensorboard" else None,
|
| 667 |
+
"report_to": report_to,
|
| 668 |
}
|
| 669 |
)
|
| 670 |
metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
|
|
|
|
| 682 |
out_path = artifacts.root / "before_after_avg_reward.png"
|
| 683 |
plt.tight_layout()
|
| 684 |
plt.savefig(out_path, dpi=200)
|
| 685 |
+
print(f"Saved {out_path}")
|
| 686 |
except Exception as e:
|
| 687 |
+
print(f"Could not generate before/after plot: {e}")
|
| 688 |
+
|
| 689 |
+
lora_dir = os.environ.get("LORA_SAVE_DIR", "./sota_sql_agent_unsloth")
|
| 690 |
+
print("\nSaving LoRA weights locally...")
|
| 691 |
+
model.save_pretrained(lora_dir)
|
| 692 |
+
|
| 693 |
+
hub_id = os.environ.get("HF_HUB_REPO_ID", "md896/sota-sql-agent-7b")
|
| 694 |
+
token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 695 |
+
if os.environ.get("SKIP_HUB_PUSH", "").strip() in ("1", "true", "yes"):
|
| 696 |
+
print("SKIP_HUB_PUSH set — not pushing to Hub.")
|
| 697 |
+
else:
|
| 698 |
+
try:
|
| 699 |
+
model.push_to_hub(hub_id, token=token)
|
| 700 |
+
print(f"Pushed LoRA to https://huggingface.co/{hub_id}")
|
| 701 |
+
except Exception as e:
|
| 702 |
+
print(f"Hub push failed (set HF_TOKEN / HF_HUB_REPO_ID or SKIP_HUB_PUSH=1): {e}")
|
| 703 |
+
|
| 704 |
+
print(f"\nTraining artifacts under {artifacts.root}")
|
| 705 |
|
| 706 |
# Script entry point: run the full GRPO training pipeline when executed directly.
if __name__ == "__main__":
    run_sota_train()
|