ror's picture
ror HF Staff
matplotlib
0eac02e
raw
history blame
1.25 kB
import json
import numpy as np
def estimate_from_measures(measures: list[float], estimator: str) -> float:
    """Collapse a list of measurements into a single representative value.

    Args:
        measures: The numeric samples to summarize.
        estimator: Name of the statistic to apply, either ``"median"`` or
            ``"mean"``.

    Returns:
        The requested statistic of ``measures`` as a built-in ``float``.

    Raises:
        ValueError: If ``estimator`` is neither ``"median"`` nor ``"mean"``.
    """
    # Pick the NumPy reducer first so the conversion to float happens in
    # exactly one place below.
    if estimator == "median":
        reducer = np.median
    elif estimator == "mean":
        reducer = np.mean
    else:
        raise ValueError(f"Invalid estimator: {estimator}")
    # NumPy returns a NumPy scalar; callers get a plain Python float instead.
    return float(reducer(measures))
class ModelBenchmarkData:
    """Benchmark results loaded from a JSON file, with aggregation helpers.

    The aggregation code expects the JSON document to be a mapping from a
    configuration name to an object holding ``"ttft"`` and ``"tpot"`` lists.
    Each list entry is an object with ``"wall_time"`` and/or ``"cuda_time"``
    numeric fields.
    NOTE(review): "ttft"/"tpot" presumably stand for time-to-first-token and
    time-per-output-token -- confirm against the benchmark producer.
    """

    def __init__(self, json_path: str) -> None:
        """Load the benchmark results stored at ``json_path``.

        Args:
            json_path: Path of the JSON file to read.
        """
        # JSON is UTF-8 by specification; relying on the locale's default
        # encoding can corrupt or reject non-ASCII configuration names.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def get_ttft_tpot_data(self, estimator: str = "median", use_cuda_time: bool = False) -> dict:
        """Aggregate the per-configuration ttft and tpot measurements.

        Args:
            estimator: Statistic used to summarize each list of measurements;
                forwarded to ``estimate_from_measures``.
            use_cuda_time: When true, read the ``"cuda_time"`` field of each
                measurement; otherwise read ``"wall_time"``.

        Returns:
            A dict with four parallel lists, one entry per configuration in
            the order they appear in the loaded data: ``"ttft"`` and
            ``"tpot"`` (aggregated values), ``"label"`` (configuration name)
            and ``"position"`` (0-based index of the configuration).
        """
        aggregated_data = {"ttft": [], "tpot": [], "label": [], "position": []}
        time_key = "cuda_time" if use_cuda_time else "wall_time"
        for position, (cfg_name, cfg_data) in enumerate(self.data.items()):
            ttft_measures = [entry[time_key] for entry in cfg_data["ttft"]]
            tpot_measures = [entry[time_key] for entry in cfg_data["tpot"]]
            aggregated_data["ttft"].append(estimate_from_measures(ttft_measures, estimator))
            aggregated_data["tpot"].append(estimate_from_measures(tpot_measures, estimator))
            aggregated_data["label"].append(cfg_name)
            aggregated_data["position"].append(position)
        return aggregated_data