sentinel-env / training /launch_hf_job.py
XcodeAddy's picture
Revert "Merge remote main with local project"
1c2514b
from __future__ import annotations
import argparse
import os
import shlex
import sys
from textwrap import dedent
from huggingface_hub import get_token, run_job
# Current Unsloth pulls torchao, which expects torch >= 2.11. Keep the Jobs
# image aligned so GRPO imports fail fast only for real code issues.
# Default container image for the job (overridable with --image).
DEFAULT_IMAGE = "pytorch/pytorch:2.11.0-cuda12.8-cudnn9-devel"
# Default Git repository cloned inside the job (overridable with --repo-url).
DEFAULT_REPO = "https://github.com/ADITYAGABA1322/sentinel-env"
# Default base model handed to training/train.py via --model.
DEFAULT_MODEL = "unsloth/Qwen2.5-0.5B-Instruct"
def shell_join(lines: list[str]) -> str:
    """Chain shell snippets into one ``&&``-separated command line.

    Every entry is stripped of surrounding whitespace first; entries that are
    empty after stripping are dropped so they never yield a dangling ``&&``.
    """
    stripped = (entry.strip() for entry in lines)
    return " && ".join(filter(None, stripped))
def bootstrap_repo(repo_url: str) -> list[str]:
    """Return the ordered shell steps that prepare the checkout inside a job.

    The steps enable strict shell mode, install git if it is missing, clone
    ``repo_url`` into ``sentinel-env``, create and activate a virtualenv that
    can see system site-packages, install both requirements files, and finish
    with a Python one-liner that imports the training stack.

    Args:
        repo_url: Git URL to clone; shell-quoted before being embedded.

    Returns:
        One shell snippet per step, in execution order.
    """
    apt_install = "apt-get update && apt-get install -y"
    make_venv = "python -m venv --system-site-packages .job-venv"

    # Statements of the import smoke check, joined into one ``python -c``.
    smoke_statements = [
        "import torch",
        "print('torch', torch.__version__)",
        "print('gpu', torch.cuda.get_device_name() if torch.cuda.is_available() else 'none')",
        "from transformers import PreTrainedModel",
        "from trl import GRPOConfig, GRPOTrainer",
        "print('training imports ok')",
    ]
    smoke_check = 'python -c "' + "; ".join(smoke_statements) + '"'

    steps = ["set -eux"]
    steps.append(f"command -v git || ({apt_install} git)")
    steps.append(f"git clone {shlex.quote(repo_url)} sentinel-env")
    steps.append("cd sentinel-env")
    # Fall back to installing python3-venv when the first attempt fails.
    steps.append(f"{make_venv} || ({apt_install} python3-venv && {make_venv})")
    steps.append(". .job-venv/bin/activate")
    steps.append("python -m pip install --upgrade pip")
    steps.append("pip install -r requirements.txt")
    steps.append("pip install -r requirements-train.txt")
    steps.append(smoke_check)
    return steps
def gpu_test_command() -> str:
    """Return a shell one-liner that prints the CUDA device name via torch."""
    probe = "import torch; print(torch.cuda.get_device_name())"
    return f"python -c '{probe}'"
