Hanrui / syxin_old /run_train_multinode_dflash.sh
Lekr0's picture
Add files using upload-large-folder tool
7c50656 verified
#!/bin/bash
# Multi-node variant of the train_dflash.py training launcher.
# Mirrors the logic of run_train_dflash_8gpu.sh, adapted to the northjob
# multi-node environment.
# Effective batch = 64 x 4 x 2 = 512, the same as the 8-GPU version
# (presumably 64 GPUs x batch-size 4 x accumulation-steps 2 — confirm against
# the job configuration).

# Strict mode: abort on command failure, on use of an unset variable, and when
# any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Project root (used as working directory and PYTHONPATH entry further down).
ROOT_DIR="/workspace/hanrui/syxin_old/Specforge"
# Scratch location shared by the cache-related settings below.
CACHE_DIR="/tmp/specforge_cache"
# Default output directory; may be replaced by a positional argument later.
OUTPUT_DIR="${ROOT_DIR}/outputs/qwen3-8b-dflash-official-64gpu"
# Parse arguments
# Usage: <script> [NUM_GPUS] [OUTPUT_DIR] [extra train_dflash.py flags...]
#
# NOTE(review): NUM_GPUS is not referenced again in the visible part of this
# script; it is still accepted so existing invocations keep working.
NUM_GPUS=8

#######################################
# Consume the optional leading positional arguments and collect the rest.
# A positional value is only consumed when it does not start with "-", so a
# flags-only invocation (e.g. `--num-epochs 3`) is forwarded untouched instead
# of having its first two words swallowed as NUM_GPUS / OUTPUT_DIR.
# Globals:
#   NUM_GPUS   (written when a first positional value is present)
#   OUTPUT_DIR (written when a second positional value is present)
#   EXTRA_ARGS (written: array of all remaining arguments, possibly empty)
# Arguments:
#   The script's command-line arguments.
#######################################
parse_positional_args() {
  if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    NUM_GPUS=$1
    shift
  fi
  if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
  fi
  EXTRA_ARGS=("$@")
}

parse_positional_args "$@"
# Environment variables
# Prepend ROOT_DIR to PYTHONPATH. The ':' separator is emitted only when a
# non-empty PYTHONPATH already exists, so no trailing empty path entry is
# produced when it was unset or empty.
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# Point the Hugging Face and TorchInductor caches at sub-directories of
# CACHE_DIR.
export HF_DATASETS_CACHE="${CACHE_DIR}/hf_datasets"
export HF_HOME="${CACHE_DIR}/hf_home"
export TORCHINDUCTOR_CACHE_DIR="${CACHE_DIR}/compiled_kernels"

# The same allocator setting is exported under both variable names
# (presumably to cover PyTorch versions that read either one — confirm).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export PYTORCH_ALLOC_CONF=expandable_segments:True

# NOTE(review): presumably the worker count for SpecForge data preprocessing;
# verify against the code that reads this variable.
export SPECFORGE_DATA_NUM_PROC=16
# Python binary
# Resolution order for the interpreter used to launch training:
#   1. a non-empty PYTHON_BIN supplied by the caller's environment,
#   2. the interpreter at DEFAULT_SPECFORGE_PY, if that file is executable,
#   3. whatever `python3` resolves to on PATH.
DEFAULT_SPECFORGE_PY=/workspace/miniconda3/envs/spec/bin/python3
if [[ -n "${PYTHON_BIN:-}" ]]; then
  : # caller-provided interpreter wins; nothing to do
elif [[ -x "$DEFAULT_SPECFORGE_PY" ]]; then
  PYTHON_BIN="$DEFAULT_SPECFORGE_PY"
else
  PYTHON_BIN=python3
fi
# Run from the project root so the relative script path below resolves.
# Quoted so a path containing whitespace or glob characters stays one word.
cd "$ROOT_DIR"

# northjob has already set up the distributed environment variables via
# torchrun. Run the training script directly; do not launch
# torch.distributed.run again.
#
# PYTHON_BIN is intentionally left unquoted so that an override carrying extra
# interpreter flags (e.g. PYTHON_BIN="python3 -u") keeps working.
# OUTPUT_DIR / CACHE_DIR / ROOT_DIR are quoted because OUTPUT_DIR can come from
# the command line and must be passed through as a single argument.
# The ${EXTRA_ARGS[@]+...} form expands to nothing when the array is empty,
# which avoids an "unbound variable" abort under `set -u` on bash < 4.4.
# shellcheck disable=SC2086
$PYTHON_BIN scripts/train_dflash.py \
  --target-model-path /workspace/models/Qwen3-8B \
  --target-model-backend hf \
  --draft-config-path "$ROOT_DIR/configs/qwen3-8b-dflash.json" \
  --attention-backend sdpa \
  --trust-remote-code \
  \
  --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
  --chat-template qwen \
  --max-length 3072 \
  --dataloader-num-workers 4 \
  \
  --num-epochs 2 \
  --batch-size 4 \
  --accumulation-steps 2 \
  --learning-rate 6e-4 \
  --warmup-ratio 0.04 \
  --max-grad-norm 1.0 \
  --loss-decay-gamma 7 \
  --random-anchor \
  --num-anchors 512 \
  --gradient-checkpointing \
  \
  --lm-head-chunk-size 256 \
  \
  --output-dir "$OUTPUT_DIR" \
  --cache-dir "$CACHE_DIR" \
  --log-interval 50 \
  --save-interval 500 \
  ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}