Hanrui / progress /dflash_lora_changelog.md
Lekr0's picture
Add files using upload-large-folder tool
40d87dd verified

DFlash LoRA 全部改动记录

概述

为了让 Qwen3-8B DFlash LoRA 训练在 2×H100 上跑通(解决 OOM),共新增/修改了 5 个文件,1084 行代码。改动分为两大阶段:基础搭建 + OOM 修复。


新增文件清单

文件 行数 用途
specforge/core/dflash_lora.py 453 训练 wrapper(OnlineDFlashLoRAModel)
specforge/modeling/draft/dflash_lora.py 141 LoRA draft 模型(DFlashLoRADraftModel)
scripts/train_dflash_lora.py 449 训练入口脚本
scripts/run_train_dflash_lora.sh 31 启动 shell 脚本
configs/qwen3-8b-dflash-lora.json 10 LoRA 配置文件

Step 1 完成过程

1.1 分析现有代码

首先分析了非 LoRA 版 train_dflash.py 的完整流程:

input_ids → target_model.generate_dflash_data() → hidden_states
         → OnlineDFlashModel.forward():
             1. 截断到 block 边界
             2. prepare_noise_input(): anchor 保留,其余 → MASK
             3. embed_tokens(noise_input_ids) → noise_embedding
             4. 构建 DFlash attention mask
             5. draft_model(noise_embedding, target_hidden, mask)
             6. lm_head(hidden) → logits → CE loss

非 LoRA 版使用独立的小型 draft model + 冻结 target model 提取 hidden states。

1.2 确定 LoRA 版设计差异

方面 非 LoRA 版 (train_dflash.py) LoRA 版 (train_dflash_lora.py)
Draft model 自定义小模型 (1-10 层) Qwen3-8B + PEFT LoRA
Target model 冻结大模型提取 hidden states 无需 — 模型用自身表征
Attention 自定义 Qwen3DFlashAttention,KV = [ctx, noise] concat 标准 HF attention + DFlash mask
KV 结构 Q_LEN = noise_len, KV_LEN = 2×noise_len Q_LEN = KV_LEN = seq_len
可训练参数 全部 draft model 参数 仅 LoRA (q/k/v/o_proj)

1.3 新建 LoRA 版三个核心文件

specforge/modeling/draft/dflash_lora.py — DFlashLoRADraftModel

  • from_pretrained(): 加载 Qwen3-8B,注入 PEFT LoRA,支持 attn_implementation 参数
  • forward(): 标准 HF forward,支持 output_hidden_states 参数(chunked loss 需要)
  • get_lm_head(): 穿透 PEFT 层级获取 lm_head 引用
  • gradient_checkpointing_enable(): 代理到底层模型
  • save_pretrained(): 仅保存 LoRA adapter 权重

specforge/core/dflash_lora.py — OnlineDFlashLoRAModel

  • prepare_noise_input(): context 部分保持不变,block 部分只保留 anchor,其余替换为 MASK
  • build_dflash_full_attn_mask_fast(): 向量化构建 4D additive mask [bsz, 1, seq, seq]
  • _compute_loss_weights(): context + anchor 权重为 0,非 anchor 权重为 1(或 decay)
  • _full_lm_loss(): 标准 CE loss 路径
  • _compute_accuracy(): block-wise acceptance rate(累积正确预测长度 / block 非 anchor 长度)
  • forward(): 完整训练 forward pass

LoRA 版 mask 规则:

  • context token i → 因果注意力 (j ≤ i)
  • block token i (属于 block b) → 所有 context + 同 block 内双向注意力

scripts/train_dflash_lora.py — 训练脚本

  • 参数解析:model/lora/dataset/training/output/distributed/tracker 7 组参数
  • build_model(): 加载模型 + 注入 LoRA + 包装 OnlineDFlashLoRAModel
  • build_dataloader(): 复用 build_eagle3_datasetprepare_dp_dataloaders
  • FSDP 包装 + BF16Optimizer
  • 训练循环:forward → backward → accumulation → optimizer step
  • checkpoint 保存/恢复

OOM 修复改动(4 项)

改动 1: FSDP FULL_SHARD (ZeRO-3)

问题: SHARD_GRAD_OP (ZeRO-2) 每卡持有完整 Qwen3-8B 参数 (~16GB bf16)

修复: train_dflash_lora.py:362

# 之前
sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
# 之后
sharding_strategy=ShardingStrategy.FULL_SHARD

效果: 参数跨卡分片,每卡省 ~8-12GB

改动 2: batch_size=1 + accumulation_steps=8

问题: batch_size=2 时峰值显存过高

修复: run_train_dflash_lora.sh

--batch-size 1 \
--accumulation-steps 8 \

效果: 等效 global batch size 不变,峰值显存减半

改动 3: flex_attention + BlockMask 替换 4D additive mask

问题: SDPA 不支持 4D additive mask → fallback 到 math backend → 每层 materialize 完整 [bsz, 32heads, 2048, 2048] attention scores

