# Decision Point Attention: Research Conversation Log

## Project Origin
- Date: 2026-03-23
- Target venue: NeurIPS 2026 (deadline: May 6, 2026)

## Core Idea
**Decision Point Attention (DPA)**: In an agent trajectory, only ~10-15% of tokens are critical decision points (tool calls, plan revisions, error recovery). DPA routes these tokens through full softmax attention and all other tokens through linear attention. The goal is Transformer-level reasoning at a cost close to that of pure linear attention.

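The per-token mixing described above can be sketched as follows. This is a minimal single-head, causal sketch that makes several assumptions. The decision-point mask is taken as given. The linear branch uses the elu(x)+1 feature map from standard linear attention, whereas the project's actual linear layer (for example a Gated DeltaNet or KDA variant) is not specified in these notes. Decision tokens attend with softmax over all preceding tokens. The names `dpa_attention` and `elu_plus_one` are hypothetical and are not the API of `src/models/dpa_model.py`.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def elu_plus_one(x: Vector) -> Vector:
    # Positive feature map phi(x) = elu(x) + 1, so the linear-attention normalizer stays > 0
    return [v + 1.0 if v > 0 else math.exp(v) for v in x]


def dpa_attention(q: List[Vector], k: List[Vector], v: List[Vector],
                  is_decision: List[bool]) -> List[Vector]:
    """Causal single-head attention: softmax for decision tokens, linear for all others."""
    d_k, d_v = len(k[0]), len(v[0])
    scale = 1.0 / math.sqrt(d_k)
    # Running linear-attention state: S = sum_j phi(k_j) v_j^T and z = sum_j phi(k_j)
    S = [[0.0] * d_v for _ in range(d_k)]
    z = [0.0] * d_k
    out: List[Vector] = []
    for t in range(len(q)):
        fk = elu_plus_one(k[t])
        for i in range(d_k):
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[t][j]
        if is_decision[t]:
            # Decision point: exact softmax over tokens 0..t (their keys/values must be kept)
            logits = [dot(q[t], k[s]) * scale for s in range(t + 1)]
            m = max(logits)
            w = [math.exp(x - m) for x in logits]
            total = sum(w)
            out.append([sum(w[s] * v[s][j] for s in range(t + 1)) / total for j in range(d_v)])
        else:
            # Ordinary token: constant-cost read from the recurrent linear-attention state
            fq = elu_plus_one(q[t])
            denom = dot(fq, z)
            out.append([sum(fq[i] * S[i][j] for i in range(d_k)) / denom for j in range(d_v)])
    return out


q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5], [0.4, 0.4]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.0], [0.1, -0.2]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
mask = [False, False, True, False]  # e.g. token 2 is a tool call
for row in dpa_attention(q, k, v, mask):
    print([round(x, 3) for x in row])
```

In this sketch, decision tokens attend to every earlier key and value. Attention compute therefore drops, but the KV cache does not shrink. Restricting the softmax branch to a subset of tokens would be a separate design choice.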
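## Motivation
- MiniMax M2 dropped linear attention because it failed at multi-turn reasoning.
- Jamba (1:7 attention-to-Mamba layers) and Kimi Linear (1:3 full-to-linear attention layers) mix attention types uniformly. This spends full-attention compute on tokens that do not need it.
- Agent trajectories have explicit structure, and not every step matters equally.
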
## Related Work
- Routing Transformer (TACL 2021): content-based sparse routing
- NAtS-L (2026): token-level hybrid routing
- Kimi Linear: 3:1 ratio of Kimi Delta Attention (KDA) to full-attention layers, aimed at agentic workloads
- Based (ICML 2024): recall-throughput tradeoff
- Gated DeltaNet (ICLR 2025): the delta rule improves recall
- The Illusion of State in State-Space Models (ICML 2024): limits of SSM state tracking
- RNNs are not Transformers (Yet) (ICLR 2025): RAG or an added attention layer closes the gap

## Experiment Plan
1. Simulation: model the quality-vs-compute tradeoff across decision ratios (DONE)
2. Trajectory Analysis: measure the fraction of decision-point tokens in real agent trajectories (DONE)
3. Training: fine-tune Qwen2.5-7B with the DPA router on Merlin (8×H100) (TODO)
4. Evaluation: HotpotQA, GSM8K, and ToolBench benchmarks (TODO)
5. Ablations: router architecture, ratio sweeps, layer placement (TODO)

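Step 2 of the plan, the trajectory analysis, comes down to labeling steps and counting tokens. The following is a minimal sketch under several assumptions. The trajectory format of `(step_type, text)` pairs is hypothetical. Whitespace splitting stands in for the real tokenizer. Steps of type `tool_call`, `plan_revision`, and `error_recovery` count as decision points, mirroring the categories in the Core Idea. None of these names come from `src/data/agent_trajectory.py`.

```python
from typing import List, Tuple

# Hypothetical step types treated as decision points (tool call, plan revision, error recovery)
DECISION_TYPES = {"tool_call", "plan_revision", "error_recovery"}

Step = Tuple[str, str]  # (step_type, text), a hypothetical trajectory format


def label_tokens(trajectory: List[Step]) -> List[Tuple[str, bool]]:
    """Return (token, is_decision) pairs; str.split() stands in for a real tokenizer."""
    labeled: List[Tuple[str, bool]] = []
    for step_type, text in trajectory:
        is_decision = step_type in DECISION_TYPES
        labeled.extend((tok, is_decision) for tok in text.split())
    return labeled


def decision_ratio(trajectory: List[Step]) -> float:
    """Fraction of tokens that fall inside decision-point steps."""
    labeled = label_tokens(trajectory)
    if not labeled:
        return 0.0
    return sum(flag for _, flag in labeled) / len(labeled)


traj = [
    ("observation", "search results list three candidate pages"),
    ("tool_call", "open_page(2)"),
    ("observation", "page 2 does not mention the date"),
    ("plan_revision", "switch to searching the archive instead"),
]
print(decision_ratio(traj))  # 7 of 20 tokens -> 0.35
```

The toy trajectory is deliberately decision-heavy. On real trajectories, the ratio depends on the model's tokenizer and on how step boundaries are labeled.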
## File Structure
- `src/models/router.py`: Decision Point Router (learned + fixed)
- `src/models/dpa_model.py`: DPA architecture (LinearAttn + FullAttn + Router)
- `src/models/baselines.py`: Full Transformer, Pure Linear, Uniform Hybrid
- `src/data/agent_trajectory.py`: trajectory generator and labeling
- `src/data/datasets.py`: HotpotQA, GSM8K, ToolBench loaders
- `src/eval/benchmark.py`: unified evaluation pipeline
- `src/eval/metrics.py`: FLOPs, latency, KV cache metrics
- `src/eval/visualize.py`: publication figures
- `configs/`: simulation and training configs
- `scripts/`: run scripts for local and Merlin runs

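`router.py` is listed as supporting both learned and fixed routing. The following minimal sketch shows what the two modes could look like. It assumes the fixed router reuses precomputed labels as-is. It assumes the learned router scores each hidden state with a linear probe and a sigmoid, then keeps the top `ceil(ratio * T)` tokens. The probe, the top-k rule, and all names are illustrative assumptions. Training such a router would also need a differentiable relaxation (e.g. a straight-through estimator), which is omitted here.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def fixed_router(labels: List[bool]) -> List[bool]:
    """Fixed routing: use precomputed decision-point labels directly as the mask."""
    return list(labels)


def learned_router(hidden: List[List[float]], weight: List[float], bias: float,
                   ratio: float) -> List[bool]:
    """Learned routing (sketch): score each token, keep the top ceil(ratio * T) scores."""
    scores = [sigmoid(sum(w * h for w, h in zip(weight, vec)) + bias) for vec in hidden]
    n = len(hidden)
    # Small epsilon guards against floating-point error in ratio * n
    k = min(n, max(0, math.ceil(ratio * n - 1e-9)))
    top = sorted(range(n), key=lambda i: scores[i], reverse=True)[:k]
    mask = [False] * n
    for i in top:
        mask[i] = True
    return mask


hidden = [[0.0], [2.0], [-1.0], [3.0]]
print(learned_router(hidden, weight=[1.0], bias=0.0, ratio=0.5))  # [False, True, False, True]
```

Global top-k looks at the whole sequence, so it is not causal. An autoregressive decoder would need a per-token decision instead, for example thresholding the score.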
## Simulation Results (random init, untrained)
| Model | FLOPs (% of Full Transformer) | PPL |
|-------|-------------------------------|-----|
| Full Transformer | 100% | 38905 |
| Pure Linear | 12.5% | 38134 |
| Uniform Hybrid | 27.1% | 38856 |
| DPA (10%) | 22.5% | 36799 |
| DPA (15%) | 27.5% | 37978 |
| DPA (25%) | 37.5% | 38174 |

(The PPL differences come from random initialization and say nothing about model quality. The models are expected to differ substantially after training.)

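The FLOPs column appears to follow a simple additive cost model. Each DPA row equals the Pure Linear cost (12.5%) plus the routed fraction at full cost: 12.5% + 10% = 22.5%, and likewise for 15% and 25%. The sketch below reproduces those entries under that assumption, which is inferred from the table rather than taken from `src/eval/metrics.py`. The Uniform Hybrid entry (27.1%) is consistent with one in six layers using full attention, but that layer ratio is also an inference.

```python
from fractions import Fraction

# Assumed cost model, inferred from the table rather than from src/eval/metrics.py
LINEAR_COST = Fraction(1, 8)  # Pure Linear row: 12.5% of Full Transformer FLOPs


def dpa_flops_ratio(decision_ratio: Fraction) -> Fraction:
    """Linear path on every token plus full-attention cost for the routed fraction."""
    return LINEAR_COST + decision_ratio


def uniform_hybrid_flops_ratio(full_layer_fraction: Fraction) -> Fraction:
    """Some layers use full attention on all tokens; the rest use linear attention."""
    return full_layer_fraction + (1 - full_layer_fraction) * LINEAR_COST


for pct in (10, 15, 25):
    print(f"DPA ({pct}%): {float(dpa_flops_ratio(Fraction(pct, 100))):.1%}")
print(f"Uniform Hybrid (1/6 full layers): {float(uniform_hybrid_flops_ratio(Fraction(1, 6))):.1%}")
# Prints 22.5%, 27.5%, 37.5% and 27.1%, matching the table
```

Under this reading, routed tokens pay for both branches. If they skipped the linear path instead, DPA (10%) would cost 0.9 × 12.5% + 10% = 21.25% rather than 22.5%.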
## Next Steps
1. Run `scripts/run_dpa.sh` on Merlin to train the 7B model
2. Evaluate multi-step reasoning accuracy on HotpotQA/GSM8K
3. Write the NeurIPS paper (LaTeX template is ready)