#!/usr/bin/env bash
# bench_pld.sh — sweep K × n-gram with corrected causal-with-past mask.
# Measures TG + accept rate stability across N_RUNS per config.
#
# For every (K, n-gram) pair the model is launched N_RUNS times with --pld.
# From each run's combined stdout/stderr the script pulls:
#   * the third-from-last field of the "decode :" line (the "TG" figure;
#     presumably tokens/s — confirm against the binary's log format), and
#   * the "avg=<number>" value on the "[pld]" line (the accept figure).
# One CSV row per configuration is appended to $OUT, a baseline without
# --pld is measured, and finally the CSV is printed sorted by median TG.
#
# Environment overrides:
#   MODEL_DIR  directory passed to --model-dir
#   N_RUNS     runs per configuration (positive integer, default 3)
#   PROMPT     prompt text
#   OUT        path of the CSV results file (default /tmp/bench_pld.csv)
#
# Only `set -u` is enabled on purpose: a run whose log cannot be parsed is
# recorded as 0 and the sweep continues, so -e / pipefail must stay off.
set -u

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS="${N_RUNS:-3}"
PROMPT="${PROMPT:-Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints.}"
VOCAB="tokenizer_data/vocab.bin"
OUT="${OUT:-/tmp/bench_pld.csv}"

# Print a diagnostic on stderr and return non-zero.
fail() {
  printf 'bench_pld: %s\n' "$*" >&2
  return 1
}

# Join arguments $2.. using the single character $1 as separator.
join_by() {
  local IFS="$1"
  shift
  printf '%s' "$*"
}

# Print the numerically largest argument (text is echoed unmodified).
max_of() {
  printf '%s\n' "$@" | LC_ALL=C sort -n | tail -n 1
}

# Print the median of the numeric arguments.
# Odd count: the middle value, echoed unmodified.
# Even count: the mean of the two middle values, formatted with %.2f.
# Returns 1 when called without arguments.
median_of() {
  (( $# > 0 )) || return 1
  printf '%s\n' "$@" | LC_ALL=C sort -n | awk '
    { v[NR] = $1 }
    END {
      if (NR % 2) print v[(NR + 1) / 2]
      else printf "%.2f\n", (v[NR / 2] + v[NR / 2 + 1]) / 2
    }'
}

# Print the arithmetic mean of the arguments with two decimals
# (no trailing newline).
mean_of() {
  printf '%s\n' "$@" | awk '{ s += $1 } END { printf "%.2f", s / NR }'
}

# Read a run log on stdin and print the third-from-last field of the
# "decode :" line. If several lines match, the last one wins so that a run
# always contributes exactly one sample. Prints nothing when no line matches.
extract_tg() {
  awk '/decode :/ && NF >= 3 { v = $(NF - 2) } END { if (v != "") print v }'
}

# Read a run log on stdin and print the number following "avg=" on the
# "[pld]" line (last match wins). Prints nothing when absent.
extract_accept() {
  grep -F '[pld]' | grep -oE 'avg=[0-9.]+' | tail -n 1 | cut -d= -f2
}

# Launch the model once with the shared benchmark flags; any extra arguments
# are appended. stderr is merged into stdout so callers can parse one stream.
run_model() {
  "$LAUNCH" "$TP" "$BIN" --model-dir "$MODEL" \
    --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
    --vocab "$VOCAB" --seed 0 --no-stream "$@" 2>&1
}

#######################################
# Benchmark one PLD configuration N_RUNS times.
# Globals:   N_RUNS, OUT (read); plus those used by run_model
# Arguments: $1 - value for --pld-k
#            $2 - value for --pld-ngram
# Outputs:   appends "k,ngram,run_tgs,best,median,avg_accept" row to $OUT
#            (run_tgs joined with '/'), prints a summary line on stdout and
#            a warning on stderr for every run without a parsable TG.
#######################################
run_one() {
  local k="$1" ng="$2"
  local -a tgs=() accs=()
  local r output tg acc

  for ((r = 1; r <= N_RUNS; r++)); do
    output=$(run_model --pld --pld-k "$k" --pld-ngram "$ng")
    tg=$(extract_tg <<<"$output")
    acc=$(extract_accept <<<"$output")
    if [[ -z "$tg" ]]; then
      printf 'bench_pld: K=%s ng=%s run %d: no "decode :" line, recording 0\n' \
        "$k" "$ng" "$r" >&2
    fi
    # Unparsable runs are kept as 0 so they stay visible in the results.
    tgs+=("${tg:-0}")
    accs+=("${acc:-0}")
  done

  local best median accs_avg
  best=$(max_of "${tgs[@]}")
  median=$(median_of "${tgs[@]}")
  accs_avg=$(mean_of "${accs[@]}")

  printf '%s,%s,%s,%s,%s,%s\n' \
    "$k" "$ng" "$(join_by / "${tgs[@]}")" "$best" "$median" "$accs_avg" >> "$OUT"
  printf " K=%-2d ng=%-1d runs=[%s] best=%s median=%s accept_avg=%s\n" \
    "$k" "$ng" "${tgs[*]}" "$best" "$median" "$accs_avg"
}

# Measure the same prompt without --pld and print the per-run TG values
# together with their median.
run_baseline() {
  local -a tgs=()
  local r tg

  for ((r = 1; r <= N_RUNS; r++)); do
    tg=$(run_model | extract_tg)
    tgs+=("${tg:-0}")
  done
  echo " baseline: ${tgs[*]} median=$(median_of "${tgs[@]}")"
}

# Entry point: validate the environment, run the sweep, the baseline and
# print the CSV sorted by the median column. Returns 1 on a failed preflight.
main() {
  # All relative paths above are resolved against the directory one level
  # above this script's own directory.
  if ! cd "$(dirname "$0")/.."; then
    fail "cannot cd to $(dirname "$0")/.."
    return 1
  fi
  if [[ ! "$N_RUNS" =~ ^[1-9][0-9]*$ ]]; then
    fail "N_RUNS must be a positive integer (got '$N_RUNS')"
    return 1
  fi
  # Stop before touching $OUT instead of filling the CSV with zeros.
  if [[ ! -x "$LAUNCH" ]]; then
    fail "launcher not found or not executable: $LAUNCH"
    return 1
  fi
  if ! echo "k,ngram,run_tgs,best,median,avg_accept" > "$OUT"; then
    fail "cannot write $OUT"
    return 1
  fi

  echo "PLD sweep on '$PROMPT' ($N_RUNS runs × $N_PREDICT tokens)"
  echo ""
  local k ng
  for k in 2 4 6 8 12; do
    for ng in 1 2 3; do
      run_one "$k" "$ng"
    done
  done

  # Baseline for reference
  echo ""
  echo "Baseline (no PLD):"
  run_baseline

  echo ""
  echo "====== Sorted by median TG ======"
  # -k5,5 restricts the key to the median column only; header stays on top.
  {
    head -n 1 "$OUT"
    tail -n +2 "$OUT" | LC_ALL=C sort -t, -k5,5gr
  } | column -t -s,
}

main "$@"