| | import contextlib |
| | import sys |
| | import time |
| |
|
| | import torch |
| |
|
if sys.version_info >= (3, 7):

    @contextlib.contextmanager
    def profile_time(trace_name,
                     name,
                     enabled=True,
                     stream=None,
                     end_stream=None):
        """Print time spent by CPU and GPU.

        Useful as a temporary context manager to find sweet spots of code
        suitable for async implementation.

        Args:
            trace_name (str): Label identifying the traced region.
            name (str): Label identifying this particular measurement.
            enabled (bool): If False, the manager is a transparent no-op.
            stream (torch.cuda.Stream, optional): Stream on which the start
                event is recorded. Defaults to the current CUDA stream.
            end_stream (torch.cuda.Stream, optional): Stream on which the end
                event is recorded. Defaults to ``stream``.
        """
        if (not enabled) or not torch.cuda.is_available():
            # Profiling disabled or no GPU available: yield without timing.
            yield
            return
        stream = stream if stream else torch.cuda.current_stream()
        end_stream = end_stream if end_stream else stream
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        stream.record_event(start)
        try:
            cpu_start = time.monotonic()
            yield
        finally:
            # Runs even if the timed region raises, so the report is
            # always emitted.
            cpu_end = time.monotonic()
            end_stream.record_event(end)
            # Block until the end event completes so elapsed_time() is valid.
            end.synchronize()
            cpu_time = (cpu_end - cpu_start) * 1000
            gpu_time = start.elapsed_time(end)
            msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
            msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
            # Fix: the original `print(msg, end_stream)` passed the stream
            # object as a second positional argument, appending its repr to
            # the report; the stream is already interpolated into `msg`.
            print(msg)
| |
|