import threading
from collections.abc import Sequence
from typing import Any, cast, Optional, Union

import torch
from torch._utils import ExceptionWrapper
from torch.cuda._utils import _get_device_index
from torch.nn.modules import Module


__all__ = ["get_a_var", "parallel_apply"]


def get_a_var(
    obj: Union[torch.Tensor, list[Any], tuple[Any, ...], dict[Any, Any]],
) -> Optional[torch.Tensor]:
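    """Return the first tensor found by recursively searching ``obj``, or ``None``."""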
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> list[Any]:
r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.
Args:
modules (Module): modules to be parallelized
inputs (tensor): inputs to the modules
devices (list of int or torch.device): CUDA devices
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
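
    Example (a minimal sketch; the two linear replicas and CUDA devices ``0`` and
    ``1`` are illustrative and assume at least two GPUs are available)::

        >>> # xdoctest: +SKIP
        >>> replicas = [torch.nn.Linear(4, 2).cuda(0), torch.nn.Linear(4, 2).cuda(1)]
        >>> inputs = [torch.randn(8, 4, device="cuda:0"), torch.randn(8, 4, device="cuda:1")]
        >>> outputs = parallel_apply(replicas, inputs, devices=[0, 1])
        >>> [out.shape for out in outputs]
        [torch.Size([8, 2]), torch.Size([8, 2])]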
"""
    assert len(modules) == len(inputs), (
        f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
    )
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
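    # Grad mode and autocast state are thread-local, so capture the caller's
    # settings here and re-apply them inside each worker thread.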
    grad_enabled, autocast_enabled = (
        torch.is_grad_enabled(),
        torch.is_autocast_enabled(),
    )

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
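        # If no device was given, infer it from the first tensor found in the input.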
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved"
                    )
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with (
                torch.cuda.device(device),
                torch.cuda.stream(stream),
                torch.amp.autocast("cuda", enabled=autocast_enabled),
            ):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}"
                )
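
    # Run the replicas on separate Python threads when there is more than one
    # module; a single module is applied directly on the current thread.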
    if len(modules) > 1:
        threads = [
            threading.Thread(
                target=_worker, args=(i, module, input, kwargs, device, stream)
            )
            for i, (module, input, kwargs, device, stream) in enumerate(
                zip(modules, inputs, kwargs_tup, devices, streams)
            )
        ]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
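
    # Collect outputs in input order, re-raising any exception captured by a worker.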
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs