File size: 8,782 Bytes
40d87dd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 | # DFlash LoRA 全部改动记录
## 概述
为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。
---
## 新增文件清单
| 文件 | 行数 | 用途 |
|------|------|------|
| `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) |
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) |
| `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 |
| `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 |
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 |
---
## Step 1 完成过程
### 1.1 分析现有代码
首先分析了非 LoRA 版 `train_dflash.py` 的完整流程:
```
input_ids → target_model.generate_dflash_data() → hidden_states
→ OnlineDFlashModel.forward():
1. 截断到 block 边界
2. prepare_noise_input(): anchor 保留,其余 → MASK
3. embed_tokens(noise_input_ids) → noise_embedding
4. 构建 DFlash attention mask
5. draft_model(noise_embedding, target_hidden, mask)
6. lm_head(hidden) → logits → CE loss
```
非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。
### 1.2 确定 LoRA 版设计差异
| 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) |
|------|------|------|
| Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
| Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
| Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
| KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |
### 1.3 新建 LoRA 版三个核心文件
#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel
- `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数
- `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要)
- `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用
- `gradient_checkpointing_enable()`: 代理到底层模型
- `save_pretrained()`: 仅保存 LoRA adapter 权重
#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel
- `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
- `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]`
- `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
- `_full_lm_loss()`: 标准 CE loss 路径
- `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
- `forward()`: 完整训练 forward pass
LoRA 版 mask 规则:
- context token i → 因果注意力 (j ≤ i)
- block token i (属于 block b) → 所有 context + 同 block 内双向注意力
#### `scripts/train_dflash_lora.py` — 训练脚本
- 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
- `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
- `build_dataloader()`: 复用 `build_eagle3_dataset` 和 `prepare_dp_dataloaders`
- FSDP 包装 + BF16Optimizer
- 训练循环:forward → backward → accumulation → optimizer step
- checkpoint 保存/恢复
---
## OOM 修复改动(4 项)
### 改动 1: FSDP FULL_SHARD (ZeRO-3)
**问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)
**修复**: `train_dflash_lora.py:362`
```python
# 之前
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# 之后
sharding_strategy=ShardingStrategy.FULL_SHARD
```
**效果**: 参数跨卡分片,每卡省 ~8-12GB
### 改动 2: batch_size=1 + accumulation_steps=8
**问题**: `batch_size=2` 时峰值显存过高
**修复**: `run_train_dflash_lora.sh`
```bash
--batch-size 1 \
--accumulation-steps 8 \
```
**效果**: 等效 global batch size 不变,峰值显存减半
### 改动 3: flex_attention + BlockMask 替换 4D additive mask
**问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores
**修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景
涉及文件:
1. **`specforge/core/dflash_lora.py`**
- `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段
- 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask
- `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask
2. **`specforge/modeling/draft/dflash_lora.py`**
- `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace
3. **`scripts/train_dflash_lora.py`**
- `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`)
- `build_model()`: 根据 backend 选择 `attn_implementation`
BlockMask mask function(LoRA 版):
```python
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
# Context query: 标准因果
is_q_ctx = q_idx < context_len
ctx_visible = is_q_ctx & (kv_idx <= q_idx)
# Block query: 全部 context + 同 block 双向
is_q_block = q_idx >= context_len
is_k_ctx = kv_idx < context_len
q_block_id = (q_idx - context_len) // block_size
k_block_id = (kv_idx - context_len) // block_size
block_attend_ctx = is_q_block & is_k_ctx
block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)
return ctx_visible | (block_attend_ctx | block_attend_same)
```
**验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。
**效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存
### 改动 4: chunked cross-entropy loss
**问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+
**修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss
涉及文件:
1. **`specforge/core/dflash_lora.py`**
- `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用)
- 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing
- 提取 `_full_lm_loss()`: 原始非 chunked 路径
- `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径
2. **`specforge/modeling/draft/dflash_lora.py`**
- `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits
- `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用
3. **`scripts/train_dflash_lora.py`**
- `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256)
- `build_model()`: 传递到 OnlineDFlashLoRAModel
Chunked loss 核心逻辑:
```python
# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
for start in range(0, effective_len, chunk_size):
end = min(start + chunk_size, effective_len)
chunk_loss, chunk_weight = grad_checkpoint(
_chunk_ce, # lm_head + CE
hidden[:, start:end, :], # 只取当前 chunk
input_ids[:, start:end],
combined_mask[:, start:end],
use_reentrant=False,
)
total_loss += chunk_loss
total_weight += chunk_weight
loss = total_loss / total_weight
```
**效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB
---
## 当前训练命令
```bash
bash run_train_dflash_lora.sh 2 # 2 = GPU 数量
```
对应完整参数:
```bash
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
--model-path /workspace/Qwen3-8B \
--train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
--output-dir outputs/qwen3-8b-dflash-lora \
--lora-config configs/qwen3-8b-dflash-lora.json \
--block-size 16 \
--max-length 2048 \
--batch-size 1 \
--num-epochs 3 \
--learning-rate 2e-4 \
--accumulation-steps 8 \
--loss-decay-gamma 7 \
--attention-backend flex_attention \
--lm-head-chunk-size 256 \
--gradient-checkpointing \
--chat-template qwen \
--log-interval 50 \
--save-interval 500
```
---
## 待验证
- [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM
- [ ] 确认无 SDPA math fallback warning
- [ ] 观察 GPU 显存峰值
- [ ] 确认 loss 下降和 accuracy 上升趋势正常
|