File size: 5,465 Bytes
02c783d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.
import torch
import triton
from typing import Callable, List, Union, Tuple, Dict
import json
import os
import time # For unique filenames if needed

# --- do_bench_config class (can remain largely the same) ---
class do_bench_config:
    """Bundle of timing settings consumed by ``PytestBenchmarker.run_benchmark``.

    The attributes are forwarded to ``triton.testing.do_bench`` as follows:
    ``warm_up`` -> ``warmup``, ``repetition`` -> ``rep``,
    ``quantiles`` -> ``quantiles`` and ``return_mode`` -> ``return_mode``.

    ``quantiles`` defaults to ``[0.5, 0.8, 0.2]`` when the caller passes
    ``None``; any other value (including an empty list) is stored as given.
    """

    def __init__(
        self,
        warm_up: int = 25, # Default values for individual benchmarks
        repetition: int = 100,
        quantiles: List[float] = None,
        return_mode: str = "median"
    ):
        if quantiles is None:
            # Build the default inside the call so instances never share one list.
            quantiles = [0.5, 0.8, 0.2]

        self.return_mode = return_mode
        self.quantiles = quantiles
        self.repetition = repetition
        self.warm_up = warm_up

# --- Global collector for results ---
# Results produced during a pytest run are accumulated in a module-level dict.
# This is a common and simple approach, but it only works as-is for
# non-parallel pytest, which is what this module assumes.
# With parallel pytest (pytest-xdist), collecting results would additionally
# require inter-process communication, which this collector does not do.

# Key: op_name, Value: list of result dicts recorded for that op (in call order).
PYTEST_BENCHMARK_RESULTS: Dict[str, List[Dict]] = {}

def add_benchmark_result(op_name: str, result_dict: Dict):
    """Append ``result_dict`` to the list of results collected under ``op_name``.

    The list for ``op_name`` is created on first use. ``result_dict`` is stored
    by reference (no copy is made).
    """
    PYTEST_BENCHMARK_RESULTS.setdefault(op_name, []).append(result_dict)

