#!/usr/bin/env bash
# bench_pld_safe.sh — PLD benchmark with output correctness check.
# Unlike bench_tg.sh (which only reports TG numbers), this wrapper also inspects the
# generated text for degeneration signals (consecutive identical tokens / very low
# distinct-token ratio in the tail) and flags runs whose high TG came from dead-loop
# output rather than real acceleration.
#
# Usage: ./scripts/bench_pld_safe.sh [N_RUNS] [PROMPT_FILE]
#   N_RUNS       runs per (prompt, flag set); positive integer, default 3.
#   PROMPT_FILE  one "tag|prompt text" entry per line; blank lines are skipped.
#                The path is resolved after cd-ing to the repository root.
#   Default: tests multiple prompt classes and reports which ones PLD helps safely.
#
# Environment:
#   MODEL_DIR    model directory handed to the binary via --model-dir.
#   N_PREDICT    tokens to generate per run (default 120).

# Deliberately no `set -e`: grep finding no match is an expected outcome here.
set -u

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
N_RUNS="${1:-3}"
N_PREDICT="${N_PREDICT:-120}"
VOCAB="tokenizer_data/vocab.bin"
LAUNCHER="./scripts/tp_launch.sh"
HCCL_LOCKFILE="/tmp/hccl_root_info.bin"
# Scratch directory for per-run logs; created by main and removed on exit.
BENCH_TMP_ROOT=""

# Default prompt suite: one per class. Override via arg 2 (file with "tag|prompt" per line).
default_prompts=(
  "story|Once upon a time, in a small village,"
  "factual|The capital of France is"
  "code|Write a Python function that computes Fibonacci."
  "essay|The history of artificial intelligence spans several decades and"
)

# Print a diagnostic on stderr, prefixed with the script name.
err() {
  printf 'bench_pld_safe: %s\n' "$*" >&2
}

# ----- Correctness classifier -----
# Reads generated text from stdin (any number of lines) and prints one verdict:
#   OK                     — no loop signals, or fewer than 5 words to judge
#   LOOP_<n>               — longest run of identical consecutive words is n (n >= 6)
#   LOW_DIVERSITY_<d>/<t>  — the last t words (t is 20..40) hold only d < 10 distinct words
# Words are whitespace-separated, lower-cased, with edge punctuation stripped.
classify_output() {
  awk '
    {
      # Tokenize on whitespace; strip punctuation at the edges for comparison.
      n = split($0, w, /[[:space:]]+/)
      for (i = 1; i <= n; i++) {
        gsub(/^[[:punct:]]+|[[:punct:]]+$/, "", w[i])
        if (w[i] == "") continue
        words[++nw] = tolower(w[i])
      }
    }
    END {
      # Too little text to judge.
      if (nw < 5) { print "OK"; exit }

      # Longest run of identical consecutive words.
      run = 1; max_run = 1
      for (i = 2; i <= nw; i++) {
        if (words[i] == words[i-1]) {
          run++
          if (run > max_run) max_run = run
        } else {
          run = 1
        }
      }
      if (max_run >= 6) { printf "LOOP_%d\n", max_run; exit }

      # Tail diversity over the last 40 words (or all of them if fewer).
      tail_start = nw - 39
      if (tail_start < 1) tail_start = 1
      distinct = 0
      for (i = tail_start; i <= nw; i++) {
        if (!(words[i] in seen)) { seen[words[i]] = 1; distinct++ }
      }
      tail_n = nw - tail_start + 1
      if (tail_n >= 20 && distinct < 10) {
        printf "LOW_DIVERSITY_%d/%d\n", distinct, tail_n
        exit
      }
      print "OK"
    }'
}

