ror HF Staff commited on
Commit
9f6e83d
·
1 Parent(s): a3c5e05

Fix sdpa backend legend

Browse files
Files changed (1) hide show
  1. bar_plot.py +14 -5
bar_plot.py CHANGED
@@ -23,11 +23,20 @@ def reorder_data(per_scenario_data: dict) -> dict:
23
 
24
  def infer_bar_label(config: dict) -> str:
25
  """Format legend labels to be more readable."""
26
- attn_implementation = {
27
- "flash_attention_2": "Flash attention",
28
- "sdpa": "SDPA",
29
- "eager": "Eager",
30
- }[config["attn_implementation"]]
 
 
 
 
 
 
 
 
 
31
  compile = "compiled" if config["compilation"] else "no compile"
32
  kernels = "kernelized" if config["kernelize"] else "no kernels"
33
  return f"{attn_implementation}, {compile}, {kernels}"
 
23
 
24
  def infer_bar_label(config: dict) -> str:
25
  """Format legend labels to be more readable."""
26
+ if config["attn_implementation"] == "eager":
27
+ attn_implementation = "Eager"
28
+ elif config["attn_implementation"] == "flash_attention_2":
29
+ attn_implementation = "Flash attention"
30
+ elif config["attn_implementation"] == "sdpa":
31
+ attn_implementation = {
32
+ "flash_attention": "SDPA (flash attention)",
33
+ "efficient_attention": "SDPA (efficient_attention)",
34
+ "cudnn_attention": "SDPA (cudnn)",
35
+ "math": "SDPA (math)",
36
+ }.get(config["sdpa_backend"], "SDPA (unknown backend)")
37
+ else:
38
+ attn_implementation = "Unknown"
39
+
40
  compile = "compiled" if config["compilation"] else "no compile"
41
  kernels = "kernelized" if config["kernelize"] else "no kernels"
42
  return f"{attn_implementation}, {compile}, {kernels}"