| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """The pipeline parallelism of Pipe.""" |
| | from queue import Queue |
| | from types import TracebackType |
| | from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence |
| |
|
| | import torch |
| | from torch import Tensor, nn |
| | from torch.autograd.profiler import record_function |
| |
|
| | from .checkpoint import Checkpointing |
| | from .copy import Copy, Wait |
| | from .dependency import fork, join |
| | from .microbatch import Batch |
| | from .skip.layout import SkipLayout |
| | from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker |
| | from .stream import AbstractStream, current_stream, use_device |
| | from .worker import Task, create_workers |
| |
|
# This module deliberately exports no public names.
__all__: List[str] = []

# Either a single tensor or a sequence of tensors.
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

# Shape of a populated exception triple: (exception type, exception value, traceback).
ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]

# ``Queue`` is subscripted only for the benefit of the type checker; at
# runtime both aliases are the plain ``queue.Queue`` class.
if TYPE_CHECKING:
    # Items are a ``Task`` or ``None``.
    InQueue = Queue[Optional["Task"]]
    # Items are ``(ok, payload)`` pairs; the payload is a ``(task, batch)``
    # pair, an exception triple, or ``None``.
    OutQueue = Queue[Tuple[bool, Optional[Union[Tuple["Task", Batch], ExcInfo]]]]
else:
    InQueue = Queue
    OutQueue = Queue
| |
|
| |
|
def _depend(fork_from: Batch, join_to: Batch) -> None:
    """Ties ``join_to`` to ``fork_from`` through a phony value.

    One tensor slot is picked in each batch via ``find_tensor_idx()``. The
    tensor of ``fork_from`` is passed through ``fork`` and written back, and
    the phony returned by ``fork`` is joined onto the tensor of ``join_to``.
    Both batches are modified in place.

    NOTE(review): presumably this creates an artificial autograd dependency
    between the two micro-batches -- confirm against ``.dependency``.
    """
    src_idx = fork_from.find_tensor_idx()
    dst_idx = join_to.find_tensor_idx()

    # Store the forked tensor first, then attach the phony to the other batch.
    forked, phony = fork(fork_from[src_idx])
    fork_from[src_idx] = forked
    join_to[dst_idx] = join(join_to[dst_idx], phony)
