Spaces:
Sleeping
Sleeping
Fix sdpa backend legend
Browse files- bar_plot.py +14 -5
bar_plot.py
CHANGED
|
@@ -23,11 +23,20 @@ def reorder_data(per_scenario_data: dict) -> dict:
|
|
| 23 |
|
| 24 |
def infer_bar_label(config: dict) -> str:
|
| 25 |
"""Format legend labels to be more readable."""
|
| 26 |
-
attn_implementation
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
"
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
compile = "compiled" if config["compilation"] else "no compile"
|
| 32 |
kernels = "kernelized" if config["kernelize"] else "no kernels"
|
| 33 |
return f"{attn_implementation}, {compile}, {kernels}"
|
|
|
|
| 23 |
|
| 24 |
def infer_bar_label(config: dict) -> str:
|
| 25 |
"""Format legend labels to be more readable."""
|
| 26 |
+
if config["attn_implementation"] == "eager":
|
| 27 |
+
attn_implementation = "Eager"
|
| 28 |
+
elif config["attn_implementation"] == "flash_attention_2":
|
| 29 |
+
attn_implementation = "Flash attention"
|
| 30 |
+
elif config["attn_implementation"] == "sdpa":
|
| 31 |
+
attn_implementation = {
|
| 32 |
+
"flash_attention": "SDPA (flash attention)",
|
| 33 |
+
"efficient_attention": "SDPA (efficient_attention)",
|
| 34 |
+
"cudnn_attention": "SDPA (cudnn)",
|
| 35 |
+
"math": "SDPA (math)",
|
| 36 |
+
}.get(config["sdpa_backend"], "SDPA (unknown backend)")
|
| 37 |
+
else:
|
| 38 |
+
attn_implementation = "Unknown"
|
| 39 |
+
|
| 40 |
compile = "compiled" if config["compilation"] else "no compile"
|
| 41 |
kernels = "kernelized" if config["kernelize"] else "no kernels"
|
| 42 |
return f"{attn_implementation}, {compile}, {kernels}"
|