|
|
|
|
|
|
|
|
import itertools |
|
|
|
|
|
import torch |
|
|
from torch.autograd.profiler_legacy import profile |
|
|
from typing import List |
|
|
|
|
|
from . import ( |
|
|
_disable_server_process_global_profiler, |
|
|
_enable_server_process_global_profiler, |
|
|
) |
|
|
|
|
|
# Intentionally empty: ``from <this module> import *`` exports nothing. The
# class defined below is underscore-prefixed, i.e. private.
__all__: List[str] = []
|
|
|
|
|
class _server_process_global_profile(profile):
    """
    It has the same API as ``torch.autograd.profiler.profile`` class,
    except that it enables profiling on all threads running RPC server request callbacks.

    Context manager that manages autograd profiler state and holds a summary of results.
    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: profiler is thread local and is automatically propagated into the async tasks

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
            Default: ``True``.

        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.
            Default: ``False``

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom most events (in a case
            of nested function calls). But for higher level functions the total
            self cpu time might be artificially increased because of the shape
            collection.

        profile_memory (bool, optional): Whether to report memory usage, default: ``False``

    .. warning::
        Enabling memory profiling incurs additional profiler overhead

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> dst_worker_name = "worker1"
        >>> x, y = torch.tensor(1), torch.tensor(2)
        >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> outer_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> inner_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> print(inner_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total  CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        85.06%           76.275us         100.00%          89.667us         89.667us         1
        empty      14.94%           13.392us         14.94%           13.392us         13.392us         1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 89.667us
        >>> print(outer_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total %  Self CPU total  CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        35.65%           76.275us         41.91%           89.667us         89.667us         1
        empty      12.67%           27.101us         12.67%           27.101us         13.551us         2
        add        51.68%           110.550us        58.09%           124.259us        124.259us        1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 213.926us
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to the base ``profile`` constructor."""
        super().__init__(*args, **kwargs)

    def __enter__(self):
        """
        Turn on server-side process-global profiling.
        This enables thread-local profiler on all RPC threads running server-side request callbacks.

        Returns:
            ``self`` when profiling is enabled; ``None`` when ``self.enabled``
            is false (the context manager is then a no-op).

        Raises:
            RuntimeError: if this profiler instance has already been entered.
        """
        if not self.enabled:
            return

        if self.entered:
            raise RuntimeError("autograd profiler traces are not reentrant")
        self.entered = True

        profiler_kind = (
            torch.autograd.ProfilerState.CUDA
            if self.use_cuda
            else torch.autograd.ProfilerState.CPU
        )
        # NOTE(review): the three positional ``False`` flags are presumably
        # with_stack / with_flops / with_modules -- confirm against the
        # ``ProfilerConfig`` signature before changing them.
        profiler_config = torch.autograd.ProfilerConfig(
            profiler_kind,
            self.record_shapes,
            self.profile_memory,
            False,
            False,
            False,
            torch.profiler._ExperimentalConfig(),
        )
        _enable_server_process_global_profiler(profiler_config)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Turn off server-side process-global profiling.
        Aggregate all profiling events recorded by RPC threads.

        These attributes are assigned on exiting context.

        Attributes:
            function_events (torch.autograd.profiler.EventList).  It's a list that has helper
            methods, like 1) show record items in a pretty-print table.
            2) do averaging by grouping on keys. 3) and more.

            process_global_function_events (List[torch.autograd.profiler.FunctionEvent]).
            It's a list of ``FunctionEvent`` elements. Every element is a profiling result
            of an RPC request handling within the profiling range.

        Returns:
            ``False`` so that an exception raised inside the ``with`` body is
            never suppressed; ``None`` (equally falsy) when ``self.enabled``
            is false and nothing was profiled.
        """
        if not self.enabled:
            return

        process_global_events = _disable_server_process_global_profiler()

        # One entry per element returned by the disable call; each entry is
        # the parsed, ordered event list for that element.
        process_global_function_events = []
        for thread_local_events in process_global_events:
            # Parse from raw legacy records into function events.
            thread_local_function_events = (
                torch.autograd.profiler_legacy._parse_legacy_records(
                    thread_local_events
                )
            )
            # Ascending start time, then descending end time: of two events
            # that start together, the longer (enclosing) one comes first.
            thread_local_function_events.sort(
                key=lambda function_event: (
                    function_event.time_range.start,
                    -(function_event.time_range.end),
                )
            )
            process_global_function_events.append(thread_local_function_events)

        # ``chain.from_iterable`` flattens the per-thread lists without
        # unpacking the outer list into call arguments.
        flattened_function_events = list(
            itertools.chain.from_iterable(process_global_function_events)
        )
        self.function_events = torch.autograd.profiler_util.EventList(
            flattened_function_events,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory,
        )
        self.function_events._build_tree()

        self.process_global_function_events = process_global_function_events

        return False
|
|
|