| import copy |
| import os |
| import torch |
| from typing import Any, Dict |
|
|
| from ..jit import build, cpp_format, generate, Runtime |
|
|
|
|
| class JITTuner: |
| def __init__(self) -> None: |
| self.tuned = {} |
|
|
| def compile_and_tune(self, name: str, keys: Dict[str, Any], space: tuple, |
| includes: tuple, arg_defs: tuple, template: str, args: tuple) -> Runtime: |
| |
| |
| |
| keys = {k: keys[k] for k in sorted(keys.keys())} |
| signature = (name, f'{keys}') |
| if signature in self.tuned: |
| if os.getenv('DG_JIT_DEBUG', None): |
| print(f'Using cached JIT kernel {name} with keys {keys}') |
| return self.tuned[signature] |
|
|
| if os.getenv('DG_JIT_DEBUG', None): |
| print(f'Auto-tuning JIT kernel {name} with keys {keys}') |
|
|
| assert signature not in self.tuned |
| assert args is not None |
| space = (dict(), ) if len(space) == 0 else space |
|
|
| kernels = [] |
| for tuned_keys in space: |
| assert isinstance(tuned_keys, dict) |
| full_keys = copy.deepcopy(keys) |
| full_keys.update(tuned_keys) |
| code = generate(includes, arg_defs, cpp_format(template, full_keys)) |
|
|
| |
| kernels.append((build(name, arg_defs, code), tuned_keys)) |
|
|
| best_runtime, best_time, best_keys = None, None, None |
| for runtime, tuned_keys in kernels: |
| if len(space) > 1: |
| |
| return_code = runtime(*args) |
| if return_code != 0: |
| |
| if os.getenv('DG_JIT_DEBUG', None): |
| print(f'Illegal JIT kernel {name} with keys {keys} and tuned keys {tuned_keys}: error code {return_code}') |
| continue |
|
|
| |
| start_event = torch.cuda.Event(enable_timing=True) |
| end_event = torch.cuda.Event(enable_timing=True) |
| torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() |
| torch.randn((8192, 8192), dtype=torch.float, device='cuda') @ torch.randn((8192, 8192), dtype=torch.float, device='cuda') |
| start_event.record() |
| for i in range(20): |
| assert runtime(*args) == 0 |
| end_event.record() |
| end_event.synchronize() |
| elapsed_time = start_event.elapsed_time(end_event) |
| else: |
| elapsed_time = 0 |
|
|
| |
| if best_time is None or elapsed_time < best_time: |
| best_runtime, best_time, best_keys = runtime, elapsed_time, tuned_keys |
| if os.getenv('DG_JIT_DEBUG', None): |
| print(f'Tuned JIT kernel {name} with keys {keys} and tuned keys {tuned_keys} has time {elapsed_time}') |
| assert best_runtime is not None, f'Failed to tune JIT kernel {name} with keys {keys}' |
|
|
| |
| if os.getenv('DG_JIT_DEBUG', None) or os.getenv('DG_PRINT_AUTOTUNE', None): |
| print(f'Best JIT kernel {name} with keys {keys} has tuned keys {best_keys} and time {best_time}') |
| self.tuned[signature] = best_runtime |
| return best_runtime |
|
|
|
|
| jit_tuner = JITTuner() |
|
|