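Decision Point Attention: Research Conversation Log
Project Origin
- Date: 2026-03-23
- Target venue: NeurIPS 2026 (deadline: May 6, 2026)
Core Idea
Decision Point Attention (DPA): in an agent trajectory, only ~10-15% of tokens are critical decision points (tool calls, plan revisions, error recovery). DPA routes these tokens through full softmax attention and all other tokens through linear attention, aiming for Transformer-level reasoning ability at close to linear-attention cost.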
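The routing idea can be sketched in a few lines of NumPy. This is a minimal single-head illustration under stated assumptions, not the project's `dpa_model.py`: the function names, the `elu(x) + 1` feature map, and the choice to run the linear pass over every token (then overwrite decision-token rows with softmax attention) are all assumptions made for the example. The decision mask is taken as given; in the real model it would come from the router.

```python
import numpy as np

def causal_linear_attention(q, k, v):
    """Causal linear attention via a running state, O(T) in sequence length.
    The elu(x) + 1 feature map is an assumed choice for this sketch."""
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))
    fq, fk = phi(q), phi(k)
    T, d = q.shape
    state = np.zeros((d, v.shape[1]))  # running sum of phi(k_j) v_j^T
    norm = np.zeros(d)                 # running sum of phi(k_j)
    out = np.empty_like(v)
    for t in range(T):
        state += np.outer(fk[t], v[t])
        norm += fk[t]
        out[t] = fq[t] @ state / (fq[t] @ norm)
    return out

def dpa_attention(q, k, v, is_decision):
    """Linear attention for every token; decision tokens are then recomputed
    with causal softmax attention over all earlier tokens."""
    out = causal_linear_attention(q, k, v)
    idx = np.flatnonzero(is_decision)
    if idx.size:
        d = q.shape[1]
        scores = q[idx] @ k.T / np.sqrt(d)                        # (n_decision, T)
        causal = np.arange(k.shape[0])[None, :] <= idx[:, None]   # key j <= query position
        scores = np.where(causal, scores, -np.inf)
        scores -= scores.max(axis=-1, keepdims=True)              # numerical stability
        weights = np.exp(scores)
        weights /= weights.sum(axis=-1, keepdims=True)
        out[idx] = weights @ v
    return out

# Toy usage: 16 tokens, 2 of them marked as decision points (hypothetical mask).
rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(3, 16, 8))
mask = np.zeros(16, dtype=bool)
mask[[5, 12]] = True
y = dpa_attention(q, k, v, mask)
```

Only the decision-token rows pay the quadratic score computation, so the softmax cost scales with the number of decision points rather than with the full sequence length.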
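Motivation
- MiniMax M2 abandoned linear attention because it failed at multi-turn reasoning
- Jamba (1:7) and Kimi Linear (1:3) mix full and linear attention at a fixed, uniform full-to-linear ratio, which wastes compute on tokens that do not need it
- Agent trajectories have a clear structure: not every step is equally important
Related Work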
- Routing Transformer (ICLR 2020): content-based sparse routing
- NAtS-L (2026): token-level hybrid routing
- Kimi Linear KDA: 3:1 ratio for agentic workloads
- Based (ICML 2024): recall-throughput tradeoff
- Gated DeltaNet (ICLR 2025): delta rule improves recall
- Illusion of State (ICML 2024): SSM state-tracking limitations
- RNNs are not Transformers Yet (ICLR 2025): RAG/attention closes gap
Experiment Plan
- Simulation: simulate the quality vs. compute tradeoff at different decision ratios (DONE)
- Trajectory Analysis: measure the proportion of decision points in real agent trajectories (DONE)
- Training: fine-tune Qwen2.5-7B with the DPA router on Merlin 8xH100 (TODO)
- Evaluation: HotpotQA, GSM8K, ToolBench benchmarks (TODO)
- Ablation: router architecture, ratio sweeps, layer placement (TODO)
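For the trajectory-analysis step, the decision-point proportion can be computed as a token-weighted fraction. The sketch below assumes a hypothetical trajectory schema (a list of steps, each with a `kind` label and a token count); the `Step` class, the kind names, and the toy numbers are illustrative only and are not taken from `src/data/agent_trajectory.py` or from measured data. The three decision kinds mirror the ones named in the core idea.

```python
from dataclasses import dataclass

# Step kinds treated as decision points (names are assumed for this sketch).
DECISION_KINDS = {"tool_call", "plan_revision", "error_recovery"}

@dataclass
class Step:
    kind: str       # e.g. "observation", "reasoning", "tool_call"
    n_tokens: int   # number of tokens in this step

def decision_token_ratio(trajectory):
    """Fraction of tokens that belong to decision-point steps."""
    total = sum(s.n_tokens for s in trajectory)
    if total == 0:
        return 0.0
    decision = sum(s.n_tokens for s in trajectory if s.kind in DECISION_KINDS)
    return decision / total

def decision_mask(trajectory):
    """Per-token boolean mask, True where the token lies in a decision step."""
    mask = []
    for s in trajectory:
        mask.extend([s.kind in DECISION_KINDS] * s.n_tokens)
    return mask

# Toy trajectory with made-up token counts (not real data).
toy = [
    Step("observation", 50),
    Step("tool_call", 8),
    Step("reasoning", 38),
    Step("plan_revision", 4),
]
print(decision_token_ratio(toy))  # 0.12
```

The same per-token mask could serve as the label for a fixed (rule-based) router or as supervision for a learned one.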
File Structure
- src/models/router.py: Decision Point Router (learned + fixed)
- src/models/dpa_model.py: DPA architecture (LinearAttn + FullAttn + Router)
- src/models/baselines.py: Full Transformer, Pure Linear, Uniform Hybrid
- src/data/agent_trajectory.py: Trajectory generator & labeling
- src/data/datasets.py: HotpotQA, GSM8K, ToolBench loaders
- src/eval/benchmark.py: Unified evaluation pipeline
- src/eval/metrics.py: FLOPs, latency, KV cache metrics
- src/eval/visualize.py: Publication figures
- configs/: Simulation & training configs
- scripts/: Run scripts for local & Merlin
Simulation Results (random init, untrained)
| Model | FLOPs Ratio | PPL |
|---|---|---|
| Full Transformer | 100% | 38905 |
| Pure Linear | 12.5% | 38134 |
| Uniform Hybrid | 27.1% | 38856 |
| DPA (10%) | 22.5% | 36799 |
| DPA (15%) | 27.5% | 37978 |
| DPA (25%) | 37.5% | 38174 |
(The PPL differences here come from random initialization; meaningful differences are expected only after training.)
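The FLOPs Ratio column can be reproduced with simple arithmetic. The sketch below is one reading that matches the table, not a confirmed description of how `metrics.py` computes it: it assumes every token pays the linear-attention cost (12.5%, the Pure Linear row) and decision tokens additionally pay the full-attention cost. For the Uniform Hybrid row, a 1:5 full-to-linear layer mix happens to give 27.1%; the source does not state the ratio used, so that value is an assumption.

```python
LINEAR_COST = 0.125  # Pure Linear row, relative to full attention (from the table)

def dpa_flops_ratio(decision_ratio, linear_cost=LINEAR_COST):
    """Assumed model: linear cost for all tokens plus full cost for decision tokens."""
    return linear_cost + decision_ratio * 1.0

def uniform_hybrid_flops_ratio(full_layers, linear_layers, linear_cost=LINEAR_COST):
    """Layer-weighted average cost of a fixed full:linear layer mix."""
    total = full_layers + linear_layers
    return (full_layers * 1.0 + linear_layers * linear_cost) / total

for r in (0.10, 0.15, 0.25):
    print(f"DPA ({r:.0%}): {dpa_flops_ratio(r):.1%}")
# DPA (10%): 22.5%
# DPA (15%): 27.5%
# DPA (25%): 37.5%

# The 1:5 mix is an assumed ratio that reproduces the table value.
print(f"Uniform Hybrid (1:5): {uniform_hybrid_flops_ratio(1, 5):.1%}")
# Uniform Hybrid (1:5): 27.1%
```

Under this reading, DPA at a 15% decision ratio costs about the same as the uniform hybrid, while a 10% ratio is cheaper.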
Next Steps
- Run scripts/run_dpa.sh on Merlin to train the 7B model
- Evaluate multi-step reasoning accuracy on HotpotQA/GSM8K
- Write the NeurIPS paper (LaTeX template is ready)