#!/bin/bash
# run_train_dflash_8gpu.sh
# Reproduce the official DFlash results; aligned with examples/run_qwen3_8b_dflash_online.sh.
# Effective batch size = 8 GPUs x 4 per GPU = 32 (the official recipe uses no gradient accumulation).
# Run 2 epochs first to verify that the loss decreases normally; once confirmed,
# switch back to 6 epochs for the full training run.
#
# Usage:
#   bash run_train_dflash_8gpu.sh                     # default: 8 GPUs
#   bash run_train_dflash_8gpu.sh 4                   # use 4 GPUs
#   bash run_train_dflash_8gpu.sh 8 --num-epochs 6    # forward extra args to the trainer
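# Hardening (added): abort on errors, unset variables, and pipeline failures.
# Note: requires bash >= 4.4 so that expanding the possibly-empty EXTRA_ARGS
# array below is safe under `set -u`.
set -euo pipefail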
# Resolve the directory containing this script, then treat its parent as the repo root.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
# Shared cache location for HF datasets/hub and TorchInductor compiled kernels.
CACHE_DIR=/workspace/hanrui/cache
export PYTHONPATH=$ROOT_DIR:${PYTHONPATH:-}
export HF_DATASETS_CACHE=$CACHE_DIR/hf_datasets
export HF_HOME=$CACHE_DIR/hf_home
export TORCHINDUCTOR_CACHE_DIR=$CACHE_DIR/compiled_kernels
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export SPECFORGE_DATA_NUM_PROC=32
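# Create the cache directories up front (added); mkdir -p is idempotent, and
# this avoids first-run surprises if /workspace/hanrui/cache does not exist yet.
mkdir -p "$HF_DATASETS_CACHE" "$HF_HOME" "$TORCHINDUCTOR_CACHE_DIR"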
NUM_GPUS=${1:-8}                  # first positional arg selects the GPU count (default: 8)
if [[ $# -ge 1 ]]; then shift; fi
EXTRA_ARGS=("$@")                 # remaining args are forwarded verbatim to the trainer
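# Guard (added): the first positional argument must be a positive integer,
# otherwise torchrun would fail later with a less obvious error.
if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
    echo "Error: NUM_GPUS must be a positive integer, got '$NUM_GPUS'" >&2
    exit 1
fi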
# Pinned interpreter from the dedicated dflash conda environment.
PYTHON_BIN=/workspace/miniconda3/envs/dflash/bin/python3
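# Sanity check (added): fail fast if the pinned dflash conda env is missing.
if [[ ! -x "$PYTHON_BIN" ]]; then
    echo "Error: python interpreter not found at $PYTHON_BIN" >&2
    exit 1
fi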
"$PYTHON_BIN" -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_dflash.py" \
\
--target-model-path /workspace/models/Qwen3-8B \
--target-model-backend hf \
    --draft-config-path "$ROOT_DIR/configs/qwen3-8b-dflash.json" \
--attention-backend flex_attention \
--trust-remote-code \
\
--train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
--chat-template qwen \
--max-length 3072 \
\
--num-epochs 2 \
--batch-size 4 \
--learning-rate 6e-4 \
--warmup-ratio 0.04 \
--max-grad-norm 1.0 \
--random-anchor \
--num-anchors 512 \
--loss-decay-gamma 7.0 \
\
    --output-dir "$ROOT_DIR/outputs/qwen3-8b-dflash-official" \
    --cache-dir "$CACHE_DIR" \
--log-interval 50 \
--save-interval 1000 \
"${EXTRA_ARGS[@]}"