File size: 4,923 Bytes
2d67aa6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 | # Specforge 训练环境配置指南
> 适用于 `train_dflash_lora_inject.py`(Qwen3-8B LoRA draft model 训练)
## 硬件要求
- GPU: NVIDIA H800/A100(8 卡单机 或 32 卡多机)
- CUDA Driver: >= 12.2
- 内存: >= 256GB(推荐)
## 1. 创建 conda 环境
```bash
conda create -n spec python=3.11 -y
conda activate spec
```
## 2. 安装 PyTorch(需要 >= 2.5,支持 flex_attention)
pyproject.toml 指定版本为 `torch==2.9.1`,按需选择:
```bash
pip install -U pip setuptools wheel
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
```
> **注意**:PyTorch 必须从官方源装(带 CUDA),清华源没有 cu128 版本的 whl。
验证 flex_attention:
```bash
python -c "from torch.nn.attention.flex_attention import flex_attention; print('flex_attention OK')"
```
## 3. 安装核心依赖
```bash
# transformers(pyproject.toml 指定版本)
pip install transformers==4.57.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 训练必需
pip install accelerate datasets tqdm peft safetensors pydantic numpy typing_extensions \
-i https://pypi.tuna.tsinghua.edu.cn/simple
# sglang(specforge/args.py 顶层 import,必须装)
pip install sglang==0.5.6 -i https://pypi.tuna.tsinghua.edu.cn/simple
# sgl-kernel(sglang backend 用到的 CUDA kernel)
pip install sgl-kernel -i https://pypi.tuna.tsinghua.edu.cn/simple
# yunchang(序列并行通信,specforge/distributed.py 顶层 import,必须装)
pip install yunchang -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## 4. 安装可选依赖
```bash
# flash attention(可选,有则用,没有 fallback 到 flex_attention)
pip install flash-attn --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple
# 8-bit optimizer(仅 --optimizer-type adamw_8bit 时需要)
pip install bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple
# 实验追踪(按需选一个)
pip install wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple
# VL 模型数据处理(当前训练不需要)
pip install qwen-vl-utils==0.0.11 -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## 5. 安装其他工具包
```bash
pip install setuptools packaging ninja psutil -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## 6. 安装 Specforge 本身
```bash
cd /workspace/hanrui/syxin_old/Specforge
pip install -e . --no-deps
```
> 用 `--no-deps` 跳过 pyproject.toml 里的依赖解析(前面已经手动装完了),避免 pip 覆盖 PyTorch 版本。
## 7. 验证安装
```bash
conda activate spec
python -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}, {torch.version.cuda}')
from torch.nn.attention.flex_attention import flex_attention
print('flex_attention: OK')
import transformers; print(f'transformers: {transformers.__version__}')
import accelerate; print(f'accelerate: {accelerate.__version__}')
import datasets; print('datasets: OK')
import peft; print(f'peft: {peft.__version__}')
import yunchang; print('yunchang: OK')
import sglang; print('sglang: OK')
import safetensors; print('safetensors: OK')
import pydantic; print('pydantic: OK')
print()
print('All good!')
"
```
## 8. 启动训练
### 单机 8 卡(本地)
```bash
cd /workspace/hanrui/syxin_old/Specforge/scripts
./run_train_dflash_lora_inject.sh # 默认 8 GPU
./run_train_dflash_lora_inject.sh 4 # 4 GPU
```
### 多机 32 卡(northjob)
```bash
cd /workspace/nex-agi/Megatron-BPLM
./run_qwen3_8b_sft_32gpu.sh
```
> **注意**:`run_train_multinode.sh` 里的 `DEFAULT_SPECFORGE_PY` 已改为指向
> `/workspace/miniconda3/envs/spec/bin/python3`,northjob 容器也会用 conda 环境。
## 完整依赖汇总
| 包名 | 版本要求 | 必须/可选 | 说明 |
|------|---------|----------|------|
| torch | ==2.9.1 (或 >=2.5) | 必须 | 需支持 flex_attention |
| torchaudio | ==2.9.1 | 必须 | |
| torchvision | ==0.24.1 | 必须 | |
| transformers | ==4.57.1 | 必须 | |
| accelerate | latest | 必须 | set_seed 等 |
| datasets | latest | 必须 | 数据加载 |
| tqdm | latest | 必须 | 进度条 |
| peft | latest | 必须 | LoRA 适配器 |
| yunchang | latest | 必须 | 顶层 import,序列并行 |
| sglang | ==0.5.6 | 必须 | 顶层 import(args.py) |
| sgl-kernel | latest | 必须 | sglang CUDA kernels |
| pydantic | latest | 必须 | 数据模板 |
| safetensors | latest | 推荐 | checkpoint 保存 |
| numpy | latest | 必须 | |
| setuptools | latest | 必须 | |
| packaging | latest | 必须 | |
| ninja | latest | 必须 | |
| psutil | latest | 必须 | |
| typing_extensions | latest | 必须 | |
| flash-attn | latest | 可选 | 有更快,没有 fallback |
| bitsandbytes | latest | 可选 | 8bit optimizer |
| wandb / tensorboard | latest | 可选 | 实验追踪 |
|