# qwen3-moe-aclnn 纯 C++ 实现的 **Qwen3-235B-A22B-Instruct** BF16 推理运行时,运行于 **Ascend 910 × 16 NPU**,直接基于 aclnn EAGER 单算子 API(无图编译、无 PyTorch、无 ggml 依赖)。 English version: [README.md](README.md) --- ## 性能表现 在 Ascend 910 初代 × 16 NPU(TP=16)上、Qwen3-235B-A22B-Instruct-2507 BF16 权重实测。 所有数字均为**质量保持前提下的 TG**(输出已人工核验),greedy `temperature=0`。 | 配置 | TG | 适用 prompt | |---|---|---| | 未调优基线 | 12 t/s | 全部 | | **推荐默认**(不开 PLD) | **~27 t/s** | **全部 prompt,输出稳定** | | PLD + degeneration guard | 29-45 t/s | 结构化文本(论述、长回答) | | PLD 创意类 prompt | 25-40 t/s | 故事 / 多样生成 | | PLD 事实 / 代码类 prompt | 不稳定(21-95 t/s,方差大) | 不推荐 | 参考:`cann-recipes-infer` GE graph 方案在同硬件约 54 t/s。**本项目未超越该基线**——以峰值速度换取 (a) 无需图编译、(b) 无 PyTorch 依赖、(c) 完整的算子调度控制权。 ### 核心优化(按贡献排序) | 排名 | 优化项 | 收益 | 位置 | |---|---|---|---| | 🥇 | HCCL env 调参(`AIV` + `FFTS` + `TASK_QUEUE=2`) | +89%(12→23 t/s) | `scripts/tp_launch.sh` | | 🥈 | Fused RoPE(`aclnnApplyRotaryPosEmbV2`) | +17%(23→27 t/s) | `include/rope.h` | | 🥉 | Prompt Lookup Decoding(PLD)+ degeneration guard | 适用 prompt 上 +10-60% | `src/main_cli.cpp` | | ○ | Device-side topk-w normalize、MoE argsort、cos/sin cache | 累计 ~+15% | `include/engine.h` | | ○ | WorkspacePool(thread-local + retain-old) | 降低 malloc 开销 | `include/workspace_pool.h` | --- ## 架构 **模型**:Qwen3-235B-A22B,94 层,128 experts(top-k=8),GQA(64 Q heads / 4 KV heads),BF16。 **并行**:TP=16,HCCL ring AllReduce。KV head 每 rank 1 份(4 个 KV heads < 16 ranks,因此每 rank 上的 Q head 0-3 共享 KV head 0)。 **执行**:aclnn EAGER 模式——所有算子走 `aclnn*` 单算子 API,配合 workspace 池;无 graph capture、无 GE IR。异步 stream 执行配合 `TASK_QUEUE_ENABLE=2` 实现 kernel 提交重叠。 **Tokenizer**:encode 通过 Python 子进程调用 HuggingFace `transformers`;decode 纯 C++,从导出的 `vocab.bin` 查表。 ### 单层 forward 数据流 ``` x_in [S, D=4096] ↓ ┌── Attention 分支(TP:Q_DIM=512=4h×128,KV_DIM=128=1h×128) ──┐ │ RmsNorm(input_layernorm) │ linear_hf q_proj / k_proj / v_proj → q, k, v │ Per-head RmsNorm q_norm, k_norm │ Fused RoPE:aclnnApplyRotaryPosEmbV2(layout=1, "half") │ K、V 追加到每层 KV cache │ Mask 选择: │ prefill: 2048×2048 causal + sparse_mode=3 │ decode S=1: mask=nullptr + sparse_mode=0 │ batch decode:[1,1,S,past+S] 自定义 bool mask + sparse_mode=0 │ FIAS(aclnnFusedInferAttentionScore) │ o_proj linear_hf → per-rank partial │ HCCL AllReduce(ring + AIV + FFTS) → full └─────────┘ ↓ residual add ┌── MoE 分支 ──┐ │ RmsNorm(post_attention_layernorm) │ router linear_hf → logits [S, 128] │ moe_gating_topk_softmax → topk_w[S,8], topk_idx[S,8] │ Device-side normalize(reduce_sum + adds + cast + div) │ moe_init_routing_v3 → expanded_x, expanded_ri, tokens_per_expert │ grouped_matmul_v4 gate/up/down(SwiGLU 激活) │ Device-side argsort × 2 → 前向置换(避免 host sync) │ IndexSelect → packed │ 广播 mul topk_w + ReduceSum axis=1 │ HCCL AllReduce → full └─────────┘ ↓ residual add x_out ``` --- ## 模型权重 目标模型:**Qwen3-235B-A22B-Instruct-2507**(BF16),safetensors 分片约 **470 GB**。 **下载地址**: - HuggingFace:https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507 - ModelScope:https://www.modelscope.cn/models/Qwen/Qwen3-235B-A22B-Instruct-2507 通过 `huggingface-cli` 或 `modelscope` CLI 下载: ```bash # HuggingFace huggingface-cli download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 # ModelScope modelscope download --model Qwen/Qwen3-235B-A22B-Instruct-2507 --local_dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 ``` **权重格式**:二进制直接读 HuggingFace `.safetensors` 分片(多 shard mmap)、`config.json`、`tokenizer.json`。**无需转换**——`--model-dir` 指向下载目录即可。 **目录结构**: ``` Qwen3-235B-A22B-Instruct-2507-BF16/ ├── config.json ├── tokenizer.json ├── tokenizer_config.json ├── model-00001-of-000XX.safetensors ├── ... └── model.safetensors.index.json ``` --- ## 编译 ```bash source /usr/local/Ascend/ascend-toolkit/set_env.sh cmake -B build cmake --build build -j8 --target qwen3-moe-aclnn ``` **依赖**: - CANN 8.5.1 或兼容版本 - Python 3 + `transformers` + `torch_npu`(仅 tokenizer 子进程和参考数据生成用) - C++17 编译器 - Ascend 910 × 16 NPU - nlohmann/json(已打包在 `external/json.hpp`) **Python 环境设置** — Tokenizer 会调用 Python 子进程。如果 conda / venv 路径与默认不一致,用 `QWEN3_PYENV_INIT` 覆盖: ```bash export QWEN3_PYENV_INIT="source /opt/my_conda/etc/profile.d/conda.sh && conda activate my_env && " ``` 未设置时,默认尝试 `${HOME}/miniconda3` + env `qwen3`,并自动 source Ascend toolkit。 --- ## 快速开始 ```bash # 1. 导出 tokenizer 词表到二进制(一次性) python3 scripts/export_vocab.py /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 # 2. 运行推理(TP=16) ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn \ --model-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 \ --prompt "The capital of France is" \ --n-predict 100 \ --temperature 0 \ --vocab tokenizer_data/vocab.bin ``` 预期:~27 t/s,输出连贯正确。 ### 按场景推荐的参数 **通用默认(稳定、任意 prompt)** — 不开 PLD: ```bash ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --temperature 0 --no-stream ``` **结构化 / 长文本(论述、说明)** — PLD + guard 提升 +60-90%: ```bash ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --pld --temperature 0 --no-stream ``` **交互式 REPL(多轮对话)**: ```bash ./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... \ --interactive --chat --temperature 0.7 --top-p 0.8 ``` --- ## PLD degeneration guard Prompt Lookup Decoding 通过让模型一次性 batch verify 若干"draft"token 来加速生成,draft 来自生成历史里的 n-gram 匹配。 **已知失效模式**:在模型本身就有重复倾向的 prompt(事实问答、代码生成)上,n-gram 会把模型的重复 token 当作 draft 喂回给模型,形成**正反馈循环**——以 batch 速度加速退化输出。本项目早期曾误报由此类死循环得出的"高 TG"。 **本项目的 guard** 用两条启发式拦截可疑 draft: 1. **low-distinct**:draft 中 distinct token 数 < 阈值 → 拒绝 2. **tail-echo**:最后 N 个 hist token 全部等于 draft[0] → 拒绝 被拒 draft 走单 token decode fallback。生成末尾如出现 8 个连续相同 token,stderr 打印一次 `[warn]` 提示。 参数: ``` --pld 启用 PLD(opt-in) --pld-k N draft 窗口(默认 10) --pld-ngram N n-gram 匹配长度(默认 1,带多级回退) --pld-min-hist N hist 达到 N 个 token 前跳过 PLD(默认 20) --pld-no-guard 关闭 degeneration guard(危险,可能产生死循环) --pld-guard-distinct N draft distinct token 最小值(默认 3) --pld-guard-tail N tail-echo 检测窗口(默认 6) --pld-loop-warn N 生成 N 个连续相同 token 时报警(默认 8) ``` **诚实 benchmark**:用 `scripts/bench_pld_safe.sh`,它会自动把每次 run 的输出分类为 OK / LOOP_N / LOW_DIVERSITY,并分别统计 OK-only 与 degraded 的 TG。 --- ## 正确性验证 15+ 单元 / 集成测试,与 Python(HuggingFace Transformers)参考对比: ```bash ./build/test_attention_layer # rel=4.9e-4 vs Python prefill ./build/test_attention_decode # rel=0(bit-exact) ./build/test_moe_layer # rel=3.6e-3 ./build/test_layer_forward # 完整单层 ./build/test_runner # 多层 runner ./build/test_rope_fused # aclnnApplyRotaryPosEmbV2 vs 手写 HF rotate_half ./build/test_batch_decode # S=1..8 耗时 ./build/test_batch_correctness # argmax 一致性 ./build/test_op_support # 910 特定 op 可用性探针 # 集成冒烟: ./tests/test_chat_flow.sh # 7/7 PASS ``` 测试期望 `tests/_data/` 下存放参考数据,由 `scripts/gen_*_reference.py` 生成。各脚本顶部有 docstring 说明。 --- ## 环境变量调优(`tp_launch.sh` 自动应用) ```bash HCCL_WHITELIST_DISABLE=1 HCCL_ALGO=level0:ring # ring,非 fullmesh(fullmesh 会让输出乱码) HCCL_BUFFSIZE=200 # sweet spot;100 和 400 都更慢 HCCL_OP_EXPANSION_MODE=AIV # 关键:AI Vector cores 参与 reduce 调度 HCCL_OP_BASE_FFTS_MODE_ENABLE=1 # 关键:Fast Frequently-used Transfer Scheduling TASK_QUEUE_ENABLE=2 # 关键:激进异步任务入队 ``` 三个"关键"env 任何一个去掉都会让 TG 降 20-40%。 --- ## 目录结构 ``` include/ ├── acl_common.h RAII 包装、DeviceBuffer、make_contig_tensor ├── aclnn_ops.h 单算子 wrapper + WorkspacePool 集成 ├── acl_runtime.h AclRuntime(device + stream 管理) ├── device_weights.h safetensors → device 加载 + TP 切分 ├── engine.h attention_forward + moe_forward + RopeCache ├── hccl_comm.h HCCL init + allreduce + broadcast ├── model_config.h Qwen3 超参 + compute_derived ├── rope.h apply_rope_fused(aclnnApplyRotaryPosEmbV2 wrapper) ├── runner.h Runner 类(prefill/decode/decode_batch/rewind/profile) ├── safetensors_loader.h 多 shard safetensors mmap parser ├── tokenizer.h vocab decode + Python 子进程 encode └── workspace_pool.h thread-local aclnn workspace 池(retain-old) src/ ├── device_weights.cpp load_attention(GQA 修复)、load_moe(permute sync 修复) ├── main_cli.cpp CLI 入口 + PLD 主循环 + degeneration guard + 多轮对话 ├── model_config.cpp compute_derived(GQA KV 切分) ├── runner.cpp Runner 实现(build_batch_decode_mask_ 等) ├── safetensors_loader.cpp └── tokenizer.cpp scripts/ ├── tp_launch.sh 产线启动器(自动设置 HCCL env) ├── bench_tg.sh 稳定 N-run TG 测量 ├── bench_pld_safe.sh 带输出正确性分类器的 PLD benchmark ├── bench_hccl[_adv].sh HCCL 参数 sweep ├── bench_pld[_k].sh PLD K × ngram sweep(旧版,优先用 bench_pld_safe.sh) ├── export_vocab.py 从 HF tokenizer 导出 vocab.bin └── gen_*_reference.py 逐 op 的 Python 参考数据生成器 tests/ ├── test_attention_* attention 正确性(prefill / decode) ├── test_moe_layer MoE 正确性 ├── test_layer_forward 完整单层 ├── test_runner 多层 Runner ├── test_rope_fused fused RoPE vs 手写 HF ├── test_batch_* batch decode 耗时 + 正确性 ├── test_op_support 910 特定 op 可用性探针 └── test_chat_flow.sh 端到端集成冒烟 ``` --- ## CLI 参数参考 ``` --model-dir (必填) HF safetensors 目录 --prompt "" prompt 文本 --prompt-file FILE 从文件读 prompt(避免 shell 转义) --n-predict N 最大生成 token 数 --tp-size N tensor parallelism(也可通过 TP_SIZE env 设置) --max-seq N KV cache + 上下文上限(默认 512) --temperature F 0 = greedy;典型 0.7 --top-k N 0 = 禁用 --top-p F 1.0 = 禁用 --seed N 0 = 基于时间 --chat 应用 Qwen3 chat 模板 --system "" system role(配合 --chat) --interactive, -i REPL 模式(配合 --chat 实现多轮记忆) --reset 强制无状态 REPL(每轮重置 KV) --no-stream 批量打印最终文本,不逐 token 流式输出 --vocab vocab.bin 路径(默认 tokenizer_data/vocab.bin) --pld* 见上方 "PLD degeneration guard" 章节 ``` --- ## 已知限制 - **未达到 cann-recipes GE graph 54 t/s 基线**(当前稳定 ~27 t/s,PLD 场景最高 ~45 t/s)。 要追上基线需要以下之一:(a) 真正的图编译;(b) 融合集合算子(`MatmulAllReduce`、`GroupedMatmulAllReduce`)——910 初代没有;(c) 迁移到 910B / A2 / A3 硬件。 - **仅支持 `tp_size` ∈ {1, 2, 4, 8, 16}**。不能整除 64 Q heads 的值会报错。 - **PLD 在事实 / 代码类 prompt 上不可靠**——要么产出基线 TG(guard 拒绝了绝大多数 draft),要么进入低强度退化(classifier 未必能抓到)。用 `bench_pld_safe.sh` 诚实评估。 - **Tokenizer 依赖 Python 子进程**——首次 encode 有 ~1s 启动开销。默认 conda 路径不匹配时用 `QWEN3_PYENV_INIT` 覆盖。 - **NPU 性能 run-to-run 方差巨大**(某些配置下高达 4×),源于 BF16 + MoE 固有非决定性与硬件资源共享。报告数字时用 ≥5 runs 的中位数。 --- ## 下一步方向(按优先级) 1. **Draft Model Speculative Decoding**(Qwen3-0.6B)——比 n-gram PLD 接受率稳定得多,预期跨 prompt 类型 +60-100% TG(1-2 周工程量)。 2. **HCCL AllReduce / compute 重叠**——理论上 +10-15%,受 EAGER 路径串行依赖限制。 3. **KV cache INT8 量化**——降低 memory-bandwidth 压力,长上下文场景 +15-25%(需先验证 910 初代算子支持)。 4. **W8 权重量化**——若 910 初代有 aclnn 量化 kernel 可用,+10-20%。 不推荐: - `aclmdlRI` stream-capture 图记录(POC 证明上限仅 1.13×,工程成本不值)。 - 自研 AscendC 融合算子(高维护成本,除非有专职 kernel 工程师)。 - torchair / torch.compile 迁移(破坏 pure-C++ 设计)。 --- ## 文档 - [`docs/optimization-summary-zh.md`](docs/optimization-summary-zh.md) — 阶段性优化总结(中文):关键优化原因、PLD 正确性边界、项目级教训 - [`docs/next-steps-draft-model-speculative.md`](docs/next-steps-draft-model-speculative.md) — Draft Model Speculative Decoding(Qwen3-0.6B)执行规格:M1-M4 里程碑、正确性测试协议、风险兜底 --- ## License Apache License 2.0,见 `LICENSE`。