File size: 2,210 Bytes
7c50656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#!/bin/bash
# Strict mode: abort on command failure, on use of an unset variable, and when
# any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Multi-node version of the train_dflash.py training script.
# Mirrors the logic of run_train_dflash_8gpu.sh, adapted to the northjob
# multi-node environment.
# effective batch = 64 x 4 x 2 = 512, the same as the 8-GPU version.

# Repository root, shared cache location, and default output directory
# (the output directory may be overridden by a positional argument below).
ROOT_DIR="/workspace/hanrui/syxin_old/Specforge"
CACHE_DIR="/tmp/specforge_cache"
OUTPUT_DIR="${ROOT_DIR}/outputs/qwen3-8b-dflash-official-64gpu"

#######################################
# Parse the script's positional arguments.
#
# Usage: <script> [NUM_GPUS] [OUTPUT_DIR] [extra train_dflash.py args...]
#
# A leading positional is taken as NUM_GPUS and the next one as OUTPUT_DIR,
# but only when it does not start with "-". Anything from the first
# dash-prefixed argument onward is forwarded verbatim to the training script,
# so an invocation like `<script> --num-epochs 3` no longer loses its first
# flag to NUM_GPUS.
#
# Globals:
#   NUM_GPUS   (written) defaults to 8
#   OUTPUT_DIR (written only when a second positional is supplied)
#   EXTRA_ARGS (written) array of the remaining arguments
# Arguments:
#   "$@" of the script
#######################################
parse_args() {
  # NOTE(review): NUM_GPUS is not referenced by any command visible in this
  # script; it is still accepted so existing invocations keep working.
  NUM_GPUS=8
  if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    NUM_GPUS=$1
    shift
  fi
  if [[ $# -ge 1 && "${1:0:1}" != "-" ]]; then
    OUTPUT_DIR=$1
    shift
  fi
  EXTRA_ARGS=("$@")
}

parse_args "$@"

# Environment exported to the training process.

# Put the repository root first on PYTHONPATH; ${PYTHONPATH:-} keeps any
# existing value and expands to empty when PYTHONPATH is unset.
export PYTHONPATH="${ROOT_DIR}:${PYTHONPATH:-}"

# Cache locations, all placed under CACHE_DIR.
export HF_HOME="${CACHE_DIR}/hf_home"
export HF_DATASETS_CACHE="${CACHE_DIR}/hf_datasets"
export TORCHINDUCTOR_CACHE_DIR="${CACHE_DIR}/compiled_kernels"

# The same allocator setting is exported under both variable names.
export PYTORCH_ALLOC_CONF='expandable_segments:True'
export PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'

export SPECFORGE_DATA_NUM_PROC='16'

# Interpreter selection, in priority order:
#   1. a non-empty PYTHON_BIN already present in the environment,
#   2. DEFAULT_SPECFORGE_PY, if that path is executable,
#   3. plain python3.
DEFAULT_SPECFORGE_PY="/workspace/miniconda3/envs/spec/bin/python3"
if [[ -n "${PYTHON_BIN:-}" ]]; then
  : # keep the caller-supplied interpreter
elif [[ -x "${DEFAULT_SPECFORGE_PY}" ]]; then
  PYTHON_BIN="${DEFAULT_SPECFORGE_PY}"
else
  PYTHON_BIN="python3"
fi

# Run from the repository root so the relative path scripts/train_dflash.py
# resolves. Quoted so a root path with spaces or glob characters stays intact.
cd "$ROOT_DIR"

# northjob has already set up the distributed environment variables through
# torchrun, so invoke the training script directly; do not launch
# torch.distributed.run again.
#
# Every expansion is quoted: OUTPUT_DIR comes from the command line and must
# reach the script as a single argument. PYTHON_BIN is treated as a single
# path to an interpreter, not a command line with embedded options.
#
# ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"} expands to nothing when no extra
# arguments were given; a bare "${EXTRA_ARGS[@]}" on an empty array is an
# "unbound variable" error under `set -u` on bash older than 4.4.
"$PYTHON_BIN" scripts/train_dflash.py \
  --target-model-path     /workspace/models/Qwen3-8B \
  --target-model-backend  hf \
  --draft-config-path     "$ROOT_DIR/configs/qwen3-8b-dflash.json" \
  --attention-backend     sdpa \
  --trust-remote-code \
  \
  --train-data-path       /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
  --chat-template         qwen \
  --max-length            3072 \
  --dataloader-num-workers 4 \
  \
  --num-epochs            2 \
  --batch-size            4 \
  --accumulation-steps    2 \
  --learning-rate         6e-4 \
  --warmup-ratio          0.04 \
  --max-grad-norm         1.0 \
  --loss-decay-gamma      7 \
  --random-anchor \
  --num-anchors           512 \
  --gradient-checkpointing \
  \
  --lm-head-chunk-size    256 \
  \
  --output-dir            "$OUTPUT_DIR" \
  --cache-dir             "$CACHE_DIR" \
  --log-interval          50 \
  --save-interval         500 \
  ${EXTRA_ARGS[@]+"${EXTRA_ARGS[@]}"}