diff --git "a/lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py" "b/lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py" new file mode 100644--- /dev/null +++ "b/lib.linux-x86_64-cpython-310/tensorrt_llm/auto_parallel/parallelization.py" @@ -0,0 +1,2297 @@ +import contextlib +import copy +import itertools +import pickle # nosec B403 +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, ClassVar, Dict, List, Sequence, Set, Tuple, Union + +import numpy as np +import tensorrt as trt +import torch +from filelock import FileLock + +from tensorrt_llm._utils import (str_dtype_to_trt, trt_dtype_to_np, + trt_dtype_to_torch, trt_gte_10) +from tensorrt_llm.functional import (AllReduceConfig, AllReduceFusionParams, + AllReduceStrategy, create_allreduce_plugin) +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.network import (PluginInfo, delete_plugin_info, get_np_weight, + get_plugin_info, set_plugin_info) +from tensorrt_llm.plugin import TRT_LLM_PLUGIN_NAMESPACE, init_all_reduce_helper +from tensorrt_llm.plugin.plugin import (CustomAllReduceHelper, + current_all_reduce_helper) +from tensorrt_llm.version import __version__ + +from .config import AutoParallelConfig +from .device_mesh import LogicalDeviceMesh +from .pipeline_graph import Layer, PipelineGraph, Tensor +from .shape_info import (ShapeInfo, get_per_layer_graph, get_shape_layers, + infer_per_layer_shapes) +from .simplifier import GraphConfig, GraphMapping, Simplifier, StageType +from .tensor_parallel.comm_spec import CommSpec +from .tensor_parallel.plugin_nodes.gpt_attention_node import ( + GPTAttentionPlugin, IdxEntry, IdxEntryParser) +from .tensor_parallel.sharding_spec import ShardingSpec, get_sharding_sequence +from .tensor_parallel.sharding_strategy import ShardingStrategy +from .utils import (get_updated_plugin, to_base_class_layer, to_subclass_layer, + to_trt_weights) + +default_int_dtype = trt.int64 if trt_gte_10() else trt.int32 + + +@dataclass +class ParallelConfig: + VERSION: ClassVar[str] = __version__ + + version: str = VERSION + network_hash: str = None + auto_parallel_config: AutoParallelConfig = None + graph_config: GraphConfig = None + lmesh: LogicalDeviceMesh = None + cost: float = None + graph_strategy: Dict[str, ShardingStrategy] = None + stage_type: StageType = None + + def save(self, filename): + with open(filename, 'wb') as file: + pickle.dump(self, file) + + @staticmethod + def from_file(filename) -> "ParallelConfig": + with open(filename, "rb") as file: + return pickle.load(file) # nosec B301 + + def print_graph_strategy(self, file=None): + for index, (node_name, + strategy) in enumerate(self.graph_strategy.items()): + print(f'\n[{index}]: node_name = {node_name}', file=file) + strategy.print_strategy(best_resharding_cost_only=True, file=file) + + +def desimplify_strategy( + graph: PipelineGraph, + graph_strategy: Dict[str, ShardingStrategy], + graph_mapping: GraphMapping, +): + for strategy in graph_strategy.values(): + for name, commspec in list(strategy.communication_actions.items()): + strategy.communication_actions[name] = [commspec] + strategy.sharding_specs[ + f"{name}_after_comm"] = strategy.sharding_specs[name] + + # insert same spec layers' communication actions after + # its producer's communication actions + same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping + for same_spec_layer_name in same_spec_layer_mapping.keys(): + 
same_spec_strategy = graph_strategy[same_spec_layer_name] + same_spec_commspecs = same_spec_strategy.best_resharding_cost[0][0][1] + if len(same_spec_commspecs) == 0: + continue + output_name = same_spec_layer_name[:-len("_same_spec")] + output = graph.get_tensor(output_name) + layer_name = output.producer.name + output_index = output.output_index + strategy = graph_strategy[layer_name] + commspecs = strategy.communication_actions.get(f"output{output_index}", + []) + commspecs.extend(same_spec_commspecs) + strategy.communication_actions[f"output{output_index}"] = commspecs + strategy.sharding_specs[ + f"output{output_index}_after_comm"] = same_spec_strategy.sharding_specs[ + "output0"] + + layer_mapping = graph_mapping.layer_mapping + for removed_layer_name, layer_name in layer_mapping.items(): + if layer_name in graph_strategy: + strategy = copy.copy(graph_strategy[layer_name]) + layer = graph.get_layer(removed_layer_name) + if layer is not None: + strategy.node_names = strategy.node_names.copy() + for index, name in list(strategy.node_names.items()): + input = layer.get_input(index) + node_name = input.name if input.producer is None else input.producer.name + strategy.node_names[index] = node_name + graph_strategy[removed_layer_name] = strategy + + +@dataclass +class SplitInfo: + input_dim: Union[int, trt.ITensor] + partition: int + + def __deepcopy__(self, memo) -> "SplitInfo": + return SplitInfo(self.input_dim, self.partition) + + +@dataclass +class TensorInfo: + name: str = None + split_infos: Dict[int, SplitInfo] = field(default_factory=dict) + + def set_split_info(self, dim, split_info): + self.split_infos[dim] = split_info + + def __deepcopy__(self, memo) -> "TensorInfo": + return TensorInfo(self.name, copy.deepcopy(self.split_infos)) + + +@dataclass +class TensorContext: + info_by_device: Dict[int, TensorInfo] = field(default_factory=dict) + device_dims_for_shape: Set[int] = field(default_factory=set) + + def update_name_mapping(self, device_id, new_name): + if device_id not in self.info_by_device: + self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].name = new_name + + def set_split_info(self, device_id, dim, split_info): + if device_id not in self.info_by_device: + self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].set_split_info(dim, split_info) + + def set_split_infos(self, device_id, split_infos: Dict[int, SplitInfo]): + if device_id not in self.info_by_device: + self.info_by_device[device_id] = TensorInfo() + self.info_by_device[device_id].split_infos = split_infos + + def __deepcopy__(self, memo) -> "TensorContext": + return TensorContext(copy.deepcopy(self.info_by_device), + set(self.device_dims_for_shape)) + + +@dataclass +class LayerUpdate: + updated_attrs: Dict[str, Any] = field(default_factory=dict) + updated_inputs: Dict[int, trt.ITensor] = field(default_factory=dict) + split_info_updated: bool = False + + @staticmethod + def none() -> "LayerUpdate": + return LayerUpdate() + + +@dataclass +class GraphContext: + tensor_contexts: Dict[str, TensorContext] = field(default_factory=dict) + + def get_name(self, tensor_name, device_id): + if tensor_name not in self.tensor_contexts: + return None + if device_id not in self.tensor_contexts[tensor_name].info_by_device: + return None + return self.tensor_contexts[tensor_name].info_by_device[device_id].name + + def update_name_mapping(self, tensor_name, device_id, new_name): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + 
self.tensor_contexts[tensor_name].update_name_mapping( + device_id, new_name) + + def get_name_mapping(self, device_id, prefix: str) -> Dict[str, str]: + name_mapping = {} + for tensor_name in self.tensor_contexts.keys(): + new_name = self.get_name(tensor_name, device_id) + if new_name is not None: + name_mapping[f"{prefix}{tensor_name}"] = new_name + return name_mapping + + def add_device_dims_for_shape(self, tensor_name: str, + device_dims: Sequence[int]): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].device_dims_for_shape.update( + device_dims) + + def get_device_dims_for_shape(self, tensor_name: str): + if tensor_name not in self.tensor_contexts: + return set() + return self.tensor_contexts[tensor_name].device_dims_for_shape + + def get_split_infos(self, tensor_name, device_id): + if tensor_name not in self.tensor_contexts: + return None + if device_id not in self.tensor_contexts[tensor_name].info_by_device: + return None + return self.tensor_contexts[tensor_name].info_by_device[ + device_id].split_infos + + def set_split_info(self, tensor_name, device_id, dim, split_info): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].set_split_info( + device_id, dim, split_info) + + def set_split_infos(self, tensor_name, device_id, + split_infos: Dict[int, SplitInfo]): + if tensor_name not in self.tensor_contexts: + self.tensor_contexts[tensor_name] = TensorContext() + self.tensor_contexts[tensor_name].set_split_infos( + device_id, split_infos) + + def update_layer_context(self, wrapped_layer: Layer, + layer_update: LayerUpdate, + local_context: "GraphContext", device_id: int, + device_ids: np.ndarray, + sharding_specs: Dict[str, ShardingSpec]): + layer = wrapped_layer.as_trt() + for i in range(layer.num_outputs): + output = layer.get_output(i) + new_name = local_context.get_name(output.name, device_id) + if new_name is not None: + self.update_name_mapping(output.name, device_id, new_name) + if layer_update.split_info_updated: + for i in range(layer.num_outputs): + output = layer.get_output(i) + split_infos = local_context.get_split_infos( + output.name, device_id) + if split_infos is not None: + self.set_split_infos(output.name, device_id, split_infos) + return + split_info_by_device_dim = {} + for i in range(layer.num_inputs): + input = layer.get_input(i) + if input is None: + continue + sharding_spec = sharding_specs[f"input{i}"] + split_infos = local_context.get_split_infos(input.name, device_id) + if split_infos is None: + continue + for dim, split_info in split_infos.items(): + device_dim = tuple(sharding_spec.dim_partition_dict[dim]) + split_info_by_device_dim[device_dim] = split_info + for i in range(layer.num_outputs): + output = layer.get_output(i) + sharding_spec = sharding_specs[f"output{i}"] + for dim, device_dim in sharding_spec.dim_partition_dict.items(): + split_info = split_info_by_device_dim.get(tuple(device_dim)) + if split_info is None: + if device_dim == [0, 1] or device_dim == [1, 0]: + if (0, ) in split_info_by_device_dim and ( + 1, ) in split_info_by_device_dim: + split_info = SplitInfo( + split_info_by_device_dim[(0, )].input_dim * + split_info_by_device_dim[(1, )].input_dim, + split_info_by_device_dim[(0, )].partition * + split_info_by_device_dim[(1, )].partition, + ) + assert split_info is not None + partition = get_partition(device_dim, device_ids) + if split_info.input_dim != output.shape[dim]: 
+ assert output.shape[ + dim] > 0 and output.shape[dim] % partition == 0 + output_split_info = SplitInfo(output.shape[dim], partition) + self.set_split_info(output.name, device_id, dim, + output_split_info) + + def get_local_context(self, layer: trt.ILayer) -> "GraphContext": + local_context = GraphContext() + for i in range(layer.num_inputs): + input = layer.get_input(i) + if input is None: + continue + local_context.tensor_contexts[input.name] = copy.deepcopy( + self.tensor_contexts[input.name]) + return local_context + + def get_local_context_for_output(self, + output: trt.ITensor) -> "GraphContext": + local_context = GraphContext() + local_context.tensor_contexts[output.name] = copy.deepcopy( + self.tensor_contexts[output.name]) + return local_context + + def merge_context(self, context: "GraphContext"): + self.tensor_contexts.update(context.tensor_contexts) + + +@dataclass +class ShardContext: + graph_context: GraphContext + layer: Layer + nditer: np.nditer + device_ids: np.ndarray + strategy: ShardingStrategy + + +def get_partition(device_dim, device_ids): + if device_dim == [0]: + partition = device_ids.shape[0] + elif device_dim == [1]: + partition = device_ids.shape[1] + else: + assert device_dim == [0, 1] or device_dim == [1, 0] + partition = device_ids.size + return partition + + +def get_index(device_dim, iter): + if device_dim == [0]: + index = iter.multi_index[0] + elif device_dim == [1]: + index = iter.multi_index[1] + else: + assert device_dim == [0, 1] or device_dim == [1, 0] + index = iter.iterindex + return index + + +def get_full_sharding_spec(sharding_spec): + return ShardingSpec(sharding_spec.device_mesh, + sharding_spec.data_type_size, + sharding_spec.entire_shape, + sharding_spec.max_entire_shape, + sharding_spec.raw_shape, + dim_partition_dict={}) + + +def get_comm_action_sequence(from_sharding_sepc, to_sharding_sepc): + comm_action_sequence = from_sharding_sepc.device_mesh.shape_consistency_manager.shape_consistency( + from_sharding_sepc, to_sharding_sepc)[1] + # TODO: should merged by shape_consistency + if len(comm_action_sequence) == 2: + if comm_action_sequence[0].comm_pattern == comm_action_sequence[ + 1].comm_pattern == "all_gather": + if comm_action_sequence[0].gather_dim == comm_action_sequence[ + 1].gather_dim: + comm_action_sequence = [ + CommSpec( + comm_action_sequence[0].comm_pattern, + comm_action_sequence[0].sharding_spec, + comm_action_sequence[0].gather_dim, + comm_action_sequence[0].shard_dim, [[ + *comm_action_sequence[0].logical_process_axis[0], + *comm_action_sequence[1].logical_process_axis[0] + ]], comm_action_sequence[0].mix_gather, + comm_action_sequence[0].forward_only) + ] + assert len(comm_action_sequence[0].logical_process_axis[0]) <= 2 + assert len(comm_action_sequence) <= 1 + return comm_action_sequence + + +class GraphGroup(ABC): + + @staticmethod + def from_graph( + graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> "GraphGroup": + if auto_parallel_config.debug_mode: + return PrefixedGraphGroup(graph, config, auto_parallel_config) + else: + return DistributedGraphGroup(graph, config, auto_parallel_config) + + @property + @abstractmethod + def auto_parallel_config(self) -> AutoParallelConfig: + ... + + @abstractmethod + def add_input(self, tensor, device_ids, strategy: ShardingStrategy): + ... + + @abstractmethod + def add_layer(self, layer, device_ids, strategy: ShardingStrategy): + ... + + @abstractmethod + def add_output(self, tensor, device_ids, sharding_spec: ShardingSpec): + ... 
+ + @abstractmethod + def get_network(self, device_id) -> trt.INetworkDefinition: + ... + + @abstractmethod + def get_graph(self, device_id) -> PipelineGraph: + ... + + @property + @abstractmethod + def full_graph(self) -> PipelineGraph: + ... + + @abstractmethod + def get_prefix(self, device_id) -> str: + ... + + @abstractmethod + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + ... + + @abstractmethod + def get_values(self, device_id) -> Dict[str, List[int]]: + ... + + @abstractmethod + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + ... + + @abstractmethod + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + ... + + @abstractmethod + def register_layer(self, + layer, + base_name, + input_name, + output_name=None, + device_id=None, + keep_tensor_name=False) -> Layer: + ... + + def get_tensor(self, context: GraphContext, tensor_name: str, + device_id: int) -> Tensor: + name = context.get_name(tensor_name, device_id) + return self.get_graph(device_id).get_tensor(name) + + def add_comm(self, + context: GraphContext, + input_name, + device_ids, + commspec, + output_name=None, + is_singleton=False): + remove_index = [] + for i, device_dim in enumerate(commspec.logical_process_axis): + partition = get_partition(device_dim, device_ids) + if partition == 1: + remove_index.append(i) + if len(remove_index) > 0: + if commspec.comm_pattern in ["all_gather", "all_to_all"]: + commspec.gather_dim = [ + dim for i, dim in enumerate(commspec.gather_dim) + if i not in remove_index + ] + if commspec.comm_pattern in [ + "split", "reduce_scatter", "all_to_all" + ]: + commspec.shard_dim = [ + dim for i, dim in enumerate(commspec.shard_dim) + if i not in remove_index + ] + commspec.logical_process_axis = [ + dim for i, dim in enumerate(commspec.logical_process_axis) + if i not in remove_index + ] + flatten_device_dim = list( + itertools.chain.from_iterable(commspec.logical_process_axis)) + if flatten_device_dim == []: + return + if flatten_device_dim == [0, 1] or flatten_device_dim == [1, 0]: + self._add_comm(context, input_name, device_ids, commspec, + output_name, is_singleton) + elif flatten_device_dim == [0]: + for i in range(device_ids.shape[1]): + self._add_comm(context, input_name, device_ids[:, i:i + 1], + commspec, output_name, is_singleton) + elif flatten_device_dim == [1]: + for i in range(device_ids.shape[0]): + self._add_comm(context, input_name, device_ids[i:i + 1, :], + commspec, output_name, is_singleton) + else: + raise RuntimeError( + f"Invalid flatten device_dim: {flatten_device_dim}") + + def _add_comm(self, + context: GraphContext, + input_name, + device_ids, + commspec, + output_name=None, + is_singleton=False): + input_tensors = [ + self.get_tensor(context, input_name, device_id.item()) + for device_id in np.nditer(device_ids) + ] + comm_pattern = commspec.comm_pattern + if comm_pattern == "split": + self.add_split(context, input_name, output_name, device_ids, + commspec.shard_dim, commspec.logical_process_axis) + elif comm_pattern == "all_gather": + self.add_all_gather(context, input_name, output_name, device_ids, + commspec.gather_dim, + commspec.logical_process_axis, is_singleton) + elif comm_pattern == "all_reduce": + self.add_all_reduce(context, input_name, output_name, device_ids) + elif comm_pattern == "reduce_scatter": + self.add_reduce_scatter(context, input_name, output_name, + device_ids, commspec.shard_dim, + 
commspec.logical_process_axis) + elif comm_pattern == "all_to_all": + self.add_all_to_all(context, input_name, output_name, device_ids, + commspec.gather_dim, commspec.shard_dim, + commspec.logical_process_axis) + else: + raise NotImplementedError + output_tensors = [ + self.get_tensor(context, input_name, device_id.item()) + for device_id in np.nditer(device_ids) + ] + for input_tensor, output_tensor in zip(input_tensors, output_tensors): + if input_tensor.dtype != output_tensor.dtype: + raise ValueError( + f"Input tensor and output tensor should have the same dtype for communication layers, " + f"input dtype is {input_tensor.dtype} for {input_tensor.name}, " + f"but output dtype is {output_tensor.dtype} for {output_tensor.name}" + ) + + def add_all_reduce(self, context: GraphContext, input_name, output_name, + device_ids): + dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype) + to_reduce_tensors = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + input_dtype = input_tensor.dtype + if input_dtype != dtype: + to_reduce_tensor = self.cast( + network, + input_tensor, + dtype, + layer_info, + ) + else: + to_reduce_tensor = input_tensor + to_reduce_tensors.append(to_reduce_tensor) + self.add_all_reduce_layer(context, input_name, output_name, device_ids, + to_reduce_tensors) + if input_dtype != dtype: + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor( + context, + input_name, + device_id, + ).as_trt() + output_tensor = self.cast( + network, + input_tensor, + input_dtype, + layer_info, + ) + context.update_name_mapping( + input_name, + device_id, + output_tensor.name, + ) + + def add_reduce_scatter(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + self.add_all_reduce(context, input_name, output_name, device_ids) + self.add_split(context, input_name, output_name, device_ids, shard_dims, + device_dims) + + # TODO: use native all_to_all operation + def add_all_to_all(self, context: GraphContext, input_name, output_name, + device_ids, gather_dims, shard_dims, device_dims): + self.add_all_gather(context, input_name, output_name, device_ids, + gather_dims, device_dims) + self.add_split(context, input_name, output_name, device_ids, shard_dims, + device_dims) + + def get_item(self, network, tensor, index, layer_info): + get_item_layer = network.add_slice(tensor, [index], [1], [1]) + self.register_layer(get_item_layer, f"get_item{index}", *layer_info) + return get_item_layer.get_output(0) + + def get_shape(self, network, tensor, layer_info): + shape_layer = network.add_shape(tensor) + self.register_layer(shape_layer, "shape", *layer_info) + return shape_layer.get_output(0) + + def concat(self, network, tensors, layer_info): + concat_layer = network.add_concatenation(tensors) + self.register_layer(concat_layer, "concat", *layer_info) + return concat_layer.get_output(0) + + def flatten(self, network, tensor, layer_info): + shuffle_layer = network.add_shuffle(tensor) + shuffle_layer.reshape_dims = [-1] + shuffle_layer.zero_is_placeholder = False + self.register_layer(shuffle_layer, "flatten", *layer_info) + return shuffle_layer.get_output(0) + + def reshape(self, network, tensor, reshape_dims, layer_info): + 
reshape_layer = network.add_shuffle(tensor) + reshape_layer.set_input(1, reshape_dims) + reshape_layer.zero_is_placeholder = False + self.register_layer(reshape_layer, "reshape", *layer_info) + return reshape_layer.get_output(0) + + def cast(self, network, tensor, dtype, layer_info): + if tensor.dtype == dtype: + return tensor + cast_layer = network.add_cast(tensor, dtype) + self.register_layer(cast_layer, "cast", *layer_info) + return cast_layer.get_output(0) + + def const_int(self, network, name, value, layer_info): + const_layer = network.add_constant( + [1], np.array([value], dtype=trt_dtype_to_np(default_int_dtype))) + self.register_layer(const_layer, name, *layer_info) + return const_layer.get_output(0) + + def get_dim_size(self, network, tensor, dim, layer_info, shape_tensor=None): + raw_shape = tensor.shape + dim_size = raw_shape[dim] + if dim_size != -1: + return dim_size + else: + if shape_tensor is None: + shape_tensor = self.get_shape(network, tensor, layer_info) + return self.get_item(network, shape_tensor, dim, layer_info) + + def add_split(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + raw_input_shape = input_tensor.shape + start = [] + output_dims = [] + stride = [] + input_shape_tensor = self.get_shape(network, input_tensor, + layer_info) + for dim in range(len(raw_input_shape)): + stride.append(1) + if dim not in shard_dims: + start.append(0) + output_dims.append( + self.get_item(network, input_shape_tensor, dim, + layer_info)) + else: + start.append(None) + output_dims.append(None) + + for dim, device_dim in zip(shard_dims, device_dims): + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, it) + input_dim = raw_input_shape[dim] + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + start[dim] = index * quotient + output_dims[dim] = self.const_int(network, f"output_dim{dim}", + quotient, layer_info) + context.set_split_info(input_name, device_id, dim, + SplitInfo(input_dim, partition)) + output_dims_tensor = self.concat(network, output_dims, layer_info) + split_layer = network.add_slice(input_tensor, start, [], stride) + split_layer.set_input(2, output_dims_tensor) + wrapped_layer = self.register_layer(split_layer, "split", + *layer_info) + wrapped_layer.attrs["strategy"] = get_sharding_sequence( + len(raw_input_shape), + shard_dims, + device_dims, + ) + + output_tensor = split_layer.get_output(0) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather(self, + context: GraphContext, + input_name, + output_name, + device_ids, + gather_dims, + device_dims, + is_singleton=False): + to_gather_tensors = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + to_gather_tensor = self.flatten(network, input_tensor, layer_info) + to_gather_tensors.append(to_gather_tensor) + + all_gather_layers = self.add_all_gather_layer( + context, + input_name, + output_name, + device_ids, + to_gather_tensors, + ) + + if len(device_dims) == 1: + gather_indices = [0] + elif 
len(device_dims) == 2 and device_dims[0] == [1]: + gather_indices = [1, 0] + else: + gather_indices = [0, 1] + + for device_id, all_gather_layer in zip(np.nditer(device_ids), + all_gather_layers): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + permutation = [] + gathered_dims = [] + output_dims = [] + partitions = [] + raw_input_shape = input_tensor.shape + + wrapped_layer = self.get_graph(device_id).get_layer( + all_gather_layer.name) + wrapped_layer.attrs["strategy"] = get_sharding_sequence( + len(raw_input_shape), + gather_dims, + device_dims, + ) + + input_shape_layer = network.add_shape(input_tensor) + self.register_layer(input_shape_layer, "input_shape", *layer_info) + input_shape_tensor = input_shape_layer.get_output(0) + split_infos = context.get_split_infos(input_name, device_id) + for index in gather_indices: + gather_dim = gather_dims[index] + device_dim = device_dims[index] + partition = get_partition(device_dim, device_ids) + assert partition == split_infos[gather_dim].partition + partitions.append( + self.const_int(network, f"partition_num{gather_dim}", + partition, layer_info)) + for dim in range(len(raw_input_shape)): + if dim in gather_dims: + gather_index = gather_dims.index(dim) + device_dim = device_dims[gather_index] + permutation.append(gather_indices.index(gather_index)) + permutation.append(dim + len(gather_dims)) + if dim not in split_infos: + output_dim_layer = network.add_slice( + input_shape_tensor, [dim], [1], [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + dim_tensor = output_dim_layer.get_output(0) + output_dims.append(dim_tensor) + gathered_dims.append(dim_tensor) + else: + input_dim = split_infos[dim].input_dim + partition = split_infos[dim].partition + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + output_dims.append( + self.const_int(network, f"output_dim{dim}", quotient, + layer_info)) + if dim in gather_dims: + gathered_dims.append( + self.const_int(network, f"gathered_dim{dim}", + quotient * partition, layer_info)) + del split_infos[dim] + else: + gathered_dims.append(output_dim_layer.get_output(0)) + + reshape_dims_for_transpose_layer = network.add_concatenation( + [*partitions, *output_dims]) + self.register_layer(reshape_dims_for_transpose_layer, + "reshape_dims_for_transpose", *layer_info) + reshape_dims_tensor = reshape_dims_for_transpose_layer.get_output(0) + transpose_layer = network.add_shuffle( + all_gather_layer.get_output(0)) + transpose_layer.set_input(1, reshape_dims_tensor) + transpose_layer.second_transpose = permutation + transpose_layer.zero_is_placeholder = False + self.register_layer(transpose_layer, "transpose", *layer_info) + + reshape_dims_for_reshape_layer = network.add_concatenation( + gathered_dims) + self.register_layer(reshape_dims_for_reshape_layer, + "reshape_dims_for_reshape", *layer_info) + reshape_dims_tensor = reshape_dims_for_reshape_layer.get_output(0) + output_tensor = self.reshape( + network, + transpose_layer.get_output(0), + reshape_dims_tensor, + layer_info, + ) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + if is_singleton: + break + + def register_unfilled_weights(self, graph, layer): + if (layer.name in self.full_graph._unfilled_weights + and layer.name not in graph._unfilled_weights): + weights, values = 
self.full_graph._unfilled_weights[layer.name] + graph._register_unfilled_weights( + layer.name, + weights, + values, + ) + + def shard_constant(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + device_id = context.nditer.value.item() + device_ids = context.device_ids + layer = context.layer.as_trt() + graph = self.get_graph(device_id) + if len(shard_dims) == 0: + self.register_unfilled_weights(graph, layer) + return LayerUpdate(split_info_updated=True) + flatten_device_dim = list( + itertools.chain.from_iterable(shard_dims.values())) + output_name = layer.get_output(0).name + output_dtype = layer.get_output(0).dtype + output_shape = layer.shape + output_dims = [] + weight_index = [] + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, context.nditer) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + weight_index.append( + slice(index * quotient, (index + 1) * quotient)) + context.graph_context.set_split_info( + output_name, device_id, dim, + SplitInfo(output_dim, partition)) + else: + output_dims.append(output_dim) + weight_index.append(slice(None)) + if layer.name in self.full_graph._unfilled_weights: + values = self.full_graph._unfilled_weights[layer.name][1] + else: + values = layer.weights + if isinstance(values, trt.Weights): + values = values.numpy() + # TODO: remove this WAR after https://nvbugs/4359151 fixed. + if isinstance(values, trt.Weights): + network = context.layer.graph.as_trt() + values = get_np_weight(network, layer.name) + if values is not None: + values = values.reshape(layer.shape) + assert values.size == np.prod(layer.shape) + sharded_values = values[tuple(weight_index)] + assert sharded_values.size * get_partition( + flatten_device_dim, device_ids) == np.prod(layer.shape) + else: + sharded_values = None + dtype = trt_dtype_to_np(output_dtype) + sharded_weights = np.empty(tuple(output_dims), dtype) + graph._register_unfilled_weights( + f"device{device_id}_{layer.name}", + sharded_weights, + sharded_values, + ) + sharded_weights = to_trt_weights(sharded_weights) + return LayerUpdate( + updated_attrs=dict( + shape=trt.Dims(output_dims), + weights=sharded_weights, + ), + split_info_updated=True, + ) + + def shard_fill(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate(split_info_updated=True) + device_id = context.nditer.value.item() + device_ids = context.device_ids + layer = context.layer.as_trt() + output_name = layer.get_output(0).name + output_shape = layer.shape + output_dims = [] + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + context.graph_context.set_split_info( + output_name, device_id, dim, + SplitInfo(output_dim, partition)) + else: + output_dims.append(output_dim) + return LayerUpdate( + updated_attrs=dict(shape=trt.Dims(output_dims), ), + split_info_updated=True, + ) + + def update_shape(self, context: ShardContext): + if not context.layer.is_shape_io: + return + layer = 
context.layer.as_trt() + input_name = layer.get_input(0).name + output_name = layer.get_output(0).name + device_id = context.nditer.value.item() + layer_info = (output_name, None, device_id) + split_infos = context.graph_context.get_split_infos( + input_name, device_id) + if len(split_infos) == 0: + return + network = self.get_network(device_id) + shape_tensor = self.get_tensor(context.graph_context, output_name, + device_id).as_trt() + output_dims = [] + for dim in range(len(context.layer.get_input(0).shape)): + if dim not in split_infos: + output_dim_layer = network.add_slice(shape_tensor, [dim], [1], + [1]) + else: + input_dim = split_infos[dim].input_dim + output_dim_layer = network.add_constant( + [1], np.array([input_dim], dtype=default_int_dtype)) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + output_dims.append(output_dim_layer.get_output(0)) + new_shape_layer = network.add_concatenation(output_dims) + self.register_layer(new_shape_layer, "new_shape", *layer_info) + new_shape_tensor = new_shape_layer.get_output(0) + context.graph_context.update_name_mapping(output_name, device_id, + new_shape_tensor.name) + + def shard_slice(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate.none() + device_id = context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + layer = context.layer.as_trt() + output_dims = [] + updated_attrs = {} + updated_inputs = {} + + if layer.num_inputs >= 3: + raw_output_shape = layer.get_output(0).shape + input_name = layer.get_input(2).name + layer_info = (input_name, layer.name, device_id) + shape_tensor = self.get_tensor(context.graph_context, input_name, + device_id).as_trt() + for dim in range(len(raw_output_shape)): + output_dim_layer = network.add_slice(shape_tensor, [dim], [1], + [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + partition_num_tensor = self.const_int( + network, f"partition_num{dim}", partition, layer_info) + quotient_layer = network.add_elementwise( + output_dim_layer.get_output(0), partition_num_tensor, + trt.ElementWiseOperation.FLOOR_DIV) + self.register_layer(quotient_layer, f"quotient{dim}", + *layer_info) + output_dim = self.cast(network, + quotient_layer.get_output(0), + default_int_dtype, layer_info) + output_dims.append(output_dim) + else: + output_dims.append(output_dim_layer.get_output(0)) + output_dims_layer = network.add_concatenation(output_dims) + self.register_layer(output_dims_layer, "output_dims", *layer_info) + updated_inputs[2] = output_dims_layer.get_output(0) + else: + output_shape = layer.shape + for dim in range(len(output_shape)): + output_dim = output_shape[dim] + assert output_dim != -1 + if dim in shard_dims: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + assert output_dim % partition == 0 + quotient = output_dim // partition + output_dims.append(quotient) + else: + output_dims.append(output_dim) + updated_attrs["shape"] = trt.Dims(output_dims) + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_shuffle(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + if len(shard_dims) == 0: + return LayerUpdate.none() + device_id = 
context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + layer = context.layer.as_trt() + updated_attrs = {} + updated_inputs = {} + updated_reshape_dims = {} + second_transpose = layer.second_transpose + + if layer.num_inputs >= 2: + raw_output_shape = layer.get_output(0).shape + input_name = layer.get_input(1).name + layer_info = (input_name, layer.name, device_id) + reshape_dims_tensor = self.get_tensor(context.graph_context, + input_name, device_id) + reshape_dims = context.layer.get_input(1).value + reshape_dims_tensor = reshape_dims_tensor.as_trt() + for dim in range(len(raw_output_shape)): + if second_transpose is not None: + reshape_dim = second_transpose[dim] + else: + reshape_dim = dim + output_dim_layer = network.add_slice(reshape_dims_tensor, + [reshape_dim], [1], [1]) + self.register_layer(output_dim_layer, f"output_dim{dim}", + *layer_info) + output_dim = reshape_dims[reshape_dim] + if dim in shard_dims and output_dim != -1: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + partition_num_tensor = self.const_int( + network, f"partition_num{dim}", partition, layer_info) + quotient_layer = network.add_elementwise( + output_dim_layer.get_output(0), partition_num_tensor, + trt.ElementWiseOperation.FLOOR_DIV) + self.register_layer(quotient_layer, f"quotient{dim}", + *layer_info) + updated_reshape_dims[reshape_dim] = self.cast( + network, + quotient_layer.get_output(0), + default_int_dtype, + layer_info, + ) + else: + updated_reshape_dims[ + reshape_dim] = output_dim_layer.get_output(0) + updated_reshape_dims = list( + map(lambda x: x[1], sorted(updated_reshape_dims.items()))) + reshape_dims_layer = network.add_concatenation(updated_reshape_dims) + self.register_layer(reshape_dims_layer, "reshape_dims", *layer_info) + updated_inputs[1] = reshape_dims_layer.get_output(0) + else: + reshape_dims = layer.reshape_dims + if reshape_dims.__len__() < 0: + return LayerUpdate.none() + for dim in range(len(reshape_dims)): + if second_transpose is not None: + reshape_dim = second_transpose[dim] + else: + reshape_dim = dim + output_dim = reshape_dims[reshape_dim] + if dim in shard_dims and output_dim != -1: + device_dim = shard_dims[dim] + partition = get_partition(device_dim, device_ids) + quotient = output_dim // partition + updated_reshape_dims[reshape_dim] = quotient + else: + updated_reshape_dims[reshape_dim] = output_dim + updated_reshape_dims = list( + map(lambda x: x[1], sorted(updated_reshape_dims.items()))) + updated_attrs["reshape_dims"] = trt.Dims(updated_reshape_dims) + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_gpt_attention(self, context: ShardContext): + layer = context.layer.as_trt() + plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + parser = IdxEntryParser(plugin_info) + head_dim = 1 if parser.remove_input_padding else 2 + sharding_spec = context.strategy.sharding_specs[ + f"input{parser.get_index(IdxEntry.QKV_TENSOR)}"] + shard_dims = sharding_spec.dim_partition_dict + if head_dim not in shard_dims: + return LayerUpdate.none() + device_id = context.nditer.value.item() + network = self.get_network(device_id) + device_ids = context.device_ids + updated_attrs = {} + updated_inputs = {} + device_dim = shard_dims[head_dim] + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, context.nditer) + if parser.is_entry_used(IdxEntry.K_TENSOR): + kv_sharding_spec = context.strategy.sharding_specs[ + 
f"input{parser.get_index(IdxEntry.K_TENSOR)}"] + kv_shard_dims = kv_sharding_spec.dim_partition_dict + if head_dim in kv_shard_dims: + kv_device_dim = kv_shard_dims[head_dim] + kv_partition = get_partition(kv_device_dim, device_ids) + else: + kv_partition = 1 + else: + kv_partition = 1 + num_heads = plugin_info.pfc_as_ndarray["num_heads"].copy() + num_kv_heads = plugin_info.pfc_as_ndarray["num_kv_heads"].copy() + tp_size = plugin_info.pfc_as_ndarray["tp_size"].copy() + tp_rank = plugin_info.pfc_as_ndarray["tp_rank"].copy() + num_kv_heads = np.maximum(num_kv_heads // kv_partition, 1) + num_heads = np.maximum(num_heads // partition, 1) + tp_size[0] = partition + tp_rank[0] = index + + new_plugin, new_plugin_info = get_updated_plugin( + plugin_info, + dict( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + tp_size=tp_size, + tp_rank=tp_rank, + )) + prefix = self.get_prefix(device_id) + new_layer_name = f"{prefix}{layer.name}" + set_plugin_info(network, new_layer_name, new_plugin_info) + updated_attrs["plugin"] = new_plugin + return LayerUpdate(updated_attrs, updated_inputs) + + def shard_lookup(self, context: ShardContext): + sharding_spec = context.strategy.sharding_specs["input1"] + shard_dims = sharding_spec.dim_partition_dict + if 0 not in shard_dims: + return LayerUpdate.none() + layer = context.layer.as_trt() + plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + device_id = context.nditer.value.item() + network = self.get_network(device_id) + updated_attrs = {} + device_dim = shard_dims[0] + index = get_index(device_dim, context.nditer) + rank = plugin_info.pfc_as_ndarray["rank"].copy() + rank[0] = index + + new_plugin, new_plugin_info = get_updated_plugin( + plugin_info, dict(rank=rank, )) + prefix = self.get_prefix(device_id) + new_layer_name = f"{prefix}{layer.name}" + set_plugin_info(network, new_layer_name, new_plugin_info) + updated_attrs["plugin"] = new_plugin + return LayerUpdate(updated_attrs) + + +class GraphGroupBase(GraphGroup): + + def __init__( + self, + full_graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> None: + self._full_graph = full_graph + self.config = config + self._auto_parallel_config = auto_parallel_config + self.infer_shape = auto_parallel_config.infer_shape + self.global_context = GraphContext() + self.shape_cache = {} + self.suffix = 0 + self.current_block_id = -1 + + @property + def auto_parallel_config(self) -> AutoParallelConfig: + return self._auto_parallel_config + + @property + def full_graph(self) -> PipelineGraph: + return self._full_graph + + def register_layer(self, + layer, + base_name, + input_name, + output_name=None, + device_id=None, + keep_tensor_name=False) -> Layer: + layer_name = f"{base_name}_{input_name}" + if device_id is not None: + layer_name = f"{self.get_prefix(device_id)}{layer_name}" + if output_name is not None: + layer_name = f"{layer_name}_to_{output_name}" + suffix = self.suffix + self.suffix += 1 + layer_name = f"{layer_name}_{suffix}" + if layer.type == trt.LayerType.PLUGIN_V2: + network = self.get_network(device_id) + plugin_info = get_plugin_info(network, layer.name) + if plugin_info is not None: + set_plugin_info(network, layer_name, plugin_info) + delete_plugin_info(network, layer.name) + layer.name = layer_name + layer.metadata = layer.name + if not keep_tensor_name: + for i in range(layer.num_outputs): + output_tensor = layer.get_output(i) + assert output_tensor.shape.__len__() >= 0 + output_tensor.name = f"{layer.name}_output_{i}" + 
wrapped_layer = self.get_graph(device_id).register_layer(layer) + if self.current_block_id != -1: + wrapped_layer.attrs["block_id"] = self.current_block_id + wrapped_layer.attrs["role"] = "helper" + if self.infer_shape: + infer_per_layer_shapes( + layer, + self.get_shapes(device_id), + self.get_values(device_id), + self.shape_cache, + is_shape_io=True, + ) + wrapped_layer.assign_shapes( + self.get_shapes(device_id), + self.get_values(device_id), + ) + return wrapped_layer + + def add_layer(self, wrapped_layer: Layer, device_ids, + strategy: ShardingStrategy): + layer = wrapped_layer.as_trt() + local_context = self.global_context.get_local_context(layer) + self.current_block_id = wrapped_layer.attrs["block_id"] + + for i, input in enumerate(wrapped_layer.inputs): + if input is None: + continue + if i not in strategy.best_resharding_cost: + continue + comm_action_sequence = strategy.best_resharding_cost[i][0][1] + for commspec in comm_action_sequence: + self.add_comm(local_context, + input.name, + device_ids, + commspec, + output_name=layer.name) + + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = device_id.item() + + layer_type = layer.type + to_subclass_layer(layer) + shard_context = ShardContext( + local_context, + wrapped_layer, + it, + device_ids, + strategy, + ) + if layer_type == trt.LayerType.CONSTANT: + layer_update = self.shard_constant(shard_context) + elif layer_type == trt.LayerType.FILL: + layer_update = self.shard_fill(shard_context) + elif layer_type == trt.LayerType.SLICE: + layer_update = self.shard_slice(shard_context) + elif layer_type == trt.LayerType.SHUFFLE: + layer_update = self.shard_shuffle(shard_context) + elif layer_type == trt.LayerType.PLUGIN_V2: + if layer.plugin.plugin_type == "GPTAttention": + layer_update = self.shard_gpt_attention(shard_context) + elif layer.plugin.plugin_type == "Lookup": + layer_update = self.shard_lookup(shard_context) + else: + layer_update = LayerUpdate.none() + else: + layer_update = LayerUpdate.none() + to_base_class_layer(layer) + + for i, updated_input in layer_update.updated_inputs.items(): + input_name = layer.get_input(i).name + local_context.update_name_mapping(input_name, device_id, + updated_input.name) + if layer.get_input(i).dtype != updated_input.dtype: + raise ValueError( + f"Input dtype mismatch for {layer.name}, " + f"expect {layer.get_input(i).dtype} for {input_name}, " + f"get {updated_input.dtype} for {updated_input.name}") + + prefix = self.get_prefix(device_id) + new_wrapped_layer = self.get_graph(device_id).add_layer( + layer, + prefix=prefix, + input_mapping=local_context.get_name_mapping(device_id, + prefix=prefix), + updated_attrs=layer_update.updated_attrs, + ) + new_wrapped_layer.attrs["strategy"] = strategy.name + new_wrapped_layer.attrs["block_id"] = self.current_block_id + new_layer = new_wrapped_layer.as_trt() + + if self.infer_shape: + infer_per_layer_shapes( + new_layer, + self.get_shapes(device_id), + self.get_values(device_id), + self.shape_cache, + is_shape_io=wrapped_layer.is_shape_io, + ) + new_wrapped_layer.assign_shapes( + self.get_shapes(device_id), + self.get_values(device_id), + ) + + for i in range(layer.num_outputs): + output_tensor = new_layer.get_output(i) + assert output_tensor.shape.__len__() >= 0 + local_context.update_name_mapping( + layer.get_output(i).name, device_id, output_tensor.name) + + if layer.type == trt.LayerType.SHAPE: + self.update_shape(shard_context) + + self.global_context.update_layer_context( + wrapped_layer, + layer_update, + 
local_context, + device_id, + device_ids, + strategy.sharding_specs, + ) + + for i in range(layer.num_outputs): + commspecs = strategy.communication_actions.get(f"output{i}") + if commspecs is None: + continue + output = layer.get_output(i) + for commspec in commspecs: + self.add_comm( + self.global_context, + output.name, + device_ids, + commspec, + ) + + self.current_block_id = -1 + + +class DistributedGraphGroup(GraphGroupBase): + + def __init__( + self, + full_graph: PipelineGraph, + config: ParallelConfig, + auto_parallel_config: AutoParallelConfig, + ) -> None: + super().__init__(full_graph, config, auto_parallel_config) + self.graphs = {} + self.io_tensor_shards = {} + self.shapes_by_device = {} + self.values_by_device = {} + phy_mesh = config.graph_config.phy_mesh + device_ids = phy_mesh.phy_devices_id + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = PipelineGraph.create_graph() + graph._auto_parallel_config = { + "io_shards": {}, + "mapping": + Mapping( + world_size=device_ids.size, + rank=device_id, + gpus_per_node=device_ids.shape[1], + tp_size=device_ids.size // config.graph_config.num_stages, + pp_size=config.graph_config.num_stages, + ), + } + self.graphs[device_id] = graph + self.shapes_by_device[device_id] = {} + self.values_by_device[device_id] = {} + + @contextlib.contextmanager + def disable_infer_shape(self): + infer_shape = self.infer_shape + self.infer_shape = False + yield + self.infer_shape = infer_shape + + def get_network(self, device_id) -> trt.INetworkDefinition: + return self.graphs[device_id].as_trt() + + def get_graph(self, device_id) -> PipelineGraph: + return self.graphs[device_id] + + def get_prefix(self, device_id) -> str: + return "" + + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + return self.shapes_by_device[device_id] + + def get_values(self, device_id) -> Dict[str, List[int]]: + return self.values_by_device[device_id] + + def add_reduce_scatter(self, context: GraphContext, input_name, output_name, + device_ids, shard_dims, device_dims): + dtype = str_dtype_to_trt(self.full_graph._plugin_config.dtype) + it = np.nditer(device_ids, flags=['multi_index']) + for device_id in it: + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + input_tensor = self.get_tensor(context, input_name, + device_id).as_trt() + raw_input_shape = input_tensor.shape + input_shape_tensor = self.get_shape(network, input_tensor, + layer_info) + if shard_dims != [0]: + permutation = list(range(len(raw_input_shape))) + for dim in shard_dims: + permutation.remove(dim) + permutation = shard_dims + permutation + transpose_layer = network.add_shuffle(input_tensor) + transpose_layer.second_transpose = permutation + self.register_layer(transpose_layer, "input_transpose", + *layer_info) + input_tensor = transpose_layer.get_output(0) + flatten_tensor = self.flatten(network, input_tensor, layer_info) + input_dtype = flatten_tensor.dtype + if input_dtype != dtype: + to_reduce_tensor = self.cast( + network, + flatten_tensor, + dtype, + layer_info, + ) + else: + to_reduce_tensor = flatten_tensor + + reduce_scatter_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('ReduceScatter', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert reduce_scatter_plg_creator is not None + + group = trt.PluginField( + "group", + np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField( + "type_id", 
np.array([int(to_reduce_tensor.dtype)], np.int32), + trt.PluginFieldType.INT32) + + pfc = trt.PluginFieldCollection([group, pf_type]) + rs_plug = reduce_scatter_plg_creator.create_plugin( + "reduce_scatter", pfc) + + reduce_scatter_layer = network.add_plugin_v2([to_reduce_tensor], + rs_plug) + plugin_info = PluginInfo(reduce_scatter_plg_creator, + "reduce_scatter", pfc) + set_plugin_info(network, reduce_scatter_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + reduce_scatter_layer, + "reduce_scatter", + *layer_info, + ).get_output(0) + reduce_scatter_tensor = reduce_scatter_layer.get_output(0) + if self.infer_shape: + shape = self.shapes_by_device[device_id][to_reduce_tensor.name] + assert len(shape) == 1 + output_shape = (shape[0] // device_ids.size, ) + self.shapes_by_device[device_id][ + reduce_scatter_tensor.name] = output_shape + wrapped_tensor.shape = output_shape + if input_dtype != dtype: + reduce_scatter_tensor = self.cast( + network, + reduce_scatter_tensor, + input_dtype, + layer_info, + ) + + start = [] + output_dims = [] + stride = [] + for dim in range(len(raw_input_shape)): + stride.append(1) + if dim not in shard_dims: + start.append(0) + output_dims.append( + self.get_item(network, input_shape_tensor, dim, + layer_info)) + else: + start.append(None) + output_dims.append(None) + + for dim, device_dim in zip(shard_dims, device_dims): + partition = get_partition(device_dim, device_ids) + index = get_index(device_dim, it) + input_dim = raw_input_shape[dim] + assert input_dim != -1 + assert input_dim % partition == 0 + quotient = input_dim // partition + start[dim] = index * quotient + output_dims[dim] = self.const_int(network, f"output_dim{dim}", + quotient, layer_info) + context.set_split_info(input_name, device_id, dim, + SplitInfo(input_dim, partition)) + if shard_dims != [0]: + output_dims = [ + output_dims[permutation[i]] for i in range(len(output_dims)) + ] + output_dims_tensor = self.concat(network, output_dims, layer_info) + output_tensor = self.reshape( + network, + reduce_scatter_tensor, + output_dims_tensor, + layer_info, + ) + if shard_dims != [0]: + transpose_layer = network.add_shuffle(output_tensor) + transpose_layer.second_transpose = permutation + self.register_layer(transpose_layer, "output_transpose", + *layer_info) + output_tensor = transpose_layer.get_output(0) + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + counter = current_all_reduce_helper().gen_id() + for device_id, to_reduce_tensor in zip(np.nditer(device_ids), + to_reduce_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + graph = self.get_graph(device_id) + strategy = AllReduceStrategy.AUTO + workspace = graph.get_input("all_reduce_workspace").as_trt() + + all_reduce_layer, allreduce_plg_creator, pfc = create_allreduce_plugin( + network=network, + tensor=to_reduce_tensor, + workspace=workspace, + group=np.ascontiguousarray( + device_ids.reshape(-1).astype(np.int32)), + strategy=strategy, + dtype=to_reduce_tensor.dtype, + config=AllReduceConfig(0), + counter=counter, + reduce_fusion_params=AllReduceFusionParams(), + ) + plugin_info = PluginInfo(allreduce_plg_creator, "allreduce", pfc) + set_plugin_info(network, all_reduce_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + 
all_reduce_layer, + "all_reduce", + *layer_info, + ).get_output(0) + output_tensor = all_reduce_layer.get_output(0) + if self.infer_shape: + shape = self.shapes_by_device[device_id][to_reduce_tensor.name] + self.shapes_by_device[device_id][output_tensor.name] = shape + wrapped_tensor.shape = shape + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + all_gather_layers = [] + for device_id, to_gather_tensor in zip(np.nditer(device_ids), + to_gather_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + + allgather_plg_creator = trt.get_plugin_registry( + ).get_plugin_creator('AllGather', '1', TRT_LLM_PLUGIN_NAMESPACE) + assert allgather_plg_creator is not None + + group = trt.PluginField( + "group", + np.ascontiguousarray(device_ids.reshape(-1).astype(np.int32)), + trt.PluginFieldType.INT32) + pf_type = trt.PluginField( + "type_id", np.array([int(to_gather_tensor.dtype)], np.int32), + trt.PluginFieldType.INT32) + pfc = trt.PluginFieldCollection([group, pf_type]) + allgather = allgather_plg_creator.create_plugin("allgather", pfc) + + all_gather_layer = network.add_plugin_v2([to_gather_tensor], + allgather) + plugin_info = PluginInfo(allgather_plg_creator, "allgather", pfc) + set_plugin_info(network, all_gather_layer.name, plugin_info) + with self.disable_infer_shape(): + wrapped_tensor = self.register_layer( + all_gather_layer, + "all_gather", + *layer_info, + ).get_output(0) + if self.infer_shape: + output_tensor = all_gather_layer.get_output(0) + shape = self.shapes_by_device[device_id][to_gather_tensor.name] + assert len(shape) == 1 + output_shape = (shape[0] * device_ids.size, ) + self.shapes_by_device[device_id][ + output_tensor.name] = output_shape + wrapped_tensor.shape = output_shape + all_gather_layers.append(all_gather_layer) + return all_gather_layers + + def set_shard_num(self, tensor_name, dim, shard_num): + for graph in self.graphs.values(): + io_shards = graph._auto_parallel_config["io_shards"] + if tensor_name not in io_shards: + io_shards[tensor_name] = {} + io_shards[tensor_name][dim] = shard_num + + def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): + context = self.global_context + sharding_spec = strategy.sharding_specs["output0"] + shard_dims = sharding_spec.dim_partition_dict + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + self.set_shard_num(tensor.name, dim, partition) + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = self.get_graph(device_id) + new_input = graph.add_input(tensor.as_trt()) + shape = [*tensor.shape] + if len(shard_dims) != 0: + output_shape = [*tensor.raw_shape] + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + output_dim = output_shape[dim] + assert output_dim != -1 + assert output_dim % partition == 0 + quotient = output_dim // partition + output_shape[dim] = quotient + shape[dim] = quotient + assert tensor.value is None + context.set_split_info(tensor.name, device_id, dim, + SplitInfo(output_dim, partition)) + new_input.raw_shape = output_shape + context.update_name_mapping(tensor.name, device_id, tensor.name) + if self.infer_shape: + self.shapes_by_device[device_id][tensor.name] = tuple(shape) + new_input.shape = tuple(shape) + if tensor.value is not None: + 
self.values_by_device[device_id][tensor.name] = tensor.value + new_input.value = tensor.value + + def add_output(self, tensor: Tensor, device_ids, + strategy: ShardingStrategy): + comm_action_sequence = strategy.best_resharding_cost[0][0][1] + for commspec in comm_action_sequence: + self.add_comm(self.global_context, tensor.name, device_ids, + commspec) + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = self.get_graph(device_id) + output_name = tensor.name + new_output_name = self.global_context.get_name( + output_name, device_id) + if new_output_name != output_name: + suffix = self.suffix + self.suffix += 1 + original_name = f"original_{output_name}_{suffix}" + original_tensor = graph.get_tensor(output_name) + original_tensor.as_trt().name = original_name + output_tensor = graph.get_tensor(new_output_name) + output_tensor.as_trt().name = output_name + graph._tensors[original_name] = original_tensor + graph._tensors[output_name] = output_tensor + del graph._tensors[new_output_name] + else: + output_tensor = graph.get_tensor(output_name) + trt_output = output_tensor.as_trt() + if trt_output.is_shape_tensor: + graph.add_output_shape(trt_output) + else: + graph.add_output(trt_output) + trt_output.dtype = tensor.dtype + if tensor.dtype != output_tensor.dtype: + raise ValueError( + f"Output dtype mismatch, " + f"expect {tensor.dtype} for {tensor.name}, " + f"get {output_tensor.dtype} for {output_tensor.name}") + + shard_dims = strategy.sharding_specs["input0"].dim_partition_dict + for dim, device_dim in shard_dims.items(): + partition = get_partition(device_dim, device_ids) + self.set_shard_num(tensor.name, dim, partition) + + +class PrefixedGraphGroup(GraphGroupBase): + + def __init__( + self, + full_graph: PipelineGraph = None, + config: ParallelConfig = None, + auto_parallel_config: AutoParallelConfig = None, + ) -> None: + auto_parallel_config = auto_parallel_config or dict( + infer_shape=False, + validation_mode=False, + ) + super().__init__(full_graph, config, auto_parallel_config) + self.validation_mode = auto_parallel_config.validation_mode + if not self.infer_shape: + self.validation_mode = False + self.prefixed_graph = PipelineGraph.create_graph() + if self.validation_mode: + self.layer_mapping = config.graph_config.graph_mapping.layer_mapping + self.graph_strategy = config.graph_strategy + self.shapes = {} + self.values = {} + self.timing_cache = None + + def get_network(self, device_id) -> trt.INetworkDefinition: + return self.prefixed_graph.as_trt() + + def get_graph(self, device_id) -> PipelineGraph: + return self.prefixed_graph + + def get_prefix(self, device_id) -> str: + return f"device{device_id}_" + + def get_shapes(self, device_id) -> Dict[str, Tuple[int, ...]]: + return self.shapes + + def get_values(self, device_id) -> Dict[str, List[int]]: + return self.values + + def add_all_reduce_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_reduce_tensors): + reshaped_tensors = [] + for device_id, to_reduce_tensor in zip(np.nditer(device_ids), + to_reduce_tensors): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + reshape_dims_tensor = self.concat( + network, + [ + self.get_shape(network, to_reduce_tensor, layer_info), + self.const_int(network, "expanded_dim", 1, layer_info) + ], + layer_info, + ) + reshaped_tensor = self.reshape( + network, + to_reduce_tensor, + reshape_dims_tensor, + layer_info, + ) + reshaped_tensors.append(reshaped_tensor) + + 
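+        # Debug-mode emulation of all-reduce within the single prefixed graph:
+        # the per-device tensors reshaped above (each given a trailing axis of
+        # size 1) are concatenated along that new axis and summed with a
+        # reduce layer below, instead of using the NCCL all-reduce plugin.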
for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + input_tensor = self.get_tensor(context, input_name, 0).as_trt() + num_dims = len(input_tensor.shape) + network = self.get_network(device_id) + concat_layer = network.add_concatenation(reshaped_tensors) + concat_layer.axis = num_dims + self.register_layer(concat_layer, "concat", *layer_info) + reduce_layer = network.add_reduce(concat_layer.get_output(0), + trt.ReduceOperation.SUM, + axes=1 << num_dims, + keep_dims=False) + dtype = to_reduce_tensors[0].dtype + reduce_layer.precision = dtype + reduce_layer.set_output_type(0, dtype) + self.register_layer(reduce_layer, "reduce", *layer_info) + output_tensor = reduce_layer.get_output(0) + + context.update_name_mapping(input_name, device_id, + output_tensor.name) + + def add_all_gather_layer(self, context: GraphContext, input_name, + output_name, device_ids, to_gather_tensors): + all_gather_layers = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (input_name, output_name, device_id) + network = self.get_network(device_id) + all_gather_layer = network.add_concatenation(to_gather_tensors) + all_gather_layer.axis = 0 + self.register_layer(all_gather_layer, "all_gather", *layer_info) + all_gather_layers.append(all_gather_layer) + return all_gather_layers + + def add_input(self, tensor: Tensor, device_ids, strategy: ShardingStrategy): + + def add_identity(): + identity_layer = network.add_identity(input.as_trt()) + return identity_layer + + input = self.prefixed_graph.add_input(tensor.as_trt()) + if self.infer_shape: + self.shapes[tensor.name] = tensor.shape + input.shape = tensor.shape + if tensor.value is not None: + self.values[tensor.name] = tensor.value + input.value = tensor.value + network = self.get_network(None) + if self.validation_mode: + identity_layer = add_identity() + identity_layer.get_output(0).name = f"ref_{tensor.name}" + layer_info = (tensor.name, None, None) + self.register_layer(identity_layer, + "identity", + *layer_info, + keep_tensor_name=True) + input.attrs["strategy"] = strategy.name + sharding_spec = strategy.sharding_specs["output0"] + pre_sharding_sepc = get_full_sharding_spec(sharding_spec) + comm_action_sequence = get_comm_action_sequence(pre_sharding_sepc, + sharding_spec) + context = self.global_context + for device_id in np.nditer(device_ids): + device_id = device_id.item() + layer_info = (tensor.name, None, device_id) + context.update_name_mapping(tensor.name, device_id, tensor.name) + if len(comm_action_sequence + ) == 0 and not tensor.as_trt().is_shape_tensor: + identity_layer = add_identity() + self.register_layer(identity_layer, "identity", *layer_info) + context.update_name_mapping( + tensor.name, + device_id, + identity_layer.get_output(0).name, + ) + for commspec in comm_action_sequence: + self.add_comm(context, tensor.name, device_ids, commspec) + + def get_graph_in_range(self, graph_group, src_layer, layer_range, + device_ids, shapes, values): + src_network = self.prefixed_graph.as_trt() + graph = graph_group.prefixed_graph + network = graph.as_trt() + input_mapping = {} + for device_id in np.nditer(device_ids): + device_id = device_id.item() + for i in range(src_layer.num_inputs): + src_input = src_layer.get_input(i) + if src_input is not None: + input = self.get_tensor( + self.global_context, + src_input.name, + device_id, + ).as_trt() + if graph.get_input(src_input.name) is not None: + new_input = graph_group.get_tensor( + 
graph_group.global_context, + src_input.name, + device_id, + ).as_trt() + input_mapping[input.name] = new_input.name + continue + if graph.get_tensor(input.name) is not None: + continue + shape = shapes[input.name] + assert input.name in values + value = values[input.name] + weights = np.asarray(value, + dtype=trt_dtype_to_np(input.dtype)) + weights = to_trt_weights(weights) + input_layer = network.add_constant(shape, weights) + new_input = input_layer.get_output(0) + new_input.name = input.name + graph.register_layer(input_layer) + for i in layer_range: + layer = src_network.get_layer(i) + graph.add_layer(layer, input_mapping=input_mapping) + + def add_layer_singleton(self, output, device_ids, sharding_spec): + assert self.prefixed_graph.get_tensor(output.name) is None + network = self.prefixed_graph.as_trt() + full_sharding_sepc = get_full_sharding_spec(sharding_spec) + comm_action_sequence = get_comm_action_sequence(sharding_spec, + full_sharding_sepc) + output_context = self.global_context.get_local_context_for_output( + output) + if len(comm_action_sequence) != 0: + for commspec in comm_action_sequence[:-1]: + self.add_comm(output_context, output.name, device_ids, commspec) + self.add_comm( + output_context, + output.name, + device_ids, + comm_action_sequence[-1], + is_singleton=True, + ) + device_id = next(np.nditer(device_ids)).item() + layer_info = (output.name, None, device_id) + output_tensor = self.get_tensor(output_context, output.name, + device_id).as_trt() + singleton_layer = network.add_identity(output_tensor) + singleton_layer.get_output(0).name = output.name + self.register_layer(singleton_layer, + "singleton", + *layer_info, + keep_tensor_name=True) + + def add_layer(self, wrapped_layer: Layer, device_ids, + strategy: ShardingStrategy): + graph = self.prefixed_graph + network = graph.as_trt() + start_layer_id = network.num_layers + + super().add_layer(wrapped_layer, device_ids, strategy) + + layer = wrapped_layer.as_trt() + + if self.validation_mode: + is_shape = (wrapped_layer.is_shape_io + or layer.type == trt.LayerType.SHAPE) + + if not is_shape: + self.current_block_id = wrapped_layer.attrs["block_id"] + for i, wrapped_output in enumerate(wrapped_layer.outputs): + if wrapped_output.is_graph_output: + continue + output = wrapped_output.as_trt() + output_name = f"output{i}" + if strategy.communication_actions.get( + output_name) is not None: + output_name += "_after_comm" + sharding_spec = strategy.sharding_specs[output_name] + self.add_layer_singleton(output, device_ids, sharding_spec) + self.current_block_id = -1 + end_layer_id = network.num_layers + + is_skip = (is_shape or layer.type == trt.LayerType.CONSTANT + or layer.name in self.layer_mapping) + sharded = False + for sharding_spec in strategy.sharding_specs.values(): + if len(sharding_spec.dim_partition_dict) > 0: + sharded = True + break + if not sharded: + is_skip = True + + ref_layer = graph.add_layer(layer, prefix="ref_") + ref_layer.attrs["strategy"] = strategy.name + ref_layer.attrs["block_id"] = wrapped_layer.attrs["block_id"] + if layer.type == trt.LayerType.CONSTANT: + self.register_unfilled_weights(graph, layer) + + if is_skip: + return + + logger.debug(f"validating layer {layer.name}") + + layer_type = layer.type + generated_input_values = {} + to_subclass_layer(layer) + if layer_type == trt.LayerType.PLUGIN_V2: + if layer.plugin.plugin_type == "GPTAttention": + sharding_specs = {} + for name, sharding_spec in strategy.sharding_specs.items(): + sharding_specs[name] = get_full_sharding_spec( + 
sharding_spec) + plugin_info = get_plugin_info( + self.full_graph.as_trt(), + layer.name, + ) + generated_input_values = GPTAttentionPlugin.parameter_generator( + sharding_specs, plugin_info) + to_base_class_layer(layer) + + validation_graph_group = PrefixedGraphGroup() + validation_graph = validation_graph_group.prefixed_graph + validation_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping + extra_input_values = {} + validation_shapes = {} + for i, wrapped_input in enumerate(wrapped_layer.inputs): + if wrapped_input is None: + continue + input = wrapped_input.as_trt() + validation_shapes[input.name] = wrapped_input.shape + if wrapped_input.value is None: + if i in generated_input_values: + extra_input_value = generated_input_values[i] + else: + extra_input_value = torch.empty( + tuple(wrapped_input.shape), + dtype=trt_dtype_to_torch(input.dtype), + device=torch.cuda.current_device(), + ) + if torch.is_floating_point(extra_input_value): + extra_input_value.normal_() + # extra_input_value[:] = random.choice([2, 3, 5, 7]) + extra_input_values[input.name] = extra_input_value + self.values[input.name] = extra_input_value + if wrapped_input.producer is not None: + node_name = wrapped_input.producer.name + output_index = wrapped_input.output_index + else: + node_name = wrapped_input.name + output_index = 0 + sharding_spec = self.graph_strategy[ + node_name].sharding_specs[f"output{output_index}"] + validation_graph_group.add_input( + wrapped_input, + device_ids, + ShardingStrategy( + sharding_specs={"output0": sharding_spec}), + ) + validation_graph.get_input( + input.name).raw_shape = wrapped_input.shape + + self.get_graph_in_range( + validation_graph_group, + layer, + range(start_layer_id, end_layer_id), + device_ids, + self.shapes, + self.values, + ) + + for i, wrapped_output in enumerate(wrapped_layer.outputs): + output = wrapped_output.as_trt() + if wrapped_output.is_graph_output: + output_name = f"output{i}" + if strategy.communication_actions.get( + output_name) is not None: + output_name += "_after_comm" + sharding_spec = strategy.sharding_specs[output_name] + validation_graph_group.global_context.merge_context( + self.global_context.get_local_context_for_output( + output)) + validation_graph_group.add_layer_singleton( + output, device_ids, sharding_spec) + validation_graph.add_output(output) + validation_shapes[output.name] = wrapped_output.shape + if not self.timing_cache: + self.timing_cache = network.builder.create_builder_config( + ).create_timing_cache(b"") + logger.debug(f"run validation graph for layer {layer.name}") + validation_runner = validation_graph.get_runner( + validation_shapes, + self.values, + timing_cache=self.timing_cache, + opt_level=0, + ) + values = validation_runner.run() + refer_input_values = {} + for wrapped_input in wrapped_layer.inputs: + if wrapped_input is None: + continue + if wrapped_input.value is not None: + refer_input_values[wrapped_input.name] = wrapped_input.value + refer_graph, output_mapping = get_per_layer_graph( + layer, + validation_shapes, + refer_input_values, + is_shape_io=False, + ) + refer_graph._io_buffer_mapping = self.full_graph._io_buffer_mapping + for proxy_output, output in output_mapping.items(): + validation_shapes[proxy_output] = validation_shapes[output] + logger.debug(f"run refer graph for layer {layer.name}") + refer_runner = refer_graph.get_runner( + validation_shapes, + self.values, + timing_cache=self.timing_cache, + opt_level=0, + ) + refer_outputs = refer_runner.run() + for name, refer_output in 
refer_outputs.items(): + if name in output_mapping: + refer_output = refer_output.bool() + output = values[name] + # ∣output−refer_output∣ <= atol+rtol*∣refer_output∣ + atol = 1e-02 + rtol = 1e-02 + if not torch.allclose( + output, + refer_output, + rtol=rtol, + atol=atol, + equal_nan=True, + ): + size = output.nelement() + diff = (output - refer_output).abs() + diff_index = (~torch.isnan(diff)) & ( + diff > (atol + rtol * refer_output.abs())) + diff_output = diff[diff_index] + diff_size = diff_output.nelement() + logger.warning( + f"output {name} of {layer.name} is not accurate after parallelization. " + f"{diff_size} out of {size} elements ({diff_size / size * 100:.2f}%) are not close. " + f"max: {diff_output.max():.5f}, mean: {diff_output.float().mean():.5f}, std: {diff_output.float().std():.5f}. " + f"mean of reference: {refer_output.float().mean():.5f}, mean of output: {output.float().mean():.5f}." + ) + for name in extra_input_values.keys(): + del self.values[name] + + def add_output(self, tensor: Tensor, device_ids, + strategy: ShardingStrategy): + trt_output = tensor.as_trt() + comm_action_sequence = strategy.best_resharding_cost[0][0][1] + for commspec in comm_action_sequence: + self.add_comm(self.global_context, tensor.name, device_ids, + commspec) + self.add_layer_singleton(trt_output, device_ids, + strategy.sharding_specs["input0"]) + if trt_output.is_shape_tensor: + output = self.prefixed_graph.add_output_shape(trt_output) + else: + output = self.prefixed_graph.add_output(trt_output) + trt_output.dtype = tensor.dtype + output.attrs["strategy"] = strategy.name + + def assign_shapes(self, shape_info: ShapeInfo): + if self.validation_mode: + shapes = { + f"ref_{name}": shape + for name, shape in shape_info.shapes.items() + } + values = { + f"ref_{name}": value + for name, value in shape_info.values.items() + } + self.shapes.update(shapes) + self.values.update(values) + shape_layers = get_shape_layers(self.prefixed_graph.as_trt()) + shape_info = ShapeInfo(self.shapes, self.values, shape_layers) + self.prefixed_graph.assign_shapes(shape_info) + + +def parallelize( + simplifier: Simplifier, + config: ParallelConfig, +): + auto_parallel_config = simplifier.config + debug_mode = auto_parallel_config.debug_mode + dump_path = auto_parallel_config.dump_path + debug_outputs = auto_parallel_config.debug_outputs + + simplifier.infer_shapes(config.graph_config.num_micro_batches) + network = simplifier.network + graph = simplifier.graph + phy_mesh = config.graph_config.phy_mesh + # TODO: test device_ids = [[0]] + device_ids = phy_mesh.phy_devices_id + stage_phy_meshes = config.graph_config.stage_phy_meshes + block_to_stage = config.graph_config.graph_mapping.block_to_stage + graph_strategy = config.graph_strategy + desimplify_strategy( + graph, + graph_strategy, + config.graph_config.graph_mapping, + ) + graph._plugin_config = simplifier.llm_network.plugin_config + graph_group = GraphGroup.from_graph(graph, config, auto_parallel_config) + + if not debug_mode: + init_all_reduce_helper() + tp_size = phy_mesh.size // config.graph_config.num_stages + shape = (CustomAllReduceHelper.POINTERS_PER_RANK * tp_size, ) + workspace = graph.as_trt().add_input( + name="all_reduce_workspace", + dtype=trt.int64, + shape=shape, + ) + tensor = graph.register_input(workspace) + tensor.shape = shape + graph_strategy["all_reduce_workspace"] = ShardingStrategy( + sharding_specs={ + "output0": + ShardingSpec( + device_mesh=phy_mesh.as_logical_mesh(), + data_type_size=tensor.dtype_str_size, + data_shape=shape, + 
max_data_shape=shape, + raw_data_shape=shape, + dim_partition_dict={}, + ) + }) + + if dump_path is not None: + lock = FileLock(f"{dump_path}/path.lock", thread_local=False) + with lock: + with open(f'{dump_path}/sharded_graph.log', 'w+') as file: + config.print_graph_strategy(file) + + for input in graph.inputs: + graph_group.add_input(input, device_ids, graph_strategy[input.name]) + for block in simplifier.blocks: + stage_id = block_to_stage[block.block_id] + stage_phy_mesh = stage_phy_meshes[stage_id] + stage_device_ids = stage_phy_mesh.phy_devices_id.reshape( + config.lmesh.mesh_shape) + for i in block.sorted_layer_ids: + layer = graph.get_layer(network.get_layer(i).name) + layer.attrs["block_id"] = block.block_id + graph_group.add_layer( + layer, + stage_device_ids, + graph_strategy[layer.name], + ) + for output in graph.outputs: + graph_group.add_output(output, device_ids, graph_strategy[output.name]) + + if debug_mode: + new_graph = graph_group.prefixed_graph + debug_outputs = debug_outputs or [] + if isinstance(debug_outputs, str): + if debug_outputs == 'validation': + debug_outputs = [] + for tensor in new_graph.tensors: + if tensor.name.startswith('ref_'): + original_name = tensor.name[4:] + original_tensor = new_graph.get_tensor(original_name) + if original_tensor is not None: + if not original_tensor.is_graph_io: + debug_outputs.append(tensor.name) + debug_outputs.append(original_name) + if original_tensor.is_graph_output: + debug_outputs.append(tensor.name) + else: + pattern = debug_outputs + debug_outputs = [] + for tensor in new_graph.tensors: + if tensor.as_trt().is_shape_tensor: + continue + if tensor.producer is not None: + layer = tensor.producer + if layer.type == trt.LayerType.SHAPE: + continue + if re.match(pattern, tensor.name): + debug_outputs.append(tensor.name) + for output_name in debug_outputs: + trt_output = new_graph.get_tensor(output_name).as_trt() + if trt_output.is_shape_tensor: + output = new_graph.add_output_shape(trt_output) + else: + output = new_graph.add_output(trt_output) + graph_group.assign_shapes(simplifier.shape_info) + if dump_path is not None: + with lock: + new_graph.to_dot( + f'{dump_path}/sharded_graph.dot', + per_device=True, + per_block=True, + # ignore_shape_io=True, + extra_attrs=['strategy'], + ) + return [new_graph] + else: + graphs = [] + for device_id in np.nditer(device_ids): + device_id = device_id.item() + graph = graph_group.graphs[device_id] + graphs.append(graph) + return graphs
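+
+# Illustrative call sequence (a sketch only; how the Simplifier and
+# ParallelConfig are produced by the surrounding build pipeline is assumed
+# here and not shown):
+#     graphs = parallelize(simplifier, parallel_config)
+# In debug mode this returns a one-element list holding the prefixed graph;
+# otherwise it returns one sharded PipelineGraph per physical device.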