File size: 4,764 Bytes
30c14cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/bin/bash
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
# Launches CLaRa stage1_2 SFT training via torchrun (single node by default).
# Set DEBUG=1 to launch under debugpy and wait for a debugger to attach.


# -e: exit on error; -x: trace commands; pipefail: make a pipeline fail when
# any stage fails (plain `set -ex` reports only the last stage's status).
set -exo pipefail

# DEBUG=1 switches the launch section below into debugpy mode.
DEBUG=${DEBUG:-0}

########################################
# Environment Variables
########################################
# Make the project importable and route W&B logs to the shared artifacts dir.
# ${PYTHONPATH:-} avoids an unbound-variable error when PYTHONPATH is unset.
export PYTHONPATH=/mnt/ceph_rbd/comp_rag/clara:${PYTHONPATH:-}
export WANDB_DIR=/mnt/task_wrapper/user_output/artifacts/data/wandb_logs
export NCCL_DEBUG=INFO

########################################
# Configuration
########################################
data_path=/mnt/ceph_rbd/comp_rag/clara/debug_data
SAVE_MODEL_NAME=clara_cluster1_2_2m_split_data_single_32_mistral
SAVE_PATH=/mnt/ceph_rbd/comp_rag/clara/debug_data/train_checkpoint/$SAVE_MODEL_NAME
# Placeholder token; override via the environment before enabling --use_wandb
# below. Do not commit a real key here.
WANDB_TOKEN=${WANDB_TOKEN:-xx}
MODEL_PATH=/mnt/ceph_rbd/model/Mistral-7B-Instruct-v0.2
PRETRAIN_CKPT=/mnt/ceph_rbd/comp_rag/clara/debug_data/train_checkpoint/clara_cluster2_2m_mix_stage1

# Quoted to survive any future path containing spaces (SC2086).
mkdir -p -- "$SAVE_PATH"
# cp -r /mnt/ceph_rbd/comp_rag/clara $SAVE_PATH/

echo "Currently using $(command -v python)"

########################################
# Extract Distributed Parameters
########################################
# Single-node defaults; WORLD_SIZE is derived as GPUs-per-node * node count.
NUM_NODES=1
MASTER=127.0.0.1
MASTER_PORT=29500
NODE_RANK=0
NUM_LOCAL_GPUS=4
WORLD_SIZE=$((NUM_LOCAL_GPUS * NUM_NODES))

# Log the topology once at startup, one item per line.
printf '%s\n' \
    "Number of nodes: ${NUM_NODES}" \
    "WORLD_SIZE: ${WORLD_SIZE}" \
    "Number of local GPUs: ${NUM_LOCAL_GPUS}" \
    "Master: ${MASTER}" \
    "Master port: ${MASTER_PORT}" \
    "Node rank: ${NODE_RANK}"

# Placeholder; point at a real eval set before enabling --eval_dataset below.
eval_dataset=xx
########################################
# Training Command
########################################
# Module path plus CLI flags handed to `torchrun ... -m` below. Kept as one
# whitespace-separated string and deliberately expanded unquoted so the shell
# word-splits it into argv — individual arguments must not contain spaces.
# NOTE(review): --save_steps -2 is passed verbatim — presumably a sentinel
# (e.g. "no intermediate checkpoints"); confirm against train_sft's CLI.
training_commands="openrlhf.cli.train_sft \
   --max_len 2048 \
   --dataset $data_path/instruction_tuning_data.jsonl \
   --pretrain $MODEL_PATH \
   --pretrain_checkpoint $PRETRAIN_CKPT \
   --train_batch_size 128 \
   --micro_train_batch_size 2 \
   --ckpt_path $SAVE_PATH \
   --max_samples 500 \
   --save_path $SAVE_PATH \
   --save_steps -2 \
   --logging_steps 1 \
   --eval_steps 30 \
   --zero_stage 2 \
   --max_epochs 1 \
   --bf16 \
   --flash_attn \
   --learning_rate 1e-4 \
   --gradient_checkpointing \
   --generation_top_k 5 \
   --stage stage1_2 \
   --doc_max_length 256 \
   --compress_rate 32 \
   --mse_loss \
   --do_eval_gen"