# Print the generated text (prompt echo included) found in a captured stdout log.
# Arguments: $1 - prompt (matched literally), $2 - path to the stdout log
# The text starts at the first line that begins with the prompt; if no line does,
# at the last line that contains the prompt and does not start with "[". It runs
# until the next line starting with "[" or end of file, so multi-line generations
# are kept whole. Prints nothing when the prompt is not found.
# NOTE(review): assumes log/perf lines start with "[" as the binary's output is
# described; a generated line that itself starts with "[" ends the capture early.
extract_generated_text() {
  local prompt="$1" log_file="$2"
  # The prompt travels via the environment because awk -v would interpret
  # backslash escapes inside it.
  BENCH_PLD_PROMPT="$prompt" awk '
    BEGIN { prompt = ENVIRON["BENCH_PLD_PROMPT"] }
    {
      line[NR] = $0
      pos = index($0, prompt)
      if (pos == 1 && !first) first = NR
      if (pos > 0 && substr($0, 1, 1) != "[") last = NR
    }
    END {
      start = first ? first : last
      if (!start) exit
      for (i = start; i <= NR; i++) {
        if (i > start && substr(line[i], 1, 1) == "[") break
        print line[i]
      }
    }' "$log_file"
}

# Remove the HCCL root-info file a previous launch may have left behind.
ssh_cleanup_lockfile() {
  rm -f -- "$HCCL_LOCKFILE" 2>/dev/null || true
}

# Run the binary once and print "TG|verdict|WARN-or-empty|last pld line".
# Arguments: $1 - prompt, $2 - extra flags as one whitespace-separated string
# Globals:   LAUNCHER, BIN, MODEL, N_PREDICT, VOCAB, BENCH_TMP_ROOT (all read)
# TG is field NF-2 of the last "[perf] decode" line on stdout, or 0 if there is
# none; in that case the verdict is FAILED so the run is not averaged as "OK".
# Returns non-zero (printing nothing) only if temp files cannot be created.
run_once() {
  local prompt="$1"
  local extra_flags="$2"
  local tmp_base="${BENCH_TMP_ROOT:-${TMPDIR:-/tmp}}"

  # Launch. The binary prints to stdout: rank/cli headers, runner loading lines,
  # generated text (--no-stream), then perf lines. pld/warn go to stderr.
  local stdout_file stderr_file
  stdout_file=$(mktemp "${tmp_base}/bench_pld_out.XXXXXX") || return 1
  stderr_file=$(mktemp "${tmp_base}/bench_pld_err.XXXXXX") || {
    rm -f -- "$stdout_file"
    return 1
  }

  # Split the flag string into words without glob expansion.
  local -a flag_args=()
  read -r -a flag_args <<< "$extra_flags"

  # Ensure no lockfile leftover.
  ssh_cleanup_lockfile
  # The ${arr[@]+...} form keeps an empty array from tripping `set -u` on bash < 4.4.
  "$LAUNCHER" 16 "$BIN" --model-dir "$MODEL" \
    --prompt "$prompt" --n-predict "$N_PREDICT" \
    --vocab "$VOCAB" --seed 0 --no-stream --temperature 0 \
    ${flag_args[@]+"${flag_args[@]}"} >"$stdout_file" 2>"$stderr_file"

  # TG lives on stdout (from printf in binary). Keep only the last perf line so
  # the result stays a single "|"-separated record.
  local tg
  tg=$(awk '/\[perf\] decode/ { v = $(NF-2) } END { if (v != "") print v }' "$stdout_file")

  # Generated text: --no-stream echoes prompt+text; drop the echoed prompt.
  # The pattern is quoted so the prompt is stripped literally, not as a glob.
  local gen_text stripped verdict
  gen_text=$(extract_generated_text "$prompt" "$stdout_file")
  stripped="${gen_text#"$prompt"}"
  if [[ -z "$tg" ]]; then
    verdict="FAILED"
  else
    verdict=$(printf '%s\n' "$stripped" | classify_output)
  fi

  local has_warn=""
  if grep -q "\[warn\]" "$stderr_file"; then has_warn="WARN"; fi
  local pld_line
  pld_line=$(grep "\[pld\]" "$stderr_file" | tail -n 1 | sed 's/^\[pld\] //')

  rm -f -- "$stdout_file" "$stderr_file"
  printf '%s|%s|%s|%s\n' "${tg:-0}" "$verdict" "$has_warn" "$pld_line"
}

