File size: 6,260 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env bash
# bench_pld_safe.sh — PLD benchmark with output correctness check.
# Unlike bench_tg.sh (which only reports TG numbers), this wrapper also inspects the
# generated text for degeneration signals (consecutive identical tokens / very low
# distinct-token ratio in the tail) and flags runs whose high TG came from dead-loop
# output rather than real acceleration.
#
# Usage:  ./scripts/bench_pld_safe.sh [N_RUNS] [PROMPT_FILE]
#         Prompts with "|" separator: "tag|prompt text"
#         Default: tests multiple prompt classes and reports which ones PLD helps safely.
set -u
# Every relative path below (binary, launcher script, vocab) is resolved from the
# repo root, so a failed cd must abort rather than run from the wrong directory.
cd "$(dirname "$0")/.." || { echo "bench_pld_safe: cannot cd to repo root" >&2; exit 1; }

# Run configuration. MODEL_DIR and N_PREDICT may be overridden from the environment;
# N_RUNS comes from the first positional argument.
MODEL="${MODEL_DIR:-/path/to/Qwen3-235B-A22B-Instruct-2507-BF16}"
BIN="./build/qwen3-moe-aclnn"
N_RUNS="${1:-3}"
N_PREDICT="${N_PREDICT:-120}"
VOCAB="tokenizer_data/vocab.bin"

# Default prompt suite: one per class. Override via arg 2 (file with "tag|prompt" per line).
default_prompts=(
    "story|Once upon a time, in a small village,"
    "factual|The capital of France is"
    "code|Write a Python function that computes Fibonacci."
    "essay|The history of artificial intelligence spans several decades and"
)

prompt_file="${2:-}"
prompts=()
if [[ -n "$prompt_file" && -f "$prompt_file" ]]; then
    # Blank lines are skipped: an empty entry would otherwise be benchmarked as an
    # empty tag with an empty prompt. The "|| [[ -n ... ]]" keeps a last line that
    # has no trailing newline.
    while IFS= read -r prompt_line || [[ -n "$prompt_line" ]]; do
        if [[ -n "${prompt_line//[[:space:]]/}" ]]; then
            prompts+=("$prompt_line")
        fi
    done < "$prompt_file"
elif [[ -n "$prompt_file" ]]; then
    # Make the fallback visible instead of silently ignoring a mistyped path.
    echo "bench_pld_safe: prompt file '$prompt_file' not found; using default prompts" >&2
