File size: 2,534 Bytes
a9bd396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import re
import subprocess
import sys
from pathlib import Path

from tabulate import tabulate


# Example script under benchmark, located three directory levels above this file.
SCRIPT_LOCATION = Path(__file__).parent.parent.parent.joinpath("examples/pytorch/continuous_batching.py").as_posix()
# Flags appended to every invocation: quiet logging and a fixed seed for comparable runs.
COMMON_ARGS = ["--log-level", "WARNING", "--seed", "0"]


def run_and_parse_cb_example(args: str) -> dict:
    """Run the continuous-batching example with ``args`` and parse its timing summary.

    The example script at ``SCRIPT_LOCATION`` is launched as a subprocess with
    ``COMMON_ARGS`` appended, and its stdout is searched for a line of the form
    ``CB generation took: <secs> seconds for <n> tokens. <rate>tok/s``.

    Args:
        args: Space-separated command-line arguments for the example script,
            e.g. ``"--samples 100 --attn sdpa"``.

    Returns:
        A row dict with keys ``args``, ``time_seconds``, ``num_tokens`` and
        ``throughput_tok_per_sec``. If the run exits with a non-zero status,
        reports that its generation thread terminated, or prints no parsable
        summary line, the numeric fields hold the placeholders ``"X"`` /
        ``"ERROR"`` so the failed run still appears, labelled, in the table.
    """
    print(f"Benchmarking with args: {args}")
    error_row = {
        "args": args,
        "time_seconds": "X",
        "num_tokens": "X",
        "throughput_tok_per_sec": "ERROR",
    }

    try:
        # sys.executable runs the example under the same interpreter/venv as this
        # script, rather than whatever `python` happens to resolve to on PATH.
        output = subprocess.check_output(
            [sys.executable, SCRIPT_LOCATION, *args.split(), *COMMON_ARGS],
            encoding="utf-8",
        )
    except subprocess.CalledProcessError as err:
        # Keep benchmarking the remaining configurations instead of aborting the suite.
        print(f"Benchmark failed with exit code {err.returncode} for args: {args}")
        return error_row

    # NOTE(review): only stdout is captured here; if the example emits this message
    # through a logger writing to stderr it will never be seen — confirm.
    if "Generation thread terminated unexpectedly." in output:
        return error_row

    pattern = r"CB generation took: ([\d.]+) seconds for (\d+) tokens\. ([\d.]+)tok/s"
    match = re.search(pattern, output)
    if match is None:
        # A labelled error row keeps the failed run identifiable in the results table.
        return error_row
    return {
        "args": args,
        "time_seconds": float(match.group(1)),
        "num_tokens": int(match.group(2)),
        "throughput_tok_per_sec": float(match.group(3)),
    }


if __name__ == "__main__":
    # Column titles for the results table, keyed like the row dicts returned by
    # run_and_parse_cb_example.
    headers = {
        "args": "Arguments",
        "time_seconds": "Duration (s)",
        "num_tokens": "Generated tokens",
        "throughput_tok_per_sec": "Throughput (tok/s)",
    }

    benchmark_args = [
        # Low number of samples
        "--samples 10",
        "--samples 20 --num-blocks 16",  # ... and a low number of blocks
        "--samples 50",
        # Attention implementations: default, flash_attention_2 and sdpa
        "--samples 100",
        "--samples 100 --attn flash_attention_2",
        "--samples 100 --attn sdpa",
        # High number of samples
        "--samples 500",
        # Prefix sharing and compile (best performance, but not reproducible due to compilation)
        "--samples 500 --add-prefix --compile",
        # Parallel decoding
        "--samples 50 --num-return-sequences 8 --do-sample",
        "--samples 100 --num-return-sequences 4 --do-sample",
    ]

    results = [run_and_parse_cb_example(args) for args in benchmark_args]

    print()
    # Passing the titles via `headers=` (rather than as the first data row) makes
    # tabulate emit the separator line under the header, so the "github" format
    # renders as a proper Markdown table.
    print(tabulate(results, headers=headers, tablefmt="github"))