# qwen3-moe-aclnn
Pure C++ inference of **Qwen3-235B-A22B-Instruct** BF16 on **Ascend 910 × 16 NPUs**, built directly on the aclnn EAGER API (no graph compilation, no PyTorch, no ggml).
Chinese version: [README_zh.md](README_zh.md)
---
## Performance
Measured on Ascend 910 (initial generation) × 16 NPUs (TP=16) with Qwen3-235B-A22B-Instruct-2507 BF16 weights.
All numbers are **quality-preserving TG** (output was manually verified), using greedy decoding (`temperature=0`).
| Configuration | TG | Applicable prompts |
|---|---|---|
| Untuned baseline | 12 t/s | All |
| **Default recommended** (no PLD) | **~27 t/s** | **All prompts, stable output** |
| PLD with degeneration guard | 29-45 t/s | Structured text (essays, long-form answers) |
| PLD on creative prompts | 25-40 t/s | Stories / varied generation |
| PLD on factual / code prompts | unstable (21-95 t/s, high variance) | Not recommended |
Reference: `cann-recipes-infer` GE graph baseline reports ~54 t/s on the same hardware. **This project does not exceed that baseline**; it trades some peak speed for (a) no graph compilation, (b) no PyTorch dependency, (c) full control over operator scheduling.
### Key optimizations that contributed (in order of magnitude)
| Rank | Optimization | Gain | Where |
|---|---|---|---|
| 1 | HCCL env tuning (`AIV` + `FFTS` + `TASK_QUEUE=2`) | +89% (12→23 t/s) | `scripts/tp_launch.sh` |
| 2 | Fused RoPE via `aclnnApplyRotaryPosEmbV2` | +17% (23→27 t/s) | `include/rope.h` |
| 3 | Prompt Lookup Decoding (PLD) w/ degeneration guard | +10-60% on applicable prompts | `src/main_cli.cpp` |
| – | Device-side topk-w normalize, MoE argsort, cos/sin cache | ~+15% cumulative | `include/engine.h` |
| – | WorkspacePool (thread-local + retain-old) | reduces alloc overhead | `include/workspace_pool.h` |
---
## Architecture
**Model**: Qwen3-235B-A22B, 94 layers, 128 experts (top-k=8), GQA (64 Q heads, 4 KV heads), BF16.
**Parallelism**: TP=16 via HCCL ring AllReduce. KV heads are distributed one per rank (since there are only 4 KV heads for 16 ranks, each KV head is held by several ranks); the 4 Q heads local to each rank all share that rank's single KV head (local index 0).
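A minimal sketch of that sharding arithmetic is below. It is illustrative only; the real derivation lives in `compute_derived` (`src/model_config.cpp`), and the rank-to-global-KV-head mapping shown here is an assumption.

```cpp
#include <algorithm>

// Illustrative GQA/TP sharding for 64 Q heads, 4 KV heads, TP=16.
// The mapping of rank -> global KV head is an assumption; the project's
// actual logic is compute_derived() in src/model_config.cpp.
struct TpShard {
    int q_heads_per_rank;   // 64 / 16 = 4
    int kv_heads_per_rank;  // each rank holds exactly one KV head
    int kv_head_global;     // which of the 4 global KV heads this rank replicates
};

TpShard compute_shard(int rank, int tp_size, int n_q_heads = 64, int n_kv_heads = 4) {
    TpShard s;
    s.q_heads_per_rank  = n_q_heads / tp_size;                 // 4 local Q heads
    s.kv_heads_per_rank = std::max(n_kv_heads / tp_size, 1);   // 1 (KV head replicated across ranks)
    // Assumed mapping: consecutive rank groups share one KV head,
    // e.g. with TP=16: ranks 0-3 -> KV head 0, ranks 4-7 -> KV head 1, ...
    s.kv_head_global = rank / std::max(tp_size / n_kv_heads, 1);
    return s;
}
```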
**Execution**: aclnn EAGER mode; every op goes through an `aclnn*` single-op API backed by a workspace pool, with no graph capture and no GE IR. Async stream execution with `TASK_QUEUE_ENABLE=2` overlaps kernel submission.
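For readers unfamiliar with the aclnn single-op API, the sketch below shows the general two-phase call pattern (query workspace size, then launch on a stream) combined with a thread-local "retain-old" workspace, in the spirit of `include/aclnn_ops.h` and `include/workspace_pool.h`. Names such as `run_add` and `WorkspacePool` are illustrative, and error handling is omitted.

```cpp
#include <cstddef>
#include <acl/acl.h>
#include <aclnnop/aclnn_add.h>   // one header per single-op aclnn API

// Thread-local retain-old workspace: keep the largest buffer seen so far and
// reuse it whenever the requested size fits, instead of reallocating per op.
struct WorkspacePool {
    void*  buf  = nullptr;
    size_t size = 0;
    void* get(size_t need) {
        if (need > size) {
            if (buf) aclrtFree(buf);
            aclrtMalloc(&buf, need, ACL_MEM_MALLOC_HUGE_FIRST);
            size = need;
        }
        return buf;
    }
};
thread_local WorkspacePool g_ws;

// Every op follows the same two-phase pattern: query workspace size, then launch.
void run_add(aclTensor* a, aclTensor* b, aclScalar* alpha,
             aclTensor* out, aclrtStream stream) {
    uint64_t ws_size = 0;
    aclOpExecutor* executor = nullptr;
    aclnnAddGetWorkspaceSize(a, b, alpha, out, &ws_size, &executor);
    aclnnAdd(g_ws.get(ws_size), ws_size, executor, stream);
    // No graph capture: the kernel is queued on `stream` immediately (EAGER mode).
}
```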
**Tokenizer**: Uses HuggingFace `transformers` via a Python subprocess for encoding; vocab decode is pure C++ from an exported `vocab.bin`.
### Per-layer forward flow
```
x_in [S, D=4096]
  │
  ├── Attention branch (TP: Q_DIM=512=4h×128, KV_DIM=128=1h×128) ──────────────
  │     RmsNorm(input_layernorm)
  │     linear_hf q_proj / k_proj / v_proj → q, k, v
  │     Per-head RmsNorm q_norm, k_norm
  │     Fused RoPE: aclnnApplyRotaryPosEmbV2 (layout=1, "half")
  │     Append K, V to per-layer KV cache
  │     Mask selection:
  │       prefill:      2048×2048 causal + sparse_mode=3
  │       decode S=1:   nullptr + sparse_mode=0
  │       batch decode: [1,1,S,past+S] custom bool mask + sparse_mode=0
  │     FIAS (aclnnFusedInferAttentionScore)
  │     o_proj linear_hf → partial per-rank
  │     HCCL AllReduce (ring + AIV + FFTS) → full
  ├─────────────────────────────────────────────────────────────────────────────
  │   residual add
  ├── MoE branch ───────────────────────────────────────────────────────────────
  │     RmsNorm(post_attention_layernorm)
  │     router linear_hf → logits [S, 128]
  │     moe_gating_topk_softmax → topk_w[S,8], topk_idx[S,8]
  │     Device-side normalize (reduce_sum + adds + cast + div)
  │     moe_init_routing_v3 → expanded_x, expanded_ri, tokens_per_expert
  │     grouped_matmul_v4 gate/up/down (SwiGLU activation)
  │     Device-side argsort × 2 → fwd permutation (avoids host sync)
  │     IndexSelect → packed
  │     Broadcast-mul by topk_w + ReduceSum axis=1
  │     HCCL AllReduce → full
  ├─────────────────────────────────────────────────────────────────────────────
  │   residual add
x_out
```
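To make the "batch decode" mask line concrete, here is a minimal sketch of how a `[1, 1, S, past+S]` boolean mask can be built: each draft token attends to every cached position and causally to earlier draft tokens. The `1 = masked-out` convention is an assumption for illustration; the project's actual construction is `build_batch_decode_mask_` in `src/runner.cpp`.

```cpp
#include <cstdint>
#include <vector>

// Sketch: custom bool mask of shape [1, 1, S, past + S] for batch decode.
// Draft token t may see all `past` cached positions plus draft positions 0..t.
// Convention assumed here: 1 marks a position that must NOT be attended.
std::vector<uint8_t> build_batch_decode_mask(int past, int S) {
    std::vector<uint8_t> mask(static_cast<size_t>(S) * (past + S), 0);
    for (int t = 0; t < S; ++t)
        for (int j = past + t + 1; j < past + S; ++j)
            mask[static_cast<size_t>(t) * (past + S) + j] = 1;  // future draft tokens masked
    return mask;
}
```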
---
## Model weights
This project targets **Qwen3-235B-A22B-Instruct-2507** (BF16). About **470 GB** of safetensors shards.
**Download sources**:
- HuggingFace: https://huggingface.co/Qwen/Qwen3-235B-A22B-Instruct-2507
- ModelScope: https://www.modelscope.cn/models/Qwen/Qwen3-235B-A22B-Instruct-2507
Download via `huggingface-cli` or `modelscope` CLI:
```bash
# HuggingFace
huggingface-cli download Qwen/Qwen3-235B-A22B-Instruct-2507 --local-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
# ModelScope
modelscope download --model Qwen/Qwen3-235B-A22B-Instruct-2507 --local_dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
```
**Weights format**: the binary reads HuggingFace `.safetensors` shards (multi-shard mmap), `config.json`, and `tokenizer.json` directly from the model directory. No conversion step is needed; point `--model-dir` at the downloaded directory.
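For reference, the safetensors shard layout the loader relies on is simple: an 8-byte little-endian header length, followed by a JSON header that maps tensor names to dtype, shape, and byte offsets into the rest of the file. The sketch below (using the bundled `external/json.hpp`) just dumps that header; it is an illustration of the format, not the project's `safetensors_loader`.

```cpp
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include "external/json.hpp"   // bundled nlohmann/json (see Build section)

// Dump the JSON header of a .safetensors shard (assumes a little-endian host).
int main(int argc, char** argv) {
    if (argc < 2) return 1;
    std::ifstream f(argv[1], std::ios::binary);
    uint64_t header_len = 0;
    f.read(reinterpret_cast<char*>(&header_len), 8);            // 8-byte header length
    std::string header(header_len, '\0');
    f.read(header.data(), static_cast<std::streamsize>(header_len));
    auto j = nlohmann::json::parse(header);
    for (auto& [name, meta] : j.items()) {
        if (name == "__metadata__") continue;                    // optional format metadata
        std::cout << name << " dtype=" << meta["dtype"]
                  << " offsets=" << meta["data_offsets"] << "\n";
        // Tensor bytes live at: 8 + header_len + data_offsets[0] .. data_offsets[1]
    }
}
```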
**Expected directory contents**:
```
Qwen3-235B-A22B-Instruct-2507-BF16/
├── config.json
├── tokenizer.json
├── tokenizer_config.json
├── model-00001-of-000XX.safetensors
├── ...
└── model.safetensors.index.json
```
---
## Build
```bash
source /usr/local/Ascend/ascend-toolkit/set_env.sh
cmake -B build
cmake --build build -j8 --target qwen3-moe-aclnn
```
**Requires**:
- CANN 8.5.1 or compatible
- Python 3 + `transformers` + `torch_npu` (for tokenizer subprocess and reference-data generation only)
- C++17 compiler
- Ascend 910 × 16 NPUs
- nlohmann/json (bundled as `external/json.hpp`)
**Python environment setup**: the tokenizer calls a Python subprocess. Override the activation command via `QWEN3_PYENV_INIT` if your conda / venv layout differs from the default:
```bash
export QWEN3_PYENV_INIT="source /opt/my_conda/etc/profile.d/conda.sh && conda activate my_env && "
```
If unset, the default tries `${HOME}/miniconda3` with env `qwen3` and auto-sources the Ascend toolkit.
---
## Quick-start inference
```bash
# 1. Export tokenizer vocab to binary (one-time setup)
python3 scripts/export_vocab.py /path/to/Qwen3-235B-A22B-Instruct-2507-BF16
# 2. Run inference (TP=16)
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn \
--model-dir /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 \
--prompt "The capital of France is" \
--n-predict 100 \
--temperature 0 \
--vocab tokenizer_data/vocab.bin
```
Expected: ~27 t/s, coherent output.
### Recommended flags by use case
**Universal default (stable, any prompt)**, no PLD:
```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --temperature 0 --no-stream
```
**Structured / long-form (essays, explanations)**, PLD with guard gives +60-90%:
```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... --pld --temperature 0 --no-stream
```
**Interactive REPL (multi-turn chat)**:
```bash
./scripts/tp_launch.sh 16 ./build/qwen3-moe-aclnn --model-dir ... \
--interactive --chat --temperature 0.7 --top-p 0.8
```
---
## PLD degeneration guard
Prompt Lookup Decoding speeds up generation by having the model verify a batch of "draft" tokens in a single forward pass. The drafts are copied from the generation history via n-gram match.
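In essence, the draft-proposal step looks like the sketch below. Parameter names mirror the `--pld-k` / `--pld-ngram` flags; the actual loop in `src/main_cli.cpp` additionally supports multi-level n-gram fallback.

```cpp
#include <cstdint>
#include <vector>

// Prompt lookup: find the most recent earlier occurrence of the trailing
// `ngram` tokens in the history and propose the up-to-`k` tokens that
// followed it as a draft. Returns an empty draft when no match exists.
std::vector<int32_t> pld_draft(const std::vector<int32_t>& hist, int ngram, int k) {
    const int n = static_cast<int>(hist.size());
    if (n < ngram + 1) return {};
    for (int start = n - ngram - 1; start >= 0; --start) {   // scan backwards, skip trivial self-match
        bool match = true;
        for (int j = 0; j < ngram; ++j)
            if (hist[start + j] != hist[n - ngram + j]) { match = false; break; }
        if (!match) continue;
        std::vector<int32_t> draft;                           // copy the continuation as draft tokens
        for (int j = start + ngram; j < n && static_cast<int>(draft.size()) < k; ++j)
            draft.push_back(hist[j]);
        if (!draft.empty()) return draft;
    }
    return {};
}
```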
**Known failure mode**: on prompts the model tends to repeat on (factual Q&A, code generation), the n-gram match feeds the model's own repetition back as drafts, creating a positive feedback loop that accelerates degenerate output. Early versions of this project reported misleading peak TG numbers driven by this loop.
**This project's guard** blocks suspect drafts with two heuristics:
1. **low-distinct**: the draft's distinct-token count is below a threshold → reject
2. **tail-echo**: all of the last N history tokens equal draft[0] → reject
Rejected drafts fall back to single-token decode. A `[warn]` line is emitted once if the generated tail shows 8 consecutive identical tokens.
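A minimal sketch of those two checks (thresholds correspond to `--pld-guard-distinct` and `--pld-guard-tail`; the exact behavior lives in `src/main_cli.cpp`):

```cpp
#include <cstdint>
#include <unordered_set>
#include <vector>

// Returns true when the draft should be rejected and decoding should fall
// back to single-token decode for this step.
bool guard_rejects_draft(const std::vector<int32_t>& hist,
                         const std::vector<int32_t>& draft,
                         int min_distinct /*--pld-guard-distinct, default 3*/,
                         int tail         /*--pld-guard-tail,     default 6*/) {
    // 1) low-distinct: a draft built from too few distinct tokens is a repetition loop.
    std::unordered_set<int32_t> distinct(draft.begin(), draft.end());
    if (static_cast<int>(distinct.size()) < min_distinct) return true;

    // 2) tail-echo: if the last `tail` history tokens all equal draft[0],
    //    the draft just echoes the model's own repetition back at it.
    if (!draft.empty() && static_cast<int>(hist.size()) >= tail) {
        bool echo = true;
        for (int i = 0; i < tail; ++i)
            if (hist[hist.size() - 1 - i] != draft[0]) { echo = false; break; }
        if (echo) return true;
    }
    return false;
}
```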
Flags:
```
--pld enable PLD (opt-in)
--pld-k N draft window size (default: 10)
--pld-ngram N n-gram match size (default: 1, with multi-level fallback)
--pld-min-hist N skip PLD until history >= N tokens (default: 20)
--pld-no-guard disable the degeneration guard (dangerous: can produce dead loops)
--pld-guard-distinct N minimum distinct tokens in draft (default: 3)
--pld-guard-tail N tail-echo window (default: 6)
--pld-loop-warn N emit warning on N consecutive identical tokens (default: 8)
```
**Honest benchmarking**: use `scripts/bench_pld_safe.sh`, which classifies each run's output as OK / LOOP_N / LOW_DIVERSITY and separates TG statistics for OK-only vs degraded runs.
---
## Correctness verification
15+ unit / integration tests checked against Python (HuggingFace Transformers) reference:
```bash
./build/test_attention_layer # rel=4.9e-4 vs Python prefill
./build/test_attention_decode # rel=0 (bit-exact)
./build/test_moe_layer # rel=3.6e-3
./build/test_layer_forward # full single layer
./build/test_runner # multi-layer runner
./build/test_rope_fused # aclnnApplyRotaryPosEmbV2 vs manual HF rotate_half
./build/test_batch_decode # S=1..8 timing
./build/test_batch_correctness # argmax consistency
./build/test_op_support # 910-specific op availability
# Integration smoke:
./tests/test_chat_flow.sh # 7/7 PASS
```
Tests expect reference data under `tests/<name>_data/` generated by `scripts/gen_*_reference.py`. See each script's docstring.
---
## Environment tuning (auto-applied by `tp_launch.sh`)
```bash
HCCL_WHITELIST_DISABLE=1
HCCL_ALGO=level0:ring # ring, not fullmesh (fullmesh causes garbled output)
HCCL_BUFFSIZE=200 # sweet spot; 100 and 400 both slower
HCCL_OP_EXPANSION_MODE=AIV # key: AI Vector cores participate in reduce scheduling
HCCL_OP_BASE_FFTS_MODE_ENABLE=1 # key: Fast Frequently-used Transfer Scheduling
TASK_QUEUE_ENABLE=2 # key: aggressive async task submission
```
Removing any of the three "key" env vars drops TG by 20-40%.
---
## Directory layout
```
include/
├── acl_common.h            RAII wrappers, DeviceBuffer, make_contig_tensor
├── aclnn_ops.h             single-op wrappers + WorkspacePool integration
├── acl_runtime.h           AclRuntime (device + stream management)
├── device_weights.h        safetensors → device loading + TP sharding
├── engine.h                attention_forward + moe_forward + RopeCache
├── hccl_comm.h             HCCL init + allreduce + broadcast
├── model_config.h          Qwen3 hyperparameters + compute_derived
├── rope.h                  apply_rope_fused (aclnnApplyRotaryPosEmbV2 wrapper)
├── runner.h                Runner class (prefill/decode/decode_batch/rewind/profile)
├── safetensors_loader.h    multi-shard safetensors mmap parser
├── tokenizer.h             vocab decode + Python subprocess encode
└── workspace_pool.h        thread-local aclnn workspace pool (retain-old)
src/
├── device_weights.cpp      load_attention (GQA fix), load_moe (permute sync fix)
├── main_cli.cpp            CLI entry + PLD main loop + degeneration guard + multi-turn
├── model_config.cpp        compute_derived (GQA KV sharding)
├── runner.cpp              Runner (build_batch_decode_mask_ etc.)
├── safetensors_loader.cpp
└── tokenizer.cpp
scripts/
├── tp_launch.sh            production launcher (auto-applies HCCL env)
├── bench_tg.sh             stable N-run TG measurement
├── bench_pld_safe.sh       PLD benchmark with output-correctness classifier
├── bench_hccl[_adv].sh     HCCL parameter sweep
├── bench_pld[_k].sh        PLD K × ngram sweep (legacy, prefer bench_pld_safe.sh)
├── export_vocab.py         vocab.bin exporter from HF tokenizer
└── gen_*_reference.py      per-op Python reference data generators
tests/
├── test_attention_*        attention correctness (prefill / decode)
├── test_moe_layer          MoE correctness
├── test_layer_forward      full single layer
├── test_runner             multi-layer Runner
├── test_rope_fused         fused RoPE vs manual HF
├── test_batch_*            batch decode timing + correctness
├── test_op_support         910-specific op availability probe
└── test_chat_flow.sh       end-to-end integration smoke
```
---
## CLI reference
```
--model-dir <path> (required) HF safetensors directory
--prompt "<text>" prompt text
--prompt-file FILE read prompt from file (avoids shell-escape issues)
--n-predict N maximum tokens to generate
--tp-size N tensor parallelism (or set TP_SIZE env)
--max-seq N KV cache + context cap (default: 512)
--temperature F 0 = greedy; typical 0.7
--top-k N 0 = disabled
--top-p F 1.0 = disabled
--seed N 0 = time-based
--chat apply Qwen3 chat template
--system "<text>" system role text (with --chat)
--interactive, -i REPL mode (multi-turn memory with --chat)
--reset force stateless REPL (reset KV between turns)
--no-stream batch-print final text instead of per-token streaming
--vocab <path> vocab.bin path (default: tokenizer_data/vocab.bin)
--pld* see "PLD degeneration guard" section
```
---
## Known limitations
- **Not yet reaching cann-recipes GE graph 54 t/s baseline** (currently ~27 t/s stable / up to ~45 t/s PLD).
Closing the gap requires one of: (a) real graph compilation, (b) fused collectives (`MatmulAllReduce`, `GroupedMatmulAllReduce`) which are absent on 910 initial-gen, (c) migration to 910B/A2/A3.
- **Only `tp_size` ∈ {1, 2, 4, 8, 16}** is supported. Values that don't evenly divide the 64 Q heads will error.
- **PLD on factual/code prompts is unreliable**: it either produces baseline TG (the guard rejects most drafts) or enters partial degeneration that the classifier may not catch at low severity. Use `bench_pld_safe.sh` to evaluate honestly.
- **Tokenizer requires a Python subprocess**, adding ~1 s of startup for the first encode. Override via the `QWEN3_PYENV_INIT` env var if the default conda path doesn't match.
- **NPU performance has high run-to-run variance** (up to 4× in some configurations) due to BF16 + MoE intrinsic non-determinism and shared hardware resources. Report medians over ≥5 runs.
---
## Future directions (prioritized)
1. **Draft Model Speculative Decoding** with Qwen3-0.6B: more stable accept rate than n-gram PLD; expected +60-100% TG across prompt types (1-2 week implementation).
2. **HCCL AllReduce / compute overlap**: ~+10-15% in theory, limited by serial dependencies in the EAGER path.
3. **KV cache INT8 quantization**: reduces memory-bandwidth pressure, ~+15-25% on long contexts (pending verification of op support on the initial-gen 910).
4. **W8 weight quantization**: ~+10-20% if aclnn quantization kernels exist on the initial-gen 910.
Not recommended:
- `aclmdlRI` stream-capture-style graph recording (POC proved a 1.13× ceiling, not worth the engineering cost).
- Custom AscendC fused ops (high maintenance cost without a dedicated kernel engineer).
- torchair / torch.compile migration (breaks pure-C++ design).
---
## Documentation
- [`docs/optimization-summary-zh.md`](docs/optimization-summary-zh.md) (in Chinese): stage-by-stage optimization summary covering the key optimizations and their causes, PLD correctness boundaries, and project-level lessons learned
- [`docs/next-steps-draft-model-speculative.md`](docs/next-steps-draft-model-speculative.md): Draft Model Speculative Decoding (Qwen3-0.6B) implementation spec, with M1-M4 milestones, a correctness-test protocol, and a risk list
---
## License
Apache License 2.0 โ see `LICENSE`.