#!/usr/bin/env bash
#
# Launch the `safe_rlhf.finetune` module under the DeepSpeed launcher and log
# the run to Weights & Biases (project Inverse_Alignment_IMDb).
#
# Usage:
#   bash <this script> --train_datasets DATASET --model_name_or_path MODEL \
#       --output_dir DIR [--log_run_name NAME] [--hostfile FILE] \
#       [--zero_stage N] [--offload MODE]
#
# Every option also accepts the `--opt=value` form. Defaults:
# --log_run_name setting3-default, --zero_stage 3, --offload none; no
# hostfile is handed to deepspeed unless --hostfile is given.
# The script's stdout/stderr are additionally saved as stdout.log and
# stderr.log inside DIR.

# Everything below relies on bash (arrays, [[ ]], process substitution), so
# refuse to continue under any other shell. This guard itself must stay
# POSIX-compatible, hence `[` rather than `[[`.
[ -n "${BASH_VERSION}" ] || {
  echo "Please use bash to run this script." >&2
  exit 1
}
| |
|
# Trace every command (to stderr) as it runs.
set -x

# Absolute path of the directory holding this script, and its parent
# directory, which is put at the front of PYTHONPATH (presumably so the
# project's Python packages are importable — TODO confirm layout).
SCRIPT_DIR="$(cd "$(dirname "$0")" &>/dev/null && pwd)"
ROOT_DIR="$(dirname "${SCRIPT_DIR}")"

if [[ -n "${PYTHONPATH}" ]]; then
  export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH}"
else
  export PYTHONPATH="${ROOT_DIR}"
fi

# Fall back to WARNING when the caller has not chosen a log level.
: "${LOGLEVEL:=WARNING}"
export LOGLEVEL
| |
|
| |
|
| |
|
| | |
# This experiment logs at INFO, overriding any inherited LOGLEVEL.
export LOGLEVEL="INFO"

# Weights & Biases credentials must come from the caller's environment
# (or, presumably, an earlier `wandb login`) and must never be written into
# this file: the script is copied into every output directory and `set -x`
# echoes assignments. `${WANDB_API_KEY:+set}` keeps the key's value itself
# out of the xtrace output.
# NOTE(review): rotate any API key that was ever committed to this file.
if [[ -z "${WANDB_API_KEY:+set}" ]]; then
  echo "WARNING: WANDB_API_KEY is not set; relying on existing wandb login credentials." >&2
fi
| | |
| | |
| | |
| | |
| |
|
# Defaults for optional parameters. HOSTFILE is unset (not merely emptied)
# so that only an explicit --hostfile defines it; later code tests whether
# it is set at all.
unset HOSTFILE
ZERO_STAGE=3
OFFLOAD="none"
LOG_RUN_NAME='setting3-default'

# Parse command-line options into global variables.
#
# Every option takes exactly one value, given as `--opt value` or
# `--opt=value` (everything after the first '=' is the value):
#   --train_datasets       -> DATASET
#   --model_name_or_path   -> MODEL_NAME_OR_PATH
#   --output_dir           -> OUTPUT_DIR
#   --log_run_name         -> LOG_RUN_NAME
#   --hostfile             -> HOSTFILE
#   --zero_stage           -> ZERO_STAGE
#   --offload              -> OFFLOAD
# Arguments: the script's command-line arguments.
# Exits with status 1 on an unknown option, or when an option in the
# `--opt value` form is the last argument and therefore has no value
# (previously that silently produced an empty setting).
parse_args() {
  local arg opt var value
  while [[ "$#" -gt 0 ]]; do
    arg="$1"
    shift
    opt="${arg%%=*}"
    case "${opt}" in
      --train_datasets) var=DATASET ;;
      --model_name_or_path) var=MODEL_NAME_OR_PATH ;;
      --output_dir) var=OUTPUT_DIR ;;
      --log_run_name) var=LOG_RUN_NAME ;;
      --hostfile) var=HOSTFILE ;;
      --zero_stage) var=ZERO_STAGE ;;
      --offload) var=OFFLOAD ;;
      *)
        echo "Unknown parameter passed: '${arg}'" >&2
        exit 1
        ;;
    esac
    if [[ "${arg}" == *=* ]]; then
      value="${arg#*=}"
    elif [[ "$#" -gt 0 ]]; then
      value="$1"
      shift
    else
      echo "Missing value for parameter: '${arg}'" >&2
      exit 1
    fi
    # Assign to the global whose name is in ${var}.
    printf -v "${var}" '%s' "${value}"
  done
}

