import math
from types import MethodType
from typing import Any, Optional, Union

from .state import PartialState
from .utils import (
    calculate_maximum_sizes,
    convert_bytes,
    copy_tensor_to_devices,
    ignorant_find_batch_size,
    infer_auto_device_map,
    is_pippy_available,
    pad_input_tensors,
    send_to_device,
)


def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None):
    """
    Calculates the device map for `model` with an offset for PiPPy
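    The result maps submodule names to process indices; for example (illustrative only), with `num_processes=2` a
    small model might map to `{"embed": 0, "layers.0": 0, "layers.1": 1, "head": 1}`.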
    """
    if num_processes == 1:
        return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False)
    if max_memory is None:
        model_size, shared = calculate_maximum_sizes(model)

        # Split the model size evenly across all processes
        memory = (model_size + shared[0]) / num_processes
        memory = convert_bytes(memory)
        value, ending = memory.split(" ")

        # Round the per-device budget up and add a 10% buffer for headroom
        memory = math.ceil(float(value)) * 1.1
        memory = f"{memory} {ending}"
        max_memory = {i: memory for i in range(num_processes)}
    device_map = infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=no_split_module_classes,
        clean_result=False,
    )
    return device_map


def find_pippy_batch_size(args, kwargs):
    """
    Searches `args` then `kwargs` for the first input with a recoverable batch size and returns it, or `None` if no
    batch size can be found.
    """
    found_batch_size = None
    if args is not None:
        for arg in args:
            found_batch_size = ignorant_find_batch_size(arg)
            if found_batch_size is not None:
                break
    if kwargs is not None and found_batch_size is None:
        for kwarg in kwargs.values():
            found_batch_size = ignorant_find_batch_size(kwarg)
            if found_batch_size is not None:
                break
    return found_batch_size


def build_pipeline(model, split_points, args, kwargs, num_chunks):
    """
    Attaches the split points to the model based on the passed-in `split_points` and generates a `PipelineStage`.
    Requires passing in the `args` and `kwargs` the model expects, with all tensors placed on the CPU.

    Users can pass in a custom `num_chunks` as an optional hyper-parameter. By default it will use
    `AcceleratorState.num_processes`.
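    For example (illustrative), `split_points=["layers.8", "layers.16"]` produces a three-stage pipeline in which
    each named module begins a new stage on the next device.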
    """
    # Imported lazily since `torch.distributed.pipelining` requires PyTorch 2.4.0 or later (see `prepare_pippy`)
    from torch.distributed.pipelining import ScheduleGPipe, SplitPoint, pipeline

    # Annotate the split points in the model: each split point marks the beginning of a new pipeline stage
    state = PartialState()
    split_spec = {split_point: SplitPoint.BEGINNING for split_point in split_points}
    pipe = pipeline(
        model,
        mb_args=args,
        mb_kwargs=kwargs,
        split_spec=split_spec,
    )
    stage = pipe.build_stage(state.local_process_index, device=state.device)
    schedule = ScheduleGPipe(stage, num_chunks)

    return schedule


def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs):
    """
    Runs a pipeline-parallel forward pass. The local main process feeds the inputs (padded so the batch divides
    evenly into `num_chunks` micro-batches) into the pipeline schedule, intermediate processes run their stage,
    and the last process produces the output. If `gather_output` is `True`, the output is copied to all devices.
    """
    state = PartialState()
    output = None

    if state.num_processes == 1:
        output = forward(*args, **kwargs)
    elif state.is_local_main_process:
        found_batch_size = find_pippy_batch_size(args, kwargs)
        if found_batch_size is None:
            raise ValueError("Could not find batch size from args or kwargs")
        else:
            if found_batch_size != num_chunks:
                args = pad_input_tensors(args, found_batch_size, num_chunks)
                kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks)
        forward(*args, **kwargs)
    elif state.is_last_process:
        output = forward()
    else:
        forward()
    if gather_output:
        # Only the last process holds the real output, so broadcast a copy of it to every device
        output = copy_tensor_to_devices(output)
    return output


def prepare_pippy(
    model,
    split_points: Optional[Union[str, list[str]]] = "auto",
    no_split_module_classes: Optional[list[str]] = None,
    example_args: Optional[tuple[Any]] = (),
    example_kwargs: Optional[dict[str, Any]] = None,
    num_chunks: Optional[int] = None,
    gather_output: Optional[bool] = False,
):
    """
    Wraps `model` for pipeline-parallel inference.

    Args:
        model (`torch.nn.Module`):
            A model we want to split for pipeline-parallel inference
        split_points (`str` or `List[str]`, defaults to 'auto'):
            How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced
            split given any model. Otherwise, it should be a list of layer names in the model to split by.
        no_split_module_classes (`List[str]`):
            A list of class names for layers we don't want to be split.
        example_args (tuple of model inputs):
            The expected inputs for the model that uses order-based inputs for a *single process*. Recommended to
            use this method if possible.
        example_kwargs (dict of model inputs):
            The expected inputs for the model that uses dictionary-based inputs for a *single process*. This is a
            *highly* limiting structure that requires the same keys to be present at *all* inference calls. Not
            recommended unless the prior condition is true for all cases.
        num_chunks (`int`, defaults to the number of available GPUs):
            The number of chunks (micro-batches) each batch is split into for the pipeline schedule. By default one
            chunk is assigned per GPU, but this can be tuned; in general `num_chunks >= num_gpus`.
        gather_output (`bool`, defaults to `False`):
            If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs.
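
    Example (an illustrative sketch, not taken from the library documentation; the toy model, the module name used
    as a split point, and the 2-process `accelerate launch` setup are assumptions):

    ```python
    import torch
    from torch import nn

    from accelerate import prepare_pippy

    # Toy stand-in for a real model; split points are module names, here the layer named "2"
    model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
    # Example inputs only need the right shapes/dtypes; batch size should be >= the number of processes
    example_input = torch.randn(2, 512)

    # One pipeline stage per process; run with `accelerate launch --num_processes 2 script.py`
    model = prepare_pippy(model, split_points=["2"], example_args=(example_input,), gather_output=True)

    with torch.no_grad():
        output = model(example_input)
    ```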
    """
    if not is_pippy_available():
        raise ImportError("Using `torch.distributed.pipelining` requires PyTorch 2.4.0 or later.")
    state = PartialState()
    example_args = send_to_device(example_args, "cpu")
    example_kwargs = send_to_device(example_kwargs, "cpu")
    if num_chunks is None:
        num_chunks = state.num_processes
    if split_points == "auto":
        device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes)
        split_points = []
        for i in range(1, num_chunks):
            # The first module assigned to each new device begins a new pipeline stage
            split_points.append(next(k for k, v in device_map.items() if v == i))
        model.hf_split_points = split_points
    stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks)
    model._original_forward = model.forward
    model._original_call = model.__call__
    model.pippy_stage = stage
    model.hf_split_points = split_points

    def forward(*args, **kwargs):
        return pippy_forward(stage.step, num_chunks, gather_output, *args, **kwargs)

    # Expose the bound method through `__wrapped__` so the original forward can be recovered from the wrapper later
    model_forward = MethodType(forward, model)
    forward.__wrapped__ = model_forward
    model.forward = forward
    return model