#!/bin/bash
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
# Fail fast on errors (including inside pipelines) and trace commands.
set -exo pipefail

# DEBUG=1 switches the launch step below to a debugpy-wrapped run.
DEBUG=${DEBUG:-0}

########################################
# Environment Variables
########################################
export PYTHONPATH=/mnt/ceph_rbd/comp_rag/clara:$PYTHONPATH
export WANDB_DIR=/mnt/task_wrapper/user_output/artifacts/data/wandb_logs
export NCCL_DEBUG=INFO

########################################
# Configuration
########################################
data_path=/mnt/ceph_rbd/comp_rag/clara/debug_data
SAVE_MODEL_NAME=clara_cluster1_2_2m_split_data_single_32_mistral
SAVE_PATH=/mnt/ceph_rbd/comp_rag/clara/debug_data/train_checkpoint/$SAVE_MODEL_NAME
# SECURITY: "xx" is a placeholder — supply the real token via the environment,
# never commit a live key. Env value takes precedence over the placeholder.
WANDB_TOKEN=${WANDB_TOKEN:-xx}
MODEL_PATH=/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2
PRETRAIN_CKPT=/mnt/ceph_rbd/comp_rag/clara/debug_data/train_checkpoint/clara_cluster2_2m_mix_stage1

# Quote the path so a space in SAVE_MODEL_NAME cannot split the argument.
mkdir -p "$SAVE_PATH"
# cp -r /mnt/ceph_rbd/comp_rag/clara $SAVE_PATH/
echo "Currently using $(command -v python)"
########################################
# Extract Distributed Parameters
########################################
# Single-node defaults; WORLD_SIZE is derived, not hardcoded.
NUM_NODES=1
MASTER=127.0.0.1
MASTER_PORT=29500
NODE_RANK=0
NUM_LOCAL_GPUS=4
WORLD_SIZE=$((NUM_LOCAL_GPUS * NUM_NODES))

# Report the effective topology (printf avoids echo's flag quirks).
printf 'Number of nodes: %s\n' "${NUM_NODES}"
printf 'WORLD_SIZE: %s\n' "${WORLD_SIZE}"
printf 'Number of local GPUs: %s\n' "${NUM_LOCAL_GPUS}"
printf 'Master: %s\n' "${MASTER}"
printf 'Master port: %s\n' "${MASTER_PORT}"
printf 'Node rank: %s\n' "${NODE_RANK}"

# Placeholder — wired to the (currently disabled) --eval_dataset flag below.
eval_dataset=xx
########################################
# Training Command
########################################
# Module name plus flags for `torchrun -m`. Assembled from an array for
# readability, then flattened to a single string because the launch step
# relies on unquoted word-splitting of $training_commands.
_train_args=(
  openrlhf.cli.train_sft
  --max_len 2048
  --dataset "$data_path/instruction_tuning_data.jsonl"
  --pretrain "$MODEL_PATH"
  --pretrain_checkpoint "$PRETRAIN_CKPT"
  --train_batch_size 128
  --micro_train_batch_size 2
  --ckpt_path "$SAVE_PATH"
  --max_samples 500
  --save_path "$SAVE_PATH"
  --save_steps -2
  --logging_steps 1
  --eval_steps 30
  --zero_stage 2
  --max_epochs 1
  --bf16
  --flash_attn
  --learning_rate 1e-4
  --gradient_checkpointing
  --generation_top_k 5
  --stage stage1_2
  --doc_max_length 256
  --compress_rate 32
  --mse_loss
  --do_eval_gen
)
training_commands="${_train_args[*]}"
# Optional flags, disabled for now (enable by appending to _train_args):
#   --eval_dataset $eval_dataset
#   --use_wandb $WANDB_TOKEN
#   --wandb_run_name $SAVE_MODEL_NAME
#   --wandb_project CLaRa
########################################
# Distributed Arguments for torchrun
########################################
# Rendezvous (c10d) plus legacy master_addr/master_port flags. Built from an
# array, then flattened: the launch step word-splits $DISTRIBUTED_ARGS.
_dist_args=(
  --nproc_per_node "${NUM_LOCAL_GPUS}"
  --nnodes "${NUM_NODES}"
  --rdzv_id 101
  --rdzv_backend c10d
  --rdzv_endpoint "${MASTER}:${MASTER_PORT}"
  --master_addr "${MASTER}"
  --master_port "${MASTER_PORT}"
  --node_rank "${NODE_RANK}"
)
DISTRIBUTED_ARGS="${_dist_args[*]}"
########################################
# Multi-node Training
########################################
echo "Starting CLaRa stage1_2 training (multinode with torchrun)..."
# NOTE: $DISTRIBUTED_ARGS and $training_commands are intentionally unquoted —
# they are flat strings that must word-split into separate arguments.
if [ "$DEBUG" -eq 0 ]; then
    if [ "$NUM_NODES" -gt 1 ]; then
        # Multi-node only: probe the EFA fabric when libfabric tools exist;
        # the probe is informational, so failures are tolerated.
        if command -v fi_info >/dev/null 2>&1; then
            fi_info -p efa -t FI_EP_RDM || true
        fi
    fi
    # The multi-node and single-node branches previously duplicated this
    # identical command; launch once for both topologies.
    torchrun $DISTRIBUTED_ARGS -m $training_commands
else
    # Debug mode: wrap the trainer in debugpy and block until a client
    # attaches on port 5678. (torch.distributed.launch is deprecated in
    # recent PyTorch; kept as-is to preserve the existing debug workflow.)
    WORLD_SIZE=1 LOCAL_RANK=0 \
    python -m debugpy --listen 0.0.0.0:5678 --wait-for-client \
        -m torch.distributed.launch --nproc_per_node=2 --master_port=20001 \
        -m $training_commands
fi
########################################
# Copy Model Files
########################################
# Snapshot the model implementation next to the checkpoint for reproducibility.
# NOTE(review): the relative source path assumes the script is launched from
# its own directory — confirm the expected working directory of the job runner.
cp ../openrlhf/models/modeling_clara.py "$SAVE_PATH"
########################################
# Final Inference (only on rank 0)
########################################
# echo "Running final inference..."
# cd /mnt/ceph_rbd/comp_rag/clara/evaluation
# # Clean and set PYTHONPATH to avoid conflicts
# unset PYTHONPATH
# export PYTHONPATH=/mnt/conductor_data/clara:$PYTHONPATH
# echo "Starting inference on node $NODE_RANK of $NUM_NODES nodes..."
# if [ "$NODE_RANK" -eq 0 ]; then
# # Run inference with gold retrieval
# accelerate launch \
# --num_processes=8 \
# --num_machines=1 \
# evaluate.py \
# --model_path $SAVE_MODEL_NAME \
# --stage stage1 \
# --dataset musique,hotpotqa,2wiki,nq \
# --gold_retrieval
# # Run inference without gold retrieval
# accelerate launch \
# --num_processes=8 \
# --num_machines=1 \
# evaluate.py \
# --model_path $SAVE_MODEL_NAME \
# --stage stage1 \
# --dataset musique,hotpotqa,2wiki,nq
# else
# echo "Node rank $NODE_RANK: skipping inference"
# exit 0
# fi
# echo "CLaRa stage1_2 training and inference completed successfully!" |