parse_args "$@"
| |
|
# Refuse to start without the inputs that have no default. In particular an
# empty --output_dir is not caught by mkdir (this script does not abort on
# errors), and `cd ""` can resolve to the current directory, which would put
# a catch-all .gitignore and the run logs there.
missing_args=()
[[ -n "${DATASET:-}" ]] || missing_args+=("--train_datasets")
[[ -n "${MODEL_NAME_OR_PATH:-}" ]] || missing_args+=("--model_name_or_path")
[[ -n "${OUTPUT_DIR:-}" ]] || missing_args+=("--output_dir")
if (( ${#missing_args[@]} > 0 )); then
  echo "Missing required parameter(s): ${missing_args[*]}" >&2
  exit 1
fi

# Create the output directory and resolve it to an absolute path. A temp
# variable is used so the error message can still show the requested path.
mkdir -p "${OUTPUT_DIR}" || exit 1
if ! resolved_output_dir="$(cd "${OUTPUT_DIR}" &>/dev/null && pwd)"; then
  echo "Cannot access output directory: '${OUTPUT_DIR}'" >&2
  exit 1
fi
OUTPUT_DIR="${resolved_output_dir}"

# Have git ignore everything inside the output directory.
if [[ ! -f "${OUTPUT_DIR}/.gitignore" ]]; then
  echo '*' >"${OUTPUT_DIR}/.gitignore"
fi

# Keep a copy of this launch script next to the run's outputs.
cp -f "$0" "${OUTPUT_DIR}/script.sh"
| |
|
| |
|
# Pick a random TCP port for the DeepSpeed master: candidates are all ports
# in [MASTER_PORT_START, MASTER_PORT_END] minus every local port listed by
# `ss -Htan` (port = last ':'-separated piece of the local-address column).
MASTER_PORT_START=10000
MASTER_PORT_END=65535
MASTER_PORT="$(
  comm -23 \
    <(seq "${MASTER_PORT_START}" "${MASTER_PORT_END}" | sort) \
    <(ss -Htan | awk '{ n = split($4, part, ":"); print part[n] }' | sort -u) \
    | shuf \
    | head -n 1
)"

# Launcher options: the hostfile goes first, and only when HOSTFILE is set
# (even to an empty string); the chosen master port always follows.
DEEPSPEED_ARGS=()
[[ -z "${HOSTFILE+x}" ]] || DEEPSPEED_ARGS+=("--hostfile" "${HOSTFILE}")
DEEPSPEED_ARGS+=("--master_port" "${MASTER_PORT}")
| |
|
# From here on, duplicate the script's stdout and stderr into log files in
# the output directory while still writing them to the terminal.
exec \
  1> >(tee "${OUTPUT_DIR}/stdout.log") \
  2> >(tee "${OUTPUT_DIR}/stderr.log" >&2)
| |
|
# Run the `safe_rlhf.finetune` module under the DeepSpeed launcher with this
# experiment's fixed hyperparameters. Per-run settings are read from globals
# set earlier in the script: DEEPSPEED_ARGS, DATASET, MODEL_NAME_OR_PATH,
# OUTPUT_DIR, LOG_RUN_NAME, ZERO_STAGE, OFFLOAD.
# Returns: deepspeed's exit status.
launch_training() {
  # The dataset spec is quoted as a single word so paths containing spaces
  # or glob characters reach the trainer intact. The `inverse-json::` prefix
  # is presumably a safe_rlhf dataset-type selector — TODO confirm.
  deepspeed "${DEEPSPEED_ARGS[@]}" \
    --module safe_rlhf.finetune \
    --train_datasets "inverse-json::${DATASET}" \
    --model_name_or_path "${MODEL_NAME_OR_PATH}" \
    --max_length 512 \
    --trust_remote_code True \
    --epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --learning_rate 1e-5 \
    --lr_warmup_ratio 0 \
    --weight_decay 0.0 \
    --lr_scheduler_type constant \
    --seed 42 \
    --output_dir "${OUTPUT_DIR}" \
    --log_type wandb \
    --log_run_name "${LOG_RUN_NAME}" \
    --log_project Inverse_Alignment_IMDb \
    --zero_stage "${ZERO_STAGE}" \
    --offload "${OFFLOAD}" \
    --bf16 True \
    --tf32 True \
    --save_16bit
}

launch_training
| |
|