| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import math |
| | from types import MethodType |
| | from typing import Any, Optional, Union |
| |
|
| | from .state import PartialState |
| | from .utils import ( |
| | calculate_maximum_sizes, |
| | convert_bytes, |
| | copy_tensor_to_devices, |
| | ignorant_find_batch_size, |
| | infer_auto_device_map, |
| | is_pippy_available, |
| | pad_input_tensors, |
| | send_to_device, |
| | ) |
| |
|
| |
|
def generate_device_map(
    model, num_processes: int = 1, no_split_module_classes=None, max_memory: Optional[dict] = None
):
    """
    Calculates the device map for `model` with an offset for PiPPy
    """
    # Single process: no per-device budget needed, just infer a raw map.
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)

    if max_memory is None:
        # Split the full model footprint (weights + largest shared buffer)
        # evenly across processes to get a per-device budget.
        model_size, shared = calculate_maximum_sizes(model)
        per_process_bytes = (model_size + shared[0]) / num_processes
        value, unit = convert_bytes(per_process_bytes).split(" ")
        # Round up, then add a 10% safety margin so layers near the boundary
        # still fit on their assigned device.
        budget = math.ceil(float(value)) * 1.1
        max_memory = dict.fromkeys(range(num_processes), f"{budget} {unit}")

    return infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
| |
|
| |
|
def find_pippy_batch_size(args, kwargs):
    """
    Scans positional `args` then keyword `kwargs` values and returns the first
    batch size that `ignorant_find_batch_size` can detect, or `None` if no
    candidate yields one. Either container may be `None`.
    """
    candidates = tuple(args or ()) + tuple((kwargs or {}).values())
    for candidate in candidates:
        batch_size = ignorant_find_batch_size(candidate)
        if batch_size is not None:
            return batch_size
    return None
| |
|
| |
|
def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing
    in needed `args` and `kwargs` as the model needs on the CPU.

    Users can pass in custom `num_chunks` as an optional hyper-parameter. By default will use
    `AcceleratorState.num_processes`
    """
    # Imported lazily: torch.distributed.pipelining only exists on torch >= 2.4.
    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

    state = PartialState()
    # Each named split point marks the BEGINNING of a new pipeline stage.
    traced_pipe = pipeline(
        model,
        mb_args=args,
        mb_kwargs=kwargs,
        split_spec=dict.fromkeys(split_points, SplitPoint.BEGINNING),
    )
    # This process owns the stage matching its local rank, on its own device.
    local_stage = traced_pipe.build_stage(state.local_process_index, device=state.device)
    return ScheduleGPipe(local_stage, num_chunks)
| |
|
| |
|
def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    """
    Drives one pipelined forward pass. Only the first local process feeds
    inputs and only the last process receives the real output; every other
    rank just participates in the schedule. With a single process this is a
    plain forward call.
    """
    state = PartialState()
    output = None

    if state.num_processes == 1:
        # No pipeline: run the model directly.
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        batch_size = find_pippy_batch_size(args, kwargs)
        if batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        if batch_size != num_chunks:
            # Pad inputs so the batch divides evenly into `num_chunks` microbatches.
            args = pad_input_tensors(args, batch_size, num_chunks)
            kwargs = pad_input_tensors(kwargs, batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        # The final stage is the only rank that sees the model's true output.
        output = forward()
    else:
        forward()

    if gather_output:
        # Broadcast the last stage's output so every rank returns it.
        output = copy_tensor_to_devices(output)
    return output
| |
|
| |
|
def prepare_pippy(
    model,
    split_points: Optional[Union[str, list[str]]] = "auto",
    no_split_module_classes: Optional[list[str]] = None,
    example_args: Optional[tuple[Any]] = (),
    example_kwargs: Optional[dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Should be a list of layer names in the model to split by otherwise.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to use
            this method if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
            *highly* limiting structure that requires the same keys be present at *all* inference calls. Not
            recommended unless the prior condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but
            this can be tuned and played with. In general one should have num_chunks >= num_gpus.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
    """
    if not is_pippy_available():
        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
    state = PartialState()
    # Tracing/splitting happens on CPU; move the example inputs there first.
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        # The first module assigned to each device after the first marks a stage boundary.
        split_points = []
        for i in range(1, num_chunks):
            split_points.append(next(k for k, v in device_map.items() if v == i))
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    # Keep originals around so the wrapping can be undone later.
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    # Fix: previously `hf_split_points` was assigned twice (once inside the
    # "auto" branch and once here); the single assignment below covers both paths.
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

    # Expose the bound method via `__wrapped__` so the wrapper behaves like a
    # decorator and can be unwrapped (e.g. when extracting the base model).
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model
| |
|