DFlash LoRA 全部改动记录
概述
为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 5 个文件,1084 行代码。改动分为两大阶段:基础搭建 + OOM 修复。
新增文件清单
| 文件 | 行数 | 用途 |
|---|---|---|
specforge/core/dflash_lora.py |
453 | 训练 wrapper(OnlineDFlashLoRAModel) |
specforge/modeling/draft/dflash_lora.py |
141 | LoRA draft 模型(DFlashLoRADraftModel) |
scripts/train_dflash_lora.py |
449 | 训练入口脚本 |
scripts/run_train_dflash_lora.sh |
31 | 启动 shell 脚本 |
configs/qwen3-8b-dflash-lora.json |
10 | LoRA 配置文件 |
Step 1 完成过程
1.1 分析现有代码
首先分析了非 LoRA 版 train_dflash.py 的完整流程:
input_ids → target_model.generate_dflash_data() → hidden_states
→ OnlineDFlashModel.forward():
1. 截断到 block 边界
2. prepare_noise_input(): anchor 保留,其余 → MASK
3. embed_tokens(noise_input_ids) → noise_embedding
4. 构建 DFlash attention mask
5. draft_model(noise_embedding, target_hidden, mask)
6. lm_head(hidden) → logits → CE loss
非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。
1.2 确定 LoRA 版设计差异
| 方面 | 非 LoRA 版 (train_dflash.py) |
LoRA 版 (train_dflash_lora.py) |
|---|---|---|
| Draft model | 自定义小模型 (1-10 层) | Qwen3-8B + PEFT LoRA |
| Target model | 冻结大模型提取 hidden states | 无需 — 模型用自身表征 |
| Attention | 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat | 标准 HF attention + DFlash mask |
| KV 结构 | Q_LEN = noise_len, KV_LEN = 2×noise_len | Q_LEN = KV_LEN = seq_len |
| 可训练参数 | 全部 draft model 参数 | 仅 LoRA (q/k/v/o_proj) |
1.3 新建 LoRA 版三个核心文件
specforge/modeling/draft/dflash_lora.py — DFlashLoRADraftModel
from_pretrained(): 加载 Qwen3-8B,注入 PEFT LoRA,支持attn_implementation参数forward(): 标准 HF forward,支持output_hidden_states参数(chunked loss 需要)get_lm_head(): 穿透 PEFT 层级获取 lm_head 引用gradient_checkpointing_enable(): 代理到底层模型save_pretrained(): 仅保存 LoRA adapter 权重
specforge/core/dflash_lora.py — OnlineDFlashLoRAModel
prepare_noise_input(): context 部分保持不变,block 部分只保留 anchor,其余替换为 MASKbuild_dflash_full_attn_mask_fast(): 向量化构建 4D additive mask[bsz, 1, seq, seq]_compute_loss_weights(): context + anchor 权重为 0,非 anchor 权重为 1(或 decay)_full_lm_loss(): 标准 CE loss 路径_compute_accuracy(): block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)forward(): 完整训练 forward pass
LoRA 版 mask 规则:
- context token i → 因果注意力 (j ≤ i)
- block token i (属于 block b) → 所有 context + 同 block 内双向注意力
scripts/train_dflash_lora.py — 训练脚本
- 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
build_model(): 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModelbuild_dataloader(): 复用build_eagle3_dataset和prepare_dp_dataloaders- FSDP 包装 + BF16Optimizer
- 训练循环:forward → backward → accumulation → optimizer step
- checkpoint 保存/恢复
OOM 修复改动(4 项)
改动 1: FSDP FULL_SHARD (ZeRO-3)
问题: SHARD_GRAD_OP (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)
修复: train_dflash_lora.py:362
# 之前
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# 之后
sharding_strategy=ShardingStrategy.FULL_SHARD
效果: 参数跨卡分片,每卡省 ~8-12GB
改动 2: batch_size=1 + accumulation_steps=8
问题: batch_size=2 时峰值显存过高
修复: run_train_dflash_lora.sh
--batch-size 1 \
--accumulation-steps 8 \
效果: 等效 global batch size 不变,峰值显存减半
改动 3: flex_attention + BlockMask 替换 4D additive mask
问题: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 [bsz, 32heads, 2048, 2048] attention scores
修复: 从非 LoRA 版 dflash.py 移植 _get_or_create_block_mask() 方法,适配 LoRA 场景
涉及文件:
specforge/core/dflash_lora.py__init__(): 添加attention_backend参数(默认"flex_attention"),BlockMask 缓存字段- 新增
_get_or_create_block_mask(): 用create_block_mask()构建零显存的 BlockMask forward(): 根据attention_backend选择 BlockMask 或 additive mask
specforge/modeling/draft/dflash_lora.pyfrom_pretrained(): 当 backend 为 flex_attention 时,传attn_implementation="flex_attention"给 HuggingFace
scripts/train_dflash_lora.pyparse_args():--attention-backend参数 (flex_attention|additive)build_model(): 根据 backend 选择attn_implementation
BlockMask mask function(LoRA 版):
def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
# Context query: 标准因果
is_q_ctx = q_idx < context_len
ctx_visible = is_q_ctx & (kv_idx <= q_idx)
# Block query: 全部 context + 同 block 双向
is_q_block = q_idx >= context_len
is_k_ctx = kv_idx < context_len
q_block_id = (q_idx - context_len) // block_size
k_block_id = (kv_idx - context_len) // block_size
block_attend_ctx = is_q_block & is_k_ctx
block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)
return ctx_visible | (block_attend_ctx | block_attend_same)
验证: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。
效果: 不再 fallback 到 SDPA math backend,省去 [bsz, heads, seq, seq] attention scores 显存
改动 4: chunked cross-entropy loss
问题: [bsz, 2048, 151936] bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+
修复: 从非 LoRA 版 dflash.py:419-478 移植 chunked loss
涉及文件:
specforge/core/dflash_lora.py__init__(): 添加lm_head_chunk_size参数(默认 0 = 不启用)- 新增
_chunked_lm_loss(): 分 chunk 过 lm_head + CE loss + gradient checkpointing - 提取
_full_lm_loss(): 原始非 chunked 路径 forward():lm_head_chunk_size > 0时走 chunked 路径
specforge/modeling/draft/dflash_lora.pyforward(): 新增output_hidden_states参数,True 时返回 last hidden state 而非 logitsget_lm_head(): 穿透 PEFT 层级返回base_model.lm_head引用
scripts/train_dflash_lora.pyparse_args():--lm-head-chunk-size参数(默认 0,推荐 256)build_model(): 传递到 OnlineDFlashLoRAModel
Chunked loss 核心逻辑:
# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
for start in range(0, effective_len, chunk_size):
end = min(start + chunk_size, effective_len)
chunk_loss, chunk_weight = grad_checkpoint(
_chunk_ce, # lm_head + CE
hidden[:, start:end, :], # 只取当前 chunk
input_ids[:, start:end],
combined_mask[:, start:end],
use_reentrant=False,
)
total_loss += chunk_loss
total_weight += chunk_weight
loss = total_loss / total_weight
效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size),256 chunk → ~150MB vs 1.18GB
当前训练命令
bash run_train_dflash_lora.sh 2 # 2 = GPU 数量
对应完整参数:
torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
--model-path /workspace/Qwen3-8B \
--train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
--output-dir outputs/qwen3-8b-dflash-lora \
--lora-config configs/qwen3-8b-dflash-lora.json \
--block-size 16 \
--max-length 2048 \
--batch-size 1 \
--num-epochs 3 \
--learning-rate 2e-4 \
--accumulation-steps 8 \
--loss-decay-gamma 7 \
--attention-backend flex_attention \
--lm-head-chunk-size 256 \
--gradient-checkpointing \
--chat-template qwen \
--log-interval 50 \
--save-interval 500
待验证
- 跑
bash run_train_dflash_lora.sh 2确认不再 OOM - 确认无 SDPA math fallback warning
- 观察 GPU 显存峰值
- 确认 loss 下降和 accuracy 上升趋势正常