File size: 8,782 Bytes
40d87dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# DFlash LoRA 全部改动记录

## 概述

为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。

---

## 新增文件清单

| 文件 | 行数 | 用途 |
|------|------|------|
| `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) |
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) |
| `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 |
| `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 |
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 |

---

## Step 1 完成过程

### 1.1 分析现有代码

首先分析了非 LoRA 版 `train_dflash.py` 的完整流程:

```
input_ids → target_model.generate_dflash_data() → hidden_states
         → OnlineDFlashModel.forward():
             1. 截断到 block 边界
             2. prepare_noise_input(): anchor 保留,其余 → MASK
             3. embed_tokens(noise_input_ids) → noise_embedding
             4. 构建 DFlash attention mask
             5. draft_model(noise_embedding, target_hidden, mask)
             6. lm_head(hidden) → logits → CE loss
```

非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。

### 1.2 确定 LoRA 版设计差异

| 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) |
|------|------|------|
| Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
| Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
| Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
| KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |

### 1.3 新建 LoRA 版三个核心文件

#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel

- `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数
- `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要)
- `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用
- `gradient_checkpointing_enable()`: 代理到底层模型
- `save_pretrained()`: 仅保存 LoRA adapter 权重

#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel

- `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
- `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]`
- `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
- `_full_lm_loss()`: 标准 CE loss 路径
- `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
- `forward()`: 完整训练 forward pass

LoRA 版 mask 规则:
- context token i → 因果注意力 (j ≤ i)
- block token i (属于 block b) → 所有 context + 同 block 内双向注意力

#### `scripts/train_dflash_lora.py` — 训练脚本

- 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
- `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
- `build_dataloader()`: 复用 `build_eagle3_dataset``prepare_dp_dataloaders`
- FSDP 包装 + BF16Optimizer
- 训练循环:forward → backward → accumulation → optimizer step
- checkpoint 保存/恢复

---

## OOM 修复改动(4 项)

### 改动 1: FSDP FULL_SHARD (ZeRO-3)

**问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)

**修复**: `train_dflash_lora.py:362`
```python
# 之前
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# 之后
sharding_strategy=ShardingStrategy.FULL_SHARD
```

**效果**: 参数跨卡分片,每卡省 ~8-12GB

### 改动 2: batch_size=1 + accumulation_steps=8

**问题**: `batch_size=2` 时峰值显存过高

**修复**: `run_train_dflash_lora.sh`
```bash
--batch-size 1 \
--accumulation-steps 8 \
```

**效果**: 等效 global batch size 不变,峰值显存减半

### 改动 3: flex_attention + BlockMask 替换 4D additive mask

**问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores

**修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景

涉及文件:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段
   - 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask
   - `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`)
   - `build_model()`: 根据 backend 选择 `attn_implementation`

BlockMask mask function(LoRA 版):
```python
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
    # Context query: 标准因果
    is_q_ctx = q_idx < context_len
    ctx_visible = is_q_ctx & (kv_idx <= q_idx)

    # Block query: 全部 context + 同 block 双向
    is_q_block = q_idx >= context_len
    is_k_ctx = kv_idx < context_len
    q_block_id = (q_idx - context_len) // block_size
    k_block_id = (kv_idx - context_len) // block_size
    block_attend_ctx = is_q_block & is_k_ctx
    block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)

    return ctx_visible | (block_attend_ctx | block_attend_same)
```

**验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。

**效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存

### 改动 4: chunked cross-entropy loss

**问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+

**修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss

涉及文件:

1. **`specforge/core/dflash_lora.py`**
   - `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用)
   - 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing
   - 提取 `_full_lm_loss()`: 原始非 chunked 路径
   - `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径

2. **`specforge/modeling/draft/dflash_lora.py`**
   - `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits
   - `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用

3. **`scripts/train_dflash_lora.py`**
   - `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256)
   - `build_model()`: 传递到 OnlineDFlashLoRAModel

Chunked loss 核心逻辑:
```python
# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
for start in range(0, effective_len, chunk_size):
    end = min(start + chunk_size, effective_len)
    chunk_loss, chunk_weight = grad_checkpoint(
        _chunk_ce,                          # lm_head + CE
        hidden[:, start:end, :],            # 只取当前 chunk
        input_ids[:, start:end],
        combined_mask[:, start:end],
        use_reentrant=False,
    )
    total_loss += chunk_loss
    total_weight += chunk_weight
loss = total_loss / total_weight
```

**效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB

---

## 当前训练命令

```bash
bash run_train_dflash_lora.sh 2   # 2 = GPU 数量
```

对应完整参数:
```bash
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
    --model-path /workspace/Qwen3-8B \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-dflash-lora \
    --lora-config configs/qwen3-8b-dflash-lora.json \
    --block-size 16 \
    --max-length 2048 \
    --batch-size 1 \
    --num-epochs 3 \
    --learning-rate 2e-4 \
    --accumulation-steps 8 \
    --loss-decay-gamma 7 \
    --attention-backend flex_attention \
    --lm-head-chunk-size 256 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500
```

---

## 待验证

- [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM
- [ ] 确认无 SDPA math fallback warning
- [ ] 观察 GPU 显存峰值
- [ ] 确认 loss 下降和 accuracy 上升趋势正常