ror HF Staff commited on
Commit
0eac02e
·
1 Parent(s): 46ba2c6

matplotlib

Browse files
Files changed (2) hide show
  1. app.py +101 -11
  2. data.py +4 -1
app.py CHANGED
@@ -1,9 +1,16 @@
1
  import pandas as pd
2
  import gradio as gr
3
- import random
 
 
 
4
 
5
  from data import ModelBenchmarkData
6
 
 
 
 
 
7
 
8
  DATA = ModelBenchmarkData("data.json")
9
 
@@ -22,9 +29,97 @@ def load_css():
22
  except FileNotFoundError:
23
  return "body { background: #000; color: #fff; }"
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def refresh_plot():
26
- """Generate new random data and update description."""
27
- return refresh_plot_data(), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
28
 
29
  # Create Gradio interface
30
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
@@ -37,14 +132,9 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
37
 
38
  # Main plot area
39
  with gr.Column(elem_classes=["main-content"]):
40
- plot = gr.ScatterPlot(
41
- refresh_plot_data(),
42
- x="ttft", y="tpot",
43
- tooltip="all",
44
- height="100vh",
45
- container=False,
46
- show_fullscreen_button=True,
47
- elem_classes=["plot-container"]
48
  )
49
 
50
  # Button click handler
 
1
  import pandas as pd
2
  import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+ import io
6
+ import base64
7
 
8
  from data import ModelBenchmarkData
9
 
10
+ # Configure matplotlib for better performance
11
+ matplotlib.use('Agg')
12
+ plt.ioff()
13
+
14
 
15
  DATA = ModelBenchmarkData("data.json")
16
 
 
29
  except FileNotFoundError:
30
  return "body { background: #000; color: #fff; }"
31
 
32
+ def create_matplotlib_bar_charts():
33
+ """Create side-by-side matplotlib bar charts for TTFT and TPOT data."""
34
+ data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
35
+
36
+ # Create figure with dark theme - wider for side-by-side plots
37
+ plt.style.use('dark_background')
38
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 12))
39
+ fig.patch.set_facecolor('#000000')
40
+
41
+ # Prepare data
42
+ labels = data['label']
43
+ ttft_values = data['ttft']
44
+ tpot_values = data['tpot']
45
+
46
+ # Define color mapping based on configuration keywords
47
+ def get_color_for_config(label):
48
+ is_eager = 'eager' in label.lower()
49
+ is_sdpa = 'sdpa' in label.lower()
50
+ is_compiled = '_compiled' in label.lower()
51
+ if is_eager:
52
+ if is_compiled:
53
+ return '#FF4444' # Red for eager compiled
54
+ else:
55
+ return '#FF6B6B' # Light red for eager uncompiled
56
+ elif is_sdpa:
57
+ if is_compiled:
58
+ return '#4A90E2' # Blue for SDPA compiled
59
+ else:
60
+ return '#7BB3F0' # Light blue for SDPA uncompiled
61
+ else:
62
+ return '#FFD700' # Yellow for others
63
+
64
+ # Get colors for each bar
65
+ colors = [get_color_for_config(label) for label in labels]
66
+
67
+ # TTFT Plot (left)
68
+ ax1.set_facecolor('#000000')
69
+ bars1 = ax1.bar(range(len(labels)), ttft_values,
70
+ color=colors, width=1.0, edgecolor='white', linewidth=1)
71
+
72
+ ax1.set_xlabel('Model Configuration', color='white', fontsize=14)
73
+ ax1.set_ylabel('TTFT (seconds)', color='white', fontsize=14)
74
+ ax1.set_title('Time To First Token by Configuration', color='white', fontsize=16, pad=20)
75
+
76
+ ax1.set_xticks(range(len(labels)))
77
+ ax1.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels],
78
+ rotation=45, ha='right', color='white', fontsize=10)
79
+
80
+ ax1.tick_params(colors='white')
81
+ ax1.grid(True, alpha=0.3, color='white')
82
+
83
+ # TPOT Plot (right)
84
+ ax2.set_facecolor('#000000')
85
+ bars2 = ax2.bar(range(len(labels)), tpot_values,
86
+ color=colors, width=1.0, edgecolor='white', linewidth=1)
87
+
88
+ ax2.set_xlabel('Model Configuration', color='white', fontsize=14)
89
+ ax2.set_ylabel('TPOT (seconds)', color='white', fontsize=14)
90
+ ax2.set_title('Time Per Output Token by Configuration', color='white', fontsize=16, pad=20)
91
+
92
+ ax2.set_xticks(range(len(labels)))
93
+ ax2.set_xticklabels([label[:12] + '...' if len(label) > 12 else label for label in labels],
94
+ rotation=45, ha='right', color='white', fontsize=10)
95
+
96
+ ax2.tick_params(colors='white')
97
+ ax2.grid(True, alpha=0.3, color='white')
98
+
99
+ # Tight layout to prevent label cutoff
100
+ plt.tight_layout()
101
+
102
+ # Save plot to bytes
103
+ buffer = io.BytesIO()
104
+ plt.savefig(buffer, format='png', facecolor='#000000',
105
+ bbox_inches='tight', dpi=100)
106
+ buffer.seek(0)
107
+
108
+ # Convert to base64 for HTML embedding
109
+ img_data = base64.b64encode(buffer.getvalue()).decode()
110
+ plt.close(fig)
111
+
112
+ # Return HTML with embedded image - almost full height
113
+ html = f"""
114
+ <div style="width: 100%; height: 95vh; background: #000; display: flex; justify-content: center; align-items: center;">
115
+ <img src="data:image/png;base64,{img_data}" style="max-width: 100%; max-height: 100%; object-fit: contain;" />
116
+ </div>
117
+ """
118
+ return html
119
+
120
  def refresh_plot():
