Baladithya Balamurugan
Wave 20: fix SageMaker smoke — torch-2.7 DLC + drop vllm pin (the real conflict)
a578ad9
Raw
History Blame Contribute Delete
6.23 kB
"""Launch the GSM8K GRPO smoke as a SageMaker Training Job (F3 §3.4).
Two ways to run, selected by --image:
* --image dlc (default): use the AWS PyTorch DLC directly as the training
image and ship the framework + script via source_dir; deps install at job
start from requirements.txt. No local Docker build — the lowest-friction
path for a one-shot smoke (F3's "source_dir on top of the baked image",
here on top of the *stock* DLC).
* --image <ecr-uri>: use a prebuilt baked image (docker/Dockerfile.sagemaker
pushed to ECR via scripts/build_and_push_ecr.sh) — the repeatable path that
avoids the ~5-10 min startup pip-install (F3 §3.2).
Live facts (verified 2026-06-09, acct 386931836011, us-west-2):
* ml.g5.2xlarge training-job quota = 1 → runs with no quota ticket.
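* DLC tag in use (see DLC_IMAGE below): 2.7.1-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.26.
The tag first resolved live, 2.6.0-gpu-py312-cu126-ubuntu22.04-sagemaker-v1.25,
ships torch 2.6 and is too old for the trl/transformers stack this smoke uses.
(NOTE: the CUDA suffix is cu126/cu128, not the cu124 in some docs; a bare
floating tag does not exist, the -vX.YY build suffix is required.)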
* role: AmazonSageMaker-ExecutionRole-20250725T133247
* bucket: amazon-sagemaker-386931836011-us-west-2-7597bf4d9a3d
Usage (from the laptop / Studio, sagemaker SDK v2):
pip install 'sagemaker>=2.200,<3'
python examples/gsm8k_grpo/run_sagemaker_launch.py --max-steps 20
python examples/gsm8k_grpo/run_sagemaker_launch.py --no-wait # fire and poll later
"""
from __future__ import annotations
import argparse
import os
import shutil
import sys
import tempfile
# Fixed deployment coordinates for this smoke: a single account in one region.
REGION = "us-west-2"
ACCOUNT = "386931836011"

# Execution role ARN and default S3 bucket, both derived from the account id.
ROLE = (
    "arn:aws:iam::" + ACCOUNT
    + ":role/service-role/AmazonSageMaker-ExecutionRole-20250725T133247"
)
BUCKET = "-".join(("amazon-sagemaker", ACCOUNT, REGION, "7597bf4d9a3d"))

# Training image: the AWS PyTorch DLC, tag resolved live against the
# 763104351884 us-west-2 registry.
# This has to be the torch-2.7 build. The dependency chain is
# ComposerReplicationTrainer -> trl 1.5.x -> transformers>=4.56.2, which
# references torch.float8_e8m0fnu (torch>=2.7). On the torch-2.6 DLC (cu126)
# AutoModel.from_pretrained raises AttributeError on that dtype.
# (Learned from two live runs 2026-06-09; see requirements.txt + the quickstart.)
DLC_IMAGE = "/".join((
    "763104351884.dkr.ecr.us-west-2.amazonaws.com",
    "pytorch-training:2.7.1-gpu-py312-cu128-ubuntu22.04-sagemaker-v1.26",
))

# Directory holding this script, and the repository root two levels above it.
_HERE = os.path.dirname(os.path.abspath(__file__))
_REPO_ROOT = os.path.abspath(os.path.join(_HERE, os.pardir, os.pardir))
def _stage_source() -> str:
    """Build a minimal source_dir for the training job.

    The staging directory receives the entry script (run_sagemaker.py),
    requirements.txt, and a copy of the composer_replication package so it is
    importable as a local package from /opt/ml/code. Bytecode caches, tests
    and egg-info are skipped to keep the S3 upload tiny (no
    .venv/research/docs/.git either, since only these three items are copied).

    Returns:
        Path of the freshly created staging directory. The caller owns it and
        is responsible for deleting it.

    Raises:
        OSError: if any of the inputs is missing or cannot be copied. The
            partially populated staging directory is removed before the
            exception propagates, so a failed call leaves nothing behind.
    """
    staging = tempfile.mkdtemp(prefix="composer-sm-src-")
    try:
        for filename in ("run_sagemaker.py", "requirements.txt"):
            shutil.copy2(os.path.join(_HERE, filename), staging)
        shutil.copytree(
            os.path.join(_REPO_ROOT, "composer_replication"),
            os.path.join(staging, "composer_replication"),
            ignore=shutil.ignore_patterns("__pycache__", "*.pyc", "tests", "*.egg-info"),
        )
    except BaseException:
        # The caller never receives the path on failure, so it cannot clean
        # up; do it here (also on Ctrl-C mid-copy) and re-raise unchanged.
        shutil.rmtree(staging, ignore_errors=True)
        raise
    return staging
def main() -> int:
    """Parse CLI flags, stage the source bundle, and submit the training job.

    Flow: parse arguments, create a SageMaker session bound to BUCKET, stage a
    temporary source_dir via _stage_source(), build an Estimator on either the
    stock DLC image or a user-supplied image URI, and call fit(). The staging
    directory is removed afterwards whether or not submission succeeded.

    Returns:
        0 after the job was submitted (and, unless --no-wait was given, after
        fit() returned). Failures propagate as exceptions.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--image", default="dlc",
                    help="'dlc' (stock PyTorch DLC + source_dir) or a prebuilt ECR image URI")
    ap.add_argument("--instance-type", default="ml.g5.2xlarge")
    ap.add_argument("--max-steps", type=int, default=20)
    ap.add_argument("--n-train-rows", type=int, default=100)
    ap.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
    ap.add_argument("--vllm", action="store_true",
                    help="enable colocated vLLM rollout. OFF by default: the default "
                         "requirements.txt omits vllm (it hard-pins torch==2.6, fighting "
                         "the torch-2.7 DLC trl 1.5 needs). Only pass this with a baked "
                         "image (--image) carrying a torch-2.7-matched vllm>=0.9.")
    ap.add_argument("--spot", action="store_true", help="use managed spot (quota=1 too)")
    ap.add_argument("--no-wait", action="store_true", help="submit and return; poll later")
    ap.add_argument("--max-run", type=int, default=3600)
    args = ap.parse_args()

    if args.vllm and args.image == "dlc":
        # The --vllm help text says this combination is unsupported; surface
        # that before a job is submitted rather than only in the help output.
        print("[launch] WARNING: --vllm was given with the stock DLC image; the default "
              "requirements.txt omits vllm. Use a baked image via --image <ecr-uri>.",
              file=sys.stderr)

    # Imported here so that --help and argument errors work without the SDK.
    import sagemaker
    from sagemaker.estimator import Estimator

    sess = sagemaker.Session(default_bucket=BUCKET)
    if sess.boto_region_name != REGION:
        # The session takes its region from the ambient AWS configuration,
        # while the image URI and bucket name above are tied to REGION.
        print(f"[launch] WARNING: session region is {sess.boto_region_name}, "
              f"expected {REGION}; the job will be submitted to the session region.",
              file=sys.stderr)

    image = DLC_IMAGE if args.image == "dlc" else args.image
    print(f"[launch] region={REGION} image={image}")
    print(f"[launch] role={ROLE}")
    print(f"[launch] instance={args.instance_type} max_steps={args.max_steps} "
          f"vllm={args.vllm} spot={args.spot}")

    staging = _stage_source()
    print(f"[launch] staged source_dir at {staging}")

    hyperparameters = {
        "model": args.model,
        "n_train_rows": args.n_train_rows,
        "max_steps": args.max_steps,
        "use_vllm": "true" if args.vllm else "false",
    }
    spot_kwargs = {}
    if args.spot:
        # max_wait covers spot waiting time plus the run itself.
        spot_kwargs = {"use_spot_instances": True, "max_wait": args.max_run + 3600}

    wait = not args.no_wait
    # Estimator construction sits inside the try so the staging directory is
    # removed even when the constructor itself raises.
    try:
        est = Estimator(
            image_uri=image,
            role=ROLE,
            instance_type=args.instance_type,
            instance_count=1,
            volume_size=100,
            max_run=args.max_run,
            sagemaker_session=sess,
            output_path=f"s3://{BUCKET}/composer-rl/smoke/output",
            base_job_name="composer-grpo-smoke",
            entry_point="run_sagemaker.py",
            source_dir=staging,
            hyperparameters=hyperparameters,
            environment={
                "HF_HUB_ENABLE_HF_TRANSFER": "1",
                # DLC sagemaker-training-toolkit installs requirements.txt from source_dir.
            },
            # EnableNetworkIsolation MUST be False (default) so the container can reach
            # huggingface.co (model + GSM8K) and S3 (F3 §3.5).
            keep_alive_period_in_seconds=0,  # warm-pool quota=0 in this acct → leave off
            **spot_kwargs,
        )
        est.fit(wait=wait, logs=("All" if wait else None))
    finally:
        shutil.rmtree(staging, ignore_errors=True)

    print(f"[launch] job name: {est.latest_training_job.name}")
    if wait:
        print(f"[launch] model artifact: {est.model_data}")
    return 0
# Script entry point: run the launcher and use its return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())