Hanrui / syxin /Specforge /IMPLEMENTATION_PLAN.md
Lekr0's picture
Add files using upload-large-folder tool
2d67aa6 verified

LoRA Direct Injection Implementation Plan

目标

实现一个基于 LoRA 的 draft model,通过直接注入 target model 的 hidden states 来进行投机解码。

核心特性

  1. Draft model 与 target model 结构完全相同,但附加了 LoRA adapters
  2. 推理时不维护 draft model 自己的 KV cache
  3. 直接将 target model 每层的 hidden state 注入到 draft model 对应层
  4. 省去 dflash 中的特征提取层(fc + hidden_norm)

需要实现的组件

1. 训练阶段 (已基本完成)

  • DFlashLoRAInjectDraftModel: 支持 layer-by-layer injection 的 draft model
  • OnlineDFlashLoRAInjectModel: 训练时的 wrapper,需要创建
  • ✅ Training script: 修改 train_dflash_lora.py 支持 injection 模式

2. 推理阶段 (需要实现)

  • ⚠️ spec_generate_with_injection: 投机解码推理函数
  • ⚠️ Target model hidden states extraction: 从 target model 提取每层 hidden states
  • ⚠️ Layer-by-layer injection logic: 在 draft model forward 时注入

实现步骤

Step 1: 完善训练逻辑

创建 OnlineDFlashLoRAInjectModel,在训练时:

  • 使用 target model 生成每层的 hidden states
  • 将这些 hidden states 注入到 draft model 的对应层
  • 使用 dflash attention mask 进行训练

Step 2: 实现推理逻辑

DFlashLoRAInjectDraftModel 中添加 spec_generate 方法:

  • Prefill: target model 处理 prompt,保存每层 hidden states
  • Decode loop:
    • Draft model 基于注入的 hidden states 并行生成 block_size 个 tokens
    • Target model 验证这些 tokens
    • 计算 acceptance length
    • 更新 target model 的 hidden states

Step 3: 创建训练脚本

修改或创建新的训练脚本 train_dflash_lora_inject.py

Step 4: 创建推理脚本

创建 inference_dflash_lora_inject.py 用于测试

关键技术点

Hidden States Injection

# 在每一层:
target_ctx = target_hidden_states[layer_idx]  # [bsz, ctx_len, hidden_dim]
layer_input = torch.cat([target_ctx, draft_hidden], dim=1)
# 使用扩展的 attention mask
layer_output = layer(layer_input, attention_mask=extended_mask)
# 只保留 draft 部分的输出
draft_hidden = layer_output[:, ctx_len:, :]

Attention Mask

Draft tokens 可以 attend 到:

  • 所有 target context tokens
  • Block 内的所有 tokens (bidirectional)

不维护 Draft KV Cache

  • Draft model 每次都重新计算,不使用 use_cache=True
  • 只有 target model 维护 KV cache

文件清单

需要修改的文件

  1. /workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py

    • 添加 spec_generate 方法
    • 完善 _forward_with_injection 逻辑
  2. /workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py

    • 确保 return_layer_hidden_states=True 正常工作

需要创建的文件

  1. /workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py

    • OnlineDFlashLoRAInjectModel 训练 wrapper
  2. /workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py

    • 训练脚本
  3. /workspace/hanrui/syxin_old/Specforge/scripts/inference_dflash_lora_inject.py

    • 推理脚本