| #!/usr/bin/env bash |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# The rest of this script needs bash-only features (arrays, [[ ]], process
# substitution), so refuse to continue under any other shell. This check
# itself stays POSIX-compatible because it runs in whatever shell loaded us.
[ -n "${BASH_VERSION}" ] || {
  echo "Please use bash to run this script." >&2
  exit 1
}

# Echo each command before it runs.
set -x
|
|
# Resolve the directory holding this script and the repository root one level
# above it, and put the root first on PYTHONPATH (existing entries are kept).
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# Log verbosity: INFO by default, but a LOGLEVEL supplied by the caller wins.
export LOGLEVEL="${LOGLEVEL:-INFO}"

# Weights & Biases credentials are intentionally NOT hard-coded here: this
# script runs under `set -x` and is copied verbatim into every output
# directory, so any key written in it would leak. Export WANDB_API_KEY in the
# calling environment (or log in with the `wandb` CLI beforehand) instead.
| |
| |
| |
| |
|
|
# Defaults for optional parameters. HOSTFILE is unset (not merely emptied) so
# that only an explicit --hostfile defines it; an inherited value is discarded.
unset HOSTFILE
ZERO_STAGE=3
OFFLOAD="none"
LOG_RUN_NAME='setting3-default'

# Abort when an option written as "--name value" has nothing after it (it was
# the last argument). Without this, the value silently became empty.
_missing_value() {
  echo "Missing value for parameter: '$1'" >&2
  exit 1
}

# Accept both "--name value" and "--name=value" for every option; anything
# else is rejected. --train_datasets, --model_name_or_path and --output_dir
# get no default here.
while [[ "$#" -gt 0 ]]; do
  arg="$1"
  shift
  case "${arg}" in
    --train_datasets)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      DATASET="$1"
      shift
      ;;
    --train_datasets=*)
      DATASET="${arg#*=}"
      ;;
    --model_name_or_path)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      MODEL_NAME_OR_PATH="$1"
      shift
      ;;
    --model_name_or_path=*)
      MODEL_NAME_OR_PATH="${arg#*=}"
      ;;
    --output_dir)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      OUTPUT_DIR="$1"
      shift
      ;;
    --output_dir=*)
      OUTPUT_DIR="${arg#*=}"
      ;;
    --log_run_name)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      LOG_RUN_NAME="$1"
      shift
      ;;
    --log_run_name=*)
      LOG_RUN_NAME="${arg#*=}"
      ;;
    --hostfile)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      HOSTFILE="$1"
      shift
      ;;
    --hostfile=*)
      HOSTFILE="${arg#*=}"
      ;;
    --zero_stage)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      ZERO_STAGE="$1"
      shift
      ;;
    --zero_stage=*)
      ZERO_STAGE="${arg#*=}"
      ;;
    --offload)
      [[ "$#" -gt 0 ]] || _missing_value "${arg}"
      OFFLOAD="$1"
      shift
      ;;
    --offload=*)
      OFFLOAD="${arg#*=}"
      ;;
    *)
      echo "Unknown parameter passed: '${arg}'" >&2
      exit 1
      ;;
  esac
done
|
|
# Fail fast, before anything is written to disk, when a required parameter is
# missing. An empty OUTPUT_DIR is the dangerous case: `mkdir -p ""` fails but
# execution would continue, and the .gitignore, script copy and logs written
# below would land in the current directory (or in /) instead.
_require_param() {
  if [[ -z "$2" ]]; then
    echo "Missing required parameter: '$1'" >&2
    exit 1
  fi
}
_require_param --train_datasets "${DATASET:-}"
_require_param --model_name_or_path "${MODEL_NAME_OR_PATH:-}"
_require_param --output_dir "${OUTPUT_DIR:-}"

# Create the output directory and replace OUTPUT_DIR with its absolute path.
if ! mkdir -p -- "${OUTPUT_DIR}"; then
  echo "Cannot create output directory: '${OUTPUT_DIR}'" >&2
  exit 1
fi
if ! output_dir_abs="$(cd -- "${OUTPUT_DIR}" &>/dev/null && pwd)"; then
  echo "Cannot enter output directory: '${OUTPUT_DIR}'" >&2
  exit 1
fi
OUTPUT_DIR="${output_dir_abs}"

# Keep everything in the output directory out of git. Only written when absent,
# so an existing (possibly hand-edited) .gitignore is left alone.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
  echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Snapshot this launch script next to the results (presumably so the run can
# be reproduced later).
cp -f "$0" "${OUTPUT_DIR}/script.sh"
|
|
|
|
# Rendezvous port for the DeepSpeed launcher: a random port in
# [MASTER_PORT_START, MASTER_PORT_END] that is not the local port of any TCP
# socket reported by `ss -Htan` (4th column is "local-address:port").
# Both lists go through the same `sort` so that `comm` sees them in the same
# collation order; `comm -23` keeps only ports that appear in the first list.
# NOTE(review): the port could still be taken between this check and the
# launcher binding it.
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
  comm -23 \
    <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
    <(ss -Htan | awk '{ n = split($4, fields, ":"); print fields[n] }' | sort -u) \
    | shuf -n 1
)"
|
|
# Launcher-level options for `deepspeed`: a hostfile only when HOSTFILE is set
# (even to an empty string), followed by the master port.
DEEPSPEED_ARGS=()
[[ -z "${HOSTFILE+set}" ]] || DEEPSPEED_ARGS+=(--hostfile "${HOSTFILE}")
DEEPSPEED_ARGS+=(--master_port "${MASTER_PORT}")
|
|
# From here on, copy everything this script writes to stdout and stderr into
# ${OUTPUT_DIR}/stdout.log and ${OUTPUT_DIR}/stderr.log, while each `tee` still
# forwards its input back to the console stream it replaced (`>&1` / `>&2`).
# This relies on bash applying the two redirections left to right, so the
# stderr `tee` is started before fd 2 is redirected.
# NOTE(review): bash does not wait for process substitutions when the script
# exits, so the final lines may reach the console/log files slightly after
# the script itself has returned.
exec 1> >(tee "${OUTPUT_DIR}/stdout.log" >&1) 2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
|
|
# Launch the `safe_rlhf.finetune` module under the DeepSpeed launcher.
# Training hyperparameters are fixed here; the dataset, model, output
# location, run name, ZeRO stage and offload mode come from the CLI flags.
# The dataset spec is quoted so a path containing spaces or glob characters
# reaches the trainer as a single "inverse-json::<path>" argument. Weight
# decay is passed once (it was previously duplicated).
deepspeed "${DEEPSPEED_ARGS[@]}" \
  --module safe_rlhf.finetune \
  --train_datasets "inverse-json::${DATASET}" \
  --model_name_or_path "${MODEL_NAME_OR_PATH}" \
  --max_length 512 \
  --trust_remote_code True \
  --epochs 1 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 8 \
  --gradient_checkpointing \
  --learning_rate 1e-5 \
  --lr_warmup_ratio 0 \
  --weight_decay 0.0 \
  --lr_scheduler_type constant \
  --seed 42 \
  --output_dir "${OUTPUT_DIR}" \
  --log_type wandb \
  --log_run_name "${LOG_RUN_NAME}" \
  --log_project Inverse_Alignment_IMDb \
  --zero_stage "${ZERO_STAGE}" \
  --offload "${OFFLOAD}" \
  --bf16 True \
  --tf32 True \
  --save_16bit
|
|