ror's picture
ror HF Staff
data backend
46f4b10
raw
history blame
1.08 kB
import json
import numpy as np
def estimate_from_measures(measures: list[float], estimator: str) -> float:
    """Reduce a list of measurements to a single summary value.

    Args:
        measures: the raw measurements to summarize.
        estimator: ``"median"`` or ``"mean"``; selects the numpy reducer.

    Returns:
        The summary as a plain Python ``float``. numpy yields ``nan`` (with a
        RuntimeWarning) when *measures* is empty.

    Raises:
        ValueError: if *estimator* is not one of the supported names.
    """
    # Reject unsupported names before touching the data.
    if estimator not in ("median", "mean"):
        raise ValueError(f"Invalid estimator: {estimator}")
    reducer = np.median if estimator == "median" else np.mean
    # Convert the numpy scalar to a builtin float for callers.
    return float(reducer(measures))
class ModelBenchmarkData:
    """Benchmark measurements loaded from a JSON file.

    As read by ``get_ttft_tpot_data``, the top-level JSON object maps a
    configuration name to a record containing ``"ttft"`` and ``"tpot"``
    lists, whose entries carry ``"wall_time"`` and ``"cuda_time"`` values.
    """

    def __init__(self, json_path: str) -> None:
        """Parse the JSON document at *json_path* and store it in ``self.data``.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # locale encoding, which can mis-decode non-ASCII config names.
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def get_ttft_tpot_data(
        self, model_name: str, estimator: str = "median", use_cuda_time: bool = False
    ) -> list[dict]:
        """Summarize the "ttft" and "tpot" measurements of every configuration.

        Args:
            model_name: currently unused; every entry of ``self.data`` is
                summarized. NOTE(review): confirm whether filtering by model
                was intended here.
            estimator: ``"median"`` or ``"mean"``, forwarded to
                ``estimate_from_measures``.
            use_cuda_time: read each measurement's ``"cuda_time"`` instead of
                its ``"wall_time"``.

        Returns:
            One dict per configuration, in the loaded JSON's key order:
            ``{"x": <ttft estimate>, "y": <tpot estimate>, "label": <config name>}``.

        Raises:
            ValueError: from ``estimate_from_measures`` for an unknown
                estimator (only reached when at least one configuration exists).
            KeyError: if a record lacks ``"ttft"``/``"tpot"`` or the selected
                time key.
        """
        time_key = "cuda_time" if use_cuda_time else "wall_time"
        data_points = []
        for cfg_name, cfg_data in self.data.items():
            ttft_measures = [m[time_key] for m in cfg_data["ttft"]]
            tpot_measures = [m[time_key] for m in cfg_data["tpot"]]
            data_points.append({
                "x": estimate_from_measures(ttft_measures, estimator),
                "y": estimate_from_measures(tpot_measures, estimator),
                "label": cfg_name,
            })
        return data_points