# qwen3-moe-aclnn

Pure C++ inference of **Qwen3-235B-A22B-Instruct** BF16 on **Ascend 910 × 16 NPU**, built directly on the aclnn EAGER API (no graph compilation, no PyTorch, no ggml).

Chinese version: [README_zh.md](README_zh.md)

---

## Performance

Measured on Ascend 910 initial-gen × 16 NPU (TP=16) with Qwen3-235B-A22B-Instruct-2507 BF16 weights. All numbers are **quality-preserving TG** (output was manually verified); greedy `temperature=0`.

| Configuration | TG | Applicable prompts |
|---|---|---|
| Untuned baseline | 12 t/s | All |
| **Default recommended** (no PLD) | **~27 t/s** | **All prompts, stable output** |
| PLD with degeneration guard | 29-45 t/s | Structured text (essays, long-form answers) |
| PLD on creative prompts | 25-40 t/s | Stories / varied generation |
| PLD on factual / code prompts | unstable (21-95 t/s, high variance) | Not recommended |

Reference: the `cann-recipes-infer` GE graph baseline reports ~54 t/s on the same hardware. **This project does not exceed that baseline** — it trades some peak speed for (a) no graph compilation, (b) no PyTorch dependency, (c) full control over operator scheduling.

### Key optimizations that contributed (in order of magnitude)

| Rank | Optimization | Gain | Where |
|---|---|---|---|
| 🥇 | HCCL env tuning (`AIV` + `FFTS` + `TASK_QUEUE=2`) | +89% (12→23 t/s) | `scripts/tp_launch.sh` |
| 🥈 | Fused RoPE via `aclnnApplyRotaryPosEmbV2` | +17% (23→27 t/s) | `include/rope.h` |
| 🥉 | Prompt Lookup Decoding (PLD) w/ degeneration guard | +10-60% on applicable prompts | `src/main_cli.cpp` |
| ○ | Device-side topk-w normalize, MoE argsort, cos/sin cache | ~+15% cumulative | `include/engine.h` |
| ○ | WorkspacePool (thread-local + retain-old) | reduces alloc overhead | `include/workspace_pool.h` |

---

## Architecture

**Model**: Qwen3-235B-A22B, 94 layers, 128 experts (top-k=8), GQA (64 Q heads, 4 KV heads), BF16.

**Parallelism**: TP=16 via HCCL ring AllReduce. KV heads are sharded one per rank: since there are only 4 KV heads for 16 ranks, each rank stores a single KV head, and its 4 local Q heads (indices 0-3) all share that local KV head 0 (see the sketch below).

**Execution**: aclnn EAGER mode — every op goes through the `aclnn*` single-op API with a workspace pool; no graph capture, no GE IR. Async stream execution with `TASK_QUEUE_ENABLE=2` for kernel-submission overlap.

**Tokenizer**: uses HuggingFace `transformers` via a Python subprocess for encoding; vocab decode is pure C++ from an exported `vocab.bin`.
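The head mapping described above can be written down in a few lines. The sketch below is illustrative only: the constants come from the model description, but the function and variable names are made up, and the project's actual sharding logic lives in `compute_derived` (`src/model_config.cpp`).

```cpp
#include <cstdio>

// Illustrative constants from the architecture description above.
constexpr int kQHeads  = 64;   // global Q heads
constexpr int kKVHeads = 4;    // global KV heads
constexpr int kTpSize  = 16;   // TP ranks

int main() {
    const int q_per_rank = kQHeads / kTpSize;  // 4 local Q heads per rank
    for (int rank = 0; rank < kTpSize; ++rank) {
        // One plausible replication scheme: 4 consecutive ranks share each
        // global KV head, stored locally as "KV head 0" on every rank.
        const int kv_head = rank * kKVHeads / kTpSize;
        const int q_first = rank * q_per_rank;
        std::printf("rank %2d: Q heads %2d-%2d -> global KV head %d\n",
                    rank, q_first, q_first + q_per_rank - 1, kv_head);
    }
    return 0;
}
```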
### Per-layer forward flow

```
x_in [S, D=4096]
  ↓
┌── Attention branch (TP: Q_DIM=512=4h×128, KV_DIM=128=1h×128) ──┐
│ RmsNorm(input_layernorm)
│ linear_hf q_proj / k_proj / v_proj → q, k, v
│ Per-head RmsNorm q_norm, k_norm
│ Fused RoPE: aclnnApplyRotaryPosEmbV2 (layout=1, "half")
│ Append K, V to per-layer KV cache
│ Mask selection:
│   prefill:       2048×2048 causal + sparse_mode=3
│   decode S=1:    nullptr + sparse_mode=0
│   batch decode:  [1,1,S,past+S] custom bool mask + sparse_mode=0
│ FIAS (aclnnFusedInferAttentionScore)
│ o_proj linear_hf → partial per-rank
│ HCCL AllReduce (ring + AIV + FFTS) → full
└─────────┘
  ↓ residual add
┌── MoE branch ──┐
│ RmsNorm(post_attention_layernorm)
│ router linear_hf → logits [S, 128]
│ moe_gating_topk_softmax → topk_w[S,8], topk_idx[S,8]
│ Device-side normalize (reduce_sum + adds + cast + div)
│ moe_init_routing_v3 → expanded_x, expanded_ri, tokens_per_expert
│ grouped_matmul_v4 gate/up/down (SwiGLU activation)
│ Device-side argsort × 2 → fwd permutation (avoids host sync)
│ IndexSelect → packed
│ Broadcast-mul by topk_w + ReduceSum axis=1
│ HCCL AllReduce → full
└─────────┘
  ↓ residual add
x_out
```

---

## Model weights

This project targets **Qwen3-235B-A22B-Instruct-2507** (BF16): about **470 GB** of safetensors shards.

**Download sources**:

- HuggingFace: https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507
- ModelScope: https://www.modelscope.cn/models/Qwen/Qwen3-235B-A22B-Instruct-2507

Download via `huggingface-cli` or the `modelscope` CLI:

```bash
# HuggingFace
huggingface-cli download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16

# ModelScope
modelscope download --model Qwen/Qwen3-235B-A22B-Instruct-2507 --local_dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
```

**Weights format**: the binary reads HuggingFace `.safetensors` shards (multi-shard mmap), `config.json`, and `tokenizer.json` directly from the model directory. No conversion step is needed — point `--model-dir` at the downloaded directory.

**Expected directory contents**:

```
Qwen3-235B-A22B-Instruct-2507-BF16/
├── config.json
├── tokenizer.json
├── tokenizer_config.json
├── model-00001-of-000XX.safetensors
├── ...
└── model.safetensors.index.json
```

---

## Build

```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
cmake -B build
cmake --build build -j8 --target qwen3-moe-aclnn
```

**Requires**:

- CANN 8.5.1 or compatible
- Python 3 + `transformers` + `torch_npu` (for the tokenizer subprocess and reference-data generation only)
- C++17 compiler
- Ascend 910 × 16 NPU
- nlohmann/json (bundled as `external/json.hpp`)

**Python environment setup** — the tokenizer calls a Python subprocess. Override the activation command via `QWEN3_PYENV_INIT` if your conda / venv layout differs from the default:

```bash
export QWEN3_PYENV_INIT="source /opt/my_conda/etc/profile.d/conda.sh && conda activate my_env && "
```

If unset, the default tries `${HOME}/miniconda3` with env `qwen3` and auto-sources the Ascend toolkit.

---

## Quick-start inference

```bash
# 1. Export tokenizer vocab to binary (one-time setup)
python3 scripts/export_vocab.py /path/to/Qwen3-235B-A22B-Instruct-2507-BF16

# 2. Run inference (TP=16)
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn \
    --model-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 \
    --prompt "The capital of France is" \
    --n-predict 100 \
    --temperature 0 \
    --vocab tokenizer_data/vocab.bin
```

Expected: ~27 t/s, coherent output.
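One detail from the per-layer flow above that matters for the next section: batch decode with the `[1,1,S,past+S]` bool mask is what lets PLD verify several draft tokens in one forward pass. Below is a minimal host-side sketch of such a mask. This is hypothetical code, not the project's `Runner::build_batch_decode_mask_`, and it assumes `1` means "masked out"; check the real convention before reusing it.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch of a causal bool mask of shape [1, 1, S, past+S]
// for verifying S draft tokens in one pass. Assumes 1 = "masked out".
std::vector<std::uint8_t> make_batch_decode_mask(int past, int S) {
    const int cols = past + S;
    std::vector<std::uint8_t> mask(static_cast<size_t>(S) * cols, 0);
    for (int i = 0; i < S; ++i) {
        // Draft token i may attend to all `past` tokens plus drafts 0..i;
        // later draft positions are hidden.
        for (int j = past + i + 1; j < cols; ++j)
            mask[static_cast<size_t>(i) * cols + j] = 1;
    }
    return mask;
}
```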
### Recommended flags by use case

**Universal default (stable, any prompt)** — no PLD:

```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --temperature 0 --no-stream
```

**Structured / long-form (essays, explanations)** — PLD with the guard gives +10-60%:

```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --pld --temperature 0 --no-stream
```

**Interactive REPL (multi-turn chat)**:

```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... \
    --interactive --chat --temperature 0.7 --top-p 0.8
```

---

## PLD degeneration guard

Prompt Lookup Decoding speeds up generation by having the model verify a batch of "draft" tokens in a single forward pass. The drafts are copied from the generation history via n-gram match.

**Known failure mode**: on prompts the model tends to repeat on (factual Q&A, code generation), the n-gram match feeds the model's own repetition back as drafts, creating a positive feedback loop that accelerates degenerate output. Early versions of this project reported misleading peak TG numbers driven by this loop.

**This project's guard** blocks suspect drafts with two heuristics (a code sketch of both checks follows the environment-tuning section below):

1. **low-distinct**: the draft's distinct-token count < threshold → reject
2. **tail-echo**: all of the last N history tokens equal draft[0] → reject

Rejected drafts fall back to single-token decode. A `[warn]` line is emitted once if the generated tail shows 8 consecutive identical tokens.

Flags:

```
--pld                    enable PLD (opt-in)
--pld-k N                draft window size (default: 10)
--pld-ngram N            n-gram match size (default: 1, with multi-level fallback)
--pld-min-hist N         skip PLD until history >= N tokens (default: 20)
--pld-no-guard           disable the degeneration guard (dangerous: can produce dead loops)
--pld-guard-distinct N   minimum distinct tokens in draft (default: 3)
--pld-guard-tail N       tail-echo window (default: 6)
--pld-loop-warn N        emit warning on N consecutive identical tokens (default: 8)
```

**Honest benchmarking**: use `scripts/bench_pld_safe.sh`, which classifies each run's output as OK / LOOP_N / LOW_DIVERSITY and separates TG statistics for OK-only vs degraded runs.

---

## Correctness verification

15+ unit / integration tests checked against Python (HuggingFace Transformers) references:

```bash
./build/test_attention_layer     # rel=4.9e-4 vs Python prefill
./build/test_attention_decode    # rel=0 (bit-exact)
./build/test_moe_layer           # rel=3.6e-3
./build/test_layer_forward       # full single layer
./build/test_runner              # multi-layer runner
./build/test_rope_fused          # aclnnApplyRotaryPosEmbV2 vs manual HF rotate_half
./build/test_batch_decode        # S=1..8 timing
./build/test_batch_correctness   # argmax consistency
./build/test_op_support          # 910-specific op availability

# Integration smoke:
./tests/test_chat_flow.sh        # 7/7 PASS
```

Tests expect reference data under `tests/_data/`, generated by `scripts/gen_*_reference.py`. See each script's docstring.

---

## Environment tuning (auto-applied by `tp_launch.sh`)

```bash
HCCL_WHITELIST_DISABLE=1
HCCL_ALGO=level0:ring              # ring, not fullmesh (fullmesh causes garbled output)
HCCL_BUFFSIZE=200                  # sweet spot; 100 and 400 are both slower
HCCL_OP_EXPANSION_MODE=AIV         # key: AI Vector cores participate in reduce scheduling
HCCL_OP_BASE_FFTS_MODE_ENABLE=1    # key: Fast Frequently-used Transfer Scheduling
TASK_QUEUE_ENABLE=2                # key: aggressive async task submission
```

Removing any of the three "key" env vars drops TG by 20-40%.
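As promised above, here is the code-level shape of the two guard heuristics from the PLD section. This is a re-statement with hypothetical names for clarity, not the implementation in `src/main_cli.cpp`.

```cpp
#include <unordered_set>
#include <vector>

// Heuristic 1 (low-distinct): reject drafts with too few distinct tokens.
bool low_distinct(const std::vector<int>& draft, int min_distinct /* default 3 */) {
    std::unordered_set<int> uniq(draft.begin(), draft.end());
    return static_cast<int>(uniq.size()) < min_distinct;
}

// Heuristic 2 (tail-echo): reject if all of the last N history tokens
// equal draft[0], i.e. the n-gram match is echoing a repetition loop.
bool tail_echo(const std::vector<int>& hist, const std::vector<int>& draft,
               int tail_window /* default 6 */) {
    if (draft.empty() || static_cast<int>(hist.size()) < tail_window) return false;
    for (int i = 0; i < tail_window; ++i)
        if (hist[hist.size() - 1 - i] != draft[0]) return false;
    return true;
}

// A draft is used only if neither check fires; otherwise the main loop
// falls back to single-token decode.
```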
---

## Directory layout

```
include/
├── acl_common.h           RAII wrappers, DeviceBuffer, make_contig_tensor
├── aclnn_ops.h            single-op wrappers + WorkspacePool integration
├── acl_runtime.h          AclRuntime (device + stream management)
├── device_weights.h       safetensors → device loading + TP sharding
├── engine.h               attention_forward + moe_forward + RopeCache
├── hccl_comm.h            HCCL init + allreduce + broadcast
├── model_config.h         Qwen3 hyperparameters + compute_derived
├── rope.h                 apply_rope_fused (aclnnApplyRotaryPosEmbV2 wrapper)
├── runner.h               Runner class (prefill/decode/decode_batch/rewind/profile)
├── safetensors_loader.h   multi-shard safetensors mmap parser
├── tokenizer.h            vocab decode + Python subprocess encode
└── workspace_pool.h       thread-local aclnn workspace pool (retain-old)

src/
├── device_weights.cpp     load_attention (GQA fix), load_moe (permute sync fix)
├── main_cli.cpp           CLI entry + PLD main loop + degeneration guard + multi-turn
├── model_config.cpp       compute_derived (GQA KV sharding)
├── runner.cpp             Runner (build_batch_decode_mask_ etc.)
├── safetensors_loader.cpp
└── tokenizer.cpp

scripts/
├── tp_launch.sh           production launcher (auto-applies HCCL env)
├── bench_tg.sh            stable N-run TG measurement
├── bench_pld_safe.sh      PLD benchmark with output-correctness classifier
├── bench_hccl[_adv].sh    HCCL parameter sweep
├── bench_pld[_k].sh       PLD K × ngram sweep (legacy, prefer bench_pld_safe.sh)
├── export_vocab.py        vocab.bin exporter from HF tokenizer
└── gen_*_reference.py     per-op Python reference-data generators

tests/
├── test_attention_*       attention correctness (prefill / decode)
├── test_moe_layer         MoE correctness
├── test_layer_forward     full single layer
├── test_runner            multi-layer Runner
├── test_rope_fused        fused RoPE vs manual HF
├── test_batch_*           batch decode timing + correctness
├── test_op_support        910-specific op availability probe
└── test_chat_flow.sh      end-to-end integration smoke
```

---

## CLI reference

```
--model-dir          (required) HF safetensors directory
--prompt ""          prompt text
--prompt-file FILE   read prompt from file (avoids shell-escape issues)
--n-predict N        maximum tokens to generate
--tp-size N          tensor parallelism (or set TP_SIZE env)
--max-seq N          KV cache + context cap (default: 512)
--temperature F      0 = greedy; typical 0.7
--top-k N            0 = disabled
--top-p F            1.0 = disabled
--seed N             0 = time-based
--chat               apply Qwen3 chat template
--system ""          system role text (with --chat)
--interactive, -i    REPL mode (multi-turn memory with --chat)
--reset              force stateless REPL (reset KV between turns)
--no-stream          batch-print final text instead of per-token streaming
--vocab              vocab.bin path (default: tokenizer_data/vocab.bin)
--pld*               see the "PLD degeneration guard" section
```

---

## Known limitations

- **Not yet at the cann-recipes GE graph 54 t/s baseline** (currently ~27 t/s stable / up to ~45 t/s with PLD). Closing the gap requires one of: (a) real graph compilation, (b) fused collectives (`MatmulAllReduce`, `GroupedMatmulAllReduce`), which are absent on 910 initial-gen, (c) migration to 910B/A2/A3.
- **Only `tp_size` ∈ {1, 2, 4, 8, 16} is supported.** Values that do not evenly divide the 64 Q heads will error out.
- **PLD on factual/code prompts is unreliable** — it either produces baseline TG (the guard rejects most drafts) or enters partial degeneration that the classifier may not catch at low severity. Use `bench_pld_safe.sh` to evaluate honestly.
- **Tokenizer requires a Python subprocess** — adds ~1 s of startup for the first encode. Override via the `QWEN3_PYENV_INIT` env var if the default conda path doesn't match.
- **NPU performance has high run-to-run variance** (up to 4× in some configurations) due to BF16 + MoE intrinsic non-determinism and shared hardware resources. Report medians over ≥5 runs.

---

## Future directions (prioritized)

1. **Draft Model Speculative Decoding** with Qwen3-0.6B — a more stable accept rate than n-gram PLD; expected +60-100% TG across prompt types (1-2 weeks of implementation).
2. **HCCL AllReduce / compute overlap** — ~+10-15% in theory, limited by serial dependencies on the EAGER path.
3. **KV cache INT8 quantization** — reduces memory-bandwidth pressure, ~+15-25% on long contexts (pending verification of 910-initial-gen op support).
4. **W8 weight quantization** — ~+10-20% if aclnn quantization kernels exist on 910 initial-gen.

Not recommended:

- `aclmdlRI` stream-capture-style graph recording (a POC showed a 1.13× ceiling, not worth the engineering cost).
- Custom AscendC fused ops (high maintenance cost without a dedicated kernel engineer).
- torchair / torch.compile migration (breaks the pure-C++ design).

---

## Documentation

- [`docs/optimization-summary-zh.md`](docs/optimization-summary-zh.md) — staged optimization summary (in Chinese): rationale for each key optimization, PLD correctness boundaries, project-level lessons
- [`docs/next-steps-draft-model-speculative.md`](docs/next-steps-draft-model-speculative.md) — execution spec for Draft Model Speculative Decoding (Qwen3-0.6B): M1-M4 milestones, correctness-test protocol, risk fallbacks

---

## License

Apache License 2.0 — see `LICENSE`.