# Provenance (copied from a Hugging Face file view):
# author: ror (HF Staff) · commit 55c8a69 ("Refactor") · original size 1.17 kB
import json
import numpy as np
from typing import Optional
class ModelBenchmarkData:
    """Load benchmark measurements from a JSON file and derive latency metrics.

    Judging from how the data is indexed below, the JSON document maps each
    scenario/config name to an object with a ``"measures"`` list and a
    ``"metadata"`` object containing ``"config"``. Each measure is a dict read
    through the keys ``"e2e_latency"``, ``"t_tokens"`` (a sequence of token
    timestamps) and ``"wall_time_start"``. Time units are whatever the producer
    wrote (presumably seconds — TODO confirm).
    """

    def __init__(self, json_path: str) -> None:
        """Read and parse the benchmark JSON at *json_path* into ``self.data``."""
        # JSON text is UTF-8 by spec; don't depend on the platform's default
        # locale encoding (which differs e.g. on Windows).
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def compute_e2e_latency(self, measures: dict) -> tuple[float, Optional[float]]:
        """Return the pre-recorded ``"e2e_latency"`` value of one measure as stored."""
        # NOTE(review): the annotation promises a (float, Optional[float]) pair,
        # but the value is passed through exactly as decoded from JSON (a list if
        # the file stores a pair, a float if it stores a scalar) — confirm with
        # the producer of the JSON and fix the annotation accordingly.
        return measures["e2e_latency"]

    def compute_ttft(self, measures: dict) -> Optional[float]:
        """Return the time from ``"wall_time_start"`` to the first token timestamp.

        Returns ``None`` when ``"t_tokens"`` is empty. This matches how
        ``compute_itl`` reports an undefined metric, instead of raising
        ``IndexError``.
        """
        t_tokens = measures["t_tokens"]
        if not t_tokens:
            return None
        return t_tokens[0] - measures["wall_time_start"]

    def compute_itl(self, measures: dict) -> Optional[float]:
        """Return the mean gap between consecutive token timestamps.

        The consecutive differences telescope, so their mean equals
        ``(last - first) / (count - 1)``. Returns ``None`` when fewer than two
        timestamps are available, since no gap exists.
        """
        t_tokens = measures["t_tokens"]
        if len(t_tokens) < 2:
            return None
        return (t_tokens[-1] - t_tokens[0]) / (len(t_tokens) - 1)

    def get_bar_plot_data(self) -> dict:
        """Collect per-scenario TTFT/ITL series plus the scenario's config.

        Returns a dict keyed by scenario name (in the JSON's order). Each value has
        ``"ttft"`` and ``"itl"`` lists with one entry per measure (entries may be
        ``None`` where the metric is undefined) and the scenario's
        ``metadata["config"]``.
        """
        return {
            cfg_name: {
                "ttft": [self.compute_ttft(m) for m in data["measures"]],
                "itl": [self.compute_itl(m) for m in data["measures"]],
                "config": data["metadata"]["config"],
            }
            for cfg_name, data in self.data.items()
        }