
Decision Point Attention: Research Conversation Log

Project Origin

  • Date: 2026-03-23
  • Target venue: NeurIPS 2026 (deadline: May 6, 2026)

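Core Idea

Decision Point Attention (DPA): in an agent trajectory, only ~10-15% of tokens are key decision points (tool calls, plan revisions, error recovery). DPA routes these tokens through full softmax attention and all other tokens through linear attention, with the goal of Transformer-level reasoning at close to linear-attention efficiency.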
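The routing described above can be sketched for a single attention head in plain Python. This is a minimal illustration under stated assumptions, not the code in src/models/dpa_model.py: the function names, the `is_decision` mask, and the elu(x) + 1 feature map for the linear path are all hypothetical choices made for the example.

```python
import math

def softmax_attention(q, keys, values):
    """Full softmax attention of one query over all visible keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    d_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(d_v)]

def feature_map(x):
    """elu(x) + 1: a positive feature map often used for linear attention
    (an assumption here, not taken from the log)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def dpa_attention(queries, keys, values, is_decision):
    """Causal attention: softmax for decision tokens, linear recurrence otherwise."""
    d_k, d_v = len(keys[0]), len(values[0])
    state = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k) v^T
    norm = [0.0] * d_k                         # running sum of phi(k)
    outputs = []
    for t, q in enumerate(queries):
        # Every token updates the linear state, so the state stays complete.
        phi_k = feature_map(keys[t])
        for i in range(d_k):
            norm[i] += phi_k[i]
            for j in range(d_v):
                state[i][j] += phi_k[i] * values[t][j]
        if is_decision[t]:
            # Decision point: attend over the full causal prefix.
            outputs.append(softmax_attention(q, keys[: t + 1], values[: t + 1]))
        else:
            # Non-decision token: read from the fixed-size recurrent state.
            phi_q = feature_map(q)
            denom = sum(a * b for a, b in zip(phi_q, norm))
            outputs.append([sum(phi_q[i] * state[i][j] for i in range(d_k)) / denom
                            for j in range(d_v)])
    return outputs

# Toy usage: three tokens, the middle one flagged as a decision point.
q = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
k = [[0.2, 0.1], [-0.3, 0.5], [0.7, -0.1]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = dpa_attention(q, k, v, is_decision=[False, True, False])
```

In this sketch the softmax path still needs every past key and value, so it saves compute but not KV cache memory. Whether the actual design keeps a full cache is not stated in this log.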

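Motivation

  • MiniMax M2 abandoned linear attention because it failed at multi-turn reasoning
  • Jamba (1:7) and Kimi Linear (1:3) use uniform mixing, which wastes compute
  • Agent trajectories have clear structure; not every step is equally important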

Related Work

  • Routing Transformer (ICLR 2020): content-based sparse routing
  • NAtS-L (2026): token-level hybrid routing
  • Kimi Linear KDA: 3:1 ratio for agentic workloads
  • Based (ICML 2024): recall-throughput tradeoff
  • Gated DeltaNet (ICLR 2025): delta rule improves recall
  • Illusion of State (ICML 2024): SSM state-tracking limitations
  • RNNs are not Transformers (Yet) (ICLR 2025): RAG/attention closes gap

Experiment Plan

  1. Simulation: simulate the quality vs. compute tradeoff at different decision ratios (DONE)
  2. Trajectory Analysis: measure the proportion of decision points in real agent trajectories (DONE)
  3. Training: Fine-tune Qwen2.5-7B with DPA router on Merlin 8xH100 (TODO)
  4. Evaluation: HotpotQA, GSM8K, ToolBench benchmarks (TODO)
  5. Ablation: Router architecture, ratio sweeps, layer placement (TODO)
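As a rough illustration of step 2, the decision-point ratio of one labeled trajectory could be computed as below. This is a hypothetical sketch: the step kinds, the `(kind, num_tokens)` layout, and the toy trajectory are assumptions for the example and do not reflect the actual labeling in src/data/agent_trajectory.py or any measured data.

```python
# Step kinds treated as decision points, following the categories named in
# the core idea (tool call, plan revision, error recovery). Label strings are
# hypothetical.
DECISION_KINDS = {"tool_call", "plan_revision", "error_recovery"}

def decision_ratio(steps):
    """Fraction of tokens that belong to decision-point steps.

    steps: list of (kind, num_tokens) pairs for one trajectory.
    """
    total = sum(n for _, n in steps)
    if total == 0:
        return 0.0
    decision = sum(n for kind, n in steps if kind in DECISION_KINDS)
    return decision / total

# Toy trajectory with made-up token counts, for illustration only.
toy_trajectory = [
    ("observation", 60),
    ("reasoning", 25),
    ("tool_call", 10),
    ("error_recovery", 5),
]
ratio = decision_ratio(toy_trajectory)  # 15 decision tokens out of 100
```

Counting by tokens rather than by steps matches how the ratio is used elsewhere in this log, where it is the fraction of tokens routed to full attention.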

File Structure

  • src/models/router.py — Decision Point Router (learned + fixed)
  • src/models/dpa_model.py — DPA architecture (LinearAttn + FullAttn + Router)
  • src/models/baselines.py — Full Transformer, Pure Linear, Uniform Hybrid
  • src/data/agent_trajectory.py — Trajectory generator & labeling
  • src/data/datasets.py — HotpotQA, GSM8K, ToolBench loaders
  • src/eval/benchmark.py — Unified evaluation pipeline
  • src/eval/metrics.py — FLOPs, latency, KV cache metrics
  • src/eval/visualize.py — Publication figures
  • configs/ — Simulation & training configs
  • scripts/ — Run scripts for local & Merlin

Simulation Results (random init, untrained)

Model              FLOPs Ratio   PPL
Full Transformer   100%          38905
Pure Linear        12.5%         38134
Uniform Hybrid     27.1%         38856
DPA (10%)          22.5%         36799
DPA (15%)          27.5%         37978
DPA (25%)          37.5%         38174

(The PPL differences come from random initialization; the models are expected to differ significantly after training.)
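The FLOPs column for the DPA rows is consistent with a simple additive cost model: the Pure Linear cost (12.5%) plus the decision ratio. This is an inference from the numbers in the table, not something the log states, and the function name below is hypothetical.

```python
# Relative cost of the Pure Linear row, with Full Transformer = 1.0.
LINEAR_COST = 0.125

def dpa_flops_ratio(decision_ratio, linear_cost=LINEAR_COST):
    """Assumed cost model: every token pays the linear cost, and the
    decision tokens additionally pay full-attention cost."""
    return linear_cost + decision_ratio

for r in (0.10, 0.15, 0.25):
    print(f"DPA ({r:.0%}): {dpa_flops_ratio(r):.1%}")
# Prints 22.5%, 27.5%, and 37.5%, matching the DPA rows of the table.
```

Under this model DPA stays cheaper than the Uniform Hybrid row (27.1%) only while the decision ratio is below roughly 14.6%.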

Next Steps

  1. Run scripts/run_dpa.sh on Merlin to train the 7B model
  2. Evaluate multi-step reasoning accuracy on HotpotQA/GSM8K
  3. Write the NeurIPS paper (LaTeX template is ready)