# LoRA Direct Injection Implementation Plan ## 目标 实现一个基于 LoRA 的 draft model,通过直接注入 target model 的 hidden states 来进行投机解码。 ## 核心特性 1. Draft model 与 target model 结构完全相同,但附加了 LoRA adapters 2. 推理时不维护 draft model 自己的 KV cache 3. 直接将 target model 每层的 hidden state 注入到 draft model 对应层 4. 省去 dflash 中的特征提取层(fc + hidden_norm) ## 需要实现的组件 ### 1. 训练阶段 (已基本完成) - ✅ `DFlashLoRAInjectDraftModel`: 支持 layer-by-layer injection 的 draft model - ✅ `OnlineDFlashLoRAInjectModel`: 训练时的 wrapper,需要创建 - ✅ Training script: 修改 `train_dflash_lora.py` 支持 injection 模式 ### 2. 推理阶段 (需要实现) - ⚠️ `spec_generate_with_injection`: 投机解码推理函数 - ⚠️ Target model hidden states extraction: 从 target model 提取每层 hidden states - ⚠️ Layer-by-layer injection logic: 在 draft model forward 时注入 ## 实现步骤 ### Step 1: 完善训练逻辑 创建 `OnlineDFlashLoRAInjectModel`,在训练时: - 使用 target model 生成每层的 hidden states - 将这些 hidden states 注入到 draft model 的对应层 - 使用 dflash attention mask 进行训练 ### Step 2: 实现推理逻辑 在 `DFlashLoRAInjectDraftModel` 中添加 `spec_generate` 方法: - Prefill: target model 处理 prompt,保存每层 hidden states - Decode loop: - Draft model 基于注入的 hidden states 并行生成 block_size 个 tokens - Target model 验证这些 tokens - 计算 acceptance length - 更新 target model 的 hidden states ### Step 3: 创建训练脚本 修改或创建新的训练脚本 `train_dflash_lora_inject.py` ### Step 4: 创建推理脚本 创建 `inference_dflash_lora_inject.py` 用于测试 ## 关键技术点 ### Hidden States Injection ```python # 在每一层: target_ctx = target_hidden_states[layer_idx] # [bsz, ctx_len, hidden_dim] layer_input = torch.cat([target_ctx, draft_hidden], dim=1) # 使用扩展的 attention mask layer_output = layer(layer_input, attention_mask=extended_mask) # 只保留 draft 部分的输出 draft_hidden = layer_output[:, ctx_len:, :] ``` ### Attention Mask Draft tokens 可以 attend 到: - 所有 target context tokens - Block 内的所有 tokens (bidirectional) ### 不维护 Draft KV Cache - Draft model 每次都重新计算,不使用 `use_cache=True` - 只有 target model 维护 KV cache ## 文件清单 ### 需要修改的文件 1. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py` - 添加 `spec_generate` 方法 - 完善 `_forward_with_injection` 逻辑 2. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py` - 确保 `return_layer_hidden_states=True` 正常工作 ### 需要创建的文件 1. `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py` - `OnlineDFlashLoRAInjectModel` 训练 wrapper 2. `/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py` - 训练脚本 3. `/workspace/hanrui/syxin_old/Specforge/scripts/inference_dflash_lora_inject.py` - 推理脚本