Hanrui / syxin /Specforge /docs /env_setup.md
Lekr0's picture
Add files using upload-large-folder tool
2d67aa6 verified

Specforge 训练环境配置指南

适用于 train_dflash_lora_inject.py(Qwen3-8B LoRA draft model 训练)

硬件要求

  • GPU: NVIDIA H800/A100(8 卡单机 或 32 卡多机)
  • CUDA Driver: >= 12.2
  • 内存: >= 256GB(推荐)

1. 创建 conda 环境

conda create -n spec python=3.11 -y
conda activate spec

2. 安装 PyTorch(需要 >= 2.5,支持 flex_attention)

pyproject.toml 指定版本为 torch==2.9.1,按需选择:

pip install -U pip setuptools wheel
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128

注意:PyTorch 必须从官方源装(带 CUDA),清华源没有 cu128 版本的 whl。

验证 flex_attention:

python -c "from torch.nn.attention.flex_attention import flex_attention; print('flex_attention OK')"

3. 安装核心依赖

# transformers(pyproject.toml 指定版本)
pip install transformers==4.57.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

# 训练必需
pip install accelerate datasets tqdm peft safetensors pydantic numpy typing_extensions \
    -i https://pypi.tuna.tsinghua.edu.cn/simple

# sglang(specforge/args.py 顶层 import,必须装)
pip install sglang==0.5.6 -i https://pypi.tuna.tsinghua.edu.cn/simple

# sgl-kernel(sglang backend 用到的 CUDA kernel)
pip install sgl-kernel -i https://pypi.tuna.tsinghua.edu.cn/simple

# yunchang(序列并行通信,specforge/distributed.py 顶层 import,必须装)
pip install yunchang -i https://pypi.tuna.tsinghua.edu.cn/simple

4. 安装可选依赖

# flash attention(可选,有则用,没有 fallback 到 flex_attention)
pip install flash-attn --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple

# 8-bit optimizer(仅 --optimizer-type adamw_8bit 时需要)
pip install bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple

# 实验追踪(按需选一个)
pip install wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple

# VL 模型数据处理(当前训练不需要)
pip install qwen-vl-utils==0.0.11 -i https://pypi.tuna.tsinghua.edu.cn/simple

5. 安装其他工具包

pip install setuptools packaging ninja psutil -i https://pypi.tuna.tsinghua.edu.cn/simple

6. 安装 Specforge 本身

cd /workspace/hanrui/syxin_old/Specforge
pip install -e . --no-deps

--no-deps 跳过 pyproject.toml 里的依赖解析(前面已经手动装完了),避免 pip 覆盖 PyTorch 版本。

7. 验证安装

conda activate spec
python -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}, {torch.version.cuda}')

from torch.nn.attention.flex_attention import flex_attention
print('flex_attention: OK')

import transformers; print(f'transformers: {transformers.__version__}')
import accelerate; print(f'accelerate: {accelerate.__version__}')
import datasets; print('datasets: OK')
import peft; print(f'peft: {peft.__version__}')
import yunchang; print('yunchang: OK')
import sglang; print('sglang: OK')
import safetensors; print('safetensors: OK')
import pydantic; print('pydantic: OK')

print()
print('All good!')
"

8. 启动训练

单机 8 卡(本地)

cd /workspace/hanrui/syxin_old/Specforge/scripts
./run_train_dflash_lora_inject.sh        # 默认 8 GPU
./run_train_dflash_lora_inject.sh 4      # 4 GPU

多机 32 卡(northjob)

cd /workspace/nex-agi/Megatron-BPLM
./run_qwen3_8b_sft_32gpu.sh

注意run_train_multinode.sh 里的 DEFAULT_SPECFORGE_PY 已改为指向 /workspace/miniconda3/envs/spec/bin/python3,northjob 容器也会用 conda 环境。

完整依赖汇总

包名 版本要求 必须/可选 说明
torch ==2.9.1 (或 >=2.5) 必须 需支持 flex_attention
torchaudio ==2.9.1 必须
torchvision ==0.24.1 必须
transformers ==4.57.1 必须
accelerate latest 必须 set_seed 等
datasets latest 必须 数据加载
tqdm latest 必须 进度条
peft latest 必须 LoRA 适配器
yunchang latest 必须 顶层 import,序列并行
sglang ==0.5.6 必须 顶层 import(args.py)
sgl-kernel latest 必须 sglang CUDA kernels
pydantic latest 必须 数据模板
safetensors latest 推荐 checkpoint 保存
numpy latest 必须
setuptools latest 必须
packaging latest 必须
ninja latest 必须
psutil latest 必须
typing_extensions latest 必须
flash-attn latest 可选 有更快,没有 fallback
bitsandbytes latest 可选 8bit optimizer
wandb / tensorboard latest 可选 实验追踪