# Author: Baladithya Balamurugan
# Commit 7453f13: "Wave 20: SageMaker GRPO smoke artifacts (F3) — runnable g5.2xlarge GSM8K"
"""GPU + vLLM-colocated variant of run.py, for a SageMaker Training Job.
F3 §3.3: the runnable-NOW GRPO smoke. Same GSM8K RLVR reward and same
`ComposerReplicationTrainer(alpha_sdpo=0, beta_replay=0)` (plain GRPO, channels
2/3 off) as the CPU example, lifted to one real GPU with vLLM colocated in the
training process. Proves: container builds, trainer runs on GPU, vLLM rollout
works, reward fires, checkpoint lands in S3.
This script runs INSIDE the SageMaker container. SageMaker conventions used:
* Hyperparameters arrive as CLI args ``--key value`` (the Estimator's
``hyperparameters=`` dict). We also read /opt/ml/input/config/
hyperparameters.json as a fallback.
* The final model must be written to ``/opt/ml/model`` (SM_MODEL_DIR);
SageMaker tars it to ``OutputDataConfig.S3OutputPath`` on exit.
* stdout/stderr stream to CloudWatch ``/aws/sagemaker/TrainingJobs/<job>``.
Run via examples/gsm8k_grpo/run_sagemaker_launch.py (the Estimator driver).
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import re
import sys
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
# ---------------------------------------------------------------------------
# Reward — identical RLVR `#### NUMBER` regex to the CPU example (run.py)
# ---------------------------------------------------------------------------
_ANSWER_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")
def _extract_answer(text: str) -> str | None:
matches = _ANSWER_RE.findall(text or "")
return matches[-1].strip() if matches else None
def gsm8k_reward(completions, **kwargs):
gold = kwargs.get("gold_answer")
if gold is None:
return [0.0] * len(completions)
rewards: list[float] = []
for completion, gold_ans in zip(completions, gold, strict=False):
if isinstance(completion, list):
text = "\n".join(m.get("content", "") for m in completion)
else:
text = str(completion)
pred = _extract_answer(text)
rewards.append(1.0 if (pred is not None and pred == str(gold_ans).strip()) else 0.0)
return rewards
# System instruction attached to every GSM8K question. It asks for the same
# `#### N` terminator that the reward regex extracts.
SYSTEM_PROMPT = (
    "You are a math tutor. Solve the problem step by step. "
    "End your answer with `#### N` where N is the final numeric answer."
)


def build_dataset(n_rows: int):
    """Load the first *n_rows* GSM8K ``main`` training rows as chat prompts.

    Every output row has exactly two columns:
      * ``prompt`` — a two-message chat (system instruction, then the question);
      * ``gold_answer`` — the number taken from the reference solution's
        ``#### N`` line, or ``""`` when no such line is found.
    The original GSM8K columns are removed.
    """
    split_spec = f"train[:{n_rows}]"
    source = load_dataset("openai/gsm8k", "main", split=split_spec)

    def _to_example(record):
        system_msg = {"role": "system", "content": SYSTEM_PROMPT}
        user_msg = {"role": "user", "content": record["question"]}
        answer = _extract_answer(record["answer"])
        return {
            "prompt": [system_msg, user_msg],
            "gold_answer": answer if answer else "",
        }

    return source.map(_to_example, remove_columns=source.column_names)
# ---------------------------------------------------------------------------
# Hyperparameters — SageMaker passes them as CLI args; JSON fallback.
# ---------------------------------------------------------------------------
def _parse_hyperparameters() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct")
p.add_argument("--n_train_rows", type=int, default=100)
p.add_argument("--max_steps", type=int, default=20)
p.add_argument("--num_generations", type=int, default=8)
p.add_argument("--per_device_train_batch_size", type=int, default=8)
p.add_argument("--max_completion_length", type=int, default=256)
p.add_argument("--learning_rate", type=float, default=1e-5)
p.add_argument("--beta", type=float, default=0.04)
p.add_argument("--vllm_gpu_memory_utilization", type=float, default=0.3)
p.add_argument("--use_vllm", type=lambda s: str(s).lower() != "false", default=True)
# SageMaker model output dir (env SM_MODEL_DIR, default /opt/ml/model).
p.add_argument("--model_dir", default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
args, _unknown = p.parse_known_args()
return args
def main() -> int:
    """Run the GRPO GSM8K smoke end to end and save the model to ``args.model_dir``.

    Steps: parse hyperparameters, load the model/tokenizer and GSM8K rows,
    build a ``GRPOConfig`` plus ``ComposerReplicationTrainer`` with
    ``alpha_sdpo=0`` and ``beta_replay=0``, train for ``max_steps``, then call
    ``trainer.save_model``. Without CUDA it loads the model in float32 and
    trains without vLLM, logging a warning if vLLM had been requested.

    Returns:
        Process exit code: 0 on success. Exceptions are not caught and
        propagate to the caller.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log = logging.getLogger("gsm8k_grpo_sagemaker")
    args = _parse_hyperparameters()
    log.info("=" * 64)
    log.info("GRPO + GSM8K + %s (SageMaker GPU, vLLM=%s)", args.model, args.use_vllm)
    log.info("=" * 64)
    log.info("hyperparameters: %s", json.dumps(vars(args), indent=2))

    cuda = torch.cuda.is_available()
    log.info("CUDA available: %s | device: %s", cuda,
             torch.cuda.get_device_name(0) if cuda else "cpu")
    # vLLM colocation is only enabled on a GPU. Say so loudly when it was
    # requested, so a CPU-only run is not mistaken for a passing vLLM smoke test.
    use_vllm = bool(args.use_vllm and cuda)
    if args.use_vllm and not cuda:
        log.warning("use_vllm requested but CUDA is unavailable; training WITHOUT vLLM.")

    log.info("[1/4] Loading model + tokenizer ...")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token_id is None:
        # Reuse EOS as the pad token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype=torch.bfloat16 if cuda else torch.float32
    )

    log.info("[2/4] Loading %d GSM8K rows ...", args.n_train_rows)
    dataset = build_dataset(args.n_train_rows)

    log.info("[3/4] Building GRPOConfig + ComposerReplicationTrainer ...")
    # Deferred imports: only needed once the heavy training path is reached.
    from trl import GRPOConfig
    from composer_replication import ComposerReplicationTrainer

    config = GRPOConfig(
        output_dir=args.model_dir,
        per_device_train_batch_size=args.per_device_train_batch_size,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        learning_rate=args.learning_rate,
        max_steps=args.max_steps,
        logging_steps=1,
        save_strategy="no",
        report_to=[],
        bf16=cuda,
        beta=args.beta,
        # vLLM colocated in-process on the same GPU (F3 §3.3 / §5).
        use_vllm=use_vllm,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=args.vllm_gpu_memory_utilization,
        vllm_tensor_parallel_size=1,
        seed=42,
    )
    trainer = ComposerReplicationTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[gsm8k_reward],
        train_dataset=dataset,
        args=config,
        alpha_sdpo=0.0,  # channels 2/3 off — plain GRPO smoke
        beta_replay=0.0,
    )

    log.info("[4/4] Training for %d steps ...", args.max_steps)
    result = trainer.train()
    log.info("Training complete: %s", result.metrics)
    # Persist to SM_MODEL_DIR → SageMaker uploads to OutputDataConfig.
    trainer.save_model(args.model_dir)
    log.info("Model saved to %s (SageMaker will upload to S3).", args.model_dir)
    return 0
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())