|
|
|
|
|
import torch |
|
|
import triton |
|
|
from typing import Callable, List, Union, Tuple, Dict |
|
|
import json |
|
|
import os |
|
|
import time |
|
|
|
|
|
|
|
|
class do_bench_config():
    """Configuration bundle for triton.testing.do_bench timing runs.

    Parameters
    ----------
    warm_up : int
        Warmup time (ms) forwarded to do_bench's ``warmup`` argument.
    repetition : int
        Measured repetition time (ms) forwarded to do_bench's ``rep`` argument.
    quantiles : list[float] | None
        Performance quantiles to report. Defaults to ``[0.5, 0.2, 0.8]``
        (median, lower, upper) so that callers unpacking the do_bench result
        as ``(ms, min_ms, max_ms)`` get values in the intended order.
    return_mode : str
        Aggregation mode used by do_bench when no quantiles are requested.
    """

    def __init__(
            self,
            warm_up: int = 25,
            repetition: int = 100,
            quantiles: Union[List[float], None] = None,
            return_mode: str = "median"
    ):
        self.warm_up = warm_up
        self.repetition = repetition
        # Bug fix: the previous default [0.5, 0.8, 0.2] made callers that
        # unpack (ms, min_ms, max_ms) receive the 0.8 (upper) quantile as
        # "min" and the 0.2 (lower) quantile as "max". [0.5, 0.2, 0.8]
        # matches triton's conventional ordering.
        self.quantiles = quantiles if quantiles is not None else [0.5, 0.2, 0.8]
        self.return_mode = return_mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level registry mapping op name -> list of per-run result dicts.
# Populated by add_benchmark_result() and flushed to disk (then cleared)
# by save_all_benchmark_results().
PYTEST_BENCHMARK_RESULTS: Dict[str, List[Dict]] = {}


def add_benchmark_result(op_name: str, result_dict: Dict) -> None:
    """Append one benchmark result dict under *op_name* in the global registry.

    Creates the per-op list on first use; setdefault replaces the previous
    check-then-create pattern with a single lookup.
    """
    PYTEST_BENCHMARK_RESULTS.setdefault(op_name, []).append(result_dict)
|
|
|
|
|
def save_all_benchmark_results(output_dir: str):
    """Write each op's accumulated results to <output_dir>/<op_name>.json.

    Ops with an empty result list are skipped. I/O failures are reported per
    file without aborting the remaining dumps. The registry is cleared at the
    end regardless of individual save failures.
    """
    os.makedirs(output_dir, exist_ok=True)
    for op_name, results_list in PYTEST_BENCHMARK_RESULTS.items():
        # Nothing recorded for this op — don't emit an empty file.
        if not results_list:
            continue
        file_path = os.path.join(output_dir, f"{op_name}.json")
        try:
            with open(file_path, 'w', encoding='utf8') as f:
                json.dump(results_list, f, indent=4, ensure_ascii=False)
        except IOError as e:
            print(f"Error saving results for {op_name} to {file_path}: {e}")
        else:
            print(f"Benchmark results for {op_name} saved to: {file_path}")
    PYTEST_BENCHMARK_RESULTS.clear()
|
|
|
|
|
|
|
|
class PytestBenchmarker:
    """Runs a pre-bound operation through triton.testing.do_bench and records
    the timings (plus optional derived throughput figures) into the
    module-level benchmark-result registry via add_benchmark_result()."""

    def __init__(self, op_callable: Callable, op_name: str, config: do_bench_config = None):
        # op_callable: zero-argument callable with all inputs already bound
        # (e.g. a lambda closing over the tensors to benchmark).
        self.op_callable = op_callable
        self.op_name = op_name
        self.config = config if config else do_bench_config()

    def run_benchmark(self, current_params_dict: Dict,
                      gbps_calculator: Callable = None,
                      tflops_calculator: Callable = None
                      ):
        """
        Runs the benchmark for the op_callable (which should be a lambda with inputs already bound).

        current_params_dict: Dictionary describing the current pytest parameters.
            Note: this dict — not the original input tensors — is what the
            calculators receive (the previous docstring incorrectly said
            "original_inputs_tuple").
        gbps_calculator: A function that takes (current_params_dict, ms) and returns GB/s.
        tflops_calculator: A function that takes (current_params_dict, ms) and returns TFLOPS.

        Returns the recorded result dict. If the benchmark itself raises, a
        dict with the params and an "error" message is recorded and returned
        instead.
        """
        try:
            # do_bench returns one timing per requested quantile, in the order
            # the quantiles appear in config.quantiles. Order the two outer
            # quantiles explicitly so min_ms <= max_ms even when the config
            # lists them as (median, upper, lower) — the previous direct
            # (ms, min_ms, max_ms) unpack swapped them for such configs.
            ms, q_a, q_b = triton.testing.do_bench(
                self.op_callable,
                warmup=self.config.warm_up,
                rep=self.config.repetition,
                quantiles=self.config.quantiles,
                return_mode=self.config.return_mode
            )
            min_ms, max_ms = min(q_a, q_b), max(q_a, q_b)

            gbps = "N/A"
            if gbps_calculator:
                try:
                    gbps = gbps_calculator(current_params_dict, ms)
                except Exception as e_gbps:
                    print(f"Warning: GB/s calculation failed for {self.op_name} with params {current_params_dict}: {e_gbps}")

            tflops = "N/A"
            if tflops_calculator:
                try:
                    tflops = tflops_calculator(current_params_dict, ms)
                except Exception as e_tflops:
                    print(f"Warning: TFLOPS calculation failed for {self.op_name} with params {current_params_dict}: {e_tflops}")

            result = {
                "params": current_params_dict,
                "ms": round(ms, 4),
                "min_ms": round(min_ms, 4),
                "max_ms": round(max_ms, 4),
                # Throughputs stay the string "N/A" when no calculator was
                # supplied or the calculation failed above.
                "GB/s": round(gbps, 2) if isinstance(gbps, (float, int)) else gbps,
                "TFLOPS": round(tflops, 2) if isinstance(tflops, (float, int)) else tflops,
            }

            add_benchmark_result(self.op_name, result)

            return result

        except Exception as e:
            print(f"Error during benchmark for {self.op_name} with params {current_params_dict}: {e}")
            error_result = {
                "params": current_params_dict,
                "error": str(e)
            }
            add_benchmark_result(self.op_name, error_result)
            return error_result