# DFlash LoRA 全部改动记录 ## 概述 为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 **5 个文件,1084 行代码**。改动分为两大阶段:基础搭建 + OOM 修复。 --- ## 新增文件清单 | 文件 | 行数 | 用途 | |------|------|------| | `specforge/core/dflash_lora.py` | 453 | 训练 wrapper(OnlineDFlashLoRAModel) | | `specforge/modeling/draft/dflash_lora.py` | 141 | LoRA draft 模型(DFlashLoRADraftModel) | | `scripts/train_dflash_lora.py` | 449 | 训练入口脚本 | | `scripts/run_train_dflash_lora.sh` | 31 | 启动 shell 脚本 | | `configs/qwen3-8b-dflash-lora.json` | 10 | LoRA 配置文件 | --- ## Step 1 完成过程 ### 1.1 分析现有代码 首先分析了非 LoRA 版 `train_dflash.py` 的完整流程: ``` input_ids → target_model.generate_dflash_data() → hidden_states → OnlineDFlashModel.forward(): 1. 截断到 block 边界 2. prepare_noise_input(): anchor 保留,其余 → MASK 3. embed_tokens(noise_input_ids) → noise_embedding 4. 构建 DFlash attention mask 5. draft_model(noise_embedding, target_hidden, mask) 6. lm_head(hidden) → logits → CE loss ``` 非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。 ### 1.2 确定 LoRA 版设计差异 | 方面 | 非 LoRA 版 (`train_dflash.py`) | LoRA 版 (`train_dflash_lora.py`) | |------|------|------| | Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA | | Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 | | Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask | | KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len | | 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) | ### 1.3 新建 LoRA 版三个核心文件 #### `specforge/modeling/draft/dflash_lora.py` — DFlashLoRADraftModel - `from_pretrained()`: 加载 Qwen3-8B,注入 PEFT LoRA,支持 `attn_implementation` 参数 - `forward()`: 标准 HF forward,支持 `output_hidden_states` 参数(chunked loss 需要) - `get_lm_head()`: 穿透 PEFT 层级获取 lm_head 引用 - `gradient_checkpointing_enable()`: 代理到底层模型 - `save_pretrained()`: 仅保存 LoRA adapter 权重 #### `specforge/core/dflash_lora.py` — OnlineDFlashLoRAModel - `prepare_noise_input()`: context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK - `build_dflash_full_attn_mask_fast()`: 向量化构建 4D additive mask `[bsz, 1, seq, seq]` - `_compute_loss_weights()`: context + anchor 权重为 0,非 anchor 权重为 1(或 decay) - `_full_lm_loss()`: 标准 CE loss 路径 - `_compute_accuracy()`: block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度) - `forward()`: 完整训练 forward pass LoRA 版 mask 规则: - context token i → 因果注意力 (j ≤ i) - block token i (属于 block b) → 所有 context + 同 block 内双向注意力 #### `scripts/train_dflash_lora.py` — 训练脚本 - 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数 - `build_model()`: 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel - `build_dataloader()`: 复用 `build_eagle3_dataset` 和 `prepare_dp_dataloaders` - FSDP 包装 + BF16Optimizer - 训练循环:forward → backward → accumulation → optimizer step - checkpoint 保存/恢复 --- ## OOM 修复改动(4 项) ### 改动 1: FSDP FULL_SHARD (ZeRO-3) **问题**: `SHARD_GRAD_OP` (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16) **修复**: `train_dflash_lora.py:362` ```python # 之前 sharding_strategy=ShardingStrategy.SHARD_GRAD_OP # 之后 sharding_strategy=ShardingStrategy.FULL_SHARD ``` **效果**: 参数跨卡分片,每卡省 ~8-12GB ### 改动 2: batch_size=1 + accumulation_steps=8 **问题**: `batch_size=2` 时峰值显存过高 **修复**: `run_train_dflash_lora.sh` ```bash --batch-size 1 \ --accumulation-steps 8 \ ``` **效果**: 等效 global batch size 不变,峰值显存减半 ### 改动 3: flex_attention + BlockMask 替换 4D additive mask **问题**: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 `[bsz, 32heads, 2048, 2048]` attention scores **修复**: 从非 LoRA 版 `dflash.py` 移植 `_get_or_create_block_mask()` 方法,适配 LoRA 场景 涉及文件: 1. **`specforge/core/dflash_lora.py`** - `__init__()`: 添加 `attention_backend` 参数(默认 `"flex_attention"`),BlockMask 缓存字段 - 新增 `_get_or_create_block_mask()`: 用 `create_block_mask()` 构建零显存的 BlockMask - `forward()`: 根据 `attention_backend` 选择 BlockMask 或 additive mask 2. **`specforge/modeling/draft/dflash_lora.py`** - `from_pretrained()`: 当 backend 为 flex_attention 时,传 `attn_implementation="flex_attention"` 给 HuggingFace 3. **`scripts/train_dflash_lora.py`** - `parse_args()`: `--attention-backend` 参数 (`flex_attention` | `additive`) - `build_model()`: 根据 backend 选择 `attn_implementation` BlockMask mask function(LoRA 版): ```python def dflash_lora_mask_fn(b, h, q_idx, kv_idx): # Context query: 标准因果 is_q_ctx = q_idx < context_len ctx_visible = is_q_ctx & (kv_idx <= q_idx) # Block query: 全部 context + 同 block 双向 is_q_block = q_idx >= context_len is_k_ctx = kv_idx < context_len q_block_id = (q_idx - context_len) // block_size k_block_id = (kv_idx - context_len) // block_size block_attend_ctx = is_q_block & is_k_ctx block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id) return ctx_visible | (block_attend_ctx | block_attend_same) ``` **验证**: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。 **效果**: 不再 fallback 到 SDPA math backend,省去 `[bsz, heads, seq, seq]` attention scores 显存 ### 改动 4: chunked cross-entropy loss **问题**: `[bsz, 2048, 151936]` bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+ **修复**: 从非 LoRA 版 `dflash.py:419-478` 移植 chunked loss 涉及文件: 1. **`specforge/core/dflash_lora.py`** - `__init__()`: 添加 `lm_head_chunk_size` 参数(默认 0 = 不启用) - 新增 `_chunked_lm_loss()`: 分 chunk 过 lm_head + CE loss + gradient checkpointing - 提取 `_full_lm_loss()`: 原始非 chunked 路径 - `forward()`: `lm_head_chunk_size > 0` 时走 chunked 路径 2. **`specforge/modeling/draft/dflash_lora.py`** - `forward()`: 新增 `output_hidden_states` 参数,True 时返回 last hidden state 而非 logits - `get_lm_head()`: 穿透 PEFT 层级返回 `base_model.lm_head` 引用 3. **`scripts/train_dflash_lora.py`** - `parse_args()`: `--lm-head-chunk-size` 参数(默认 0,推荐 256) - `build_model()`: 传递到 OnlineDFlashLoRAModel Chunked loss 核心逻辑: ```python # 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储) for start in range(0, effective_len, chunk_size): end = min(start + chunk_size, effective_len) chunk_loss, chunk_weight = grad_checkpoint( _chunk_ce, # lm_head + CE hidden[:, start:end, :], # 只取当前 chunk input_ids[:, start:end], combined_mask[:, start:end], use_reentrant=False, ) total_loss += chunk_loss total_weight += chunk_weight loss = total_loss / total_weight ``` **效果**: logits 峰值显存从 `O(seq_len × vocab_size)` 降至 `O(chunk_size × vocab_size)`,256 chunk → ~150MB vs 1.18GB --- ## 当前训练命令 ```bash bash run_train_dflash_lora.sh 2 # 2 = GPU 数量 ``` 对应完整参数: ```bash torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \ --model-path /workspace/Qwen3-8B \ --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \ --output-dir outputs/qwen3-8b-dflash-lora \ --lora-config configs/qwen3-8b-dflash-lora.json \ --block-size 16 \ --max-length 2048 \ --batch-size 1 \ --num-epochs 3 \ --learning-rate 2e-4 \ --accumulation-steps 8 \ --loss-decay-gamma 7 \ --attention-backend flex_attention \ --lm-head-chunk-size 256 \ --gradient-checkpointing \ --chat-template qwen \ --log-interval 50 \ --save-interval 500 ``` --- ## 待验证 - [ ] 跑 `bash run_train_dflash_lora.sh 2` 确认不再 OOM - [ ] 确认无 SDPA math fallback warning - [ ] 观察 GPU 显存峰值 - [ ] 确认 loss 下降和 accuracy 上升趋势正常