def save_all_benchmark_results(output_dir: str):
    """Write every collected result list to ``<output_dir>/<op_name>.json``.

    ``output_dir`` is created if it does not exist. Ops whose result list is
    empty are skipped. A failure to serialize or write one op is reported on
    stdout and does not prevent the remaining ops from being saved. The
    collector is emptied once the loop finishes, whether or not every file
    could be written.

    Raises:
        OSError: if ``output_dir`` itself cannot be created.
    """
    os.makedirs(output_dir, exist_ok=True)
    for op_name, results_list in PYTEST_BENCHMARK_RESULTS.items():
        if not results_list: # Only save if there are results
            continue
        file_path = os.path.join(output_dir, op_name + ".json") # e.g., add_kernel.json
        try:
            # Serialize before opening the file: a result that is not JSON
            # serializable then fails here, without leaving a truncated file
            # behind and without aborting the ops that follow.
            payload = json.dumps(results_list, indent=4, ensure_ascii=False)
            with open(file_path, 'w', encoding='utf8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving results for {op_name} to {file_path}: {e}")
        else:
            print(f"Benchmark results for {op_name} saved to: {file_path}")
    PYTEST_BENCHMARK_RESULTS.clear() # Clear after saving

# --- Simplified Benchmarking Helper ---
class PytestBenchmarker:
    """Time one callable with ``triton.testing.do_bench`` and record the outcome.

    Each call to ``run_benchmark`` appends exactly one entry (a timing record
    or an error record) to the module-level collector through
    ``add_benchmark_result``, filed under ``op_name``.
    """

    def __init__(self, op_callable: Callable, op_name: str, config: do_bench_config = None):
        """
        op_callable: The callable handed to ``triton.testing.do_bench``; intended
            to be a kernel-launch lambda with its inputs already bound.
        op_name: Key under which results are collected,
            e.g. "add_kernel" or "test_add_SIZE_BLOCK_SIZE_dtype".
        config: Timing settings; a default ``do_bench_config()`` is used when None.
        """
        self.op_callable = op_callable
        self.op_name = op_name
        self.config = config if config is not None else do_bench_config()

    def _safe_metric(self, label: str, calculator: Callable, current_params_dict: Dict, ms: float):
        """Return ``calculator(current_params_dict, ms)``, or "N/A" if it is absent or raises."""
        if not calculator:
            return "N/A"
        try:
            return calculator(current_params_dict, ms)
        except Exception as exc:
            # A broken metric calculator must not discard a valid timing.
            print(f"Warning: {label} calculation failed for {self.op_name} with params {current_params_dict}: {exc}")
            return "N/A"

    @staticmethod
    def _rounded(value, digits: int):
        """Round plain int/float values; pass anything else (e.g. "N/A") through unchanged."""
        return round(value, digits) if isinstance(value, (float, int)) else value

    def run_benchmark(self, current_params_dict: Dict,
                      gbps_calculator: Callable = None, # (current_params_dict, ms) -> gbps
                      tflops_calculator: Callable = None # (current_params_dict, ms) -> tflops
                     ):
        """
        Runs the benchmark for the op_callable (which should be a lambda with inputs already bound).

        current_params_dict: Dictionary describing the current pytest parameters.
            It is stored in the result by reference and is also what the
            calculators receive.
        gbps_calculator: Optional function called as (current_params_dict, ms) that returns GB/s.
        tflops_calculator: Optional function called as (current_params_dict, ms) that returns TFLOPS.

        Returns the recorded dict. On success it has the keys "params", "ms",
        "min_ms", "max_ms", "GB/s" and "TFLOPS" (the last two are "N/A" when
        no calculator was given or the calculator raised). If benchmarking
        raises, the dict has the keys "params" and "error" instead. In both
        cases the dict is also passed to ``add_benchmark_result``.
        """
        try:
            # The result is unpacked into three values, so the configured
            # quantiles list must have exactly three entries (assumes do_bench
            # returns one value per requested quantile); otherwise the unpack
            # raises and an error record is produced below.
            # NOTE(review): Triton's do_bench is understood to ignore
            # return_mode when quantiles is given -- confirm against the
            # installed Triton version.
            ms, spread_a, spread_b = triton.testing.do_bench(
                self.op_callable,
                warmup=self.config.warm_up,
                rep=self.config.repetition,
                quantiles=self.config.quantiles,
                return_mode=self.config.return_mode
            )
            # The default config lists the quantiles as [0.5, 0.8, 0.2], i.e.
            # the higher quantile comes before the lower one. Order the two
            # spread values so "min_ms" never exceeds "max_ms", whichever
            # order the quantiles were configured in.
            min_ms = min(spread_a, spread_b)
            max_ms = max(spread_a, spread_b)

            gbps = self._safe_metric("GB/s", gbps_calculator, current_params_dict, ms)
            tflops = self._safe_metric("TFLOPS", tflops_calculator, current_params_dict, ms)

            result = {
                "params": current_params_dict, # From pytest.mark.parametrize
                "ms": round(ms, 4),
                "min_ms": round(min_ms, 4),
                "max_ms": round(max_ms, 4),
                "GB/s": self._rounded(gbps, 2),
                "TFLOPS": self._rounded(tflops, 2),
            }
            # Use a unique op_name for each kernel/function being benchmarked
            add_benchmark_result(self.op_name, result)
            return result

        except Exception as e:
            # Record the failure instead of propagating it, so one failing
            # parameter set still leaves an entry in the collected results.
            print(f"Error during benchmark for {self.op_name} with params {current_params_dict}: {e}")
            error_result = {
                "params": current_params_dict,
                "error": str(e)
            }
            add_benchmark_result(self.op_name, error_result)
            return error_result