"""Launch SENTINEL GRPO training runs on Hugging Face Jobs.

The launcher builds one ``bash -lc`` command that does the following:

- clones the repo;
- creates a virtualenv layered on the image's system site-packages;
- installs the requirements;
- depending on ``--mode``, trains, evaluates and uploads artifacts.

Building the command in Python avoids hand-written shell quoting.
"""

from __future__ import annotations

import argparse
import os
import shlex
import sys
from textwrap import dedent

# Current Unsloth pulls torchao, which expects torch >= 2.11. Keep the Jobs
# image aligned so GRPO imports fail fast only for real code issues.
DEFAULT_IMAGE = "pytorch/pytorch:2.11.0-cuda12.8-cudnn9-devel"
DEFAULT_REPO = "https://github.com/ADITYAGABA1322/sentinel-env"
DEFAULT_MODEL = "unsloth/Qwen2.5-0.5B-Instruct"
# Hub repos that a ``train-full`` run uploads to. The job reads them from the
# SENTINEL_MODEL_REPO / SENTINEL_ARTIFACT_REPO env vars and falls back to these.
DEFAULT_MODEL_REPO = "XcodeAddy/sentinel-grpo-qwen05"
DEFAULT_ARTIFACT_REPO = "XcodeAddy/sentinel-env-artifacts"

# Sanity check run after installing requirements: prints the torch version and
# GPU, then imports the training stack so dependency breakage fails early.
_IMPORT_CHECK_CODE = (
    "import torch; "
    'print("torch", torch.__version__); '
    'print("gpu", torch.cuda.get_device_name() if torch.cuda.is_available() else "none"); '
    "from transformers import PreTrainedModel; "
    "from trl import GRPOConfig, GRPOTrainer; "
    'print("training imports ok")'
)


def shell_join(lines: list[str]) -> str:
    """Join shell commands into a single ``&&`` chain.

    Blank entries are dropped and each entry is stripped. Each command is
    wrapped in a ``{ ...; }`` group so that a ``a || b`` fallback inside one
    entry only guards ``a``. Bash gives ``&&`` and ``||`` equal precedence with
    left associativity, so an unwrapped ``x && a || b && y`` would also run
    ``b`` when ``x`` fails. Brace groups run in the current shell, so ``cd``
    and ``.`` (source) in one entry still affect the following entries.
    """
    commands = (line.strip() for line in lines)
    return " && ".join(f"{{ {cmd}; }}" for cmd in commands if cmd)


def _python_c(code: str) -> str:
    """Return a ``python -c`` shell command with *code* safely shell-quoted."""
    return f"python -c {shlex.quote(code)}"


def bootstrap_repo(repo_url: str) -> list[str]:
    """Return shell steps that clone *repo_url* and prepare its Python env.

    The steps install git if missing, clone into ``sentinel-env`` and ``cd``
    into it. They create and activate ``.job-venv`` (installing python3-venv
    if ``venv`` fails), install both requirements files, and run an import
    smoke check.
    """
    return [
        "set -eux",
        "command -v git || (apt-get update && apt-get install -y git)",
        f"git clone {shlex.quote(repo_url)} sentinel-env",
        "cd sentinel-env",
        "python -m venv --system-site-packages .job-venv || (apt-get update && apt-get install -y python3-venv && python -m venv --system-site-packages .job-venv)",
        ". .job-venv/bin/activate",
        "python -m pip install --upgrade pip",
        "pip install -r requirements.txt",
        "pip install -r requirements-train.txt",
        _python_c(_IMPORT_CHECK_CODE),
    ]


def gpu_test_command() -> str:
    """Return a one-liner that prints the CUDA device name (``gpu-test`` mode)."""
    return "python -c 'import torch; print(torch.cuda.get_device_name())'"


def _post_training_commands(args: argparse.Namespace) -> list[str]:
    """Return the replay / evaluate / plot / upload steps for ``train-full``."""
    output_dir = args.output_dir
    # Python source is built with repr() and shell-quoted by _python_c, so
    # paths or model ids containing quotes or '$' cannot break the command.
    replay_code = (
        "from training.replay import record_trained_actions; "
        f"record_trained_actions(adapter_path={output_dir!r}, "
        f"base_model={args.model!r}, tasks=['task1','task2','task3'], "
        "seeds=range(30), out_path='outputs/trained_policy_replay.jsonl')"
    )
    upload_code = (
        "import os; "
        "from huggingface_hub import HfApi; "
        "token=os.environ.get('HF_TOKEN'); "
        "api=HfApi(token=token); "
        f"model_repo=os.environ.get('SENTINEL_MODEL_REPO',{DEFAULT_MODEL_REPO!r}); "
        f"artifact_repo=os.environ.get('SENTINEL_ARTIFACT_REPO',{DEFAULT_ARTIFACT_REPO!r}); "
        "job_id=os.environ.get('JOB_ID','manual'); "
        "api.create_repo(model_repo, repo_type='model', exist_ok=True); "
        f"api.upload_folder(folder_path={output_dir!r}, repo_id=model_repo, repo_type='model'); "
        "api.create_repo(artifact_repo, repo_type='dataset', exist_ok=True); "
        "api.upload_folder(folder_path='outputs', repo_id=artifact_repo, repo_type='dataset', path_in_repo=f'job-{job_id}/outputs'); "
        "print('Uploaded model adapter to', model_repo); "
        "print('Uploaded outputs to', artifact_repo, 'under', f'job-{job_id}/outputs')"
    )
    return [
        _python_c(replay_code),
        "python training/evaluate.py --episodes 30 --task all "
        "--policies random,heuristic,oracle_lite,trained "
        "--replay outputs/trained_policy_replay.jsonl "
        "--out outputs/eval_post.json --no-plot",
        "cp outputs/eval_post.json outputs/evaluation_results.json",
        # NOTE(review): outputs/eval_pre.json is not produced by this command;
        # presumably training/train.py writes it - confirm.
        "python -m training.plots --pre outputs/eval_pre.json "
        "--post outputs/eval_post.json --out-dir outputs/charts",
        _python_c(upload_code),
    ]