修复: 从非 LoRA 版 dflash.py 移植 _get_or_create_block_mask() 方法,适配 LoRA 场景

涉及文件:

  1. specforge/core/dflash_lora.py

    • __init__(): 添加 attention_backend 参数(默认 "flex_attention"),BlockMask 缓存字段
    • 新增 _get_or_create_block_mask(): 用 create_block_mask() 构建零显存的 BlockMask
    • forward(): 根据 attention_backend 选择 BlockMask 或 additive mask
  2. specforge/modeling/draft/dflash_lora.py

    • from_pretrained(): 当 backend 为 flex_attention 时,传 attn_implementation="flex_attention" 给 HuggingFace
  3. scripts/train_dflash_lora.py

    • parse_args(): --attention-backend 参数 (flex_attention | additive)
    • build_model(): 根据 backend 选择 attn_implementation

BlockMask mask function(LoRA 版):

def dflash_lora_mask_fn(b, h, q_idx, kv_idx):
    # Context query: 标准因果
    is_q_ctx = q_idx < context_len
    ctx_visible = is_q_ctx & (kv_idx <= q_idx)

    # Block query: 全部 context + 同 block 双向
    is_q_block = q_idx >= context_len
    is_k_ctx = kv_idx < context_len
    q_block_id = (q_idx - context_len) // block_size
    k_block_id = (kv_idx - context_len) // block_size
    block_attend_ctx = is_q_block & is_k_ctx
    block_attend_same = is_q_block & (~is_k_ctx) & (q_block_id == k_block_id)

    return ctx_visible | (block_attend_ctx | block_attend_same)

验证: 手动逐元素对比 BlockMask 和 additive mask 输出,三组测试 (context_len=4/0, seq=12/16/64) pattern 完全一致。

效果: 不再 fallback 到 SDPA math backend,省去 [bsz, heads, seq, seq] attention scores 显存

改动 4: chunked cross-entropy loss

问题: [bsz, 2048, 151936] bf16 logits ≈ 1.18GB,加梯度 ~2.4GB+

修复: 从非 LoRA 版 dflash.py:419-478 移植 chunked loss

涉及文件:

  1. specforge/core/dflash_lora.py

    • __init__(): 添加 lm_head_chunk_size 参数(默认 0 = 不启用)
    • 新增 _chunked_lm_loss(): 分 chunk 过 lm_head + CE loss + gradient checkpointing
    • 提取 _full_lm_loss(): 原始非 chunked 路径
    • forward(): lm_head_chunk_size > 0 时走 chunked 路径
  2. specforge/modeling/draft/dflash_lora.py

    • forward(): 新增 output_hidden_states 参数,True 时返回 last hidden state 而非 logits
    • get_lm_head(): 穿透 PEFT 层级返回 base_model.lm_head 引用
  3. scripts/train_dflash_lora.py

    • parse_args(): --lm-head-chunk-size 参数(默认 0,推荐 256)
    • build_model(): 传递到 OnlineDFlashLoRAModel

Chunked loss 核心逻辑:

# 分 chunk 计算,每 chunk 用 gradient checkpointing(backward 时重算 logits,不存储)
for start in range(0, effective_len, chunk_size):
    end = min(start + chunk_size, effective_len)
    chunk_loss, chunk_weight = grad_checkpoint(
        _chunk_ce,                          # lm_head + CE
        hidden[:, start:end, :],            # 只取当前 chunk
        input_ids[:, start:end],
        combined_mask[:, start:end],
        use_reentrant=False,
    )
    total_loss += chunk_loss
    total_weight += chunk_weight
loss = total_loss / total_weight

效果: logits 峰值显存从 O(seq_len × vocab_size) 降至 O(chunk_size × vocab_size),256 chunk → ~150MB vs 1.18GB


当前训练命令

bash run_train_dflash_lora.sh 2   # 2 = GPU 数量

对应完整参数:

torchrun --nproc_per_node 2 scripts/train_dflash_lora.py \
    --model-path /workspace/Qwen3-8B \
    --train-data-path /workspace/hanrui/datasets/Nemotron-CodeAlpaca-qwen3-8b-800K \
    --output-dir outputs/qwen3-8b-dflash-lora \
    --lora-config configs/qwen3-8b-dflash-lora.json \
    --block-size 16 \
    --max-length 2048 \
    --batch-size 1 \
    --num-epochs 3 \
    --learning-rate 2e-4 \
    --accumulation-steps 8 \
    --loss-decay-gamma 7 \
    --attention-backend flex_attention \
    --lm-head-chunk-size 256 \
    --gradient-checkpointing \
    --chat-template qwen \
    --log-interval 50 \
    --save-interval 500

待验证

  • bash run_train_dflash_lora.sh 2 确认不再 OOM
  • 确认无 SDPA math fallback warning
  • 观察 GPU 显存峰值
  • 确认 loss 下降和 accuracy 上升趋势正常