| |
|
| |
|
def _copy(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    """Replaces the values of ``batch`` with the result of ``Copy.apply``.

    ``Copy.apply`` receives ``prev_stream``, ``next_stream`` and every value
    of the batch. Afterwards each non-floating-point tensor in the batch is
    swapped for its detached version. The batch is modified in place.
    """
    batch[:] = Copy.apply(prev_stream, next_stream, *batch)

    # Only floating-point tensors are left attached; presumably because the
    # other dtypes cannot carry gradients.
    values = []
    for value in batch:
        if torch.is_tensor(value) and not value.is_floating_point():
            value = value.detach()
        values.append(value)
    batch[:] = tuple(values)
| |
|
| |
|
def _wait(batch: Batch, prev_stream: AbstractStream, next_stream: AbstractStream) -> None:
    """Replaces the values of ``batch`` with the result of ``Wait.apply``.

    ``Wait.apply`` receives ``prev_stream``, ``next_stream`` and every value
    of the batch. Afterwards each non-floating-point tensor in the batch is
    swapped for its detached version. The batch is modified in place.
    """
    waited = Wait.apply(prev_stream, next_stream, *batch)
    batch[:] = waited

    # Only floating-point tensors are left attached; presumably because the
    # other dtypes cannot carry gradients.
    batch[:] = tuple(
        item.detach() if torch.is_tensor(item) and not item.is_floating_point() else item
        for item in batch
    )
| |
|
| |
|
def _clock_cycles(m: int, n: int) -> Iterable[List[Tuple[int, int]]]:
    """Generates schedules for each clock cycle."""
    # m: number of micro-batches
    # n: number of partitions
    # i: index of micro-batch
    # j: index of partition
    # k: clock number
    #
    # k (i,j) (i,j) (i,j)
    # - ----- ----- -----
    # 0 (0,0)
    # 1 (1,0) (0,1)
    # 2 (2,0) (1,1) (0,2)
    # 3       (2,1) (1,2)
    # 4             (2,2)
    total_clocks = m + n - 1
    for clock in range(total_clocks):
        # Partitions active at this clock: those whose micro-batch index
        # ``clock - part`` lies in [0, m).
        first_part = max(clock + 1 - m, 0)
        end_part = min(clock + 1, n)
        yield [(clock - part, part) for part in range(first_part, end_part)]
| |
|
| |
|
class Pipeline:
    """The pipeline parallelism for Pipe."""

    def __init__(
        self,
        partitions: List[nn.Sequential],
        devices: List[torch.device],
        copy_streams: List[List[AbstractStream]],
        skip_layout: SkipLayout,
        checkpoint_stop: int,
    ) -> None:
        """Stores the pipeline configuration and creates the worker queues.

        Args:
            partitions: the model partitions, indexed by partition ``j``.
            devices: devices indexed by partition ``j``; also handed to
                ``create_workers``.
            copy_streams: streams indexed as ``copy_streams[j][i]`` for
                partition ``j`` and micro-batch ``i``.
            skip_layout: layout queried through ``copy_policy(j)`` and used
                to build one skip tracker per micro-batch.
            checkpoint_stop: micro-batches with an index below this value
                are run through ``Checkpointing`` (only while the first
                partition is in training mode).
        """
        self.partitions = partitions
        self.devices = devices
        self.copy_streams = copy_streams
        self.skip_layout = skip_layout
        self.checkpoint_stop = checkpoint_stop
        (self.in_queues, self.out_queues) = create_workers(devices)

    def run(self, batches: List[Batch]) -> None:
        """Runs pipeline parallelism.

        It modifies the given batches in place.

        """
        partitions = self.partitions
        skip_layout = self.skip_layout

        m = len(batches)
        n = len(partitions)

        # One skip tracker per micro-batch.
        skip_trackers = [SkipTrackerThroughPotals(skip_layout) for _ in batches]

        # Each clock cycle first moves inputs into place, then computes.
        for schedule in _clock_cycles(m, n):
            self.fence(batches, schedule, skip_trackers)
            self.compute(batches, schedule, skip_trackers)

    def fence(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Copies micro-batches after computation for the previous
        micro-batches.

        Args:
            batches: all micro-batches; entries are modified in place.
            schedule: ``(i, j)`` pairs of micro-batch index and partition
                index that are active in this clock cycle.
            skip_trackers: one skip tracker per micro-batch.
        """
        copy_streams = self.copy_streams
        skip_layout = self.skip_layout

        for i, j in schedule:
            # Link micro-batch i to micro-batch i-1 through a phony value.
            # NOTE(review): presumably this orders the two micro-batches
            # during backpropagation -- confirm against ``.dependency``.
            if i != 0 and j != 0:
                _depend(batches[i - 1], batches[i])

            next_stream = copy_streams[j][i]

            # Copy skip tensors from the partitions named by the layout.
            for prev_j, ns, name in skip_layout.copy_policy(j):
                prev_stream = copy_streams[prev_j][i]
                skip_trackers[i].copy(batches[i], prev_stream, next_stream, ns, name)

            # Copy the micro-batch itself from the previous partition.
            if j != 0:
                prev_stream = copy_streams[j - 1][i]
                _copy(batches[i], prev_stream, next_stream)

    def compute(
        self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals],
    ) -> None:
        """Runs tasks with synchronization to copy streams.

        Args:
            batches: all micro-batches; ``batches[i]`` is replaced by the
                batch returned for each successfully completed task.
            schedule: ``(i, j)`` pairs of micro-batch index and partition
                index that are active in this clock cycle.
            skip_trackers: one skip tracker per micro-batch.

        Raises:
            BaseException: the first exception reported through the output
                queues, re-raised with the traceback it was reported with.
        """
        partitions = self.partitions
        devices = self.devices
        copy_streams = self.copy_streams
        checkpoint_stop = self.checkpoint_stop

        # Disable checkpointing when the first partition is not in training mode.
        if not partitions[0].training:
            checkpoint_stop = 0

        n = len(partitions)
        streams = [current_stream(d) for d in devices]
        exc_info: Optional[ExcInfo] = None

        # Phase 1: build one task per scheduled (i, j) and submit it to the
        # input queue of partition j.
        for i, j in schedule:
            batch = batches[i]
            partition = partitions[j]

            # Synchronize the compute stream with the copied input.
            if j != 0:
                _wait(batch, copy_streams[j][i], streams[j])

            # Determine whether this micro-batch is checkpointed.
            checkpoint = i < checkpoint_stop
            if checkpoint:

                # Loop variables are bound as defaults so the closure keeps
                # the values of this iteration.
                def function(
                    *inputs,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> TensorOrTensors:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return partition(*inputs)

                chk = Checkpointing(function, batch)
                task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute)
                # Drop the local references; the task keeps what it needs.
                del function, chk

            else:

                # Loop variables are bound as defaults so the closure keeps
                # the values of this iteration.
                def run_partition(
                    batch: Batch = batch,
                    partition: nn.Module = partition,
                    skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
                    chunk_id: int = i,
                    part_id: int = j,
                ) -> Batch:
                    with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)):
                        return batch.call(partition)

                task = Task(streams[j], compute=run_partition, finalize=None)
                # Drop the local reference; the task keeps what it needs.
                del run_partition

            self.in_queues[j].put(task)

        # Phase 2: collect exactly one result per submitted task. Every
        # queue is still drained after a failure so no result is left behind.
        for i, j in schedule:
            ok, payload = self.out_queues[j].get()

            # Hold the first exception and skip further processing.
            if exc_info is not None:
                continue
            elif not ok:
                exc_info = cast(ExcInfo, payload)
                continue

            task, batch = cast(Tuple[Task, Batch], payload)

            # Synchronize the copy stream with the computed output, except
            # on the last partition.
            if j != n - 1:
                _wait(batch, streams[j], copy_streams[j][i])

            # Finalize the task on the device of its partition.
            # NOTE(review): with checkpointing this is ``chk.recompute`` --
            # presumably it schedules recomputation for the backward pass.
            with use_device(devices[j]):
                task.finalize(batch)

            batches[i] = batch

        # Fail at the first exception, keeping the traceback it came with.
        if exc_info is not None:
            raise exc_info[1].with_traceback(exc_info[2])
| |
|