# Decision Point Attention: Research Conversation Log
## Project Origin
- Date: 2026-03-23
- Target venue: NeurIPS 2026 (deadline: May 6, 2026)
## Core Idea
**Decision Point Attention (DPA)**: in an agent trajectory, only ~10-15% of tokens are key decision points (tool calls, plan revisions, error recovery). DPA routes these tokens through full softmax attention and the rest through linear attention, aiming for Transformer-level reasoning at close to linear-attention efficiency.
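The routing idea can be illustrated with a minimal single-head, pure-Python sketch. It assumes query-side routing (every token is written into both the linear-attention state and the KV cache; decision-point queries read with causal softmax, all others read the linear state), the common `elu(x) + 1` feature map, and a hypothetical `fixed_router` that keeps the top `ratio` fraction of positions by score. None of these names come from `src/models/`; this is an illustration, not the project's implementation.

```python
import math

def phi(x):
    """elu(x) + 1 feature map, which keeps linear-attention weights positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def softmax_attend(q, keys, values):
    """Scaled dot-product softmax attention for a single query."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]

def fixed_router(scores, ratio):
    """Hypothetical fixed router: mark the top `ratio` fraction as decision points."""
    k = max(1, round(ratio * len(scores)))
    chosen = set(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k])
    return [i in chosen for i in range(len(scores))]

def dpa_layer(qs, ks, vs, is_decision):
    """Causal DPA sketch: softmax for routed queries, linear attention otherwise."""
    dk, dv = len(ks[0]), len(vs[0])
    S = [[0.0] * dv for _ in range(dk)]  # running sum of phi(k) v^T
    z = [0.0] * dk                       # running sum of phi(k), the normalizer
    outputs = []
    for t, (q, k, v) in enumerate(zip(qs, ks, vs)):
        fk = phi(k)
        for i in range(dk):
            z[i] += fk[i]
            for j in range(dv):
                S[i][j] += fk[i] * v[j]
        if is_decision[t]:
            # Decision point: full softmax over every key seen so far.
            outputs.append(softmax_attend(q, ks[: t + 1], vs[: t + 1]))
        else:
            # Ordinary token: read the constant-size linear-attention state.
            fq = phi(q)
            denom = sum(a * b for a, b in zip(fq, z))
            outputs.append([sum(fq[i] * S[i][j] for i in range(dk)) / denom
                            for j in range(dv)])
    return outputs
```

One simplification to keep in mind: `fixed_router` ranks over the whole sequence, which is not causal. A router usable during generation would need a per-token decision, for example a learned score compared against a threshold.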
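## Motivation
- MiniMax M2 abandoned linear attention because of failures in multi-turn reasoning
- Jamba (1:7) and Kimi Linear (1:3) mix full-attention and linear layers at a fixed, uniform full:linear ratio, which wastes compute
- Agent trajectories have a clear structure; not every step matters equally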
## Related Work
- Routing Transformer (ICLR 2020): content-based sparse routing
- NAtS-L (2026): token-level hybrid routing
- Kimi Linear KDA: 3:1 KDA-to-full-attention layer ratio for agentic workloads
- Based (ICML 2024): recall-throughput tradeoff
- Gated DeltaNet (ICLR 2025): delta rule improves recall
- Illusion of State (ICML 2024): SSM state-tracking limitations
- RNNs are not Transformers Yet (ICLR 2025): RAG/attention closes gap
## Experiment Plan
1. Simulation: simulate the quality-vs-compute tradeoff at different decision ratios (DONE)
2. Trajectory Analysis: measure the proportion of decision points in real agent trajectories (DONE)
3. Training: Fine-tune Qwen2.5-7B with DPA router on Merlin 8xH100 (TODO)
4. Evaluation: HotpotQA, GSM8K, ToolBench benchmarks (TODO)
5. Ablation: Router architecture, ratio sweeps, layer placement (TODO)
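The trajectory-analysis step reduces to labeling tokens and counting. The sketch below assumes a hypothetical step schema (`Step` with a `kind` string and a token count) and treats the three decision-point categories named in the Core Idea as step kinds; the real labeling in `src/data/agent_trajectory.py` may use a different representation. The toy trajectory in the test is illustrative only, not measured data.

```python
from dataclasses import dataclass

# Hypothetical step kinds; the notes list tool calls, plan revisions and
# error recovery as decision points but do not fix a schema.
DECISION_KINDS = {"tool_call", "plan_revision", "error_recovery"}

@dataclass
class Step:
    kind: str        # e.g. "thought", "tool_call", "observation"
    num_tokens: int  # number of tokens in this step

def label_tokens(steps):
    """Expand step-level labels into one boolean per token."""
    labels = []
    for step in steps:
        labels.extend([step.kind in DECISION_KINDS] * step.num_tokens)
    return labels

def decision_ratio(steps):
    """Fraction of tokens that fall inside decision-point steps."""
    labels = label_tokens(steps)
    return sum(labels) / len(labels) if labels else 0.0
```

Labeling at step granularity marks every token of a tool call as a decision point; a finer-grained labeler (for example, only the tool name and arguments) would give a lower ratio.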
## File Structure
- `src/models/router.py` — Decision Point Router (learned + fixed)
- `src/models/dpa_model.py` — DPA architecture (LinearAttn + FullAttn + Router)
- `src/models/baselines.py` — Full Transformer, Pure Linear, Uniform Hybrid
- `src/data/agent_trajectory.py` — Trajectory generator & labeling
- `src/data/datasets.py` — HotpotQA, GSM8K, ToolBench loaders
- `src/eval/benchmark.py` — Unified evaluation pipeline
- `src/eval/metrics.py` — FLOPs, latency, KV cache metrics
- `src/eval/visualize.py` — Publication figures
- `configs/` — Simulation & training configs
- `scripts/` — Run scripts for local & Merlin
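For the KV-cache metric in `src/eval/metrics.py`, the standard softmax-cache size formula is a useful baseline. The layer count, KV-head count, and head dimension below are assumed Qwen2.5-7B values (verify against the model's `config.json`), and `kv_cache_bytes` is a hypothetical helper name, not the project's API.

```python
# Assumed Qwen2.5-7B attention shape; verify against the model's config.json.
QWEN25_7B = dict(num_layers=28, num_kv_heads=4, head_dim=128)

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, num_cached_tokens,
                   bytes_per_elem=2):
    """Bytes held by a softmax KV cache: one K and one V vector per cached
    token, per KV head, per layer (bytes_per_elem=2 assumes bf16)."""
    return 2 * num_layers * num_kv_heads * head_dim * num_cached_tokens * bytes_per_elem

print(kv_cache_bytes(num_cached_tokens=1, **QWEN25_7B))              # 57344
print(kv_cache_bytes(num_cached_tokens=32768, **QWEN25_7B) / 2**30)  # 1.75 (GiB)
```

How much DPA saves here depends on a design choice the notes leave open: if decision-point queries attend to every past key, all tokens must stay cached; if they attend only to past decision points, `num_cached_tokens` shrinks to roughly the decision ratio times the sequence length.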
## Simulation Results (random init, untrained)
| Model | FLOPs (% of Full Transformer) | PPL |
|-------|-------------|-----|
| Full Transformer | 100% | 38905 |
| Pure Linear | 12.5% | 38134 |
| Uniform Hybrid | 27.1% | 38856 |
| DPA (10%) | 22.5% | 36799 |
| DPA (15%) | 27.5% | 37978 |
| DPA (25%) | 37.5% | 38174 |
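(The PPL differences come from random initialization and say nothing about model quality; they are expected to change substantially after training.)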
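The DPA rows in the table match a simple additive cost model: FLOPs ≈ 12.5% (the Pure Linear cost) + decision ratio × 100%. The sketch below encodes that reading; it is an inference from the table, assuming every token pays the linear-attention cost and routed tokens additionally pay the full-attention cost, not a statement of how the simulation computed FLOPs.

```python
# Values read off the simulation table, as fractions of Full Transformer FLOPs.
PURE_LINEAR = 0.125
FULL = 1.0

def dpa_flops_ratio(decision_ratio, linear=PURE_LINEAR, full=FULL):
    """Additive cost model inferred from the table: every token pays the
    linear-attention cost; routed tokens also pay the full-attention cost."""
    return linear + decision_ratio * full

for r in (0.10, 0.15, 0.25):
    print(f"DPA ({r:.0%}): {dpa_flops_ratio(r):.1%}")
# DPA (10%): 22.5%
# DPA (15%): 27.5%
# DPA (25%): 37.5%
```

Under this reading, DPA matches the Uniform Hybrid row's 27.1% at a decision ratio of about 14.6%, so the 15% setting is already slightly more expensive than the uniform baseline.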
## Next Steps
1. Run `scripts/run_dpa.sh` on Merlin to train the 7B model
2. Evaluate multi-step reasoning accuracy on HotpotQA/GSM8K
3. Write the NeurIPS paper (LaTeX template ready)
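For the GSM8K accuracy in step 2, a common scoring scheme compares the last number in the model output with the reference, which GSM8K stores after a final `#### ` marker. The "last number wins" extraction is a heuristic assumption and the function names are hypothetical; `src/eval/benchmark.py` may score differently.

```python
import re

# Optional sign, digits with optional thousands separators, optional decimals.
_NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def gsm8k_reference(answer_field):
    """GSM8K reference solutions end with '#### <final answer>'."""
    return answer_field.split("####")[-1].strip().replace(",", "")

def extract_prediction(model_output):
    """Heuristic: treat the last number in the output as the prediction."""
    matches = _NUMBER.findall(model_output)
    return matches[-1].replace(",", "") if matches else None

def exact_match_accuracy(outputs, references):
    """Fraction of outputs whose extracted number equals the reference."""
    correct = 0
    for out, ref in zip(outputs, references):
        pred = extract_prediction(out)
        if pred is not None and float(pred) == float(gsm8k_reference(ref)):
            correct += 1
    return correct / len(references) if references else 0.0
```

Comparing as floats lets `18`, `18.0`, and `18.00` count as the same answer; outputs with no number at all are scored as wrong.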