File size: 3,708 Bytes
e34b94f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
## Training

### Weaver 训练机制(基于 `larm/memory_generator/memgen_model.py: LatentMemoryModel.forward`)

- **训练目标**: 在指定插入点将 weaver 生成的潜在向量注入到 reasoner 的隐空间中,增强解码路径;仅对“非潜在位置”做语言模型监督,梯度主要更新 weaver 与映射层。
- **两类前向**:
  - **单轮 SFT**: `_instructional_forward` → 核心 `_forward`,在“提示结束处”与“推理分隔符后”插入潜在向量。
  - **多轮 SFT**: `_conversational_forward` 将对话拆成多个监督段,每段独立走一次 `_forward`,最后拼回整条序列。
  - **GRPO**: 由 `WeaverGRPOTrainer` 负责采样、奖励与策略优化,但插入与映射机制与 SFT 一致。

#### 关键入口
- 函数签名:`LatentMemoryModel.forward(input_ids, attention_mask, labels, **kwargs)`
- 路由:判断是否为对话格式,选择 `_instructional_forward`(单轮)或 `_conversational_forward`(多轮),最终统一计算自回归 LM 损失(忽略 `-100` 标签)。

#### 单轮训练流程(`_instructional_forward` → `_forward`)
1) **选择插入点**
   - 1 个“提示插入点”:label 从 `-100 → 有效` 的边界(单轮必须恰好一个)。
   - 若干“推理插入点”:在监督区间内且位于分隔符 `[",", ".", "\n"]` 之后,数量由 `max_inference_aug_num` 限制。

2) **双向投影与插入**
   - 用 reasoner 的 embedding 得到 `inputs_embeds`,逐段累加到当前序列。
   - 将当前累积的 `inputs_embeds` 投影到 weaver 隐空间(`reasoner_to_weaver`),调用:
     - `weaver.augment_prompt(...)`(提示处),或
     - `weaver.augment_inference(...)`(推理处),
     取 weaver 最后一层在新拼接位置的隐向量作为“潜在向量”。
   - 将潜在向量投影回 reasoner 隐空间(`weaver_to_reasoner`),并与当前序列拼接,同时同步 `attention_mask/position_ids`。

3) **屏蔽潜在位置监督**
   - reasoner 在“插入后序列”上前向得到 `logits`   - 通过位移后的潜在掩码剔除“潜在位置”的 `logits`,仅保留“非潜在位置”的监督目标。

4) **语言模型损失(忽略 -100)**
```python
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
loss = CrossEntropyLoss(ignore_index=-100)(
    shift_logits.reshape(-1, shift_logits.size(-1)),
    shift_labels.reshape(-1),
)
```
- **梯度路径**: 非潜在位置的 CE 损失 → `weaver_to_reasoner` → weaver → `reasoner_to_weaver`(reasoner 通常冻结)。

#### 多轮训练流程(`_conversational_forward`)
-`labels` 找到多个监督段(assistant 回复区间),逐段截取“到段末”的子序列分别调用一次 `_forward`(各段的潜在向量互不“泄漏”),把每段的 `logits` 回填到整序列对应区间,拼接后统一计算 LM 损失(忽略 `-100`)。

#### 与训练方式(SFT vs GRPO)的关系
- **SFT**: 使用上述 CE 损失;weaver 查询向量在“提示结束 + 推理分隔符后”插入。
- **GRPO**: `WeaverGRPOTrainer` 负责采样/奖励/优势计算与策略优化;生成阶段仍通过 weaver 插入与双向投影影响 reasoner 的解码。

#### 关键要点小结
- **插入点选择**: 1 个提示插入 + 若干推理插入(分隔符后,限额)。
- **双向投影**: Reasoner 隐空间 ↔ Weaver 隐空间。
- **监督屏蔽**: 仅监督“非潜在位置”,潜在位置不参与 loss。
- **多轮拆段**: 段内独立插入与前向,段间不共享潜在。
- **参数更新**: 主要更新 weaver 与映射层;reasoner 默认冻结(除非显式解冻/PEFT)。