#!/bin/bash
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
set -ex
DEBUG=${DEBUG:-0}
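# DEBUG=0 (default) runs training via torchrun; DEBUG=1 launches under debugpy
# and waits for a debugger client to attach (see the debug branch below).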
# Set environment variables
export PYTHONPATH=/mnt/ceph_rbd/comp_rag/clara:$PYTHONPATH
export WANDB_DIR=/mnt/ceph_rbd/comp_rag/clara/debug_data/wandb_logs
# Configuration
data_path=/mnt/ceph_rbd/comp_rag/clara/debug_data
SAVE_MODEL_NAME=clara_cluster2_2m_mix_stage1
SAVE_PATH=/mnt/ceph_rbd/comp_rag/clara/debug_data/train_checkpoint/$SAVE_MODEL_NAME
WANDB_TOKEN=xx # placeholder; set a real token if enabling --use_wandb below
MODEL_PATH=/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2
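# NOTE: the paths above are cluster-specific; adjust data_path, SAVE_PATH,
# and MODEL_PATH for your environment.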
mkdir -p $SAVE_PATH
# cp -r /mnt/conductor_data/code/clara_project/clara_training $SAVE_PATH/
# Distributed parameters (hardcoded here for a single-node run; for multinode,
# set NUM_NODES, MASTER, and NODE_RANK appropriately on each node)
export NCCL_DEBUG=INFO
NUM_NODES=1
MASTER=127.0.0.1
MASTER_PORT=29500
NODE_RANK=0
NUM_LOCAL_GPUS=4
WORLD_SIZE=$((NUM_LOCAL_GPUS * NUM_NODES))
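# With the defaults above: WORLD_SIZE = 4 GPUs/node * 1 node = 4 ranks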
echo "Number of nodes: ${NUM_NODES}"
echo "WORLD_SIZE: ${WORLD_SIZE}"
echo "Number of local GPUs: ${NUM_LOCAL_GPUS}"
echo "Master: ${MASTER}"
echo "Master port: ${MASTER_PORT}"
echo "Node rank: ${NODE_RANK}"
echo "Currently using $(which python)"
# Training command with torchrun
training_commands="openrlhf.cli.train_sft \
--max_len 2048 \
--dataset $data_path/pretrain_data.jsonl \
--pretrain $MODEL_PATH \
--train_batch_size 128 \
--micro_train_batch_size 2 \
--ckpt_path $SAVE_PATH \
--max_samples 500 \
--save_path $SAVE_PATH \
--save_steps -2 \
--logging_steps 1 \
--eval_steps 20 \
--zero_stage 2 \
--max_epochs 1 \
--bf16 \
--flash_attn \
--learning_rate 1e-4 \
--stage stage1 \
--generation_top_k 1 \
--qa_loss \
--doc_max_length 256 \
--compress_rate 32 \
--mse_loss \
--gradient_checkpointing"
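# Effective gradient accumulation (assuming OpenRLHF's usual convention of
# train_batch_size / (micro_train_batch_size * WORLD_SIZE)): 128 / (2 * 4) = 16
# micro-batches per optimizer step.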
# To enable Weights & Biases logging, append these flags to training_commands:
# --use_wandb $WANDB_TOKEN \
# --wandb_run_name $SAVE_MODEL_NAME \
# --wandb_project CLaRa
# Build distributed arguments
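# NOTE: with --rdzv_backend c10d, the rendezvous endpoint alone coordinates the
# nodes; --master_addr/--master_port/--node_rank are primarily for the static
# launcher and should be harmless but redundant here.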
DISTRIBUTED_ARGS="--nproc_per_node ${NUM_LOCAL_GPUS} \
    --nnodes ${NUM_NODES} \
    --rdzv_id 101 \
    --rdzv_backend c10d \
    --rdzv_endpoint ${MASTER}:${MASTER_PORT} \
    --master_addr ${MASTER} \
    --master_port ${MASTER_PORT} \
    --node_rank ${NODE_RANK}"
# Run training with torchrun for multinode
echo "Starting CLaRa training on node $NODE_RANK of $NUM_NODES nodes..."
if [ "$DEBUG" -eq 0 ]; then
    # For multinode, print EFA (Elastic Fabric Adapter) diagnostics when available
    if [ "$NUM_NODES" -gt 1 ] && command -v fi_info >/dev/null 2>&1; then
        fi_info -p efa -t FI_EP_RDM
    fi
    torchrun $DISTRIBUTED_ARGS -m $training_commands
else
    # Debug mode: wait for a debugpy client on port 5678, then launch 2 local ranks
    WORLD_SIZE=1 LOCAL_RANK=0 \
    python -m debugpy --listen 0.0.0.0:5678 --wait-for-client \
        -m torch.distributed.launch --nproc_per_node=2 --master_port=20001 \
        -m $training_commands
fi
# Copy model file
cp ../openrlhf/models/modeling_clara.py $SAVE_PATH
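# (The optional inference block below prepends $SAVE_PATH to PYTHONPATH, so
# shipping modeling_clara.py with the checkpoint keeps the directory self-contained.)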
##############################
# Optional: Final inference
##############################
# echo "Running final inference..."
# cd /mnt/conductor_data/code/clara_project/clara
# unset PYTHONPATH
# export PYTHONPATH=$SAVE_PATH:$PYTHONPATH
#
# if [ "$NODE_RANK" -eq 0 ]; then
# accelerate launch \
# --num_processes=8 \
# --num_machines=1 \
# evaluate.py \
# --model_path $SAVE_MODEL_NAME \
# --stage stage1 \
# --dataset hotpotqa,multihoprag,musique
# else
# echo "Node rank $NODE_RANK: skipping inference"
# exit 0
# fi
echo "CLaRa training completed successfully!" |