File size: 4,923 Bytes
2d67aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Specforge 训练环境配置指南

> 适用于 `train_dflash_lora_inject.py`(Qwen3-8B LoRA draft model 训练)

## 硬件要求

- GPU: NVIDIA H800/A100(8 卡单机 或 32 卡多机)
- CUDA Driver: >= 12.2
- 内存: >= 256GB(推荐)

## 1. 创建 conda 环境

```bash
conda create -n spec python=3.11 -y
conda activate spec
```

## 2. 安装 PyTorch(需要 >= 2.5,支持 flex_attention)

pyproject.toml 指定版本为 `torch==2.9.1`,按需选择:

```bash
pip install -U pip setuptools wheel
pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/cu128
```

> **注意**:PyTorch 必须从官方源装(带 CUDA),清华源没有 cu128 版本的 whl。

验证 flex_attention:
```bash
python -c "from torch.nn.attention.flex_attention import flex_attention; print('flex_attention OK')"
```

## 3. 安装核心依赖

```bash
# transformers(pyproject.toml 指定版本)
pip install transformers==4.57.1 -i https://pypi.tuna.tsinghua.edu.cn/simple

# 训练必需
pip install accelerate datasets tqdm peft safetensors pydantic numpy typing_extensions \
    -i https://pypi.tuna.tsinghua.edu.cn/simple

# sglang(specforge/args.py 顶层 import,必须装)
pip install sglang==0.5.6 -i https://pypi.tuna.tsinghua.edu.cn/simple

# sgl-kernel(sglang backend 用到的 CUDA kernel)
pip install sgl-kernel -i https://pypi.tuna.tsinghua.edu.cn/simple

# yunchang(序列并行通信,specforge/distributed.py 顶层 import,必须装)
pip install yunchang -i https://pypi.tuna.tsinghua.edu.cn/simple
```

## 4. 安装可选依赖

```bash
# flash attention(可选,有则用,没有 fallback 到 flex_attention)
pip install flash-attn --no-build-isolation -i https://pypi.tuna.tsinghua.edu.cn/simple

# 8-bit optimizer(仅 --optimizer-type adamw_8bit 时需要)
pip install bitsandbytes -i https://pypi.tuna.tsinghua.edu.cn/simple

# 实验追踪(按需选一个)
pip install wandb -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install tensorboard -i https://pypi.tuna.tsinghua.edu.cn/simple

# VL 模型数据处理(当前训练不需要)
pip install qwen-vl-utils==0.0.11 -i https://pypi.tuna.tsinghua.edu.cn/simple
```

## 5. 安装其他工具包

```bash
pip install setuptools packaging ninja psutil -i https://pypi.tuna.tsinghua.edu.cn/simple
```

## 6. 安装 Specforge 本身

```bash
cd /workspace/hanrui/syxin_old/Specforge
pip install -e . --no-deps
```

> 用 `--no-deps` 跳过 pyproject.toml 里的依赖解析(前面已经手动装完了),避免 pip 覆盖 PyTorch 版本。

## 7. 验证安装

```bash
conda activate spec
python -c "
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}, {torch.version.cuda}')

from torch.nn.attention.flex_attention import flex_attention
print('flex_attention: OK')

import transformers; print(f'transformers: {transformers.__version__}')
import accelerate; print(f'accelerate: {accelerate.__version__}')
import datasets; print('datasets: OK')
import peft; print(f'peft: {peft.__version__}')
import yunchang; print('yunchang: OK')
import sglang; print('sglang: OK')
import safetensors; print('safetensors: OK')
import pydantic; print('pydantic: OK')

print()
print('All good!')
"
```

## 8. 启动训练

### 单机 8 卡(本地)
```bash
cd /workspace/hanrui/syxin_old/Specforge/scripts
./run_train_dflash_lora_inject.sh        # 默认 8 GPU
./run_train_dflash_lora_inject.sh 4      # 4 GPU
```

### 多机 32 卡(northjob)
```bash
cd /workspace/nex-agi/Megatron-BPLM
./run_qwen3_8b_sft_32gpu.sh
```

> **注意**:`run_train_multinode.sh` 里的 `DEFAULT_SPECFORGE_PY` 已改为指向
> `/workspace/miniconda3/envs/spec/bin/python3`,northjob 容器也会用 conda 环境。

## 完整依赖汇总

| 包名 | 版本要求 | 必须/可选 | 说明 |
|------|---------|----------|------|
| torch | ==2.9.1 (或 >=2.5) | 必须 | 需支持 flex_attention |
| torchaudio | ==2.9.1 | 必须 | |
| torchvision | ==0.24.1 | 必须 | |
| transformers | ==4.57.1 | 必须 | |
| accelerate | latest | 必须 | set_seed 等 |
| datasets | latest | 必须 | 数据加载 |
| tqdm | latest | 必须 | 进度条 |
| peft | latest | 必须 | LoRA 适配器 |
| yunchang | latest | 必须 | 顶层 import,序列并行 |
| sglang | ==0.5.6 | 必须 | 顶层 import(args.py) |
| sgl-kernel | latest | 必须 | sglang CUDA kernels |
| pydantic | latest | 必须 | 数据模板 |
| safetensors | latest | 推荐 | checkpoint 保存 |
| numpy | latest | 必须 | |
| setuptools | latest | 必须 | |
| packaging | latest | 必须 | |
| ninja | latest | 必须 | |
| psutil | latest | 必须 | |
| typing_extensions | latest | 必须 | |
| flash-attn | latest | 可选 | 有更快,没有 fallback |
| bitsandbytes | latest | 可选 | 8bit optimizer |
| wandb / tensorboard | latest | 可选 | 实验追踪 |