| # DFlash LoRA 全部改动记录 |
|
|
| ## 概述 |
|
|
| 为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。 |
|
|
| --- |
|
|
| ## 新增文件清单 |
|
|
| | 文件 | 行数 | 用途 | |
| |------|------|------| |
| | `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) | |
| | `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) | |
| | `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 | |
| | `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 | |
| | `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 | |
|
|
| --- |
|
|
| ## Step 1 完成过程 |
|
|
| ### 1.1 分析现有代码 |
|
|
| 首先分析了非 LoRA 版 `train_dflash.py` 的完整流程: |
|
|
| ``` |
| input_ids → target_model.generate_dflash_data() → hidden_states |
| → OnlineDFlashModel.forward(): |
| 1. 截断到 block 边界 |
| 2. prepare_noise_input(): anchor 保留,其余 → MASK |
| 3. embed_tokens(noise_input_ids) → noise_embedding |
| 4. 构建 DFlash attention mask |
| 5. draft_model(noise_embedding, target_hidden, mask) |
| 6. lm_head(hidden) → logits → CE loss |
| ``` |
|
|
| 非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。 |
|
|
| ### 1.2 确定 LoRA 版设计差异 |
|
|
| | 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) | |
| |------|------|------| |
| | Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA | |
| | Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 | |
| | Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask | |
| | KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len | |
| | 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) | |
|
|
| ### 1.3 新建 LoRA 版三个核心文件 |
|
|
| #### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel |
| |
| - `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数 |
| - `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要) |
| - `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用 |
| - `gradient_checkpointing_enable()`: 代理到底层模型 |
| - `save_pretrained()`: 仅保存 LoRA adapter 权重 |
|
|
| #### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel |
| |
| - `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK |
| - `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]` |
| - `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay) |
| - `_full_lm_loss()`: 标准 CE loss 路径 |
| - `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度) |
| - `forward()`: 完整训练 forward pass |
|
|
| LoRA 版 mask 规则: |
| - context token i → 因果注意力 (j ≤ i) |
| - block token i (属于 block b) → 所有 context + 同 block 内双向注意力 |
|
|
| #### `scripts/train_dflash_lora.py` — 训练脚本 |
|
|
| - 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数 |
| - `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel |
| - `build_dataloader()`: 复用 `build_eagle3_dataset` 和 `prepare_dp_dataloaders` |
| - FSDP 包装 + BF16Optimizer |
| - 训练循环:forward → backward → accumulation → optimizer step |
| - checkpoint 保存/恢复 |
|
|
| --- |
|
|
| ## OOM 修复改动(4 项) |
|
|
| ### 改动 1: FSDP FULL_SHARD (ZeRO-3) |
| |
| **问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16) |
| |
| **修复**: `train_dflash_lora.py:362` |
| ```python |
| # 之前 |
| sharding_strategy=ShardingStrategy.SHARD_GRAD_OP |
| # 之后 |
| sharding_strategy=ShardingStrategy.FULL_SHARD |
| ``` |
| |
| **效果**: 参数跨卡分片,每卡省 ~8-12GB |
| |
| ### 改动 2: batch_size=1 + accumulation_steps=8 |
| |
| **问题**: `batch_size=2` 时峰值显存过高 |
| |
| **修复**: `run_train_dflash_lora.sh` |
| ```bash |
| --batch-size 1 \ |
| --accumulation-steps 8 \ |
| ``` |
| |
| **效果**: 等效 global batch size 不变,峰值显存减半 |
| |
| ### 改动 3: flex_attention + BlockMask 替换 4D additive mask |
| |
| **问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores |
| |
| **修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 |
| |
| 涉及文件: |
| |
| 1. **`specforge/core/dflash_lora.py`** |
| - `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段 |
| - 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask |
| - `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask |
| |
| 2. **`specforge/modeling/draft/dflash_lora.py`** |
| - `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace |
| |
| 3. **`scripts/train_dflash_lora.py`** |
| - `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`) |
| - `build_model()`: 根据 backend 选择 `attn_implementation` |
| |
| BlockMask mask function(LoRA 版): |
| ```python |
| def dflash_lora_mask_fn(b, h, q_idx, kv_idx): |
| # Context query: 标准因果 |
| is_q_ctx = q_idx < context_len |
| ctx_visible = is_q_ctx & (kv_idx <= q_idx) |
|
|
| # Block query: 全部 context + 同 block 双向 |
| is_q_block = q_idx >= context_len |
| is_k_ctx = kv_idx < context_len |
| q_block_id = (q_idx - context_len) // block_size |
| k_block_id = (kv_idx - context_len) // block_size |
| block_attend_ctx = is_q_block & is_k_ctx |
| block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id) |
| |
| return ctx_visible | (block_attend_ctx | block_attend_same) |
| ``` |
| |
| **验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。 |
| |
| **效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存 |
| |
| ### 改动 4: chunked cross-entropy loss |
| |
| **问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+ |
| |
| **修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss |
| |
| 涉及文件: |
| |
| 1. **`specforge/core/dflash_lora.py`** |
| - `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用) |
| - 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing |
| - 提取 `_full_lm_loss()`: 原始非 chunked 路径 |
| - `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径 |
|
|
| 2. **`specforge/modeling/draft/dflash_lora.py`** |
| - `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits |
| - `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用 |
| |
| 3. **`scripts/train_dflash_lora.py`** |
| - `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256) |
| - `build_model()`: 传递到 OnlineDFlashLoRAModel |
| |
| Chunked loss 核心逻辑: |
| ```python |
| # 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储) |
| for start in range(0, effective_len, chunk_size): |
| end = min(start + chunk_size, effective_len) |
| chunk_loss, chunk_weight = grad_checkpoint( |
| _chunk_ce, # lm_head + CE |
| hidden[:, start:end, :], # 只取当前 chunk |
| input_ids[:, start:end], |
| combined_mask[:, start:end], |
| use_reentrant=False, |
| ) |
| total_loss += chunk_loss |
| total_weight += chunk_weight |
| loss = total_loss / total_weight |
| ``` |
| |
| **效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB |
| |
| --- |
| |
| ## 当前训练命令 |
| |
| ```bash |
| bash run_train_dflash_lora.sh 2 # 2 = GPU 数量 |
| ``` |
| |
| 对应完整参数: |
| ```bash |
| torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \ |
| --model-path /workspace/Qwen3-8B \ |
| --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \ |
| --output-dir outputs/qwen3-8b-dflash-lora \ |
| --lora-config configs/qwen3-8b-dflash-lora.json \ |
| --block-size 16 \ |
| --max-length 2048 \ |
| --batch-size 1 \ |
| --num-epochs 3 \ |
| --learning-rate 2e-4 \ |
| --accumulation-steps 8 \ |
| --loss-decay-gamma 7 \ |
| --attention-backend flex_attention \ |
| --lm-head-chunk-size 256 \ |
| --gradient-checkpointing \ |
| --chat-template qwen \ |
| --log-interval 50 \ |
| --save-interval 500 |
| ``` |
| |
| --- |
| |
| ## 待验证 |
| |
| - [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM |
| - [ ] 确认无 SDPA math fallback warning |
| - [ ] 观察 GPU 显存峰值 |
| - [ ] 确认 loss 下降和 accuracy 上升趋势正常 |
| |