|
|
from typing import List |
|
|
|
|
|
import pandas as pd |
|
|
import tensorrt as trt |
|
|
|
|
|
from .pipeline_graph import PipelineGraph |
|
|
from .runtime_profiling import RuntimeProfiler |
|
|
from .simplifier import GraphConfig, StageType |
|
|
from .solver import CostGraph, Solver |
|
|
from .tensor_parallel.activation_node import Activation |
|
|
from .tensor_parallel.assertion_node import Assertion |
|
|
from .tensor_parallel.cast_node import Cast |
|
|
from .tensor_parallel.concatenation_node import Concatenation |
|
|
from .tensor_parallel.constant_node import Constant |
|
|
from .tensor_parallel.elementwise_node import ElementWise |
|
|
from .tensor_parallel.fill_node import Fill |
|
|
from .tensor_parallel.gather_node import Gather |
|
|
from .tensor_parallel.identity_node import Identity |
|
|
from .tensor_parallel.input_node import InputNode |
|
|
from .tensor_parallel.matmul_node import MatrixMultiply |
|
|
from .tensor_parallel.node import Node |
|
|
from .tensor_parallel.normalization_node import Normalization |
|
|
from .tensor_parallel.output_node import OuputNode |
|
|
from .tensor_parallel.p2p_node import P2PNode, P2PType |
|
|
from .tensor_parallel.plugin_node import PluginNode |
|
|
from .tensor_parallel.plugin_nodes.gemm_node import GemmPlugin |
|
|
from .tensor_parallel.plugin_nodes.gpt_attention_node import GPTAttentionPlugin |
|
|
from .tensor_parallel.plugin_nodes.identity_node import IdentityPlugin |
|
|
from .tensor_parallel.plugin_nodes.look_up_node import LookupPlugin |
|
|
from .tensor_parallel.plugin_nodes.normalization_node import (LayernormPlugin, |
|
|
RMSnormPlugin) |
|
|
from .tensor_parallel.reduce_node import Reduce |
|
|
from .tensor_parallel.select_node import Select |
|
|
from .tensor_parallel.shape_node import Shape |
|
|
from .tensor_parallel.shuffle_node import Shuffle |
|
|
from .tensor_parallel.slice_node import Slice |
|
|
from .tensor_parallel.softmax_node import SoftMax |
|
|
from .tensor_parallel.unary_node import Unary |
|
|
|
|
|
# Maps a TensorRT layer type to the node class used to wrap a layer of that
# type. NodeGraph.__init__ indexes this table with ``layer.type`` and calls the
# resulting class with the layer. PLUGIN_V2 layers fall back to ``PluginNode``
# when their plugin type has no entry in PLUGIN_LAYER_TYPE_2_NODE_TYPE.
LAYER_TYPE_2_NODE_TYPE = {
    trt.LayerType.ACTIVATION: Activation,
    trt.LayerType.ASSERTION: Assertion,
    trt.LayerType.CAST: Cast,
    trt.LayerType.CONCATENATION: Concatenation,
    trt.LayerType.CONSTANT: Constant,
    trt.LayerType.ELEMENTWISE: ElementWise,
    trt.LayerType.FILL: Fill,
    trt.LayerType.GATHER: Gather,
    trt.LayerType.IDENTITY: Identity,
    trt.LayerType.MATRIX_MULTIPLY: MatrixMultiply,
    trt.LayerType.NORMALIZATION: Normalization,
    trt.LayerType.PLUGIN_V2: PluginNode,
    trt.LayerType.REDUCE: Reduce,
    trt.LayerType.SELECT: Select,
    trt.LayerType.SHAPE: Shape,
    trt.LayerType.SHUFFLE: Shuffle,
    trt.LayerType.SLICE: Slice,
    trt.LayerType.SOFTMAX: SoftMax,
    trt.LayerType.UNARY: Unary,
}
|
|
|
|
|
# Maps a plugin type string (``layer.as_trt().plugin.plugin_type`` of a
# PLUGIN_V2 layer) to the specialised node class used to wrap that plugin.
# Plugin types missing from this table are handled through
# LAYER_TYPE_2_NODE_TYPE instead.
PLUGIN_LAYER_TYPE_2_NODE_TYPE = {
    'GPTAttention': GPTAttentionPlugin,
    'Gemm': GemmPlugin,
    'Layernorm': LayernormPlugin,
    'Rmsnorm': RMSnormPlugin,
    'Lookup': LookupPlugin,
    'Identity': IdentityPlugin,
}
|
|
|
|
|
|
|
|
class NodeGraph:
    """Graph of nodes wrapping the inputs, layers and outputs of a pipeline graph.

    Nodes are kept in ``self._nodes``, a dict keyed by the name of the wrapped
    input / layer / output, in insertion order (inputs, then layers, then
    outputs).
    """

    def __init__(self, graph: PipelineGraph):
        """Wrap every input, layer and output of *graph* in a node.

        After all nodes exist, ``post_init(self)`` is called on each node and
        a fresh ``RuntimeProfiler`` is attached as ``node_runtime_profiler``.

        Raises:
            KeyError: if a layer's type has no entry in
                ``LAYER_TYPE_2_NODE_TYPE``.
        """
        self._nodes = {}

        for graph_input in graph.inputs:
            self._nodes[graph_input.name] = InputNode(graph_input)
        for layer in graph.layers:
            layer.to_base_class()
            # Layers carrying a "p2p_type" attribute are always P2P nodes,
            # whatever their TensorRT layer type is.
            if "p2p_type" in layer.attrs:
                node = P2PNode(layer)
            else:
                node = self._create_layer_node(layer)
            self._nodes[layer.name] = node
        for output in graph.outputs:
            self._nodes[output.name] = OuputNode(output)
        for node in self.nodes:
            node.post_init(self)
            node.node_runtime_profiler = RuntimeProfiler()

    @staticmethod
    def _create_layer_node(layer):
        """Build the node for a non-P2P *layer* from the two lookup tables."""
        node_cls = None
        if layer.type == trt.LayerType.PLUGIN_V2:
            # The plugin type is read from the subclass view of the layer; the
            # layer is switched back to its base class before the node is built.
            layer.to_subclass()
            plugin_type = layer.as_trt().plugin.plugin_type
            layer.to_base_class()
            node_cls = PLUGIN_LAYER_TYPE_2_NODE_TYPE.get(plugin_type)
        if node_cls is None:
            try:
                node_cls = LAYER_TYPE_2_NODE_TYPE[layer.type]
            except KeyError:
                raise KeyError(
                    f"no node type registered for layer type {layer.type} "
                    f"(layer {layer.name!r})") from None
        return node_cls(layer)

    def get_node(self, name):
        """Return the node stored under *name*; raises ``KeyError`` if absent."""
        return self._nodes[name]

    @property
    def nodes(self) -> List[Node]:
        """A new list holding all nodes, in insertion order."""
        return [*self._nodes.values()]

    def assign_cost_weights(self, graph_config: GraphConfig):
        """Update ``sharding_weight`` / ``resharding_weight`` from the graph mapping.

        Each value of ``graph_mapping.layer_mapping`` names a node whose two
        weights are incremented by one (once per occurrence). Afterwards, for
        every ``same_spec_name -> layer_name`` pair of
        ``graph_mapping.same_spec_layer_mapping``, the same-spec node receives
        a copy of the weights of the node named ``layer_name``.
        """
        graph_mapping = graph_config.graph_mapping
        for layer_name in graph_mapping.layer_mapping.values():
            node = self.get_node(layer_name)
            node.sharding_weight += 1
            node.resharding_weight += 1
        same_spec_layer_mapping = graph_mapping.same_spec_layer_mapping
        for same_spec_layer_name, layer_name in same_spec_layer_mapping.items():
            node = self.get_node(layer_name)
            same_spec_node = self.get_node(same_spec_layer_name)
            same_spec_node.sharding_weight = node.sharding_weight
            same_spec_node.resharding_weight = node.resharding_weight

    def set_slowest_stage(self, stage_type: StageType,
                          graph_config: GraphConfig):
        """Recompute ``pipeline_weight`` and ``cost_level`` of every node.

        Every node is first reset to ``pipeline_weight = 0`` and
        ``cost_level = -1``; the checks below are then applied in order, so a
        later check may overwrite what an earlier one assigned.

        Args:
            stage_type: the stage the weights are computed for (presumably the
                stage assumed to be the slowest, going by the method name).
            graph_config: supplies ``num_micro_batches``, ``num_blocks``,
                ``num_stages``, ``has_cross_host`` and ``has_cross_device``.
        """
        num_micro_batches = graph_config.num_micro_batches
        block_per_stage = graph_config.num_blocks // graph_config.num_stages
        block_pipeline_weight = block_per_stage * (num_micro_batches - 1)
        for node in self.nodes:
            node.pipeline_weight = 0
            node.cost_level = -1

            # Nodes that belong to the START stage.
            if node.stage_type == StageType.START:
                if stage_type == StageType.START:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.START and node.in_start_block:
                node.pipeline_weight = block_pipeline_weight

            # Nodes that belong to the END stage.
            if node.stage_type == StageType.END:
                if stage_type == StageType.END:
                    node.pipeline_weight = num_micro_batches - 1
                    node.cost_level = 1
                else:
                    node.cost_level = 0
            if stage_type == StageType.END and node.in_end_block:
                node.pipeline_weight = block_pipeline_weight

            # P2P nodes: which p2p type is handled by which branch depends on
            # the has_cross_host / has_cross_device flags of the config.
            if isinstance(node, P2PNode):
                if (graph_config.has_cross_host
                        and node.p2p_type == P2PType.CROSS_HOST) or (
                            not graph_config.has_cross_host
                            and node.p2p_type == P2PType.CROSS_DEVICE):
                    if stage_type == StageType.BLOCK:
                        node.pipeline_weight += num_micro_batches - 1
                        node.cost_level = 1
                    else:
                        node.cost_level = 0
                elif (graph_config.has_cross_device
                      and node.p2p_type == P2PType.CROSS_DEVICE) or (
                          not graph_config.has_cross_device
                          and node.p2p_type == P2PType.CROSS_HOST):
                    node.pipeline_weight += num_micro_batches - 1

            # Applied last, so it overrides any weight assigned above.
            if stage_type == StageType.BLOCK and node.in_slowest_block:
                node.pipeline_weight = block_pipeline_weight

    def get_cost_graph(self, lmesh):
        """Prepare the strategies of every node and build a ``CostGraph``.

        Replicated nodes get ``set_strategy(None, lmesh)``; all other nodes
        collect their strategies for *lmesh*. Only once every node has been
        prepared are the resharding costs updated; the non-empty strategy
        vectors returned by that update are passed to ``CostGraph``.
        """
        for node in self.nodes:
            if node.is_replicated:
                node.set_strategy(None, lmesh)
            else:
                node.collect_strategies(lmesh)

        leaf_strategies = []
        for node in self.nodes:
            strategies_vector = node.update_resharding_cost()
            if len(strategies_vector) != 0:
                leaf_strategies.append(strategies_vector)
        return CostGraph(leaf_strategies)

    def find_solution(self, cost_graph, memory_budget):
        """Run the solver and re-key resharding costs by predecessor index.

        The solution is the second element of ``Solver.find_solution()``.
        For each chosen strategy, entries of ``best_resharding_cost`` keyed by
        a predecessor's node name are copied under that predecessor's index
        (the index is also recorded in ``strategy.node_names``), and then all
        string-keyed entries are removed.

        Returns:
            The solution object, with its strategies modified in place.
        """
        solver = Solver(cost_graph, memory_budget=memory_budget)
        solution = solver.find_solution()[1]

        graph_strategy = solution.node_best_strategy
        for node_name, strategy in graph_strategy.items():
            node = self._nodes[node_name]
            for idx, pre_node in enumerate(node.predecessor_nodes):
                if pre_node is None:
                    continue
                if pre_node.node_name not in strategy.best_resharding_cost:
                    continue
                strategy.best_resharding_cost[
                    idx] = strategy.best_resharding_cost[pre_node.node_name]
                strategy.node_names[idx] = pre_node.node_name
            # Iterate over a snapshot of the keys: entries are deleted while
            # walking them.
            for key in list(strategy.best_resharding_cost.keys()):
                if isinstance(key, str):
                    del strategy.best_resharding_cost[key]

        return solution

    def visualize(self, name='pp_graph'):
        """Write the graph to ``<name>.dot`` in Graphviz DOT format.

        Nodes whose name contains ``MATRIX_MULTIPLY`` are filled green, all
        others white. When a node has outputs, the shape of its first output
        is appended to its label. The file is written as UTF-8.
        """
        with open(name + '.dot', 'w', encoding='utf-8') as f:
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for node_name, node in self._nodes.items():
                fillcolor = 'green' if 'MATRIX_MULTIPLY' in node_name else 'white'
                label = node_name
                if len(node.outputs) > 0:
                    # '\\n' is a literal backslash-n: DOT's own line break.
                    label = node_name + '\\n' + str(node.outputs[0].shape)
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"box\", style = \"filled\"];\n"
                    .format(node_name, fillcolor, label))
            f.write(" // Edges\n")
            for node_name, node in self._nodes.items():
                for successor_node in node.successor_nodes:
                    if successor_node:
                        f.write(" \"{}\" ->\"{}\";\n".format(
                            node_name, successor_node.node_name))
            f.write(" }\n")

    @staticmethod
    def _solution_fillcolor(name):
        """Pick the DOT fill colour of a node from substrings of its *name*."""
        if 'MATRIX_MULTIPLY' in name or 'PLUGIN_V2_Gemm' in name:
            return 'orange'
        if '_same_spec' in name:
            return 'gray'
        if 'p2p_block' in name:
            return 'blue'
        if 'PLUGIN' in name:
            return 'yellow'
        return 'white'

    @staticmethod
    def _describe_comm(commspec):
        """Return the four fields of *commspec* that are shown in labels."""
        return [
            commspec.comm_pattern, commspec.gather_dim, commspec.shard_dim,
            commspec.logical_process_axis
        ]

    def visualize_solution(self,
                           solution,
                           fname='pp_graph_solution',
                           ignore_shape_io=True):
        """Write *solution* to ``<fname>.dot`` and a per-node table to ``<fname>.csv``.

        Each node label lists the node's block id and sharding weight, then
        one line per non-empty input and per output with its shape, the first
        element of ``dtype_str_size`` and, when the node has an entry in
        ``solution.node_best_strategy``, the sharding sequence, communication
        actions and costs of that strategy.

        The CSV has the columns ``node``, ``cost``, ``block_id``,
        ``cost_level``, ``sharding_weight`` (sharding weight plus pipeline
        weight), ``same_spec_id`` and ``weight_cost`` (``sharding_weight`` times
        ``cost``).

        Args:
            solution: object exposing ``node_best_strategy``, a mapping from
                node name to the chosen strategy.
            fname: path prefix of the two output files.
            ignore_shape_io: when true, nodes whose layer has ``is_shape_io``
                set are left out of the nodes, the edges and the CSV.
        """
        names, costs = [], []
        block_ids, cost_levels, weights, same_spec_ids = [], [], [], []

        with open(fname + '.dot', 'w', encoding='utf-8') as f:
            f.write("digraph {\n")
            f.write(" // Operation Nodes\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
                    continue
                cost = 0.0
                fillcolor = self._solution_fillcolor(name)

                shape = 'box'
                if node.node_type in ('output_node', 'input_node'):
                    shape = 'ellipse'
                    fillcolor = 'green'

                # Look the node's strategy up once instead of once per tensor.
                has_strategy = node.node_name in solution.node_best_strategy
                best_strategy = (solution.node_best_strategy[node.node_name]
                                 if has_strategy else None)

                label = name + f'_block{node.building_block_id}_weight{node.sharding_weight}'
                for idx, tensor in enumerate(node.inputs):
                    if not tensor:
                        continue
                    label = label + f'\\ninput{idx}_' + str(
                        tensor.shape) + f'_{tensor.dtype_str_size[0]}_'
                    if not has_strategy:
                        continue
                    label = label + str(
                        best_strategy.sharding_specs[f'input{idx}'].
                        sharding_sequence)
                    if idx not in best_strategy.best_resharding_cost:
                        continue
                    rcosts = best_strategy.best_resharding_cost[idx][0]
                    comm_action_sequence, resharding_cost = rcosts[1], rcosts[2]
                    if len(comm_action_sequence) > 0:
                        label = label + '|'
                    for commspec in comm_action_sequence:
                        label = label + '->' + str(
                            self._describe_comm(commspec))
                    if resharding_cost > 0:
                        label = label + '_rcost{:.2}'.format(resharding_cost)
                    cost = cost + resharding_cost

                if len(node.outputs) > 0:
                    for idx, output in enumerate(node.outputs):
                        label = label + f'\\noutput{idx}_' + str(
                            output.shape) + f'_{output.dtype_str_size[0]}'
                        if not has_strategy:
                            continue
                        shard_seq = str(
                            best_strategy.sharding_specs[f'output{idx}'].
                            sharding_sequence)
                        comm = None
                        if f'output{idx}' in best_strategy.communication_actions:
                            comm = self._describe_comm(
                                best_strategy.communication_actions[
                                    f'output{idx}'])
                        label = label + '_' + shard_seq
                        if comm:
                            label = label + f' | {comm}'
                    # The strategy costs are only reported for nodes that have
                    # at least one output.
                    if best_strategy:
                        cost = cost + best_strategy.sharding_cost + best_strategy.communication_cost
                        label = label + '| scost{:.2}'.format(
                            best_strategy.sharding_cost)
                        if best_strategy.communication_cost > 0:
                            label = label + ' | ccost{:.2}'.format(
                                best_strategy.communication_cost)

                names.append(name)
                costs.append(cost)
                block_ids.append(node.building_block_id)
                cost_levels.append(node.cost_level)
                weights.append(node.sharding_weight + node.pipeline_weight)
                same_spec_ids.append(node.same_spec_id)
                f.write(
                    " \"{}\" [fillcolor = \"{}\", label = \"{}\", shape = \"{}\", style = \"filled\"];\n"
                    .format(name, fillcolor, label, shape))

            f.write(" // Edges\n")
            for name, node in self._nodes.items():
                if ignore_shape_io and node.layer is not None and node.layer.is_shape_io:
                    continue
                for successor_node in node.successor_nodes:
                    if successor_node:
                        if ignore_shape_io and successor_node.layer is not None and successor_node.layer.is_shape_io:
                            continue
                        f.write(" \"{}\" ->\"{}\";\n".format(
                            name, successor_node.node_name))
            f.write(" }\n")

        # The DOT file is closed at this point; the table is written separately.
        df = pd.DataFrame.from_dict({
            'node': names,
            'cost': costs,
            'block_id': block_ids,
            'cost_level': cost_levels,
            'sharding_weight': weights,
            'same_spec_id': same_spec_ids,
        })
        df['weight_cost'] = df['sharding_weight'] * df['cost']
        df.to_csv(fname + '.csv')
|
|
|