Torch Compile Variants

This file benchmarks the PyTorch SDPA Flash Attention backend under two torch.compile modes: default and max-autotune.

Flash Attention with torch.compile(mode="default")

Cell: benchmark_default | 12.08s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    # Move heads ahead of the sequence dimension: SDPA expects (batch, heads, seq, head_dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the Flash Attention backend instead of letting it auto-select one.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Restore the original (batch, seq, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()


# Compile with default mode
compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)

kbt.add(
    "torch_flash_compiled_default",
    compiled_flash_default,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_default.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
        profile_trace=True
    )
    kbt.summarize(["attn_default.jsonl"])
======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L128
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           torch_flash_compiled_default         0.00%       0.000us         0.00%       0.000us       0.000us     967.332us       298.12%     967.332us     967.332us             1  
                           torch_flash_compiled_default         5.37%     154.798us        99.77%       2.878ms       2.878ms       0.000us         0.00%     324.481us     324.481us             1  
                             Torch-Compiled Region: 0/1        20.96%     604.478us        92.49%       2.668ms     889.236us       0.000us         0.00%     324.481us     108.160us             3  
              aten::_scaled_dot_product_flash_attention         1.54%      44.432us         8.35%     240.853us      80.284us       0.000us         0.00%     276.257us      92.086us             3  
                         aten::_flash_attention_forward         1.64%      47.371us         5.29%     152.657us      50.886us     276.257us        85.14%     276.257us      92.086us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     276.257us        85.14%     276.257us      92.086us             3  
triton_poi_fused__scaled_dot_product_flash_attention...         3.50%     100.807us         6.04%     174.309us      19.368us      36.704us        11.31%      36.704us       4.078us             9  
triton_poi_fused__scaled_dot_product_flash_attention...         0.00%       0.000us         0.00%       0.000us       0.000us      36.704us        11.31%      36.704us       4.078us             9  
                               triton_poi_fused_clone_1         1.27%      36.672us         2.17%      62.583us      20.861us      11.520us         3.55%      11.520us       3.840us             3  
                               triton_poi_fused_clone_1         0.00%       0.000us         0.00%       0.000us       0.000us      11.520us         3.55%      11.520us       3.840us             3  
                               TorchDynamo Cache Lookup         1.91%      55.093us         1.91%      55.093us      18.364us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.36%      10.400us         0.36%      10.400us       3.467us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.70%      20.280us         0.70%      20.280us       6.760us       0.000us         0.00%       0.000us       0.000us             3  
                                Activity Buffer Request        53.91%       1.555ms        53.91%       1.555ms       1.555ms       0.000us         0.00%       0.000us       0.000us             1  
                                         cuLaunchKernel         3.45%      99.413us         3.45%      99.413us       8.284us       0.000us         0.00%       0.000us       0.000us            12  
                                        aten::transpose         1.19%      34.395us         1.52%      43.764us       3.647us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::as_strided         0.32%       9.369us         0.32%       9.369us       0.781us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::empty_like         0.44%      12.621us         1.20%      34.732us      11.577us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.77%      22.111us         0.77%      22.111us       7.370us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::empty         1.24%      35.841us         1.24%      35.841us       2.987us       0.000us         0.00%       0.000us       0.000us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.884ms
Self CUDA time total: 324.481us



======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L256
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           torch_flash_compiled_default         0.00%       0.000us         0.00%       0.000us       0.000us     834.378us       233.60%     834.378us     834.378us             1  
                           torch_flash_compiled_default         4.04%      97.294us        99.68%       2.400ms       2.400ms       0.000us         0.00%     357.190us     357.190us             1  
                             Torch-Compiled Region: 0/3        19.97%     480.803us        94.43%       2.274ms     757.987us       0.000us         0.00%     357.190us     119.063us             3  
              aten::_scaled_dot_product_flash_attention         1.08%      25.983us         7.33%     176.640us      58.880us       0.000us         0.00%     300.165us     100.055us             3  
                         aten::_flash_attention_forward         1.50%      36.164us         5.01%     120.717us      40.239us     300.165us        84.04%     300.165us     100.055us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     300.165us        84.04%     300.165us     100.055us             3  
triton_poi_fused__scaled_dot_product_flash_attention...         3.30%      79.496us         6.27%     150.937us      16.771us      40.161us        11.24%      40.161us       4.462us             9  
triton_poi_fused__scaled_dot_product_flash_attention...         0.00%       0.000us         0.00%       0.000us       0.000us      40.161us        11.24%      40.161us       4.462us             9  
                               triton_poi_fused_clone_1         2.33%      56.123us         3.38%      81.404us      27.135us      16.864us         4.72%      16.864us       5.621us             3  
                               triton_poi_fused_clone_1         0.00%       0.000us         0.00%       0.000us       0.000us      16.864us         4.72%      16.864us       5.621us             3  
                               TorchDynamo Cache Lookup         1.21%      29.133us         1.21%      29.133us       9.711us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.32%       7.730us         0.32%       7.730us       2.577us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.49%      11.750us         0.49%      11.750us       3.917us       0.000us         0.00%       0.000us       0.000us             3  
                                Activity Buffer Request        56.67%       1.365ms        56.67%       1.365ms       1.365ms       0.000us         0.00%       0.000us       0.000us             1  
                                         cuLaunchKernel         4.02%      96.722us         4.02%      96.722us       8.060us       0.000us         0.00%       0.000us       0.000us            12  
                                        aten::transpose         0.90%      21.580us         1.24%      29.940us       2.495us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::as_strided         0.35%       8.360us         0.35%       8.360us       0.697us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::empty_like         0.27%       6.480us         1.00%      23.971us       7.990us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.73%      17.491us         0.73%      17.491us       5.830us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::empty         1.24%      29.800us         1.24%      29.800us       2.483us       0.000us         0.00%       0.000us       0.000us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.408ms
Self CUDA time total: 357.190us



======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L320
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           torch_flash_compiled_default         0.00%       0.000us         0.00%       0.000us       0.000us     876.295us       230.02%     876.295us     876.295us             1  
                           torch_flash_compiled_default         3.99%      99.235us        99.67%       2.477ms       2.477ms       0.000us         0.00%     380.963us     380.963us             1  
                             Torch-Compiled Region: 0/5        19.71%     489.623us        94.50%       2.348ms     782.708us       0.000us         0.00%     380.963us     126.988us             3  
              aten::_scaled_dot_product_flash_attention         1.15%      28.583us         7.58%     188.458us      62.819us       0.000us         0.00%     323.107us     107.702us             3  
                         aten::_flash_attention_forward         1.61%      40.110us         5.06%     125.615us      41.872us     323.107us        84.81%     323.107us     107.702us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     323.107us        84.81%     323.107us     107.702us             3  
triton_poi_fused__scaled_dot_product_flash_attention...         3.47%      86.344us         6.19%     153.807us      17.090us      44.448us        11.67%      44.448us       4.939us             9  
triton_poi_fused__scaled_dot_product_flash_attention...         0.00%       0.000us         0.00%       0.000us       0.000us      44.448us        11.67%      44.448us       4.939us             9  
                               triton_poi_fused_clone_1         1.44%      35.902us         2.40%      59.634us      19.878us      13.408us         3.52%      13.408us       4.469us             3  
                               triton_poi_fused_clone_1         0.00%       0.000us         0.00%       0.000us       0.000us      13.408us         3.52%      13.408us       4.469us             3  
                               TorchDynamo Cache Lookup         1.18%      29.223us         1.18%      29.223us       9.741us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.30%       7.450us         0.30%       7.450us       2.483us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.46%      11.502us         0.46%      11.502us       3.834us       0.000us         0.00%       0.000us       0.000us             3  
                                Activity Buffer Request        57.86%       1.438ms        57.86%       1.438ms       1.438ms       0.000us         0.00%       0.000us       0.000us             1  
                                         cuLaunchKernel         3.67%      91.195us         3.67%      91.195us       7.600us       0.000us         0.00%       0.000us       0.000us            12  
                                        aten::transpose         0.95%      23.681us         1.38%      34.260us       2.855us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::as_strided         0.43%      10.579us         0.43%      10.579us       0.882us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::empty_like         0.27%       6.811us         0.93%      23.051us       7.684us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.65%      16.240us         0.65%      16.240us       5.413us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::empty         1.30%      32.232us         1.30%      32.232us       2.686us       0.000us         0.00%       0.000us       0.000us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.485ms
Self CUDA time total: 380.963us



======================================================================
PROFILE TRACE: torch_flash_compiled_default | flux_L384
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           torch_flash_compiled_default         0.00%       0.000us         0.00%       0.000us       0.000us     900.385us       224.95%     900.385us     900.385us             1  
                           torch_flash_compiled_default         3.56%     101.756us        99.74%       2.848ms       2.848ms       0.000us         0.00%     400.258us     400.258us             1  
                             Torch-Compiled Region: 0/7        18.27%     521.655us        95.19%       2.718ms     906.103us       0.000us         0.00%     400.258us     133.419us             3  
              aten::_scaled_dot_product_flash_attention         0.99%      28.253us         6.33%     180.729us      60.243us       0.000us         0.00%     336.352us     112.117us             3  
                         aten::_flash_attention_forward         1.29%      36.890us         4.19%     119.565us      39.855us     336.352us        84.03%     336.352us     112.117us             3  
void pytorch_flash::flash_fwd_kernel<Flash_fwd_kerne...         0.00%       0.000us         0.00%       0.000us       0.000us     336.352us        84.03%     336.352us     112.117us             3  
triton_poi_fused__scaled_dot_product_flash_attention...         3.07%      87.777us        16.12%     460.302us      51.145us      49.985us        12.49%      49.985us       5.554us             9  
triton_poi_fused__scaled_dot_product_flash_attention...         0.00%       0.000us         0.00%       0.000us       0.000us      49.985us        12.49%      49.985us       5.554us             9  
                               triton_poi_fused_clone_1         1.24%      35.330us         2.05%      58.492us      19.497us      13.921us         3.48%      13.921us       4.640us             3  
                               triton_poi_fused_clone_1         0.00%       0.000us         0.00%       0.000us       0.000us      13.921us         3.48%      13.921us       4.640us             3  
                               TorchDynamo Cache Lookup         0.99%      28.213us         0.99%      28.213us       9.404us       0.000us         0.00%       0.000us       0.000us             3  
                                      Pregraph bytecode         0.25%       7.170us         0.25%       7.170us       2.390us       0.000us         0.00%       0.000us       0.000us             3  
                 AOTDispatcher Runtime Wrapper Prologue         0.43%      12.361us         0.43%      12.361us       4.120us       0.000us         0.00%       0.000us       0.000us             3  
                                Activity Buffer Request        51.74%       1.478ms        51.74%       1.478ms       1.478ms       0.000us         0.00%       0.000us       0.000us             1  
                                         cuLaunchKernel        13.86%     395.687us        13.86%     395.687us      32.974us       0.000us         0.00%       0.000us       0.000us            12  
                                        aten::transpose         0.83%      23.691us         1.15%      32.911us       2.743us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::as_strided         0.32%       9.220us         0.32%       9.220us       0.768us       0.000us         0.00%       0.000us       0.000us            12  
                                       aten::empty_like         0.23%       6.600us         0.78%      22.311us       7.437us       0.000us         0.00%       0.000us       0.000us             3  
                                    aten::empty_strided         0.55%      15.711us         0.55%      15.711us       5.237us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::empty         1.03%      29.502us         1.03%      29.502us       2.459us       0.000us         0.00%       0.000us       0.000us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.856ms
Self CUDA time total: 400.258us


impl                           wl           p50(ms)  ok
torch_flash_compiled_default   flux_L128       0.20  True
torch_flash_compiled_default   flux_L256       0.23  True
torch_flash_compiled_default   flux_L320       0.24  True
torch_flash_compiled_default   flux_L384       0.24  True
torch_flash_compiled_default   flux_L448       FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
torch_flash_compiled_default   flux_L512       FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
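
The flux_L448 and flux_L512 failures are a Dynamo budget issue, not a kernel issue: with dynamic=False each distinct seq_len is compiled into its own specialization, and together with the num_threads guard churn shown in the logs below this exhausts the default recompile limit of 8. A minimal sketch (not exercised in this run, reusing torch_flash_base from the cell above) of the two usual workarounds, raising the cache size as the error message suggests, or compiling with dynamic shapes:

import torch

# Workaround 1: raise the recompile budget (default 8, per the warning below)
# before the first compiled call.
torch._dynamo.config.cache_size_limit = 32

# Workaround 2: trace symbolic shapes so a single graph covers every flux_L* size,
# at the cost of slightly more general (and sometimes slower) generated code.
compiled_flash_dynamic = torch.compile(
    torch_flash_base, mode="default", fullgraph=True, dynamic=True
)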
UV Install Logs
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:27.942000 6833 torch/_dynamo/convert_frame.py:1016] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_default.py:18)
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:27.948000 6833 torch/_dynamo/convert_frame.py:1016] [0/9] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
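
The "GLOBAL_STATE changed: num_threads" reason above means Dynamo guards compiled graphs on the intra-op thread count and discards them when it changes mid-run. A hedged sketch of pinning it once, before compilation, so that guard stays stable (the value shown is illustrative):

import torch

# Fix the intra-op thread count before the first compiled call; any constant
# value keeps Dynamo's num_threads guard from invalidating cached graphs.
torch.set_num_threads(8)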

Artifacts:

attn_default.jsonl

Flash Attention with torch.compile(mode="max-autotune")

Cell: benchmark_max_autotune | 18.98s
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    # Move heads ahead of the sequence dimension: SDPA expects (batch, heads, seq, head_dim).
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    # Pin SDPA to the Flash Attention backend instead of letting it auto-select one.
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    # Restore the original (batch, seq, heads, head_dim) layout.
    return o.transpose(1, 2).contiguous()


# Compile with max-autotune mode
compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)

kbt.add(
    "torch_flash_compiled_max_autotune",
    compiled_flash_max_autotune,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_max_autotune.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn_max_autotune.jsonl"])
impl                                wl           p50(ms)  ok
torch_flash_compiled_max_autotune   flux_L128       0.19  True
torch_flash_compiled_max_autotune   flux_L256       0.20  True
torch_flash_compiled_max_autotune   flux_L320       0.21  True
torch_flash_compiled_max_autotune   flux_L384       0.21  True
torch_flash_compiled_max_autotune   flux_L448       FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
torch_flash_compiled_max_autotune   flux_L512       FAIL  False
  Error: recompile_limit reached with one_graph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.
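
max-autotune hits the same recompile_limit ceiling, which is expected: the limit is enforced by Dynamo's guard machinery and is independent of the Inductor mode. To see every recompilation reason rather than only the last one, the warnings below point at TORCH_LOGS; a small sketch of both ways to enable that logging (environment variable set before launch, or the programmatic API):

# From the shell, before launching the cell script reported in the logs below:
#   TORCH_LOGS=recompiles python benchmark_max_autotune.py
# Or programmatically, right after importing torch:
import torch._logging

torch._logging.set_logs(recompiles=True)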
UV Install Logs
W1023 17:21:15.860000 6116 torch/_dynamo/convert_frame.py:1016] [0/8] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:15.860000 6116 torch/_dynamo/convert_frame.py:1016] [0/8] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_max_autotune.py:18)
W1023 17:21:15.860000 6116 torch/_dynamo/convert_frame.py:1016] [0/8] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:15.860000 6116 torch/_dynamo/convert_frame.py:1016] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:15.860000 6116 torch/_dynamo/convert_frame.py:1016] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
W1023 17:21:15.866000 6116 torch/_dynamo/convert_frame.py:1016] [0/9] torch._dynamo hit config.recompile_limit (8)
W1023 17:21:15.866000 6116 torch/_dynamo/convert_frame.py:1016] [0/9] function: 'torch_flash_base' (/__w/kernels-benchmarks/kernels-benchmarks/benches/flash_attn/impls/.uvnote/cells/benchmark_max_autotune.py:18)
W1023 17:21:15.866000 6116 torch/_dynamo/convert_frame.py:1016] [0/9] last reason: 0/7: GLOBAL_STATE changed: num_threads
W1023 17:21:15.866000 6116 torch/_dynamo/convert_frame.py:1016] [0/9] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1023 17:21:15.866000 6116 torch/_dynamo/convert_frame.py:1016] [0/9] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.