Hanrui / progress /dflash_lora_changelog.md
Lekr0's picture
Add files using upload-large-folder tool
40d87dd verified
# DFlash LoRA 全部改动记录
## 概述
为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。
---
## 新增文件清单
| 文件 | 行数 | 用途 |
|------|------|------|
| `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) |
| `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) |
| `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 |
| `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 |
| `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 |
---
## Step 1 完成过程
### 1.1 分析现有代码
首先分析了非 LoRA 版 `train_dflash.py` 的完整流程:
```
input_ids → target_model.generate_dflash_data() → hidden_states
→ OnlineDFlashModel.forward():
1. 截断到 block 边界
2. prepare_noise_input(): anchor 保留,其余 → MASK
3. embed_tokens(noise_input_ids) → noise_embedding
4. 构建 DFlash attention mask
5. draft_model(noise_embedding, target_hidden, mask)
6. lm_head(hidden) → logits → CE loss
```
非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。
### 1.2 确定 LoRA 版设计差异
| 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) |
|------|------|------|
| Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
| Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
| Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
| KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |
### 1.3 新建 LoRA 版三个核心文件
#### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel
- `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数
- `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要)
- `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用
- `gradient_checkpointing_enable()`: 代理到底层模型
- `save_pretrained()`: 仅保存 LoRA adapter 权重
#### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel
- `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
- `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]`
- `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
- `_full_lm_loss()`: 标准 CE loss 路径
- `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
- `forward()`: 完整训练 forward pass
LoRA 版 mask 规则:
- context token i → 因果注意力 (j ≤ i)
- block token i (属于 block b) → 所有 context + 同 block 内双向注意力
#### `scripts/train_dflash_lora.py` — 训练脚本
- 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
- `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
- `build_dataloader()`: 复用 `build_eagle3_dataset``prepare_dp_dataloaders`
- FSDP 包装 + BF16Optimizer
- 训练循环:forward → backward → accumulation → optimizer step
- checkpoint 保存/恢复
---
## OOM 修复改动(4 项)
### 改动 1: FSDP FULL_SHARD (ZeRO-3)
**问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)
**修复**: `train_dflash_lora.py:362`
```python
# 之前
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# 之后
sharding_strategy=ShardingStrategy.FULL_SHARD
```
**效果**: 参数跨卡分片,每卡省 ~8-12GB
### 改动 2: batch_size=1 + accumulation_steps=8
**问题**: `batch_size=2` 时峰值显存过高
**修复**: `run_train_dflash_lora.sh`
```bash
--batch-size 1 \
--accumulation-steps 8 \
```
**效果**: 等效 global batch size 不变,峰值显存减半
### 改动 3: flex_attention + BlockMask 替换 4D additive mask
**问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores
**修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景
涉及文件:
1. **`specforge/core/dflash_lora.py`**
- `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段
- 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask
- `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask
2. **`specforge/modeling/draft/dflash_lora.py`**
- `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace
3. **`scripts/train_dflash_lora.py`**
- `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`)
- `build_model()`: 根据 backend 选择 `attn_implementation`
BlockMask mask function(LoRA 版):
```python
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
# Context query: 标准因果
is_q_ctx = q_idx < context_len
ctx_visible = is_q_ctx & (kv_idx <= q_idx)
# Block query: 全部 context + 同 block 双向
is_q_block = q_idx >= context_len
is_k_ctx = kv_idx < context_len
q_block_id = (q_idx - context_len) // block_size
k_block_id = (kv_idx - context_len) // block_size
block_attend_ctx = is_q_block & is_k_ctx
block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)
return ctx_visible | (block_attend_ctx | block_attend_same)
```
**验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。
**效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存
### 改动 4: chunked cross-entropy loss
**问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+
**修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss
涉及文件:
1. **`specforge/core/dflash_lora.py`**
- `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用)
- 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing
- 提取 `_full_lm_loss()`: 原始非 chunked 路径
- `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径
2. **`specforge/modeling/draft/dflash_lora.py`**
- `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits
- `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用
3. **`scripts/train_dflash_lora.py`**
- `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256)
- `build_model()`: 传递到 OnlineDFlashLoRAModel
Chunked loss 核心逻辑:
```python
# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
for start in range(0, effective_len, chunk_size):
end = min(start + chunk_size, effective_len)
chunk_loss, chunk_weight = grad_checkpoint(
_chunk_ce, # lm_head + CE
hidden[:, start:end, :], # 只取当前 chunk
input_ids[:, start:end],
combined_mask[:, start:end],
use_reentrant=False,
)
total_loss += chunk_loss
total_weight += chunk_weight
loss = total_loss / total_weight
```
**效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB
---
## 当前训练命令
```bash
bash run_train_dflash_lora.sh 2 # 2 = GPU 数量
```
对应完整参数:
```bash
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
--model-path /workspace/Qwen3-8B \
--train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
--output-dir outputs/qwen3-8b-dflash-lora \
--lora-config configs/qwen3-8b-dflash-lora.json \
--block-size 16 \
--max-length 2048 \
--batch-size 1 \
--num-epochs 3 \
--learning-rate 2e-4 \
--accumulation-steps 8 \
--loss-decay-gamma 7 \
--attention-backend flex_attention \
--lm-head-chunk-size 256 \
--gradient-checkpointing \
--chat-template qwen \
--log-interval 50 \
--save-interval 500
```
---
## 待验证
- [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM
- [ ] 确认无 SDPA math fallback warning
- [ ] 观察 GPU 显存峰值
- [ ] 确认 loss 下降和 accuracy 上升趋势正常