File size: 2,434 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env bash
# bench_pld.sh — sweep K × n-gram with corrected causal-with-past mask.
# Measures TG + accept rate stability across N_RUNS per config.
#
# Environment overrides:
#   MODEL_DIR  directory passed to the binary as --model-dir
#   N_RUNS     number of repetitions per configuration (default: 3)
#   PROMPT     prompt text used for every run
#
# Results are collected as CSV in $OUT (the file is truncated on every start).
set -u

# BIN, LAUNCH and VOCAB are relative paths, so everything below must run from
# the parent of this script's directory. Abort instead of benchmarking from the
# wrong place if that directory cannot be entered.
cd "$(dirname "$0")/.." || {
    echo "bench_pld.sh: cannot cd to $(dirname "$0")/.." >&2
    exit 1
}

MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
LAUNCH="./scripts/tp_launch.sh"
TP=16
N_PREDICT=200
N_RUNS="${N_RUNS:-3}"
PROMPT="${PROMPT:-Write a long Python function that computes the Fibonacci sequence with memoization, extensive comments, and type hints.}"
VOCAB="tokenizer_data/vocab.bin"

OUT=/tmp/bench_pld.csv
# Start a fresh CSV with the header row; without it the final report is useless,
# so a write failure is fatal.
echo "k,ngram,run_tgs,best,median,avg_accept" > "$OUT" || {
    echo "bench_pld.sh: cannot write $OUT" >&2
    exit 1
}

#######################################
# Benchmark one PLD configuration.
# Launches the binary N_RUNS times with --pld-k K / --pld-ngram NGRAM, scrapes
# the decode throughput and the PLD "avg=" accept figure from each run's
# combined stdout/stderr, appends one CSV row to $OUT and prints a one-line
# summary to stdout.
# Globals (read): LAUNCH TP BIN MODEL PROMPT N_PREDICT VOCAB N_RUNS OUT
# Arguments:
#   $1 - value for --pld-k
#   $2 - value for --pld-ngram
# Outputs:
#   CSV row "k,ngram,tg1/tg2/...,best,median,avg_accept" appended to $OUT;
#   human-readable summary on stdout.
# Notes:
#   A run whose output yields no value is recorded as 0. "median" is the
#   element at index n/2 of the ascending sort (upper middle for even n).
#######################################
run_one() {
    local k="$1" ng="$2"
    local tgs=() accs=()
    local r output tg acc
    for ((r = 1; r <= N_RUNS; r++)); do
        output=$("$LAUNCH" "$TP" "$BIN" --model-dir "$MODEL" \
                --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
                --vocab "$VOCAB" --seed 0 --no-stream \
                --pld --pld-k "$k" --pld-ngram "$ng" 2>&1)
        # Keep exactly one value per run. If several lines match (presumably
        # possible when more than one rank reports — verify against the
        # binary's output), the last match wins; a multi-line value would
        # inflate the sample count and split the CSV row.
        tg=$(printf '%s\n' "$output" \
            | awk '/decode :/ { v = $(NF-2) } END { print v }')
        acc=$(printf '%s\n' "$output" \
            | grep "\[pld\]" | grep -oE "avg=[0-9.]+" | tail -n 1 | cut -d= -f2)
        tgs+=("${tg:-0}")
        accs+=("${acc:-0}")
    done

    # Ascending numeric order: last element is the best run.
    local sorted=()
    mapfile -t sorted < <(printf '%s\n' "${tgs[@]}" | sort -n)
    local n=${#sorted[@]}
    local best="${sorted[-1]}"
    local median="${sorted[$((n / 2))]}"

    local accs_avg
    accs_avg=$(printf '%s\n' "${accs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')

    # Per-run throughputs share one CSV field, separated by "/".
    local runs_field
    runs_field=$(IFS=/; printf '%s' "${tgs[*]}")
    echo "$k,$ng,$runs_field,$best,$median,$accs_avg" >> "$OUT"
    printf "  K=%-2d ng=%-1d  runs=[%s]  best=%s  median=%s  accept_avg=%s\n" \
        "$k" "$ng" "${tgs[*]}" "$best" "$median" "$accs_avg"
}

# ---- Driver ----------------------------------------------------------------
# Runs every (K, n-gram) pair through run_one, measures a no-PLD baseline with
# the same prompt and seed, then prints the collected CSV ranked by median TG.
echo "PLD sweep on '$PROMPT' ($N_RUNS runs × $N_PREDICT tokens)"
echo ""

for k in 2 4 6 8 12; do
    for ng in 1 2 3; do
        run_one "$k" "$ng"
    done
done

# Baseline for reference
echo ""
echo "Baseline (no PLD):"
tgs=()
for ((r = 1; r <= N_RUNS; r++)); do
    # Keep a single throughput value per run: if several lines match
    # "decode :", the last one wins, so the sample count stays at N_RUNS.
    tg=$("$LAUNCH" "$TP" "$BIN" --model-dir "$MODEL" \
         --prompt "$PROMPT" --n-predict "$N_PREDICT" --max-seq 512 \
         --vocab "$VOCAB" --seed 0 --no-stream 2>&1 \
         | awk '/decode :/ { v = $(NF-2) } END { print v }')
    # A run with no parsable throughput is recorded as 0.
    tgs+=("${tg:-0}")
done
mapfile -t sorted < <(printf '%s\n' "${tgs[@]}" | sort -n)
echo "  baseline:  ${tgs[*]}  median=${sorted[$((${#sorted[@]}/2))]}"

echo ""
echo "====== Sorted by median TG ======"
# Header row first, then data rows in descending order of column 5 (median).
# The key is limited to -k5,5 so the trailing avg_accept column is not part of
# the sort key.
{ head -n 1 "$OUT"; tail -n +2 "$OUT" | sort -t, -k5,5gr; } | column -t -s,