#!/usr/bin/env bash
# bench_pld.sh — sweep K × n-gram with corrected causal-with-past mask.
# Measures TG + accept rate stability across N_RUNS per config.
#
# Environment overrides:
#   MODEL_DIR  value passed to the binary as --model-dir
#   N_RUNS     number of repetitions per configuration (default 3, must be >= 1)
#   PROMPT     prompt text passed as --prompt
#
# NOTE: only -u is enabled. The parsing pipelines in this file use grep, which
# exits non-zero when nothing matches, so -e / pipefail would abort the sweep
# on the first run that produces no parsable output.
set -u

# Every relative path below (BIN, LAUNCH, VOCAB) is resolved from the parent of
# this script's directory, so refuse to continue if we cannot get there.
cd "$(dirname "$0")/.." || {
  echo "error: cannot cd to $(dirname "$0")/.." >&2
  exit 1
}

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS="${N_RUNS:-3}"
PROMPT="${PROMPT:-Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints.}"
VOCAB="tokenizer_data/vocab.bin"
OUT=/tmp/bench_pld.csv

# The per-config statistics index into the list of runs, so zero or
# non-numeric run counts would only produce garbage rows.
if [[ ! "$N_RUNS" =~ ^0*[1-9][0-9]*$ ]]; then
  echo "error: N_RUNS must be a positive integer (got '$N_RUNS')" >&2
  exit 2
fi

# Start a fresh CSV; each benchmarked configuration appends one row.
echo "k,ngram,run_tgs,best,median,avg_accept" > "$OUT"
#######################################
# Benchmark one PLD configuration N_RUNS times and record the result.
# Globals (read): LAUNCH, TP, BIN, MODEL, PROMPT, N_PREDICT, VOCAB, N_RUNS, OUT
# Arguments:
#   $1 - value passed as --pld-k
#   $2 - value passed as --pld-ngram
# Outputs:
#   Appends one row "k,ngram,tg1/tg2/...,best,median,avg_accept" to $OUT and
#   prints a one-line summary on stdout. Warnings go to stderr.
# Returns:
#   1 if no run was executed (N_RUNS < 1), otherwise the status of the final
#   printf.
#######################################
run_one() {
  local k="$1" ng="$2"
  local tgs=() accs=() sorted=()
  local r output tg acc

  for ((r = 1; r <= N_RUNS; r++)); do
    output=$("$LAUNCH" "$TP" "$BIN" --model-dir "$MODEL" \
      --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
      --vocab "$VOCAB" --seed 0 --no-stream \
      --pld --pld-k "$k" --pld-ngram "$ng" 2>&1)

    # The value is the third-from-last field of the "decode :" line
    # (presumably tokens/s — confirm against the binary's output format).
    # Only the LAST matching line is used so that a report printed more than
    # once cannot smuggle newlines into the CSV row.
    # NOTE(review): confirm the last match is the line you want to record.
    tg=$(printf '%s\n' "$output" \
      | grep "decode :" | tail -n 1 | awk '{print $(NF-2)}')
    acc=$(printf '%s\n' "$output" \
      | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | tail -n 1 | cut -d= -f2)

    # A run without a parsable result is still recorded (as 0) so the row keeps
    # N_RUNS entries, but say so instead of silently skewing the statistics.
    if [[ -z "$tg" ]]; then
      printf 'warn: K=%s ng=%s run %d: no "decode :" line found, recording 0\n' \
        "$k" "$ng" "$r" >&2
    fi
    tgs+=("${tg:-0}")
    accs+=("${acc:-0}")
  done

  if ((${#tgs[@]} == 0)); then
    printf 'error: K=%s ng=%s: no runs executed (N_RUNS=%s)\n' \
      "$k" "$ng" "${N_RUNS-}" >&2
    return 1
  fi

  # LC_ALL=C keeps '.' as the decimal point regardless of the caller's locale.
  mapfile -t sorted < <(printf '%s\n' "${tgs[@]}" | LC_ALL=C sort -n)
  local n=${#sorted[@]}
  local best="${sorted[n - 1]}"
  # For an even number of runs this is the upper of the two middle values.
  local median="${sorted[n / 2]}"

  local accs_avg
  accs_avg=$(printf '%s\n' "${accs[@]}" \
    | LC_ALL=C awk '{s+=$1} END {printf "%.2f", s/NR}')

  # Per-run values are joined with '/' so they occupy a single CSV column.
  local joined
  joined=$(
    IFS=/
    echo "${tgs[*]}"
  )
  echo "$k,$ng,$joined,$best,$median,$accs_avg" >> "$OUT"
  printf " K=%-2d ng=%-1d runs=[%s] best=%s median=%s accept_avg=%s\n" \
    "$k" "$ng" "${tgs[*]}" "$best" "$median" "$accs_avg"
}
# Announce the sweep, then benchmark every (K, n-gram) combination.
# run_one appends the CSV row for a pair and prints its summary line.
printf '%s\n\n' "PLD sweep on '$PROMPT' ($N_RUNS runs × $N_PREDICT tokens)"

pld_k_values=(2 4 6 8 12)
pld_ngram_sizes=(1 2 3)
for k in "${pld_k_values[@]}"; do
  for ng in "${pld_ngram_sizes[@]}"; do
    run_one "$k" "$ng"
  done
done
# Baseline for reference
#######################################
# Run the same prompt N_RUNS times without any --pld flags and print the
# value parsed from each run's "decode :" line plus the median of those values.
# Globals (read): LAUNCH, TP, BIN, MODEL, PROMPT, N_PREDICT, VOCAB, N_RUNS
# Outputs:
#   A blank line, "Baseline (no PLD):" and one " baseline: ..." line on stdout.
# Returns:
#   1 if no run was executed (N_RUNS < 1).
#######################################
run_baseline() {
  local tgs=() sorted=()
  local r tg

  echo ""
  echo "Baseline (no PLD):"
  for ((r = 1; r <= N_RUNS; r++)); do
    # Only the last "decode :" line is used so a report printed more than once
    # cannot put several values into one entry. A run with no match counts as 0.
    tg=$("$LAUNCH" "$TP" "$BIN" --model-dir "$MODEL" \
      --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
      --vocab "$VOCAB" --seed 0 --no-stream 2>&1 \
      | grep "decode :" | tail -n 1 | awk '{print $(NF-2)}')
    tgs+=("${tg:-0}")
  done

  if ((${#tgs[@]} == 0)); then
    printf 'error: baseline: no runs executed (N_RUNS=%s)\n' "${N_RUNS-}" >&2
    return 1
  fi

  # LC_ALL=C keeps '.' as the decimal point regardless of the caller's locale.
  mapfile -t sorted < <(printf '%s\n' "${tgs[@]}" | LC_ALL=C sort -n)
  # For an even number of runs the median is the upper of the two middle values.
  echo " baseline: ${tgs[*]} median=${sorted[${#sorted[@]} / 2]}"
}
run_baseline
echo ""
echo "====== Sorted by median TG ======"
# Print the CSV header first, then the data rows ordered by column 5 (median),
# highest first, and align everything into a table.
# -k5,5 restricts the sort key to the median field instead of running to the
# end of the line; LC_ALL=C keeps '.' as the decimal point in any locale.
{
  head -n 1 "$OUT"
  tail -n +2 "$OUT" | LC_ALL=C sort -t, -k5,5 -gr
} | column -t -s,