#!/usr/bin/env bash
#
# Length-penalty lambda sweep for GRPO training.
#
# For every value in LAMBDAS this script derives a config from BASE_CONFIG
# (length penalty enabled with that lambda, plus a per-value run_name and
# output_dir) and trains it with `accelerate launch src/train_grpo.py`.
#
# Optional environment variables:
#   BASE_CONFIG        base YAML config (default: configs/grpo_llama32_3b_bf16.yaml)
#   ACCELERATE_CONFIG  accelerate config (default: configs/accelerate_ddp_4gpu.yaml)
#   NUM_PROCESSES      value passed to --num_processes (default: 4)
#   PYTHON_BIN         interpreter used to generate configs (default: python)
#   OUT_ROOT           sweep output root (default: artifacts/sweeps/length_penalty_lambda)
#   LAMBDAS            whitespace-separated lambda values (default: "0.5 0.25")
#
# The script cd's to the repository root (the parent of this script's
# directory), so relative paths are resolved against that directory.
set -eo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# nounset is switched on only after the environment script has been sourced.
# NOTE(review): presumably conda_env.sh is not nounset-safe - confirm before
# folding `-u` into the `set` line above.
# shellcheck source=/dev/null
source "${SCRIPT_DIR}/conda_env.sh"
set -u

REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
cd "${REPO_ROOT}"
export PYTHONPATH="${REPO_ROOT}/src"

BASE_CONFIG="${BASE_CONFIG:-configs/grpo_llama32_3b_bf16.yaml}"
ACCELERATE_CONFIG="${ACCELERATE_CONFIG:-configs/accelerate_ddp_4gpu.yaml}"
NUM_PROCESSES="${NUM_PROCESSES:-4}"
PYTHON_BIN="${PYTHON_BIN:-python}"
OUT_ROOT="${OUT_ROOT:-artifacts/sweeps/length_penalty_lambda}"
LAMBDAS="${LAMBDAS:-0.5 0.25}"

# Offline by default for W&B and the Hugging Face libraries; any of these can
# be overridden from the caller's environment.
export WANDB_MODE="${WANDB_MODE:-offline}"
export HF_HUB_OFFLINE="${HF_HUB_OFFLINE:-1}"
export HF_DATASETS_OFFLINE="${HF_DATASETS_OFFLINE:-1}"
export TRANSFORMERS_OFFLINE="${TRANSFORMERS_OFFLINE:-1}"

# Print an error message to stderr and abort the sweep.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Split LAMBDAS on whitespace into an array. Unlike an unquoted expansion this
# never glob-expands the values. `read -d ''` consumes the whole string and
# reports failure on reaching EOF, which is expected, hence `|| true`.
LAMBDA_LIST=()
IFS=$' \t\n' read -r -d '' -a LAMBDA_LIST <<<"${LAMBDAS}" || true
((${#LAMBDA_LIST[@]} > 0)) || die "LAMBDAS contains no values"

# Fail fast on missing inputs instead of midway through the sweep.
[[ -f "${BASE_CONFIG}" ]] || die "base config not found: ${BASE_CONFIG}"
[[ -f "${ACCELERATE_CONFIG}" ]] || die "accelerate config not found: ${ACCELERATE_CONFIG}"
command -v "${PYTHON_BIN}" >/dev/null || die "python interpreter not found: ${PYTHON_BIN}"
command -v accelerate >/dev/null || die "accelerate not found on PATH"

mkdir -p "${OUT_ROOT}"
TMP_DIR="$(mktemp -d "${OUT_ROOT}/tmp_cfgs.XXXXXX")"

# Remove the generated configs on every exit path.
cleanup() {
  rm -rf -- "${TMP_DIR:?}"
}
trap cleanup EXIT

#######################################
# Write the config for one lambda value, derived from BASE_CONFIG.
# Globals:   BASE_CONFIG, OUT_ROOT, PYTHON_BIN (read)
# Arguments: $1 - lambda value as given by the user
#            $2 - path of the YAML file to write
# Outputs:   the written path, run_name and output_dir on stdout
# Returns:   non-zero if the lambda is not a number or the base config is
#            unreadable / not a YAML mapping
#######################################
write_lambda_config() {
  local lambda=$1
  local cfg_path=$2
  BASE_CONFIG="${BASE_CONFIG}" OUT_ROOT="${OUT_ROOT}" LAMBDA="${lambda}" \
    CFG_PATH="${cfg_path}" "${PYTHON_BIN}" - <<'PY'
import os
import sys
from pathlib import Path

import yaml

base_config = Path(os.environ["BASE_CONFIG"])
out_root = Path(os.environ["OUT_ROOT"])
cfg_path = Path(os.environ["CFG_PATH"])
lam = os.environ["LAMBDA"]

# Validate the lambda before touching the filesystem.
try:
    lam_value = float(lam)
except ValueError:
    sys.exit(f"Invalid lambda {lam!r}: expected a number")

with base_config.open("r", encoding="utf-8") as handle:
    cfg = yaml.safe_load(handle)

# An empty or non-mapping document cannot be extended below.
if not isinstance(cfg, dict):
    sys.exit(f"Base config {base_config} must contain a YAML mapping")


def section(parent, key):
    """Return parent[key], creating an empty mapping if it is missing or null."""
    value = parent.get(key)
    if value is None:
        value = parent[key] = {}
    return value


kwargs = section(section(cfg, "objective"), "kwargs")
kwargs["enable_length_penalty"] = True
kwargs["reward_mode"] = "weighted_length_penalty"
kwargs["length_penalty_lambda"] = lam_value

trainer = section(cfg, "trainer")
base_run_name = trainer.get("run_name") or "grpo"
# "." is replaced so the value is safe inside run and directory names.
safe_lam = lam.replace(".", "p")
trainer["run_name"] = f"{base_run_name}-lambda-{safe_lam}"
trainer["output_dir"] = str(out_root / f"run_lambda_{safe_lam}")

with cfg_path.open("w", encoding="utf-8") as handle:
    yaml.safe_dump(cfg, handle, sort_keys=False)

print(f"Wrote config: {cfg_path}")
print(f"run_name: {trainer['run_name']}")
print(f"output_dir: {trainer['output_dir']}")
PY
}

echo "Base config: ${BASE_CONFIG}"
echo "Lambdas: ${LAMBDAS}"
echo "Output root: ${OUT_ROOT}"

# Pass 1: generate every config before any training starts, so an invalid
# lambda or a broken base config aborts the sweep immediately rather than
# after earlier runs have already completed.
CFG_PATHS=()
for lambda in "${LAMBDA_LIST[@]}"; do
  cfg_path="${TMP_DIR}/grpo_lambda_${lambda}.yaml"
  write_lambda_config "${lambda}" "${cfg_path}"
  CFG_PATHS+=("${cfg_path}")
done

# Pass 2: train one run per lambda, in the order given. Under `set -e` a
# failed run stops the sweep.
for i in "${!LAMBDA_LIST[@]}"; do
  echo
  echo "=== Running lambda=${LAMBDA_LIST[i]} ==="

  accelerate launch \
    --config_file "${ACCELERATE_CONFIG}" \
    --num_processes "${NUM_PROCESSES}" \
    src/train_grpo.py \
    --config "${CFG_PATHS[i]}"
done

echo
echo "Sweep complete. Runs are under: ${OUT_ROOT}"