Hanrui / syxin /Specforge /IMPLEMENTATION_PLAN.md
Lekr0's picture
Add files using upload-large-folder tool
2d67aa6 verified
# LoRA Direct Injection Implementation Plan
## 目标
实现一个基于 LoRA 的 draft model,通过直接注入 target model 的 hidden states 来进行投机解码。
## 核心特性
1. Draft model 与 target model 结构完全相同,但附加了 LoRA adapters
2. 推理时不维护 draft model 自己的 KV cache
3. 直接将 target model 每层的 hidden state 注入到 draft model 对应层
4. 省去 dflash 中的特征提取层(fc + hidden_norm)
## 需要实现的组件
### 1. 训练阶段 (已基本完成)
- ✅ `DFlashLoRAInjectDraftModel`: 支持 layer-by-layer injection 的 draft model
- ✅ `OnlineDFlashLoRAInjectModel`: 训练时的 wrapper,需要创建
- ✅ Training script: 修改 `train_dflash_lora.py` 支持 injection 模式
### 2. 推理阶段 (需要实现)
- ⚠️ `spec_generate_with_injection`: 投机解码推理函数
- ⚠️ Target model hidden states extraction: 从 target model 提取每层 hidden states
- ⚠️ Layer-by-layer injection logic: 在 draft model forward 时注入
## 实现步骤
### Step 1: 完善训练逻辑
创建 `OnlineDFlashLoRAInjectModel`,在训练时:
- 使用 target model 生成每层的 hidden states
- 将这些 hidden states 注入到 draft model 的对应层
- 使用 dflash attention mask 进行训练
### Step 2: 实现推理逻辑
`DFlashLoRAInjectDraftModel` 中添加 `spec_generate` 方法:
- Prefill: target model 处理 prompt,保存每层 hidden states
- Decode loop:
- Draft model 基于注入的 hidden states 并行生成 block_size 个 tokens
- Target model 验证这些 tokens
- 计算 acceptance length
- 更新 target model 的 hidden states
### Step 3: 创建训练脚本
修改或创建新的训练脚本 `train_dflash_lora_inject.py`
### Step 4: 创建推理脚本
创建 `inference_dflash_lora_inject.py` 用于测试
## 关键技术点
### Hidden States Injection
```python
# 在每一层:
target_ctx = target_hidden_states[layer_idx] # [bsz, ctx_len, hidden_dim]
layer_input = torch.cat([target_ctx, draft_hidden], dim=1)
# 使用扩展的 attention mask
layer_output = layer(layer_input, attention_mask=extended_mask)
# 只保留 draft 部分的输出
draft_hidden = layer_output[:, ctx_len:, :]
```
### Attention Mask
Draft tokens 可以 attend 到:
- 所有 target context tokens
- Block 内的所有 tokens (bidirectional)
### 不维护 Draft KV Cache
- Draft model 每次都重新计算,不使用 `use_cache=True`
- 只有 target model 维护 KV cache
## 文件清单
### 需要修改的文件
1. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py`
- 添加 `spec_generate` 方法
- 完善 `_forward_with_injection` 逻辑
2. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py`
- 确保 `return_layer_hidden_states=True` 正常工作
### 需要创建的文件
1. `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py`
- `OnlineDFlashLoRAInjectModel` 训练 wrapper
2. `/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py`
- 训练脚本
3. `/workspace/hanrui/syxin_old/Specforge/scripts/inference_dflash_lora_inject.py`
- 推理脚本