121
+ """Generate new matplotlib charts and update description."""
122
+ return create_matplotlib_bar_charts(), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
123
 
124
  # Create Gradio interface
125
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
 
132
 
133
  # Main plot area
134
  with gr.Column(elem_classes=["main-content"]):
135
+ plot = gr.HTML(
136
+ create_matplotlib_bar_charts(),
137
+ elem_classes=["plot-container"],
 
 
 
 
 
138
  )
139
 
140
  # Button click handler
data.py CHANGED
@@ -17,12 +17,15 @@ class ModelBenchmarkData:
17
  self.data = json.load(f)
18
 
19
  def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
20
- aggregated_data = {"ttft": [], "tpot": [], "label": []}
21
  time_key = "cuda_time" if use_cuda_time else "wall_time"
 
22
  for cfg_name, data in self.data.items():
23
  x_measures = [d[time_key] for d in data["ttft"]]
24
  y_measures = [d[time_key] for d in data["tpot"]]
25
  aggregated_data["ttft"].append(estimate_from_measures(x_measures, estimator))
26
  aggregated_data["tpot"].append(estimate_from_measures(y_measures, estimator))
27
  aggregated_data["label"].append(cfg_name)
 
 
28
  return aggregated_data
 
17
  self.data = json.load(f)
18
 
19
  def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
20
+ aggregated_data = {"ttft": [], "tpot": [], "label": [], "position": []}
21
  time_key = "cuda_time" if use_cuda_time else "wall_time"
22
+ position = 0
23
  for cfg_name, data in self.data.items():
24
  x_measures = [d[time_key] for d in data["ttft"]]
25
  y_measures = [d[time_key] for d in data["tpot"]]
26
  aggregated_data["ttft"].append(estimate_from_measures(x_measures, estimator))
27
  aggregated_data["tpot"].append(estimate_from_measures(y_measures, estimator))
28
  aggregated_data["label"].append(cfg_name)
29
+ aggregated_data["position"].append(position)
30
+ position += 1
31
  return aggregated_data