def train_command(args: argparse.Namespace, train: bool = True) -> str:
    """Build the shell command line that the Hugging Face job executes.

    The command always starts with the steps from ``bootstrap_repo``. With
    ``train`` true it additionally runs ``training/train.py``; in
    ``train-full`` mode it then records a replay, evaluates, plots and uploads
    the results to the Hub.

    Args:
        args: Parsed CLI options. ``repo_url`` is always read; the training
            options, ``output_dir`` and ``mode`` are read only when ``train``
            is true.
        train: When false, stop after the bootstrap/import-smoke steps.

    Returns:
        Every step joined into a single string by ``shell_join``.
    """
    lines = bootstrap_repo(args.repo_url)
    if not train:
        return shell_join(lines)

    train_flags = [
        "python training/train.py",
        f"--episodes {args.episodes}",
        f"--task {shlex.quote(args.task)}",
        f"--seed {args.seed}",
        f"--model {shlex.quote(args.model)}",
        f"--epochs {args.epochs}",
        f"--batch-size {args.batch_size}",
        f"--learning-rate {args.learning_rate}",
        f"--lora-rank {args.lora_rank}",
        f"--num-generations {args.num_generations}",
        f"--max-seq-length {args.max_seq_length}",
        f"--output-dir {shlex.quote(args.output_dir)}",
    ]
    lines.append(" ".join(train_flags))

    if args.mode == "train-full":
        # User-supplied values are embedded with ``!r`` so they become valid
        # Python string literals: a quote or backslash in a path or model name
        # can neither break nor inject into the generated snippet.
        replay_code = (
            "from training.replay import record_trained_actions; "
            f"record_trained_actions(adapter_path={args.output_dir!r}, "
            f"base_model={args.model!r}, tasks=['task1','task2','task3'], "
            "seeds=range(30), out_path='outputs/trained_policy_replay.jsonl')"
        )
        # The ``{job_id}`` fragments below are intentionally NOT in f-strings:
        # they must reach the job as literal f-string source.
        upload_code = (
            "import os; "
            "from huggingface_hub import HfApi; "
            "token=os.environ.get('HF_TOKEN'); "
            "api=HfApi(token=token); "
            "model_repo=os.environ.get('SENTINEL_MODEL_REPO','XcodeAddy/sentinel-grpo-qwen05'); "
            "artifact_repo=os.environ.get('SENTINEL_ARTIFACT_REPO','XcodeAddy/sentinel-env-artifacts'); "
            "job_id=os.environ.get('JOB_ID','manual'); "
            "api.create_repo(model_repo, repo_type='model', exist_ok=True); "
            f"api.upload_folder(folder_path={args.output_dir!r}, repo_id=model_repo, repo_type='model'); "
            "api.create_repo(artifact_repo, repo_type='dataset', exist_ok=True); "
            "api.upload_folder(folder_path='outputs', repo_id=artifact_repo, repo_type='dataset', path_in_repo=f'job-{job_id}/outputs'); "
            "print('Uploaded model adapter to', model_repo); "
            "print('Uploaded outputs to', artifact_repo, 'under', f'job-{job_id}/outputs')"
        )
        lines.extend(
            [
                # shlex.quote (rather than hand-written double quotes) keeps
                # ``$``, backticks and quotes in the snippet away from bash.
                f"python -c {shlex.quote(replay_code)}",
                "python training/evaluate.py --episodes 30 --task all "
                "--policies random,heuristic,oracle_lite,trained "
                "--replay outputs/trained_policy_replay.jsonl "
                "--out outputs/eval_post.json --no-plot",
                "cp outputs/eval_post.json outputs/evaluation_results.json",
                "python -m training.plots --pre outputs/eval_pre.json "
                "--post outputs/eval_post.json --out-dir outputs/charts",
                f"python -c {shlex.quote(upload_code)}",
            ]
        )
    return shell_join(lines)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the launcher's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list allows programmatic use.

    Returns:
        Namespace holding the job settings (``mode``, ``namespace``,
        ``flavor``, ``timeout``, ``image``, ``repo_url``) and the training
        options forwarded to ``training/train.py``.
    """
    parser = argparse.ArgumentParser(
        description="Launch SENTINEL training on Hugging Face Jobs without shell quoting pain."
    )
    parser.add_argument(
        "--mode",
        choices=["gpu-test", "import-smoke", "train-smoke", "train-full"],
        default="gpu-test",
    )
    # The namespace default can be steered through the HF_NAMESPACE variable.
    parser.add_argument("--namespace", default=os.environ.get("HF_NAMESPACE", "XcodeAddy"))
    parser.add_argument("--flavor", default="a10g-small")
    parser.add_argument("--timeout", default="2h")
    parser.add_argument("--image", default=DEFAULT_IMAGE)
    parser.add_argument("--repo-url", default=DEFAULT_REPO)
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--episodes", type=int, default=50)
    parser.add_argument("--task", choices=["task1", "task2", "task3", "all"], default="all")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=8)
    parser.add_argument("--num-generations", type=int, default=2)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--output-dir", default="training/sentinel_qwen05_grpo")
    return parser.parse_args(argv)
def main() -> None:
    """Resolve credentials, build the job command and submit it to HF Jobs.

    The token comes from ``HF_TOKEN`` or, failing that, from
    ``huggingface_hub.get_token()``. The command depends on ``--mode``: a GPU
    probe, the bootstrap/import smoke steps, or the full training pipeline.
    After submission the job URL, ID and a log-follow hint are printed.

    The upload targets passed to the job (``SENTINEL_MODEL_REPO`` and
    ``SENTINEL_ARTIFACT_REPO``) are taken from the local environment when set
    there, otherwise the project defaults are used.

    Raises:
        SystemExit: If no Hugging Face token can be found.
    """
    args = parse_args()
    token = os.environ.get("HF_TOKEN") or get_token()
    if not token:
        raise SystemExit(
            dedent(
                """
                No Hugging Face token was found.
                Either run:
                read -s HF_TOKEN
                export HF_TOKEN
                Or log in once:
                .venv/bin/hf auth login --add-to-git-credential
                """
            ).strip()
        )

    if args.mode == "gpu-test":
        command = gpu_test_command()
    elif args.mode == "import-smoke":
        command = train_command(args, train=False)
    else:
        command = train_command(args)

    # Forward locally-set repo overrides into the job; ``or`` also falls back
    # to the default when the variable is set but empty.
    repo_defaults = {
        "SENTINEL_MODEL_REPO": "XcodeAddy/sentinel-grpo-qwen05",
        "SENTINEL_ARTIFACT_REPO": "XcodeAddy/sentinel-env-artifacts",
    }
    job_env = {
        name: os.environ.get(name) or default
        for name, default in repo_defaults.items()
    }

    print("Launching HF Job:")
    print(f" mode = {args.mode}")
    print(f" namespace = {args.namespace}")
    print(f" flavor = {args.flavor}")
    print(f" timeout = {args.timeout}")
    print(f" image = {args.image}")
    # Only a 260-character preview is echoed; the full command is still sent.
    print(" command = bash -lc", shlex.quote(command[:260] + ("..." if len(command) > 260 else "")))

    job = run_job(
        image=args.image,
        command=["bash", "-lc", command],
        flavor=args.flavor,
        timeout=args.timeout,
        namespace=args.namespace,
        token=token,
        # Passed as a secret rather than a plain env var.
        secrets={"HF_TOKEN": token},
        env=job_env,
        labels={"project": "sentinel", "mode": args.mode},
    )

    print("Job launched.")
    print("URL:", job.url)
    print("ID:", job.id)
    print()
    print("Follow logs with:")
    print(f" .venv/bin/hf jobs logs -f {job.id} --namespace {args.namespace} --token \"$HF_TOKEN\"")
if __name__ == "__main__":
    # Script entry point: turn Ctrl-C into a quiet exit with status 130
    # (the conventional 128 + SIGINT code) instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        raise SystemExit(130)