# Benchmark one prompt under one flag set N_RUNS times and print a summary.
# Arguments: $1 - label, $2 - prompt, $3 - extra flags (may be empty)
# Runs whose verdict is not OK are reported separately as "degraded".
bench_prompt() {
  local tag="$1" prompt="$2" flags="$3"
  local r result tg verdict warn pld
  local -a good_tgs=() bad_tgs=()

  echo ""
  echo "=== [$tag] ${prompt:0:50}... (flags: ${flags:-none}) ==="

  for ((r = 1; r <= N_RUNS; r++)); do
    result=$(run_once "$prompt" "$flags") || result="0|FAILED||"
    IFS='|' read -r tg verdict warn pld <<< "$result"
    printf "  run %d: TG=%s verdict=%s %s\n" "$r" "$tg" "$verdict" "$warn"
    if [[ -n "$pld" ]]; then printf "    %s\n" "$pld"; fi
    # Split good vs degraded
    if [[ "$verdict" == "OK" ]]; then
      good_tgs+=("${tg:-0}")
    else
      bad_tgs+=("${tg:-0}")
    fi
    ssh_cleanup_lockfile
  done

  local n_good=${#good_tgs[@]}
  local n_bad=${#bad_tgs[@]}
  echo "  → $n_good/$N_RUNS OK, $n_bad/$N_RUNS degraded"
  if ((n_good > 0)); then
    local mean
    mean=$(printf '%s\n' "${good_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
    echo "  → OK mean TG: $mean t/s (values: ${good_tgs[*]})"
  fi
  if ((n_bad > 0)); then
    local bad_mean
    bad_mean=$(printf '%s\n' "${bad_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
    echo "  → degraded mean TG: $bad_mean t/s (DO NOT REPORT as speedup) (values: ${bad_tgs[*]})"
  fi
}

# Entry point: validate inputs, load the prompt suite, run every configuration.
# Arguments: the script's own arguments ([N_RUNS] [PROMPT_FILE])
# Returns:   0 on completion, 2 on a bad N_RUNS, 1 on any other setup failure
main() {
  local script_dir
  script_dir=$(dirname -- "$0")
  cd -- "${script_dir}/.." || {
    err "cannot cd to repository root from ${script_dir}"
    return 1
  }

  if ! [[ "$N_RUNS" =~ ^[1-9][0-9]*$ ]]; then
    err "N_RUNS must be a positive integer (got '${N_RUNS}')"
    return 2
  fi
  # Without the launcher every run would report TG=0; fail fast instead.
  if [[ ! -x "$LAUNCHER" ]]; then
    err "launcher not found or not executable: ${LAUNCHER}"
    return 1
  fi

  local prompt_file="${2:-}"
  local -a prompts=()
  if [[ -n "$prompt_file" && -f "$prompt_file" ]]; then
    local line
    while IFS= read -r line || [[ -n "$line" ]]; do
      [[ -n "${line//[[:space:]]/}" ]] || continue
      prompts+=("$line")
    done < "$prompt_file"
    if ((${#prompts[@]} == 0)); then
      err "prompt file has no entries: ${prompt_file}"
      return 1
    fi
  else
    if [[ -n "$prompt_file" ]]; then
      err "prompt file not found: ${prompt_file}; using the default prompt suite"
    fi
    prompts=("${default_prompts[@]}")
  fi

  BENCH_TMP_ROOT=$(mktemp -d "${TMPDIR:-/tmp}/bench_pld_safe.XXXXXX") || {
    err "cannot create a temporary directory"
    return 1
  }
  trap 'rm -rf -- "${BENCH_TMP_ROOT:?}"' EXIT
  trap 'exit 130' INT
  trap 'exit 143' TERM

  echo "bench_pld_safe: $N_RUNS runs × $N_PREDICT tokens per prompt; comparing [no-pld, pld+guard, pld+no-guard]"
  local entry tag prompt
  for entry in "${prompts[@]}"; do
    tag="${entry%%|*}"
    prompt="${entry#*|}"
    bench_prompt "$tag/base" "$prompt" ""
    bench_prompt "$tag/pld+guard" "$prompt" "--pld"
    bench_prompt "$tag/pld-raw" "$prompt" "--pld --pld-no-guard"
  done

  echo ""
  echo "=========================================================="
  echo "Interpretation:"
  echo "  OK mean TG is the only honest number to report."
  echo "  Any 'degraded' result with high TG is a dead-loop artifact."
  echo "  Expected: pld+guard matches or beats base on creative/story prompts,"
  echo "  matches base on factual/code prompts (drafts rejected → fallback to single decode)."
  echo "  pld-raw (no guard) on repetitive prompts produces 'degraded' with high TG."
}

main "$@"