ror HF Staff commited on
Commit
46ba2c6
·
1 Parent(s): 46f4b10

data incorproration

Browse files
Files changed (2) hide show
  1. app.py +16 -11
  2. data.py +6 -8
app.py CHANGED
@@ -2,13 +2,17 @@ import pandas as pd
2
  import gradio as gr
3
  import random
4
 
5
- def generate_random_data():
6
- """Generate random scatter plot data."""
7
- n_points = random.randint(20, 50)
8
- return pd.DataFrame({
9
- "x": [random.uniform(-100, 100) for _ in range(n_points)],
10
- "y": [random.uniform(-100, 100) for _ in range(n_points)]
11
- })
 
 
 
 
12
 
13
  def load_css():
14
  """Load CSS styling."""
@@ -20,7 +24,7 @@ def load_css():
20
 
21
  def refresh_plot():
22
  """Generate new random data and update description."""
23
- return generate_random_data(), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
24
 
25
  # Create Gradio interface
26
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
@@ -34,8 +38,9 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
34
  # Main plot area
35
  with gr.Column(elem_classes=["main-content"]):
36
  plot = gr.ScatterPlot(
37
- generate_random_data(),
38
- x="x", y="y",
 
39
  height="100vh",
40
  container=False,
41
  show_fullscreen_button=True,
@@ -46,4 +51,4 @@ with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True,
46
  summary_btn.click(fn=refresh_plot, outputs=[plot, description])
47
 
48
  if __name__ == "__main__":
49
- demo.launch()
 
2
  import gradio as gr
3
  import random
4
 
5
+ from data import ModelBenchmarkData
6
+
7
+
8
+ DATA = ModelBenchmarkData("data.json")
9
+
10
+
11
+ def refresh_plot_data():
12
+ data = DATA.get_ttft_tpot_data(estimator="median", use_cuda_time=False)
13
+ print(data)
14
+ return pd.DataFrame(data)
15
+
16
 
17
  def load_css():
18
  """Load CSS styling."""
 
24
 
25
  def refresh_plot():
26
  """Generate new random data and update description."""
27
+ return refresh_plot_data(), "**Transformer CI Dashboard**<br>-<br>**AMD runs on MI325**<br>**NVIDIA runs on A10**<br><br>*This dashboard only tracks important models*<br>*(Data refreshed)*"
28
 
29
  # Create Gradio interface
30
  with gr.Blocks(title="Random Data Dashboard", css=load_css(), fill_height=True, fill_width=True) as demo:
 
38
  # Main plot area
39
  with gr.Column(elem_classes=["main-content"]):
40
  plot = gr.ScatterPlot(
41
+ refresh_plot_data(),
42
+ x="ttft", y="tpot",
43
+ tooltip="all",
44
  height="100vh",
45
  container=False,
46
  show_fullscreen_button=True,
 
51
  summary_btn.click(fn=refresh_plot, outputs=[plot, description])
52
 
53
  if __name__ == "__main__":
54
+ demo.launch()
data.py CHANGED
@@ -16,15 +16,13 @@ class ModelBenchmarkData:
16
  with open(json_path, "r") as f:
17
  self.data = json.load(f)
18
 
19
- def get_ttft_tpot_data(self, model_name: str, estimator: str = "median", use_cuda_time: bool = False) -> dict:
20
- data_points = []
21
  time_key = "cuda_time" if use_cuda_time else "wall_time"
22
  for cfg_name, data in self.data.items():
23
  x_measures = [d[time_key] for d in data["ttft"]]
24
  y_measures = [d[time_key] for d in data["tpot"]]
25
- data_points.append({
26
- "x": estimate_from_measures(x_measures, estimator),
27
- "y": estimate_from_measures(y_measures, estimator),
28
- "label": cfg_name,
29
- })
30
- return data_points
 
16
  with open(json_path, "r") as f:
17
  self.data = json.load(f)
18
 
19
+ def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
20
+ aggregated_data = {"ttft": [], "tpot": [], "label": []}
21
  time_key = "cuda_time" if use_cuda_time else "wall_time"
22
  for cfg_name, data in self.data.items():
23
  x_measures = [d[time_key] for d in data["ttft"]]
24
  y_measures = [d[time_key] for d in data["tpot"]]
25
+ aggregated_data["ttft"].append(estimate_from_measures(x_measures, estimator))
26
+ aggregated_data["tpot"].append(estimate_from_measures(y_measures, estimator))
27
+ aggregated_data["label"].append(cfg_name)
28
+ return aggregated_data