"""Launch the GSM8K GRPO smoke as a SageMaker Training Job (F3 §3.4). Two ways to run, selected by --image: * --image dlc (default): use the AWS PyTorch DLC directly as the training image and ship the framework + script via source_dir; deps install at job start from requirements.txt. No local Docker build — the lowest-friction path for a one-shot smoke (F3's "source_dir on top of the baked image", here on top of the *stock* DLC). * --image : use a prebuilt baked image (docker/Dockerfile.sagemaker pushed to ECR via scripts/build_and_push_ecr.sh) — the repeatable path that avoids the ~5-10 min startup pip-install (F3 §3.2). Live facts (verified 2026-06-09, acct 386931836011, us-west-2): * ml.g5.2xlarge training-job quota = 1 → runs with no quota ticket. * DLC tag resolved live: 2.6.0-gpu-py312-cu126-ubuntu22.04-sagemaker-v1.25 (NOTE: cu126, not the cu124 in some docs; bare floating tag does not exist, the -v1.25 build suffix is required). * role: AmazonSageMaker-ExecutionRole-20250725T133247 * bucket: amazon-sagemaker-386931836011-us-west-2-7597bf4d9a3d Usage (from the laptop / Studio, sagemaker SDK v2): pip install 'sagemaker>=2.200,<3' python examples/gsm8k_grpo/run_sagemaker_launch.py --max-steps 20 python examples/gsm8k_grpo/run_sagemaker_launch.py --no-wait # fire and poll later """ from __future__ import annotations import argparse import os import shutil import sys import tempfile REGION = "us-west-2" ACCOUNT = "386931836011" ROLE = f"arn:aws:iam::{ACCOUNT}:role/service-role/AmazonSageMaker-ExecutionRole-20250725T133247" BUCKET = f"amazon-sagemaker-{ACCOUNT}-{REGION}-7597bf4d9a3d" # DLC tag resolved live against the 763104351884 us-west-2 registry. # MUST be the torch-2.7 DLC: ComposerReplicationTrainer → trl 1.5.x → # transformers>=4.56.2 → torch.float8_e8m0fnu (torch>=2.7). The torch-2.6 DLC # (cu126) fails AutoModel.from_pretrained with AttributeError on that dtype. # (Learned from two live runs 2026-06-09; see requirements.txt + the quickstart.) DLC_IMAGE = ( "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "pytorch-training:2.7.1-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.26" ) _HERE = os.path.dirname(os.path.abspath(__file__)) _REPO_ROOT = os.path.abspath(os.path.join(_HERE, "..", "..")) def _stage_source() -> str: """Build a minimal source_dir: the entry script + requirements.txt + the composer_replication package (importable as a local package from /opt/ml/code). Keeps the S3 upload tiny — no .venv/research/docs/.git. Returns the staging dir path (caller cleans up).""" staging = tempfile.mkdtemp(prefix="composer-sm-src-") shutil.copy2(os.path.join(_HERE, "run_sagemaker.py"), staging) shutil.copy2(os.path.join(_HERE, "requirements.txt"), staging) shutil.copytree( os.path.join(_REPO_ROOT, "composer_replication"), os.path.join(staging, "composer_replication"), ignore=shutil.ignore_patterns("__pycache__", "*.pyc", "tests", "*.egg-info"), ) return staging def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--image", default="dlc", help="'dlc' (stock PyTorch DLC + source_dir) or a prebuilt ECR image URI") ap.add_argument("--instance-type", default="ml.g5.2xlarge") ap.add_argument("--max-steps", type=int, default=20) ap.add_argument("--n-train-rows", type=int, default=100) ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct") ap.add_argument("--vllm", action="store_true", help="enable colocated vLLM rollout. OFF by default: the default " "requirements.txt omits vllm (it hard-pins torch==2.6, fighting " "the torch-2.7 DLC trl 1.5 needs). Only pass this with a baked " "image (--image) carrying a torch-2.7-matched vllm>=0.9.") ap.add_argument("--spot", action="store_true", help="use managed spot (quota=1 too)") ap.add_argument("--no-wait", action="store_true", help="submit and return; poll later") ap.add_argument("--max-run", type=int, default=3600) args = ap.parse_args() import sagemaker from sagemaker.estimator import Estimator sess = sagemaker.Session(default_bucket=BUCKET) image = DLC_IMAGE if args.image == "dlc" else args.image print(f"[launch] region={REGION} image={image}") print(f"[launch] role={ROLE}") print(f"[launch] instance={args.instance_type} max_steps={args.max_steps} " f"vllm={args.vllm} spot={args.spot}") staging = _stage_source() print(f"[launch] staged source_dir at {staging}") hyperparameters = { "model": args.model, "n_train_rows": args.n_train_rows, "max_steps": args.max_steps, "use_vllm": "true" if args.vllm else "false", } spot_kwargs = {} if args.spot: spot_kwargs = {"use_spot_instances": True, "max_wait": args.max_run + 3600} est = Estimator( image_uri=image, role=ROLE, instance_type=args.instance_type, instance_count=1, volume_size=100, max_run=args.max_run, sagemaker_session=sess, output_path=f"s3://{BUCKET}/composer-rl/smoke/output", base_job_name="composer-grpo-smoke", entry_point="run_sagemaker.py", source_dir=staging, hyperparameters=hyperparameters, environment={ "HF_HUB_ENABLE_HF_TRANSFER": "1", # DLC sagemaker-training-toolkit installs requirements.txt from source_dir. }, # EnableNetworkIsolation MUST be False (default) so the container can reach # huggingface.co (model + GSM8K) and S3 (F3 §3.5). keep_alive_period_in_seconds=0, # warm-pool quota=0 in this acct → leave off **spot_kwargs, ) try: est.fit(wait=not args.no_wait, logs=("All" if not args.no_wait else None)) finally: shutil.rmtree(staging, ignore_errors=True) print(f"[launch] job name: {est.latest_training_job.name}") if not args.no_wait: print(f"[launch] model artifact: {est.model_data}") return 0 if __name__ == "__main__": sys.exit(main())