import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple

import numpy as np

from tensorrt_llm.network import Network

from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger


class StageType(Enum):
    START = 0
    BLOCK = 1
    END = 2


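# A BuildingBlock is a contiguous range of layers in the TRT network (in practice
# one repeated module, judging by how blocks are grouped by layer hash below). It
# records the edges inside the range (intra_edges) and across its boundary
# (src/dst_inter_edges); blocks with identical intra-edge structure share a
# type_id so the Simplifier can deduplicate them.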
class BuildingBlock:

    def __init__(self, graph, layer_range) -> None:
        self.graph = graph
        self.layer_range = layer_range
        self.network = graph.as_trt()
        self.owned_inputs = {}
        self.is_edges_collected = False
        self.intra_edges = []
        self.src_inter_edges = []
        self.dst_inter_edges = []
        self.relative_src_inter_edges = []
        self.relative_dst_inter_edges = []
        self.relative_inter_edges = set()
        self.edge_hash = None
        self.outputs = None
        self.type_id = -1
        self.block_id = -1
        self.p2p_type = None
        self.is_superset = False
        self.is_subset = False
        self.sorted_layer_ids = []

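    # Edges are 4-tuples (src, src_output_index, dst, dst_input_index). In
    # intra_edges both endpoints are layer offsets relative to layer_range.start;
    # a src of -1 with the block-local owned-input id in the second slot denotes
    # a graph input owned by this block. dst_inter_edges keep the absolute
    # producer layer index (or -1 plus the absolute graph input index) for
    # sources outside the block, and src_inter_edges keep the absolute consumer
    # layer index for destinations outside the block.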
    def collect_edges(self):
        if self.is_edges_collected:
            return
        for layer_index in self.layer_range:
            trt_layer = self.network.get_layer(layer_index)
            layer = self.graph.get_layer(trt_layer.name)
            layer_offset = layer.index - self.layer_range.start
            for input_index, input in enumerate(layer.inputs):
                if input is not None:
                    if input.is_graph_input:
                        is_owned = input.graph_input_index in self.owned_inputs
                        if not is_owned and np.all([
                                layer.index in self.layer_range or np.all([
                                    output.as_trt().is_shape_tensor
                                    for output in layer.outputs
                                ]) for layer, _ in input.consumers
                        ]):
                            self.owned_inputs[input.graph_input_index] = len(
                                self.owned_inputs)
                            is_owned = True
                        if is_owned:
                            self.intra_edges.append(
                                (-1, self.owned_inputs[input.graph_input_index],
                                 layer_offset, input_index))
                        else:
                            self.dst_inter_edges.append(
                                (-1, input.graph_input_index, layer_offset,
                                 input_index))
                    else:
                        src_layer_index = input.producer.index
                        if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
                            self.dst_inter_edges.append(
                                (src_layer_index, input.output_index,
                                 layer_offset, input_index))
                        else:
                            src_layer_offset = src_layer_index - self.layer_range.start
                            self.intra_edges.append(
                                (src_layer_offset, input.output_index,
                                 layer_offset, input_index))
            for output_index, output in enumerate(layer.outputs):
                for dst_layer, dst_input_index in output.consumers:
                    dst_layer_index = dst_layer.index
                    if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
                        self.src_inter_edges.append(
                            (layer_offset, output_index, dst_layer_index,
                             dst_input_index))
        self.edge_hash = tuple(self.intra_edges)
        self.outputs = sorted(
            set((edge[0], edge[1]) for edge in self.src_inter_edges))
        self.is_edges_collected = True

    def collect_relative_inter_edges(self, layer_to_block):
        self.collect_edges()
        for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
            if src_layer_index in layer_to_block:
                src_block = layer_to_block[src_layer_index]
                src_layer_offset = src_layer_index - src_block.layer_range.start
                dst = (self.type_id, dst_layer_index, dst_input_index)
                self.relative_dst_inter_edges.append(
                    (src_block.type_id, src_layer_offset, src_output_index,
                     *dst))
            else:
                self.relative_dst_inter_edges.append(
                    (-1, src_layer_index, src_output_index, self.type_id,
                     dst_layer_index, dst_input_index))
        self.relative_inter_edges = set(self.relative_dst_inter_edges +
                                        self.outputs)

    def get_input_names(self):
        self.collect_edges()
        input_tensor_names = []
        for edge in self.dst_inter_edges:
            layer_index = edge[0]
            output_index = edge[1]
            if layer_index == -1:
                tensor_name = self.network.get_input(output_index).name
            else:
                tensor_name = self.network.get_layer(layer_index).get_output(
                    output_index).name
            input_tensor_names.append(tensor_name)
        return input_tensor_names

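    # Maps each external input tensor of this block to the equivalent tensor of
    # the representative block of the same type passed in last_blocks; inputs
    # whose producer is not part of a known block type map to themselves.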
    def get_input_mapping(self, last_blocks):
        input_mapping = {}
        for tensor_name, relative_edge in zip(self.get_input_names(),
                                              self.relative_dst_inter_edges):
            type_id = relative_edge[0]
            output_index = relative_edge[2]
            if type_id >= 0:
                last_block = last_blocks[type_id]
                layer_offset = relative_edge[1]
                mapped_layer_index = last_block.layer_range.start + layer_offset
                mapped_tensor_name = self.network.get_layer(
                    mapped_layer_index).get_output(output_index).name
                input_mapping[tensor_name] = mapped_tensor_name
            else:
                input_mapping[tensor_name] = tensor_name
        return input_mapping


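# GraphMapping and GraphConfig record how a simplified graph relates to the full
# network: which layers and blocks were folded into their representatives, where
# cross-device/cross-host p2p boundaries fall, and how blocks map to pipeline
# stages.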
@dataclass
class GraphMapping:
    layer_mapping: Dict[int, int] = None
    block_mapping: Dict[int, int] = None
    p2p_types: Dict[int, P2PType] = None
    p2p_tensors: Dict[int, List[str]] = None
    block_to_stage: Dict[int, int] = None
    same_spec_layer_mapping: Dict[str, str] = None


@dataclass
class GraphConfig:
    num_micro_batches: int = 1
    num_blocks: int = 1
    num_stages: int = 1
    has_cross_device: bool = False
    has_cross_host: bool = False
    graph_mapping: GraphMapping = None
    phy_mesh: PhysicalDeviceMesh = None
    stage_phy_meshes: List[PhysicalDeviceMesh] = None


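# The Simplifier detects repeated building blocks in a TRT-LLM network
# (presumably the stacked transformer layers), builds a deduplicated graph that
# is cheaper to analyze for auto-parallel planning, and keeps the mapping
# information needed to relate results back to the full network.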
class Simplifier:

    def __init__(self, network: Network, config: AutoParallelConfig):
        self.config = config
        self.sharded_io_allowlist = config.sharded_io_allowlist
        self.same_buffer_io = config.same_buffer_io
        self.same_spec_io = config.same_spec_io.copy()
        for key, value in self.same_buffer_io.items():
            if key not in self.same_spec_io:
                self.same_spec_io[key] = value

        self.llm_network = network
        self.network = network.trt_network
        self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
        self.graph = self.get_graph()
        self.init_layer_hash()

        module_tree = self.get_module_tree()
        building_blocks = self.collect_building_blocks(module_tree)
        blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
        self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
            blocks_by_module_hash)
        self.layer_to_block = self.get_layer_to_block()
        self.blocks = self.get_all_blocks()
        self.backbone_blocks = self.get_backbone_blocks()
        self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
        self.graph_for_shape = self.create_simplified_graph_for_shape()
        self.shape_info = None
        self.num_micro_batches = None

    def infer_shapes(self, num_micro_batches):
        if self.num_micro_batches == num_micro_batches:
            return
        with silent_trt_logger():
            self.shape_info = self.get_full_shape_info(num_micro_batches)
            self.graph.assign_shapes(self.shape_info)
            self.num_micro_batches = num_micro_batches

    def list_all_num_micro_batches(self):
        opt_batch_size = self.get_opt_batch_size()
        candidates = []
        for num_micro_batches in range(1, opt_batch_size + 1):
            if opt_batch_size % num_micro_batches == 0:
                candidates.append(num_micro_batches)
        return candidates

    def get_graph(self):
        graph = PipelineGraph.from_trt(self.network)
        graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
        graph._io_buffer_mapping
        for input in graph.inputs:
            input_name = input.name
            for pattern, repl in self.same_buffer_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = graph.get_output(output_name)
                    if output is not None:
                        graph._io_buffer_mapping[output_name] = input_name
        return graph

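    # Scans every optimization profile for dimensions named 'batch_size' and
    # returns the smallest optimum value found; this bounds the candidate
    # micro-batch counts enumerated above.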
    def get_opt_batch_size(self):
        input_tensors = self.llm_network._inputs
        num_profiles = len(list(input_tensors.values())[0].profiles)
        opt_batch_sizes = []
        for i in range(num_profiles):
            for input_tensor in input_tensors.values():
                shape_profile = input_tensor.profiles[i]
                opt_shape = shape_profile.opt
                for j in range(len(input_tensor.shape)):
                    name = input_tensor.trt_tensor.get_dimension_name(j)
                    if name == 'batch_size':
                        opt_batch_sizes.append(opt_shape[j])
        return min(opt_batch_sizes)

    def get_module_hash(self, layer_range):
        module_hash = ()
        for i in layer_range:
            assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
            layer_name = self.network.get_layer(i).name
            layer = self.graph.get_layer(layer_name)
            module_hash += (layer.attrs["hash"], )
        return module_hash

    def get_network_hash(self) -> str:
        return str(self.get_module_hash(range(self.network.num_layers)))

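    # Breadth-first walk of the module tree: a module stays a candidate block
    # only if at least one other module shares its layer hash; otherwise its
    # children are searched instead. Each kept block's layer_range is then
    # extended up to the start of the next block so that the blocks tile the
    # covered region of the network.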
    def collect_building_blocks(self, module_tree):
        building_blocks = {}
        queue = []
        for tree in module_tree["children"].values():
            queue.append(tree)
        while len(queue) > 0:
            while len(queue) > 0:
                tree = queue.pop(0)
                module_name = tree["name"]
                if module_name is None:
                    for child in tree["children"].values():
                        queue.append(child)
                    continue
                layer_range = self.module_to_layer_range_map[module_name]
                module_hash = self.get_module_hash(layer_range)
                if module_hash in building_blocks:
                    building_blocks[module_hash].append(tree)
                else:
                    building_blocks[module_hash] = [tree]
            for module_hash in [*building_blocks.keys()]:
                if len(building_blocks[module_hash]) == 1:
                    tree = building_blocks[module_hash][0]
                    for child in tree["children"].values():
                        queue.append(child)
                    del building_blocks[module_hash]
        blocks_by_module_hash = {
            module_hash: [
                BuildingBlock(self.graph,
                              self.module_to_layer_range_map[tree["name"]])
                for tree in trees
            ]
            for module_hash, trees in building_blocks.items()
        }
        building_blocks = []
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        if len(building_blocks) >= 2:
            for block, next_block in zip(building_blocks[:-1],
                                         building_blocks[1:]):
                block.layer_range = range(block.layer_range.start,
                                          next_block.layer_range.start)
        return building_blocks

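    # Produces a complete, ordered partition of the network: the repeated blocks
    # from blocks_by_edge_hash plus filler blocks for the layer ranges between
    # them, each given a block_id and its layer ids in the order reported by
    # get_sorted_layer_ids.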
    def get_all_blocks(self):
        building_blocks = []
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                building_blocks.append(block)
        building_blocks = sorted(building_blocks,
                                 key=lambda x: x.layer_range.start)
        all_blocks = []
        current_layer_index = 0
        block_id = 0
        for block in building_blocks:
            assert current_layer_index <= block.layer_range.start
            if current_layer_index < block.layer_range.start:
                new_block = BuildingBlock(
                    self.graph,
                    range(current_layer_index, block.layer_range.start))
                new_block.block_id = block_id
                block_id += 1
                all_blocks.append(new_block)
            block.block_id = block_id
            block_id += 1
            all_blocks.append(block)
            current_layer_index = block.layer_range.stop
        if current_layer_index < self.graph.num_layers:
            new_block = BuildingBlock(
                self.graph, range(current_layer_index, self.graph.num_layers))
            new_block.block_id = block_id
            all_blocks.append(new_block)
        sorted_layer_ids = get_sorted_layer_ids(self.network)
        for block in all_blocks:
            block.collect_relative_inter_edges(self.layer_to_block)
            for layer_id in sorted_layer_ids:
                if layer_id in block.layer_range:
                    block.sorted_layer_ids.append(layer_id)
        return all_blocks

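    # The backbone is the block group with the most repetitions (ties broken by
    # layer-range length), presumably the stack of transformer layers; it is the
    # part that mark_p2p_type later splits across pipeline stages.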
    def get_backbone_blocks(self):
        sorted_blocks = sorted(
            self.blocks_by_edge_hash.values(),
            key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
        )
        if len(sorted_blocks) == 0:
            return []
        else:
            return sorted_blocks[-1]

    def get_blocks_by_module_hash(self, blocks):
        blocks_by_module_hash = {}
        for block in blocks:
            module_hash = self.get_module_hash(block.layer_range)
            if module_hash not in blocks_by_module_hash:
                blocks_by_module_hash[module_hash] = []
            blocks_by_module_hash[module_hash].append(block)
        for module_hash in [*blocks_by_module_hash.keys()]:
            if len(blocks_by_module_hash[module_hash]) == 1:
                del blocks_by_module_hash[module_hash]
        return blocks_by_module_hash

    def get_module_tree(self):
        module_tree = {"children": {}, "name": None}
        for module_name in self.module_to_layer_range_map.keys():
            full_name = module_name.split('.')
            current_tree = module_tree["children"]
            for depth, name in enumerate(full_name):
                if name not in current_tree:
                    current_tree[name] = {"children": {}, "name": None}
                if depth == len(full_name) - 1:
                    current_tree[name]["name"] = module_name
                else:
                    current_tree = current_tree[name]["children"]
        return module_tree

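    # Regroups candidate blocks by their intra-edge structure (edge_hash) instead
    # of layer hashes, drops groups with a single member, and assigns each
    # surviving group a type_id.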
    def get_blocks_by_edge_hash(self, blocks_by_module_hash):
        blocks_by_edge_hash = {}
        for block_list in blocks_by_module_hash.values():
            for block in block_list:
                block.collect_edges()
                edge_hash = block.edge_hash
                if edge_hash not in blocks_by_edge_hash:
                    blocks_by_edge_hash[edge_hash] = []
                blocks_by_edge_hash[edge_hash].append(block)
        for edge_hash in [*blocks_by_edge_hash.keys()]:
            if len(blocks_by_edge_hash[edge_hash]) == 1:
                del blocks_by_edge_hash[edge_hash]
            else:
                block_list = blocks_by_edge_hash[edge_hash]
                blocks_by_edge_hash[edge_hash] = sorted(
                    block_list, key=lambda x: x.layer_range.start)
        for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
            for block in block_list:
                block.type_id = type_id
        return blocks_by_edge_hash

    def get_layer_to_block(self):
        layer_to_block = {}
        for block_list in self.blocks_by_edge_hash.values():
            for block in block_list:
                for layer_index in block.layer_range:
                    layer_to_block[layer_index] = block
        return layer_to_block

    def clean_blocks(self):
        for block in self.blocks:
            block.p2p_type = None
            block.is_superset = False
            block.is_subset = False

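    # Marks the first backbone block of each pipeline stage after the first with
    # the kind of p2p transfer needed to reach it: CROSS_HOST when the previous
    # stage ends on a different host (device id // num_devices_per_host differs),
    # otherwise CROSS_DEVICE.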
    def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
                      graph_config: GraphConfig):
        if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
            return
        assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
        block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)

        for block in self.backbone_blocks:
            block.p2p_type = None
        for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
            next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
            last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
            next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
            )[0]
            num_devices_per_host = phy_mesh.num_devices_per_host
            next_block = self.backbone_blocks[(stage_index + 1) *
                                              block_per_stage]
            if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
                next_block.p2p_type = P2PType.CROSS_HOST
                graph_config.has_cross_host = True
            else:
                next_block.p2p_type = P2PType.CROSS_DEVICE
                graph_config.has_cross_device = True

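    # For each group of repeated blocks, keeps one "superset" representative per
    # p2p type (a block whose relative inter-edges contain the others') and maps
    # every subset block's layers, owned inputs, network outputs and p2p tensor
    # names onto that representative.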
    def get_graph_mapping(self):
        layer_mapping = {}
        block_mapping = {}
        p2p_types = {}
        p2p_tensors = {}
        for block_list in self.blocks_by_edge_hash.values():
            superset_blocks = []
            superset_block_index = {}
            for block in block_list:
                block_added = False
                for index, superset_block in enumerate(list(superset_blocks)):
                    if block.p2p_type == superset_block.p2p_type:
                        if block.relative_inter_edges.issubset(
                                superset_block.relative_inter_edges):
                            block.is_subset = True
                            block.is_superset = False
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                        elif superset_block.relative_inter_edges.issubset(
                                block.relative_inter_edges):
                            superset_block.is_subset = True
                            superset_block.is_superset = False
                            block.is_subset = False
                            block.is_superset = True
                            superset_blocks[index] = block
                            superset_block_index[id(block)] = index
                            block_added = True
                            break
                if not block_added:
                    block.is_subset = False
                    block.is_superset = True
                    superset_blocks.append(block)
                    superset_block_index[id(block)] = len(superset_blocks) - 1
            for block in block_list:
                assert not (block.is_subset and block.is_superset)
                if block.is_subset:
                    superset_block = superset_blocks[superset_block_index[id(
                        block)]]
                    block_mapping[block.block_id] = superset_block.block_id
                    owned_inputs = map(
                        lambda x: x[0],
                        sorted(block.owned_inputs.items(), key=lambda x: x[1]))
                    superset_owned_inputs = map(
                        lambda x: x[0],
                        sorted(superset_block.owned_inputs.items(),
                               key=lambda x: x[1]))
                    for from_input_id, to_input_id in zip(
                            owned_inputs, superset_owned_inputs):
                        from_input_name = self.network.get_input(
                            from_input_id).name
                        to_input_name = self.network.get_input(to_input_id).name
                        layer_mapping[from_input_name] = to_input_name
                    for from_layer_id, to_layer_id in zip(
                            block.layer_range, superset_block.layer_range):
                        from_layer = self.network.get_layer(from_layer_id)
                        to_layer = self.network.get_layer(to_layer_id)
                        layer_mapping[from_layer.name] = to_layer.name
                        for i in range(from_layer.num_outputs):
                            from_output = from_layer.get_output(i)
                            if from_output.is_network_output:
                                to_output = to_layer.get_output(i)
                                layer_mapping[from_output.name] = to_output.name
                    if block.p2p_type is not None:
                        p2p_types[block.block_id] = block.p2p_type
                        p2p_tensors[block.block_id] = [
                            *set(block.get_input_names())
                        ]
                        for from_name, to_name in zip(
                                block.get_input_names(),
                                superset_block.get_input_names()):
                            layer_mapping[
                                f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
        stage_id = 0
        block_to_stage = {}
        for block in self.blocks:
            if block.p2p_type is not None:
                stage_id += 1
            block_to_stage[block.block_id] = stage_id
        return GraphMapping(
            layer_mapping,
            block_mapping,
            p2p_types,
            p2p_tensors,
            block_to_stage,
        )

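    # Builds the reduced graph: subset blocks are dropped and their consumers are
    # re-wired through input_mapping; identity layers are inserted to stand in for
    # p2p transfers and to carry "same_spec" ids marking tensors that should share
    # specs across repeated block boundaries (how these attributes are consumed
    # lives outside this file).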
    def create_simplified_graph(self, graph_config: GraphConfig):
        new_graph = PipelineGraph.create_graph()
        new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
        layer_mapping = graph_config.graph_mapping.layer_mapping

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            if trt_input.name not in layer_mapping:
                new_graph.add_input(trt_input)

        last_blocks = {}
        same_spec_mapping = {}
        same_spec_layer_mapping = {}
        shape_mapping = {}
        building_block_id = 0
        same_spec_ids = {}
        same_spec_count = 0
        for block in self.blocks:
            if not block.is_subset:
                stage_type = None
                if not block.is_superset:
                    if block.block_id == 0:
                        stage_type = StageType.START
                    elif block.block_id == len(self.blocks) - 1:
                        stage_type = StageType.END
                input_mapping = block.get_input_mapping(last_blocks)
                for from_name, to_name in [*input_mapping.items()]:
                    if to_name in same_spec_mapping:
                        input_mapping[from_name] = same_spec_mapping[to_name]
                    if to_name in layer_mapping:
                        input_mapping[from_name] = layer_mapping[to_name]
                if block.is_superset and block.p2p_type is not None:
                    for from_name, to_name in [*input_mapping.items()]:
                        output_tensor = new_graph.get_tensor(to_name)
                        p2p_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
                        p2p_layer.metadata = p2p_layer.name
                        p2p_tensor = p2p_layer.get_output(0)
                        p2p_tensor.name = f"{p2p_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(p2p_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["p2p_type"] = block.p2p_type
                        input_mapping[from_name] = p2p_tensor.name
                        shape_mapping[p2p_tensor.name] = from_name
                    building_block_id += 1
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    wrapped_layer = new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )
                    wrapped_layer.attrs["building_block_id"] = building_block_id
                    wrapped_layer.attrs["stage_type"] = stage_type
                if block.is_superset:
                    last_blocks[block.type_id] = block

                    if block.type_id in same_spec_ids:
                        same_spec_id = same_spec_ids[block.type_id]
                        update_same_spec_count = False
                    else:
                        same_spec_id = same_spec_count
                        same_spec_ids[block.type_id] = same_spec_id
                        update_same_spec_count = True
                    count = same_spec_id
                    for i, (layer_offset,
                            output_index) in enumerate(block.outputs):
                        layer = self.network.get_layer(block.layer_range.start +
                                                       layer_offset)
                        tensor_name = layer.get_output(output_index).name
                        output_tensor = new_graph.get_tensor(tensor_name)
                        same_spec_layer = new_graph.as_trt().add_identity(
                            output_tensor.as_trt())
                        same_spec_layer.name = f"{tensor_name}_same_spec"
                        same_spec_layer.metadata = same_spec_layer.name
                        same_spec_tensor = same_spec_layer.get_output(0)
                        same_spec_tensor.name = f"{same_spec_layer.name}_output"
                        wrapped_layer = new_graph.register_layer(
                            same_spec_layer)
                        wrapped_layer.attrs[
                            "building_block_id"] = building_block_id
                        wrapped_layer.attrs["same_spec_id"] = count
                        count += 1
                        same_spec_mapping[tensor_name] = same_spec_tensor.name
                        same_spec_layer_mapping[
                            same_spec_layer.name] = layer.name
                        shape_mapping[same_spec_tensor.name] = tensor_name
                    for i, graph_input_index in enumerate(
                            block.owned_inputs.keys()):
                        input_name = self.network.get_input(
                            graph_input_index).name
                        input_tensor = new_graph.get_input(input_name)
                        input_tensor.attrs["same_spec_id"] = count
                        count += 1
                    if update_same_spec_count:
                        same_spec_count = count
                building_block_id += 1
        graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping

        if len(self.backbone_blocks) >= 2:
            start_block = self.backbone_blocks[0]
            if start_block.is_subset:
                start_block = self.blocks[graph_config.graph_mapping.
                                          block_mapping[start_block.block_id]]
            for i in start_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_start_block"] = True
            end_block = self.backbone_blocks[-1]
            if end_block.is_subset:
                end_block = self.blocks[graph_config.graph_mapping.
                                        block_mapping[end_block.block_id]]
            for i in end_block.layer_range:
                layer_name = self.network.get_layer(i).name
                layer = new_graph.get_layer(layer_name)
                layer.attrs["in_end_block"] = True
            slowest_p2p_type = None
            if graph_config.has_cross_host:
                slowest_p2p_type = P2PType.CROSS_HOST
            elif graph_config.has_cross_device:
                slowest_p2p_type = P2PType.CROSS_DEVICE
            if slowest_p2p_type is not None:
                for block in self.blocks:
                    if block.is_superset and block.p2p_type == slowest_p2p_type:
                        for i in block.layer_range:
                            layer_name = self.network.get_layer(i).name
                            layer = new_graph.get_layer(layer_name)
                            layer.attrs["in_slowest_block"] = True

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
                    output.producer.index].is_subset:
                continue
            if trt_output.is_shape_tensor:
                new_output = new_graph.add_output_shape(trt_output)
            else:
                new_output = new_graph.add_output(trt_output)
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, new_output.name):
                    sharded_io = True
                    break
            if not sharded_io:
                new_output.producer.attrs["is_replicated"] = True

        for input in new_graph.inputs:
            input_name = input.name
            sharded_io = False
            for pattern in self.sharded_io_allowlist:
                if re.match(pattern, input_name):
                    sharded_io = True
                    break
            if not sharded_io:
                input.attrs["is_replicated"] = True
            for pattern, repl in self.same_spec_io.items():
                if re.match(pattern, input_name):
                    output_name = re.sub(pattern, repl, input_name)
                    output = new_graph.get_output(output_name)
                    if output is not None:
                        if "same_spec_id" in input.attrs:
                            same_spec_id = input.attrs["same_spec_id"]
                        else:
                            same_spec_id = same_spec_count
                            same_spec_count += 1
                            input.attrs["same_spec_id"] = same_spec_id
                        output.attrs["same_spec_id"] = same_spec_id
                        if math.prod(self.graph.get_input(
                                input_name).shape) < math.prod(
                                    self.graph.get_output(output_name).shape):
                            input.attrs["no_memory_footprint"] = True
                        else:
                            output.attrs["no_memory_footprint"] = True

        return new_graph, shape_mapping

    def enrich_shape_info(self, shape_mapping):
        shapes = self.shape_info.shapes.copy()
        max_shapes = self.shape_info.max_shapes.copy()
        values = self.shape_info.values.copy()
        shape_layers = self.shape_info.shape_layers
        for from_name, to_name in shape_mapping.items():
            if to_name in shapes:
                shapes[from_name] = shapes[to_name]
            if to_name in max_shapes:
                max_shapes[from_name] = max_shapes[to_name]
            if to_name in values:
                values[from_name] = values[to_name]
        shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
        return shape_info

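    # Main entry point: given a physical mesh and a pipeline split, returns the
    # simplified graph with shapes assigned plus its GraphConfig, or (None, None)
    # when the backbone cannot be divided evenly into num_stages stages.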
    def simplify_graph(
            self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
            num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
        num_blocks = len(self.backbone_blocks)
        if num_blocks % num_stages != 0:
            return None, None
        graph_config = GraphConfig()
        graph_config.num_micro_batches = self.num_micro_batches
        graph_config.num_blocks = num_blocks
        graph_config.num_stages = num_stages
        graph_config.phy_mesh = phy_mesh
        stage_phy_meshes = phy_mesh.split_pipeline_meshes(
            num_stages, num_devices_per_stage)
        graph_config.stage_phy_meshes = stage_phy_meshes
        with silent_trt_logger():
            self.clean_blocks()
            self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
            graph_config.graph_mapping = self.get_graph_mapping()
            new_graph, shape_mapping = self.create_simplified_graph(
                graph_config)
            shape_info = self.enrich_shape_info(shape_mapping)
            new_graph.assign_shapes(shape_info)
            return new_graph, graph_config

    def get_graph_mapping_for_shape(self):
        layer_mapping = {}
        tensor_mapping = {}
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            for block in block_list[1:]:
                for from_layer_id, to_layer_id in zip(block.layer_range,
                                                      head_block.layer_range):
                    from_layer = self.network.get_layer(from_layer_id)
                    to_layer = self.network.get_layer(to_layer_id)
                    layer_mapping[from_layer.name] = to_layer.name
                    for i in range(from_layer.num_outputs):
                        tensor_mapping[from_layer.get_output(
                            i).name] = to_layer.get_output(i).name
        return layer_mapping, tensor_mapping

    def create_simplified_graph_for_shape(self):
        new_graph = PipelineGraph.create_graph()

        for i in range(self.network.num_inputs):
            trt_input = self.network.get_input(i)
            new_graph.add_input(trt_input)

        head_blocks = {}
        removed_blocks = set()
        removed_layers = set()
        for block_list in self.blocks_by_edge_hash.values():
            head_block = block_list[0]
            head_blocks[head_block.type_id] = head_block
            for block in block_list[1:]:
                removed_blocks.add(id(block))
                for layer_index in block.layer_range:
                    removed_layers.add(layer_index)

        for block in self.blocks:
            if not id(block) in removed_blocks:
                input_mapping = block.get_input_mapping(head_blocks)
                for i in block.sorted_layer_ids:
                    layer = self.network.get_layer(i)
                    new_graph.add_layer(
                        layer,
                        input_mapping=input_mapping,
                    )

        for i in range(self.network.num_outputs):
            trt_output = self.network.get_output(i)
            output = self.graph.get_output(trt_output.name)
            if output.producer is not None and output.producer.index in removed_layers:
                continue
            if trt_output.is_shape_tensor:
                new_graph.add_output_shape(trt_output)
            else:
                new_graph.add_output(trt_output)

        return new_graph

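    # Runs shape inference (optimum and max profiles) on the shape-only
    # simplified graph and copies the results onto the tensors and layers that
    # were removed, using graph_mapping_for_shape.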
    def get_full_shape_info(self, num_micro_batches):
        layer_mapping, tensor_mapping = self.graph_mapping_for_shape
        optimization_profiles = self.llm_network._generate_optimization_profiles(
        )
        if len(optimization_profiles) > 0:
            optimization_profile = optimization_profiles[-1]
        else:
            optimization_profile = None
        shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                    optimization_profile)
        max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
                                        optimization_profile,
                                        shape_type=ShapeType.MAX)
        shape_info.max_shapes = max_shape_info.shapes
        for removed_tensor_name, tensor_name in tensor_mapping.items():
            shape_info.shapes[removed_tensor_name] = shape_info.shapes[
                tensor_name]
            shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
                tensor_name]
            if tensor_name in shape_info.values:
                shape_info.values[removed_tensor_name] = shape_info.values[
                    tensor_name]
        for removed_layer_name, layer_name in layer_mapping.items():
            if layer_name in shape_info.shape_layers:
                shape_info.shape_layers.add(removed_layer_name)
        return shape_info

    def init_layer_hash(self):
        with silent_trt_logger():
            optimization_profiles = self.llm_network._generate_optimization_profiles(
            )
            if len(optimization_profiles) > 0:
                optimization_profile = optimization_profiles[-1]
            else:
                optimization_profile = None
            shape_info = get_shape_info(self.network, optimization_profile)
            dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
            for layer in self.graph.layers:
                layer_hash = get_cache_key(
                    layer.as_trt(),
                    shape_info.shapes,
                    shape_info.values,
                    dtypes,
                )
                layer.attrs["hash"] = layer_hash