File size: 2,025 Bytes
7a60a87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
#!/bin/bash
# run_train_dflash_8gpu.sh
# Reproduce the official DFlash results, aligned with
# examples/run_qwen3_8b_dflash_online.sh.
# Effective batch = NUM_GPUS x per-device batch size 4
# (8 x 4 = 32 with the default 8 GPUs; the official recipe uses no
# gradient accumulation, so using fewer GPUs shrinks the effective batch).
# First run 2 epochs to verify that the loss decreases normally; once
# confirmed, switch back to 6 epochs for the full training run.
#
# Usage:
#   bash run_train_dflash_8gpu.sh                      # default: 8 GPUs
#   bash run_train_dflash_8gpu.sh 4                    # use 4 GPUs
#   bash run_train_dflash_8gpu.sh 8 --num-epochs 6     # append extra args
#
# Extra arguments are appended after the defaults below and forwarded to
# train_dflash.py, so they can override a default if its parser keeps the
# last occurrence of an option (argparse does by default; confirm for this
# script).
#
# Environment:
#   PYTHON_BIN  Python interpreter used to launch torchrun
#               (default: /workspace/miniconda3/envs/dflash/bin/python3).

set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname -- "$SCRIPT_DIR")
CACHE_DIR=/workspace/hanrui/cache

# Only add a ':' separator when PYTHONPATH is already non-empty: a trailing
# empty entry would make Python put the current working directory on sys.path.
export PYTHONPATH="${ROOT_DIR}${PYTHONPATH:+:$PYTHONPATH}"
export HF_DATASETS_CACHE="$CACHE_DIR/hf_datasets"
export HF_HOME="$CACHE_DIR/hf_home"
export TORCHINDUCTOR_CACHE_DIR="$CACHE_DIR/compiled_kernels"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export SPECFORGE_DATA_NUM_PROC=32

# The first positional argument (optional) is the GPU count; the rest is
# forwarded verbatim to train_dflash.py.
NUM_GPUS=${1:-8}
if [[ $# -ge 1 ]]; then shift; fi
EXTRA_ARGS=("$@")

# Fail fast on e.g. `bash run_train_dflash_8gpu.sh --num-epochs 6`, which would
# otherwise pass "--num-epochs" to torchrun as --nproc_per_node.
if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
    printf 'error: NUM_GPUS must be a positive integer, got %q\n' "$NUM_GPUS" >&2
    printf 'usage: %s [NUM_GPUS] [extra train_dflash.py args...]\n' "${0##*/}" >&2
    exit 2
fi

PYTHON_BIN=${PYTHON_BIN:-/workspace/miniconda3/envs/dflash/bin/python3}
if ! command -v -- "$PYTHON_BIN" > /dev/null 2>&1; then
    printf 'error: python interpreter not found or not executable: %s\n' "$PYTHON_BIN" >&2
    exit 1
fi

# exec so that signals sent to this script (Ctrl-C, scheduler SIGTERM) reach
# torchrun directly; the script's exit status is torchrun's exit status.
# ${EXTRA_ARGS[@]+...} keeps an empty array safe under `set -u` on bash < 4.4.
exec "$PYTHON_BIN" -m torch.distributed.run \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_dflash.py" \
    \
    --target-model-path     /workspace/models/Qwen3-8B \
    --target-model-backend  hf \
    --draft-config-path     "$ROOT_DIR/configs/qwen3-8b-dflash.json" \
    --attention-backend     flex_attention \
    --trust-remote-code \
    \
    --train-data-path       /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --chat-template         qwen \
    --max-length            3072 \
    \
    --num-epochs            2 \
    --batch-size            4 \
    --learning-rate         6e-4 \
    --warmup-ratio          0.04 \
    --max-grad-norm         1.0 \
    --random-anchor \
    --num-anchors           512 \
    --loss-decay-gamma      7.0 \
    \
    --output-dir            "$ROOT_DIR/outputs/qwen3-8b-dflash-official" \
    --cache-dir             "$CACHE_DIR" \
    --log-interval          50 \
    --save-interval         1000 \
    ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}