fi
# No usable entries (no file given, missing file, or file with only blank lines).
if (( ${#prompts[@]} == 0 )); then
    prompts=("${default_prompts[@]}")
fi

# ----- Correctness classifier -----
# classify_output: reads generated text on stdin and prints exactly one verdict line.
#   OK                    — fewer than 5 words, or no repetition signal found
#   LOOP_<n>              — longest run of identical adjacent words is n, with n >= 6
#   LOW_DIVERSITY_<d>/<t> — the last t words (t = up to 40, at least 20) contain
#                           only d distinct words, with d < 10
# Words are whitespace-separated, lower-cased, with leading/trailing punctuation
# removed; tokens that are empty after stripping are ignored.
classify_output() {
    awk '
    {
        # Collect normalized words from this line into the global list.
        count = split($0, parts, /[[:space:]]+/)
        for (k = 1; k <= count; k++) {
            tok = parts[k]
            gsub(/^[[:punct:]]+|[[:punct:]]+$/, "", tok)
            if (tok != "") words[++nw] = tolower(tok)
        }
    }
    END {
        verdict = "OK"
        # Too little text to judge: keep the OK verdict.
        if (nw >= 5) {
            # Longest streak of identical neighbours.
            longest = 1
            streak = 1
            for (k = 2; k <= nw; k++) {
                streak = (words[k] == words[k-1]) ? streak + 1 : 1
                if (streak > longest) longest = streak
            }
            if (longest >= 6) {
                verdict = sprintf("LOOP_%d", longest)
            } else {
                # Distinct-word count over the trailing window of at most 40 words.
                first = (nw > 40) ? nw - 39 : 1
                uniq = 0
                for (k = first; k <= nw; k++) {
                    if (!(words[k] in seen)) {
                        seen[words[k]] = 1
                        uniq++
                    }
                }
                span = nw - first + 1
                if (span >= 20 && uniq < 10)
                    verdict = sprintf("LOW_DIVERSITY_%d/%d", uniq, span)
            }
        }
        print verdict
    }'
}

# run_once PROMPT EXTRA_FLAGS
# Launches one benchmark run through ./scripts/tp_launch.sh and prints a single
# record on stdout:
#     <TG>|<verdict>|<WARN or empty>|<last [pld] line without its prefix>
# Arguments:
#   $1 - prompt text passed to the binary via --prompt
#   $2 - extra CLI flags as one string (word-split on purpose; may be empty)
# Globals read: BIN, MODEL, N_PREDICT, VOCAB
# Returns: 1 if a temp file cannot be created, otherwise the status of the final printf.
# NOTE(review): the launcher's exit status is not checked; a run that emits no
# "[perf] decode" line reports TG 0 and is classified only by its text.
run_once() {
    local prompt="$1"
    local extra_flags="$2"
    # Launch. The binary prints to stdout: rank/cli headers, runner loading lines,
    # generated text (--no-stream), then perf lines. pld/warn go to stderr.
    # Declared separately from the assignment so a mktemp failure is not masked by 'local'.
    local stdout_file stderr_file
    stdout_file=$(mktemp) || return 1
    stderr_file=$(mktemp) || { rm -f -- "$stdout_file"; return 1; }
    # Ensure no lockfile leftover.
    ssh_cleanup_lockfile
    # shellcheck disable=SC2086 # extra_flags is deliberately word-split into separate flags
    ./scripts/tp_launch.sh 16 "$BIN" --model-dir "$MODEL" \
        --prompt "$prompt" --n-predict "$N_PREDICT" \
        --vocab "$VOCAB" --seed 0 --no-stream --temperature 0 \
        $extra_flags 1>"$stdout_file" 2>"$stderr_file"
    # TG lives on stdout (from printf in binary). Keep only the last matching line so
    # that several perf lines cannot produce a multi-line value that would break the
    # '|'-delimited record; NF >= 3 avoids a negative field index on a short line.
    local tg
    tg=$(awk '/\[perf\] decode/ && NF >= 3 { v = $(NF-2) } END { print v }' "$stdout_file")
    # Generated text: the line that begins with the prompt (--no-stream echoes prompt+text).
    local gen_text
    gen_text=$(grep -F -- "$prompt" "$stdout_file" | grep -v '^\[' | tail -1)
    # The inner quotes make the prompt a literal prefix. Unquoted, characters such as
    # [ ] * ? \ in the prompt act as glob syntax and the prefix may fail to strip.
    local stripped="${gen_text#"$prompt"}"
    local verdict
    verdict=$(printf '%s\n' "$stripped" | classify_output)
    local has_warn=""
    if grep -q "\[warn\]" "$stderr_file"; then has_warn="WARN"; fi
    local pld_line
    pld_line=$(grep "\[pld\]" "$stderr_file" | tail -1 | sed 's/^\[pld\] //')
    rm -f -- "$stdout_file" "$stderr_file"
    printf '%s\n' "${tg:-0}|${verdict}|${has_warn}|${pld_line}"
}

# Remove the /tmp/hccl_root_info.bin file left behind by an earlier launch.
# Best effort: errors are silenced and the function always succeeds.
ssh_cleanup_lockfile() {
    local lockfile=/tmp/hccl_root_info.bin
    rm -f "$lockfile" 2>/dev/null || :
}

# bench_prompt TAG PROMPT FLAGS
# Runs one prompt N_RUNS times via run_once, prints a line per run, then a summary
# that keeps TG values from OK runs separate from runs with any other verdict.
# Arguments:
#   $1 - label shown in the section header
#   $2 - prompt text (only its first 50 characters are shown in the header)
#   $3 - extra flags forwarded to run_once as one string; may be empty
# Globals read: N_RUNS
# Output: human-readable report on stdout.
bench_prompt() {
    local tag="$1"; local prompt="$2"; local flags="$3"
    # Every scratch variable is local so nothing leaks into the caller's scope.
    local r result tg verdict warn pld
    local good_tgs=() bad_tgs=()
    echo ""
    # Parameter expansion instead of 'echo | head -c': it cuts on characters rather
    # than bytes and cannot misread a prompt such as "-n" as an echo option.
    echo "=== [$tag] ${prompt:0:50}... (flags: ${flags:-none}) ==="
    for (( r = 1; r <= N_RUNS; r++ )); do
        result=$(run_once "$prompt" "$flags")
        # Record layout: TG|verdict|warn|pld. The last field keeps any further '|'.
        IFS='|' read -r tg verdict warn pld <<< "$result"
        printf "  run %d: TG=%s verdict=%s %s\n" "$r" "$tg" "$verdict" "$warn"
        if [ -n "$pld" ]; then
            printf "         %s\n" "$pld"
        fi
        # Split good vs degraded as results arrive; only verdict "OK" counts as good.
        if [ "$verdict" = "OK" ]; then
            good_tgs+=("${tg:-0}")
        else
            bad_tgs+=("${tg:-0}")
        fi
        # Clear the launcher's lock file again after the run.
        ssh_cleanup_lockfile
    done
    local n_good=${#good_tgs[@]}
    local n_bad=${#bad_tgs[@]}
    echo "  → $n_good/$N_RUNS OK, $n_bad/$N_RUNS degraded"
    if [ "$n_good" -gt 0 ]; then
        local mean
        mean=$(printf '%s\n' "${good_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
        echo "  → OK mean TG: $mean t/s  (values: ${good_tgs[*]})"
    fi
    if [ "$n_bad" -gt 0 ]; then
        local bad_mean
        bad_mean=$(printf '%s\n' "${bad_tgs[@]}" | awk '{s+=$1} END {printf "%.2f", s/NR}')
        echo "  → degraded mean TG: $bad_mean t/s (DO NOT REPORT as speedup)  (values: ${bad_tgs[*]})"
    fi
}

# Driver: announce the configuration, benchmark every prompt under the three
# decoding variants, then print the interpretation guide.
printf '%s\n' "bench_pld_safe: $N_RUNS runs × $N_PREDICT tokens per prompt; comparing [no-pld, pld+guard, pld+no-guard]"

for entry in "${prompts[@]}"; do
    # Each entry is "tag|prompt"; the split is at the first '|'.
    tag="${entry%%|*}"
    prompt="${entry#*|}"
    # Variant table, "label|flags": baseline, PLD with guard, PLD without guard.
    for variant in "base|" "pld+guard|--pld" "pld-raw|--pld --pld-no-guard"; do
        bench_prompt "$tag/${variant%%|*}" "$prompt" "${variant#*|}"
    done
done

# Quoted delimiter: the text below is emitted literally, with no expansion.
cat <<'EOF'

==========================================================
Interpretation:
  OK mean TG is the only honest number to report.
  Any 'degraded' result with high TG is a dead-loop artifact.
  Expected: pld+guard matches or beats base on creative/story prompts,
  matches base on factual/code prompts (drafts rejected → fallback to single decode).
  pld-raw (no guard) on repetitive prompts produces 'degraded' with high TG.
EOF