# Specforge Training Environment Setup Guide

> Applies to `train_dflash_lora_inject.py` (LoRA draft model training for Qwen3-8B).

## Hardware Requirements

- GPU: NVIDIA H800/A100 (8 GPUs on a single node, or 32 GPUs across multiple nodes)
- CUDA driver: >= 12.2
- Memory: >= 256 GB (recommended)

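The memory recommendation can be checked programmatically before creating the environment. The following is a minimal sketch, assuming a Linux host that exposes `/proc/meminfo`; the helper names `parse_mem_total_gb` and `meets_memory_requirement` are hypothetical and not part of Specforge.

```python
import re

RECOMMENDED_MEM_GB = 256  # taken from the hardware requirements in this guide


def parse_mem_total_gb(meminfo_text: str) -> float:
    """Return MemTotal in GiB, given the text contents of /proc/meminfo."""
    match = re.search(r"^MemTotal:\s+(\d+)\s+kB", meminfo_text, re.MULTILINE)
    if match is None:
        raise ValueError("MemTotal entry not found")
    return int(match.group(1)) / (1024 ** 2)  # kB -> GiB


def meets_memory_requirement(meminfo_text: str,
                             minimum_gb: float = RECOMMENDED_MEM_GB) -> bool:
    """True if the reported total memory reaches the recommended amount."""
    return parse_mem_total_gb(meminfo_text) >= minimum_gb
```

On a real host this could be called as `meets_memory_requirement(open("/proc/meminfo").read())`. The kernel reserves part of physical RAM, so a machine sold as 256 GB typically reports slightly less; treat the result as advisory.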
## 1. Create the conda environment

```bash
conda create -n spec python=3.11 -y
conda activate spec
```

## 2. Install PyTorch (>= 2.5 required for flex_attention)

`pyproject.toml` pins `torch==2.9.1`; adjust as needed:

```bash
pip install -U pip setuptools wheel
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
```

> **Note**: Install PyTorch from the official index (with CUDA). The Tsinghua mirror does not carry cu128 wheels.

Verify flex_attention:

```bash
python -c "from torch.nn.attention.flex_attention import flex_attention; print('flex_attention OK')"
```

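The `>= 2.5` floor can also be checked against the version string itself, which is useful in scripts that should fail early with a clear message. This is a minimal sketch; `parse_release` and `supports_flex_attention` are hypothetical helper names, and the import test remains the authoritative check.

```python
import re


def parse_release(version: str) -> tuple:
    """Extract the leading numeric release from a version string.

    '2.9.1+cu128' -> (2, 9, 1); a local suffix such as '+cu128' is ignored.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    # A missing patch component (e.g. '2.5') is treated as 0.
    return tuple(int(part) for part in match.groups(default="0"))


def supports_flex_attention(torch_version: str, minimum=(2, 5, 0)) -> bool:
    """True if the version meets the >= 2.5 floor stated in this guide."""
    return parse_release(torch_version) >= minimum
```

Typical usage would be `supports_flex_attention(torch.__version__)`. Comparing integer tuples rather than strings matters here: as strings, `"2.10.0"` sorts before `"2.5.0"`.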
## 3. Install core dependencies

```bash
# transformers (version pinned in pyproject.toml)
pip install transformers==4.57.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

# Required for training
pip install accelerate datasets tqdm peft safetensors pydantic numpy typing_extensions \
    -i https://pypi.tuna.tsinghua.edu.cn/simple

# sglang (imported at the top level of specforge/args.py, so it must be installed)
pip install sglang==0.5.6 -i https://pypi.tuna.tsinghua.edu.cn/simple

# sgl-kernel (CUDA kernels used by the sglang backend)
pip install sgl-kernel -i https://pypi.tuna.tsinghua.edu.cn/simple

# yunchang (sequence-parallel communication; imported at the top level of
# specforge/distributed.py, so it must be installed)
pip install yunchang -i https://pypi.tuna.tsinghua.edu.cn/simple
```

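Because several of these packages are imported at module top level, a missing one surfaces only as an `ImportError` at launch. The sketch below lists any that cannot be located, without actually importing them. `missing_modules` is a hypothetical helper, and the import name `sgl_kernel` for the `sgl-kernel` package is an assumption.

```python
from importlib.util import find_spec

# Top-level import names for the packages installed in this step.
# Assumption: the pip package `sgl-kernel` is imported as `sgl_kernel`.
CORE_MODULES = [
    "transformers", "accelerate", "datasets", "tqdm", "peft",
    "safetensors", "pydantic", "numpy", "typing_extensions",
    "sglang", "sgl_kernel", "yunchang",
]


def missing_modules(names):
    """Return the names that cannot be located, without importing them."""
    return [name for name in names if find_spec(name) is None]
```

Running `print(missing_modules(CORE_MODULES))` inside the `spec` environment should print an empty list once this step has completed.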
## 4. Install optional dependencies

```bash
# flash attention (optional; used when present, otherwise falls back to flex_attention)
# --no-build-isolation builds against the already-installed torch, so install PyTorch first
pip install flash-attn --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple

# 8-bit optimizer (only needed with --optimizer-type adamw_8bit)
pip install bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple

# Experiment tracking (pick one as needed)
pip install wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple

# Data processing for VL models (not needed for the current training)
pip install qwen-vl-utils==0.0.11 -i https://pypi.tuna.tsinghua.edu.cn/simple
```

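The "use flash-attn if present, otherwise fall back" behavior can be illustrated as follows. This is only a sketch of the pattern: `pick_attention_backend` is a hypothetical helper that mirrors the fallback described in this guide, not Specforge's actual selection code.

```python
from importlib.util import find_spec


def _importable(name: str) -> bool:
    """True if a top-level module can be located in the current environment."""
    return find_spec(name) is not None


def pick_attention_backend(is_available=_importable) -> str:
    """Prefer flash-attn when importable, otherwise use flex_attention.

    `is_available` is injectable so the choice can be exercised without
    installing or removing flash-attn.
    """
    return "flash_attn" if is_available("flash_attn") else "flex_attention"
```

Calling `pick_attention_backend()` in the `spec` environment shows which path an import-based check of this kind would take.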
## 5. Install other tooling packages

```bash
pip install setuptools packaging ninja psutil -i https://pypi.tuna.tsinghua.edu.cn/simple
```

## 6. Install Specforge itself

```bash
cd /workspace/hanrui/syxin_old/Specforge
pip install -e . --no-deps
```

> `--no-deps` skips resolution of the dependencies listed in `pyproject.toml` (they were installed manually in the previous steps), so pip does not overwrite the installed PyTorch version.

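To confirm that the pinned versions survived the editable install, the installed distribution metadata can be compared against the pins used in this guide. This is a minimal sketch; `installed_version` and `pin_mismatches` are hypothetical helper names.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the install commands in this guide.
PINNED = {"torch": "2.9.1", "transformers": "4.57.1", "sglang": "0.5.6"}


def installed_version(dist_name):
    """Return the installed version string, or None if not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def pin_mismatches(pins, get_version=installed_version):
    """Map each package whose installed version differs from its pin
    to an (expected, found) pair."""
    mismatches = {}
    for name, expected in pins.items():
        found = get_version(name)
        # Drop a local build tag such as '+cu128' before comparing.
        base = found.split("+")[0] if found else None
        if base != expected:
            mismatches[name] = (expected, found)
    return mismatches
```

For example, `print(pin_mismatches(PINNED) or "all pins intact")` reports any package that pip replaced or that is missing.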
## 7. Verify the installation

```bash
conda activate spec
python -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}, {torch.version.cuda}')

from torch.nn.attention.flex_attention import flex_attention
print('flex_attention: OK')

import transformers; print(f'transformers: {transformers.__version__}')
import accelerate; print(f'accelerate: {accelerate.__version__}')
import datasets; print('datasets: OK')
import peft; print(f'peft: {peft.__version__}')
import yunchang; print('yunchang: OK')
import sglang; print('sglang: OK')
import safetensors; print('safetensors: OK')
import pydantic; print('pydantic: OK')

print()
print('All good!')
"
```

## 8. Launch training

### Single node, 8 GPUs (local)

```bash
cd /workspace/hanrui/syxin_old/Specforge/scripts
./run_train_dflash_lora_inject.sh      # defaults to 8 GPUs
./run_train_dflash_lora_inject.sh 4    # 4 GPUs
```

### Multi-node, 32 GPUs (northjob)

```bash
cd /workspace/nex-agi/Megatron-BPLM
./run_qwen3_8b_sft_32gpu.sh
```

> **Note**: `DEFAULT_SPECFORGE_PY` in `run_train_multinode.sh` now points to
> `/workspace/miniconda3/envs/spec/bin/python3`, so the northjob containers also use the conda environment.

## Full dependency summary

| Package | Version | Required/Optional | Notes |
|---------|---------|-------------------|-------|
| torch | ==2.9.1 (or >=2.5) | Required | Must support flex_attention |
| torchaudio | ==2.9.1 | Required | |
| torchvision | ==0.24.1 | Required | |
| transformers | ==4.57.1 | Required | |
| accelerate | latest | Required | set_seed, etc. |
| datasets | latest | Required | Data loading |
| tqdm | latest | Required | Progress bars |
| peft | latest | Required | LoRA adapters |
| yunchang | latest | Required | Top-level import; sequence parallelism |
| sglang | ==0.5.6 | Required | Top-level import (args.py) |
| sgl-kernel | latest | Required | sglang CUDA kernels |
| pydantic | latest | Required | Data templates |
| safetensors | latest | Recommended | Checkpoint saving |
| numpy | latest | Required | |
| setuptools | latest | Required | |
| packaging | latest | Required | |
| ninja | latest | Required | |
| psutil | latest | Required | |
| typing_extensions | latest | Required | |
| flash-attn | latest | Optional | Faster when available; falls back otherwise |
| bitsandbytes | latest | Optional | 8-bit optimizer |
| wandb / tensorboard | latest | Optional | Experiment tracking |