File size: 3,222 Bytes
2d67aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# LoRA Direct Injection Implementation Plan

## 目标
实现一个基于 LoRA 的 draft model,通过直接注入 target model 的 hidden states 来进行投机解码。

## 核心特性
1. Draft model 与 target model 结构完全相同,但附加了 LoRA adapters
2. 推理时不维护 draft model 自己的 KV cache
3. 直接将 target model 每层的 hidden state 注入到 draft model 对应层
4. 省去 dflash 中的特征提取层(fc + hidden_norm)

## 需要实现的组件

### 1. 训练阶段 (已基本完成)
- ✅ `DFlashLoRAInjectDraftModel`: 支持 layer-by-layer injection 的 draft model
- ✅ `OnlineDFlashLoRAInjectModel`: 训练时的 wrapper,需要创建
- ✅ Training script: 修改 `train_dflash_lora.py` 支持 injection 模式

### 2. 推理阶段 (需要实现)
- ⚠️ `spec_generate_with_injection`: 投机解码推理函数
- ⚠️ Target model hidden states extraction: 从 target model 提取每层 hidden states
- ⚠️ Layer-by-layer injection logic: 在 draft model forward 时注入

## 实现步骤

### Step 1: 完善训练逻辑
创建 `OnlineDFlashLoRAInjectModel`,在训练时:
- 使用 target model 生成每层的 hidden states
- 将这些 hidden states 注入到 draft model 的对应层
- 使用 dflash attention mask 进行训练

### Step 2: 实现推理逻辑`DFlashLoRAInjectDraftModel` 中添加 `spec_generate` 方法:
- Prefill: target model 处理 prompt,保存每层 hidden states
- Decode loop:
  - Draft model 基于注入的 hidden states 并行生成 block_size 个 tokens
  - Target model 验证这些 tokens
  - 计算 acceptance length
  - 更新 target model 的 hidden states

### Step 3: 创建训练脚本
修改或创建新的训练脚本 `train_dflash_lora_inject.py`

### Step 4: 创建推理脚本
创建 `inference_dflash_lora_inject.py` 用于测试

## 关键技术点

### Hidden States Injection
```python
# 在每一层:
target_ctx = target_hidden_states[layer_idx]  # [bsz, ctx_len, hidden_dim]
layer_input = torch.cat([target_ctx, draft_hidden], dim=1)
# 使用扩展的 attention mask
layer_output = layer(layer_input, attention_mask=extended_mask)
# 只保留 draft 部分的输出
draft_hidden = layer_output[:, ctx_len:, :]
```

### Attention Mask
Draft tokens 可以 attend 到:
- 所有 target context tokens
- Block 内的所有 tokens (bidirectional)

### 不维护 Draft KV Cache
- Draft model 每次都重新计算,不使用 `use_cache=True`
- 只有 target model 维护 KV cache

## 文件清单

### 需要修改的文件
1. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/draft/dflash_lora_inject.py`
   - 添加 `spec_generate` 方法
   - 完善 `_forward_with_injection` 逻辑

2. `/workspace/hanrui/syxin_old/Specforge/specforge/modeling/target/dflash_target_model.py`
   - 确保 `return_layer_hidden_states=True` 正常工作

### 需要创建的文件
1. `/workspace/hanrui/syxin_old/Specforge/specforge/core/dflash_lora_inject.py`
   - `OnlineDFlashLoRAInjectModel` 训练 wrapper

2. `/workspace/hanrui/syxin_old/Specforge/scripts/train_dflash_lora_inject.py`
   - 训练脚本

3. `/workspace/hanrui/syxin_old/Specforge/scripts/inference_dflash_lora_inject.py`
   - 推理脚本