def train_command(args: argparse.Namespace, train: bool = True) -> str:
    """Build the job's shell command for the import-smoke or training modes.

    Args:
        args: Parsed CLI arguments (see ``parse_args``).
        train: When False, only bootstrap the repo and check imports
            (``import-smoke``). When True, also run ``training/train.py``.
            If ``args.mode == "train-full"``, also replay, evaluate, plot and
            upload.

    Returns:
        A single shell string intended for ``bash -lc``.
    """
    lines = bootstrap_repo(args.repo_url)
    if not train:
        return shell_join(lines)

    train_argv = [
        "python training/train.py",
        f"--episodes {args.episodes}",
        f"--task {shlex.quote(args.task)}",
        f"--seed {args.seed}",
        f"--model {shlex.quote(args.model)}",
        f"--epochs {args.epochs}",
        f"--batch-size {args.batch_size}",
        f"--learning-rate {args.learning_rate}",
        f"--lora-rank {args.lora_rank}",
        f"--num-generations {args.num_generations}",
        f"--max-seq-length {args.max_seq_length}",
        f"--output-dir {shlex.quote(args.output_dir)}",
    ]
    lines.append(" ".join(train_argv))
    if args.mode == "train-full":
        lines.extend(_post_training_commands(args))
    return shell_join(lines)


def parse_args() -> argparse.Namespace:
    """Parse the launcher's command-line flags."""
    parser = argparse.ArgumentParser(
        description="Launch SENTINEL training on Hugging Face Jobs without shell quoting pain."
    )
    parser.add_argument(
        "--mode",
        choices=["gpu-test", "import-smoke", "train-smoke", "train-full"],
        default="gpu-test",
    )
    parser.add_argument("--namespace", default=os.environ.get("HF_NAMESPACE", "XcodeAddy"))
    parser.add_argument("--flavor", default="a10g-small")
    parser.add_argument("--timeout", default="2h")
    parser.add_argument("--image", default=DEFAULT_IMAGE)
    parser.add_argument("--repo-url", default=DEFAULT_REPO)
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--episodes", type=int, default=50)
    parser.add_argument("--task", choices=["task1", "task2", "task3", "all"], default="all")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--lora-rank", type=int, default=8)
    parser.add_argument("--num-generations", type=int, default=2)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--output-dir", default="training/sentinel_qwen05_grpo")
    parser.add_argument(
        "--model-repo",
        default=DEFAULT_MODEL_REPO,
        help="Hub model repo that receives the trained adapter (train-full).",
    )
    parser.add_argument(
        "--artifact-repo",
        default=DEFAULT_ARTIFACT_REPO,
        help="Hub dataset repo that receives the outputs/ folder (train-full).",
    )
    return parser.parse_args()


def main() -> None:
    """Parse flags, resolve a Hub token, build the command and submit the job.

    Raises:
        SystemExit: if no Hugging Face token is found in HF_TOKEN or the
            local huggingface_hub login.
    """
    args = parse_args()
    # Imported here (its only use) so the command builders above, and --help,
    # work even where huggingface_hub is not installed.
    from huggingface_hub import get_token, run_job

    token = os.environ.get("HF_TOKEN") or get_token()
    if not token:
        raise SystemExit(
            dedent(
                """
                No Hugging Face token was found.

                Either run:
                  read -s HF_TOKEN
                  export HF_TOKEN

                Or log in once:
                  .venv/bin/hf auth login --add-to-git-credential
                """
            ).strip()
        )

    if args.mode == "gpu-test":
        command = gpu_test_command()
    elif args.mode == "import-smoke":
        command = train_command(args, train=False)
    else:
        command = train_command(args)

    print("Launching HF Job:")
    print(f" mode = {args.mode}")
    print(f" namespace = {args.namespace}")
    print(f" flavor = {args.flavor}")
    print(f" timeout = {args.timeout}")
    print(f" image = {args.image}")
    preview = command[:260] + ("..." if len(command) > 260 else "")
    print(" command = bash -lc", shlex.quote(preview))

    job = run_job(
        image=args.image,
        command=["bash", "-lc", command],
        flavor=args.flavor,
        timeout=args.timeout,
        namespace=args.namespace,
        token=token,
        secrets={"HF_TOKEN": token},
        env={
            "SENTINEL_MODEL_REPO": args.model_repo,
            "SENTINEL_ARTIFACT_REPO": args.artifact_repo,
        },
        labels={"project": "sentinel", "mode": args.mode},
    )
    print("Job launched.")
    print("URL:", job.url)
    print("ID:", job.id)
    print()
    print("Follow logs with:")
    print(f" .venv/bin/hf jobs logs -f {job.id} --namespace {args.namespace} --token \"$HF_TOKEN\"")


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        sys.exit(130)