# Optional flags, currently disabled (note the trailing quotes if re-enabled):
#--eval_dataset $eval_dataset \
#    --use_wandb $WANDB_TOKEN"
#    --wandb_run_name $SAVE_MODEL_NAME \
#    --wandb_project CLaRa"

########################################
# Distributed Arguments for torchrun
########################################
# Rendezvous configuration for torchrun; like $training_commands, this string
# is expanded unquoted below so the shell word-splits it into argv.
# NOTE(review): --master_addr/--master_port look redundant alongside
# --rdzv_endpoint when --rdzv_backend c10d is used — torchrun derives the
# store address from the rendezvous endpoint; confirm before removing.
DISTRIBUTED_ARGS="--nproc_per_node ${NUM_LOCAL_GPUS} \
   --nnodes ${NUM_NODES} \
   --rdzv_id 101 \
   --rdzv_backend c10d \
   --rdzv_endpoint ${MASTER}:${MASTER_PORT} \
   --master_addr ${MASTER} \
   --master_port ${MASTER_PORT} \
   --node_rank ${NODE_RANK}"

########################################
# Multi-node Training
########################################
echo "Starting CLaRa stage1_2 training (multinode with torchrun)..."
# $DEBUG is quoted so the test does not break if it is ever set empty.
if [ "$DEBUG" -eq 0 ]; then
    # The EFA probe is only relevant for multi-node runs; the launch command
    # itself was byte-identical in both branches of the original if/else, so
    # run it once below instead of duplicating it.
    if [ "$NUM_NODES" -gt 1 ] && command -v fi_info >/dev/null 2>&1; then
        fi_info -p efa -t FI_EP_RDM || true
    fi
    # $DISTRIBUTED_ARGS and $training_commands rely on word-splitting and are
    # intentionally left unquoted.
    torchrun $DISTRIBUTED_ARGS -m $training_commands
else
    # Debug mode: run under debugpy and block until a debugger attaches.
    WORLD_SIZE=1 LOCAL_RANK=0 \
    python -m debugpy --listen 0.0.0.0:5678 --wait-for-client \
    -m torch.distributed.launch --nproc_per_node=2 --master_port=20001 \
    -m $training_commands
fi

########################################
# Copy Model Files
########################################
# Snapshot the model definition next to the checkpoint for reproducibility.
# The relative path assumes the script is invoked from its own directory
# (e.g. a scripts/ dir one level below the repo root) — TODO confirm caller.
cp -- ../openrlhf/models/modeling_clara.py "$SAVE_PATH"

########################################
# Final Inference (only on rank 0)
########################################
# echo "Running final inference..."
# cd /mnt/ceph_rbd/comp_rag/clara/evaluation

# # Clean and set PYTHONPATH to avoid conflicts
# unset PYTHONPATH
# export PYTHONPATH=/mnt/conductor_data/clara:$PYTHONPATH

# echo "Starting inference on node $NODE_RANK of $NUM_NODES nodes..."
# if [ "$NODE_RANK" -eq 0 ]; then
#     # Run inference with gold retrieval
#     accelerate launch \
#         --num_processes=8 \
#         --num_machines=1 \
#         evaluate.py \
#         --model_path $SAVE_MODEL_NAME \
#         --stage stage1 \
#         --dataset musique,hotpotqa,2wiki,nq \
#         --gold_retrieval
    
#     # Run inference without gold retrieval
#     accelerate launch \
#         --num_processes=8 \
#         --num_machines=1 \
#         evaluate.py \
#         --model_path $SAVE_MODEL_NAME \
#         --stage stage1 \
#         --dataset musique,hotpotqa,2wiki,nq
# else
#     echo "Node rank $NODE_RANK: skipping inference"
#     exit 0
# fi

# echo "CLaRa stage1_2 training and inference completed successfully!"