LoRA Direct Injection Implementation Plan
目标
实现一个基于 LoRA 的 draft model,通过直接注入 target model 的 hidden states 来进行投机解码。
核心特性
- Draft model 与 target model 结构完全相同,但附加了 LoRA adapters
- 推理时不维护 draft model 自己的 KV cache
- 直接将 target model 每层的 hidden state 注入到 draft model 对应层
- 省去 dflash 中的特征提取层(fc + hidden_norm)
需要实现的组件
1. 训练阶段 (已基本完成)
- ✅
DFlashLoRAInjectDraftModel: 支持 layer-by-layer injection 的 draft model - ✅
OnlineDFlashLoRAInjectModel: 训练时的 wrapper,需要创建 - ✅ Training script: 修改
train_dflash_lora.py支持 injection 模式
2. 推理阶段 (需要实现)
- ⚠️
spec_generate_with_injection: 投机解码推理函数 - ⚠️ Target model hidden states extraction: 从 target model 提取每层 hidden states
- ⚠️ Layer-by-layer injection logic: 在 draft model forward 时注入
实现步骤
Step 1: 完善训练逻辑
创建 OnlineDFlashLoRAInjectModel,在训练时:
- 使用 target model 生成每层的 hidden states
- 将这些 hidden states 注入到 draft model 的对应层
- 使用 dflash attention mask 进行训练
Step 2: 实现推理逻辑
在 DFlashLoRAInjectDraftModel 中添加 spec_generate 方法:
- Prefill: target model 处理 prompt,保存每层 hidden states
- Decode loop:
- Draft model 基于注入的 hidden states 并行生成 block_size 个 tokens
- Target model 验证这些 tokens
- 计算 acceptance length
- 更新 target model 的 hidden states
Step 3: 创建训练脚本
修改或创建新的训练脚本 train_dflash_lora_inject.py
Step 4: 创建推理脚本
创建 inference_dflash_lora_inject.py 用于测试
关键技术点
Hidden States Injection
# 在每一层:
target_ctx = target_hidden_states[layer_idx] # [bsz, ctx_len, hidden_dim]
layer_input = torch.cat([target_ctx, draft_hidden], dim=1)
# 使用扩展的 attention mask
layer_output = layer(layer_input, attention_mask=extended_mask)
# 只保留 draft 部分的输出
draft_hidden = layer_output[:, ctx_len:, :]
Attention Mask
Draft tokens 可以 attend 到:
- 所有 target context tokens
- Block 内的所有 tokens (bidirectional)
不维护 Draft KV Cache
- Draft model 每次都重新计算,不使用
use_cache=True - 只有 target model 维护 KV cache
文件清单
需要修改的文件
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py- 添加
spec_generate方法 - 完善
_forward_with_injection逻辑
- 添加
/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py- 确保
return_layer_hidden_states=True正常工作
- 确保
需要创建的文件
/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.pyOnlineDFlashLoRAInjectModel训练 wrapper
/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py- 训练脚本
/workspace/hanrui/syxin_old/Specforge/scripts/inference_dflash_lora_inject.py- 推理脚本