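"""Graph simplification for auto-parallel planning.

Detects repeated building blocks in a TensorRT network (for example the
stacked decoder layers of an LLM), deduplicates them into a smaller graph
that the auto-parallel solver only needs to analyze once per distinct
block, and records the mappings required to expand the solution back onto
the full network.
"""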
import math
import re
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Tuple
import numpy as np
from tensorrt_llm.network import Network
from .config import AutoParallelConfig
from .device_mesh import PhysicalDeviceMesh
from .pipeline_graph import PipelineGraph
from .shape_info import ShapeInfo, ShapeType, get_shape_info
from .tensor_parallel.p2p_node import P2PType
from .utils import get_cache_key, get_sorted_layer_ids, silent_trt_logger
class StageType(Enum):
START = 0
BLOCK = 1
END = 2
class BuildingBlock:
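    """A contiguous range of network layers treated as one repeatable unit.

    collect_edges() classifies every tensor edge touching the block as
    intra-block or inter-block; the tuple of intra-block edges (edge_hash)
    lets structurally identical blocks elsewhere in the network be
    recognized as copies of each other.
    """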
def __init__(self, graph, layer_range) -> None:
self.graph = graph
self.layer_range = layer_range
self.network = graph.as_trt()
self.owned_inputs = {}
self.is_edges_collected = False
self.intra_edges = []
self.src_inter_edges = []
self.dst_inter_edges = []
self.relative_src_inter_edges = []
self.relative_dst_inter_edges = []
self.relative_inter_edges = set()
self.edge_hash = None
self.outputs = None
self.type_id = -1
self.block_id = -1
self.p2p_type = None
self.is_superset = False
self.is_subset = False
self.sorted_layer_ids = []
def collect_edges(self):
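        """Classify every edge touching this block as intra- or inter-block.

        A graph input whose consumers all live inside this block (or only
        produce shape tensors) is recorded in owned_inputs and treated as an
        intra-block edge. Intra-block edges store layer offsets relative to
        layer_range.start so that two copies of the same block hash
        identically; inter-block edges keep absolute layer indices.
        Idempotent: repeated calls return immediately.
        """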
if self.is_edges_collected:
return
for layer_index in self.layer_range:
trt_layer = self.network.get_layer(layer_index)
layer = self.graph.get_layer(trt_layer.name)
layer_offset = layer.index - self.layer_range.start
for input_index, input in enumerate(layer.inputs):
if input is not None:
if input.is_graph_input:
is_owned = input.graph_input_index in self.owned_inputs
if not is_owned and np.all([
layer.index in self.layer_range or np.all([
output.as_trt().is_shape_tensor
for output in layer.outputs
]) for layer, _ in input.consumers
]):
self.owned_inputs[input.graph_input_index] = len(
self.owned_inputs)
is_owned = True
if is_owned:
self.intra_edges.append(
(-1, self.owned_inputs[input.graph_input_index],
layer_offset, input_index))
else:
self.dst_inter_edges.append(
(-1, input.graph_input_index, layer_offset,
input_index))
else:
src_layer_index = input.producer.index
if src_layer_index < self.layer_range.start or src_layer_index >= self.layer_range.stop:
self.dst_inter_edges.append(
(src_layer_index, input.output_index,
layer_offset, input_index))
else:
src_layer_offset = src_layer_index - self.layer_range.start
self.intra_edges.append(
(src_layer_offset, input.output_index,
layer_offset, input_index))
for output_index, output in enumerate(layer.outputs):
for dst_layer, dst_input_index in output.consumers:
dst_layer_index = dst_layer.index
if dst_layer_index < self.layer_range.start or dst_layer_index >= self.layer_range.stop:
self.src_inter_edges.append(
(layer_offset, output_index, dst_layer_index,
dst_input_index))
self.edge_hash = tuple(self.intra_edges)
self.outputs = sorted(
set((edge[0], edge[1]) for edge in self.src_inter_edges))
self.is_edges_collected = True
def collect_relative_inter_edges(self, layer_to_block):
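        """Rewrite incoming inter-block edges relative to their source block.

        An edge whose source layer belongs to a block in layer_to_block is
        expressed as (src_type_id, src_layer_offset, ...), so equivalent
        edges of different copies of a block compare equal; sources outside
        any block keep their absolute layer index under type_id -1.
        """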
self.collect_edges()
for src_layer_index, src_output_index, dst_layer_index, dst_input_index in self.dst_inter_edges:
if src_layer_index in layer_to_block:
src_block = layer_to_block[src_layer_index]
src_layer_offset = src_layer_index - src_block.layer_range.start
dst = (self.type_id, dst_layer_index, dst_input_index)
self.relative_dst_inter_edges.append(
(src_block.type_id, src_layer_offset, src_output_index,
*dst))
else:
self.relative_dst_inter_edges.append(
(-1, src_layer_index, src_output_index, self.type_id,
dst_layer_index, dst_input_index))
self.relative_inter_edges = set(self.relative_dst_inter_edges +
self.outputs)
def get_input_names(self):
self.collect_edges()
input_tensor_names = []
for edge in self.dst_inter_edges:
layer_index = edge[0]
output_index = edge[1]
if layer_index == -1:
tensor_name = self.network.get_input(output_index).name
else:
tensor_name = self.network.get_layer(layer_index).get_output(
output_index).name
input_tensor_names.append(tensor_name)
return input_tensor_names
def get_input_mapping(self, last_blocks):
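        """Map this block's external input tensor names to the equivalent
        tensors of the representative block of each type in last_blocks.
        Inputs that do not come from a known block map to themselves.
        """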
input_mapping = {}
for tensor_name, relative_edge in zip(self.get_input_names(),
self.relative_dst_inter_edges):
type_id = relative_edge[0]
output_index = relative_edge[2]
if type_id >= 0:
last_block = last_blocks[type_id]
layer_offset = relative_edge[1]
mapped_layer_index = last_block.layer_range.start + layer_offset
mapped_tensor_name = self.network.get_layer(
mapped_layer_index).get_output(output_index).name
input_mapping[tensor_name] = mapped_tensor_name
else:
input_mapping[tensor_name] = tensor_name
return input_mapping
@dataclass
class GraphMapping:
layer_mapping: Dict[int, int] = None
block_mapping: Dict[int, int] = None
p2p_types: Dict[int, P2PType] = None
p2p_tensors: Dict[int, List[str]] = None
block_to_stage: Dict[int, int] = None
same_spec_layer_mapping: Dict[str, str] = None
@dataclass
class GraphConfig:
num_micro_batches: int = 1
num_blocks: int = 1
num_stages: int = 1
has_cross_device: bool = False
has_cross_host: bool = False
graph_mapping: GraphMapping = None
phy_mesh: PhysicalDeviceMesh = None
stage_phy_meshes: List[PhysicalDeviceMesh] = None
class Simplifier:
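    """Builds a deduplicated view of a network for auto-parallel search.

    Repeated modules are discovered via per-layer hashes, grouped by their
    intra-block edge structure, and reduced to one representative per
    group; shape information is inferred once on an equally reduced shape
    graph and propagated back to the removed copies.
    """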
def __init__(self, network: Network, config: AutoParallelConfig):
self.config = config
self.sharded_io_allowlist = config.sharded_io_allowlist
self.same_buffer_io = config.same_buffer_io
self.same_spec_io = config.same_spec_io.copy()
for key, value in self.same_buffer_io.items():
if key not in self.same_spec_io:
self.same_spec_io[key] = value
self.llm_network = network
self.network = network.trt_network
self.module_to_layer_range_map = network._module_call_stack.module_to_layer_range_map
self.graph = self.get_graph()
self.init_layer_hash()
module_tree = self.get_module_tree()
building_blocks = self.collect_building_blocks(module_tree)
blocks_by_module_hash = self.get_blocks_by_module_hash(building_blocks)
self.blocks_by_edge_hash = self.get_blocks_by_edge_hash(
blocks_by_module_hash)
self.layer_to_block = self.get_layer_to_block()
self.blocks = self.get_all_blocks()
self.backbone_blocks = self.get_backbone_blocks()
self.graph_mapping_for_shape = self.get_graph_mapping_for_shape()
self.graph_for_shape = self.create_simplified_graph_for_shape()
self.shape_info = None
self.num_micro_batches = None
def infer_shapes(self, num_micro_batches):
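        """Infer and assign shapes for the given micro-batch count; cached,
        so repeated calls with the same value are no-ops.
        """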
if self.num_micro_batches == num_micro_batches:
return
with silent_trt_logger():
self.shape_info = self.get_full_shape_info(num_micro_batches)
self.graph.assign_shapes(self.shape_info)
self.num_micro_batches = num_micro_batches
def list_all_num_micro_batches(self):
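        """Return every divisor of the optimal batch size as a candidate
        micro-batch count, e.g. an opt batch size of 8 yields [1, 2, 4, 8].
        """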
opt_batch_size = self.get_opt_batch_size()
candidates = []
        for num_micro_batches in range(1, opt_batch_size + 1):
if opt_batch_size % num_micro_batches == 0:
candidates.append(num_micro_batches)
return candidates
def get_graph(self):
graph = PipelineGraph.from_trt(self.network)
graph._unfilled_weights = self.llm_network._unfilled_weights.copy()
for input in graph.inputs:
input_name = input.name
for pattern, repl in self.same_buffer_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = graph.get_output(output_name)
if output is not None:
graph._io_buffer_mapping[output_name] = input_name
return graph
def get_opt_batch_size(self):
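        """Return the smallest optimal 'batch_size' dimension value found
        across all input tensors and optimization profiles.
        """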
input_tensors = self.llm_network._inputs
num_profiles = len(list(input_tensors.values())[0].profiles)
opt_batch_sizes = []
for i in range(num_profiles):
for input_tensor in input_tensors.values():
shape_profile = input_tensor.profiles[i]
opt_shape = shape_profile.opt
for j in range(len(input_tensor.shape)):
name = input_tensor.trt_tensor.get_dimension_name(j)
if name == 'batch_size':
opt_batch_sizes.append(opt_shape[j])
return min(opt_batch_sizes)
def get_module_hash(self, layer_range):
module_hash = ()
for i in layer_range:
assert i < self.network.num_layers, f"layer index {i} in {layer_range} out of range of {self.network.num_layers}"
layer_name = self.network.get_layer(i).name
layer = self.graph.get_layer(layer_name)
module_hash += (layer.attrs["hash"], )
return module_hash
def get_network_hash(self) -> str:
return str(self.get_module_hash(range(self.network.num_layers)))
def collect_building_blocks(self, module_tree):
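        """Find modules that occur more than once and wrap them in blocks.

        The module tree is walked breadth-first, grouping subtrees by module
        hash; a hash seen only once is discarded and its children are
        enqueued instead, so the search descends until repetition is found.
        Each block's layer range is then extended up to the start of the
        next block so that glue layers between blocks are covered.
        """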
building_blocks = {}
queue = []
for tree in module_tree["children"].values():
queue.append(tree)
while len(queue) > 0:
while len(queue) > 0:
tree = queue.pop(0)
module_name = tree["name"]
if module_name is None:
for child in tree["children"].values():
queue.append(child)
continue
layer_range = self.module_to_layer_range_map[module_name]
module_hash = self.get_module_hash(layer_range)
if module_hash in building_blocks:
building_blocks[module_hash].append(tree)
else:
building_blocks[module_hash] = [tree]
for module_hash in [*building_blocks.keys()]:
if len(building_blocks[module_hash]) == 1:
tree = building_blocks[module_hash][0]
for child in tree["children"].values():
queue.append(child)
del building_blocks[module_hash]
blocks_by_module_hash = {
module_hash: [
BuildingBlock(self.graph,
self.module_to_layer_range_map[tree["name"]])
for tree in trees
]
for module_hash, trees in building_blocks.items()
}
building_blocks = []
for block_list in blocks_by_module_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
if len(building_blocks) >= 2:
for block, next_block in zip(building_blocks[:-1],
building_blocks[1:]):
block.layer_range = range(block.layer_range.start,
next_block.layer_range.start)
return building_blocks
def get_all_blocks(self):
building_blocks = []
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
building_blocks.append(block)
building_blocks = sorted(building_blocks,
key=lambda x: x.layer_range.start)
all_blocks = []
current_layer_index = 0
block_id = 0
for block in building_blocks:
assert current_layer_index <= block.layer_range.start
if current_layer_index < block.layer_range.start:
new_block = BuildingBlock(
self.graph,
range(current_layer_index, block.layer_range.start))
new_block.block_id = block_id
block_id += 1
all_blocks.append(new_block)
block.block_id = block_id
block_id += 1
all_blocks.append(block)
current_layer_index = block.layer_range.stop
if current_layer_index < self.graph.num_layers:
new_block = BuildingBlock(
self.graph, range(current_layer_index, self.graph.num_layers))
new_block.block_id = block_id
all_blocks.append(new_block)
sorted_layer_ids = get_sorted_layer_ids(self.network)
for block in all_blocks:
block.collect_relative_inter_edges(self.layer_to_block)
for layer_id in sorted_layer_ids:
if layer_id in block.layer_range:
block.sorted_layer_ids.append(layer_id)
return all_blocks
def get_backbone_blocks(self):
sorted_blocks = sorted(
self.blocks_by_edge_hash.values(),
key=lambda blocks: (len(blocks), len(blocks[0].layer_range)),
)
if len(sorted_blocks) == 0:
return []
else:
return sorted_blocks[-1]
def get_blocks_by_module_hash(self, blocks):
blocks_by_module_hash = {}
for block in blocks:
module_hash = self.get_module_hash(block.layer_range)
if module_hash not in blocks_by_module_hash:
blocks_by_module_hash[module_hash] = []
blocks_by_module_hash[module_hash].append(block)
for module_hash in [*blocks_by_module_hash.keys()]:
if len(blocks_by_module_hash[module_hash]) == 1:
del blocks_by_module_hash[module_hash]
return blocks_by_module_hash
def get_module_tree(self):
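        """Build a nested tree from dotted module names; for example,
        'transformer.layers.0.attention' produces children transformer ->
        layers -> 0 -> attention, with "name" filled in on every node whose
        full dotted path is itself a module.
        """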
module_tree = {"children": {}, "name": None}
for module_name in self.module_to_layer_range_map.keys():
full_name = module_name.split('.')
current_tree = module_tree["children"]
for depth, name in enumerate(full_name):
if name not in current_tree:
current_tree[name] = {"children": {}, "name": None}
if depth == len(full_name) - 1:
current_tree[name]["name"] = module_name
else:
current_tree = current_tree[name]["children"]
return module_tree
def get_blocks_by_edge_hash(self, blocks_by_module_hash):
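        """Group blocks by edge_hash and drop groups of one; each surviving
        group is sorted by layer range and all of its blocks share a
        type_id.
        """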
blocks_by_edge_hash = {}
for block_list in blocks_by_module_hash.values():
for block in block_list:
block.collect_edges()
edge_hash = block.edge_hash
if edge_hash not in blocks_by_edge_hash:
blocks_by_edge_hash[edge_hash] = []
blocks_by_edge_hash[edge_hash].append(block)
for edge_hash in [*blocks_by_edge_hash.keys()]:
if len(blocks_by_edge_hash[edge_hash]) == 1:
del blocks_by_edge_hash[edge_hash]
else:
block_list = blocks_by_edge_hash[edge_hash]
blocks_by_edge_hash[edge_hash] = sorted(
block_list, key=lambda x: x.layer_range.start)
for type_id, block_list in enumerate(blocks_by_edge_hash.values()):
for block in block_list:
block.type_id = type_id
return blocks_by_edge_hash
def get_layer_to_block(self):
layer_to_block = {}
for block_list in self.blocks_by_edge_hash.values():
for block in block_list:
for layer_index in block.layer_range:
layer_to_block[layer_index] = block
return layer_to_block
def clean_blocks(self):
for block in self.blocks:
block.p2p_type = None
block.is_superset = False
block.is_subset = False
def mark_p2p_type(self, phy_mesh, stage_phy_meshes,
graph_config: GraphConfig):
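        """Mark the first backbone block of every pipeline stage but the
        first with a P2P type: CROSS_HOST if the boundary from the previous
        stage crosses a host, CROSS_DEVICE otherwise.
        """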
if len(self.backbone_blocks) == 0 or len(stage_phy_meshes) == 1:
return
assert len(self.backbone_blocks) % len(stage_phy_meshes) == 0
block_per_stage = len(self.backbone_blocks) // len(stage_phy_meshes)
for block in self.backbone_blocks:
block.p2p_type = None
for stage_index, stage_phy_mesh in enumerate(stage_phy_meshes[:-1]):
next_stage_phy_mesh = stage_phy_meshes[stage_index + 1]
last_device_id = stage_phy_mesh.phy_devices_id.flatten()[-1]
next_first_device_id = next_stage_phy_mesh.phy_devices_id.flatten(
)[0]
num_devices_per_host = phy_mesh.num_devices_per_host
next_block = self.backbone_blocks[(stage_index + 1) *
block_per_stage]
if last_device_id // num_devices_per_host != next_first_device_id // num_devices_per_host:
next_block.p2p_type = P2PType.CROSS_HOST
graph_config.has_cross_host = True
else:
next_block.p2p_type = P2PType.CROSS_DEVICE
graph_config.has_cross_device = True
def get_graph_mapping(self):
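        """Choose one "superset" representative per block group and map the
        remaining subset blocks onto it.

        A block may represent another only if both have the same P2P type
        and its relative inter-edges are a superset of the other's. The
        returned GraphMapping records layer/input/output name mappings,
        per-block P2P metadata, and a block-to-pipeline-stage assignment.
        """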
layer_mapping = {}
block_mapping = {}
p2p_types = {}
p2p_tensors = {}
for block_list in self.blocks_by_edge_hash.values():
superset_blocks = []
superset_block_index = {}
for block in block_list:
block_added = False
for index, superset_block in enumerate(list(superset_blocks)):
if block.p2p_type == superset_block.p2p_type:
if block.relative_inter_edges.issubset(
superset_block.relative_inter_edges):
block.is_subset = True
block.is_superset = False
superset_block_index[id(block)] = index
block_added = True
break
elif superset_block.relative_inter_edges.issubset(
block.relative_inter_edges):
superset_block.is_subset = True
superset_block.is_superset = False
block.is_subset = False
block.is_superset = True
superset_blocks[index] = block
superset_block_index[id(block)] = index
block_added = True
break
if not block_added:
block.is_subset = False
block.is_superset = True
superset_blocks.append(block)
superset_block_index[id(block)] = len(superset_blocks) - 1
for block in block_list:
assert not (block.is_subset and block.is_superset)
if block.is_subset:
superset_block = superset_blocks[superset_block_index[id(
block)]]
block_mapping[block.block_id] = superset_block.block_id
owned_inputs = map(
lambda x: x[0],
sorted(block.owned_inputs.items(), key=lambda x: x[1]))
superset_owned_inputs = map(
lambda x: x[0],
sorted(superset_block.owned_inputs.items(),
key=lambda x: x[1]))
for from_input_id, to_input_id in zip(
owned_inputs, superset_owned_inputs):
from_input_name = self.network.get_input(
from_input_id).name
to_input_name = self.network.get_input(to_input_id).name
layer_mapping[from_input_name] = to_input_name
for from_layer_id, to_layer_id in zip(
block.layer_range, superset_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
from_output = from_layer.get_output(i)
if from_output.is_network_output:
to_output = to_layer.get_output(i)
layer_mapping[from_output.name] = to_output.name
if block.p2p_type is not None:
p2p_types[block.block_id] = block.p2p_type
p2p_tensors[block.block_id] = [
*set(block.get_input_names())
]
for from_name, to_name in zip(
block.get_input_names(),
superset_block.get_input_names()):
layer_mapping[
f"p2p_block{block.block_id}_{from_name}"] = f"p2p_block{superset_block.block_id}_{to_name}"
stage_id = 0
block_to_stage = {}
for block in self.blocks:
if block.p2p_type is not None:
stage_id += 1
block_to_stage[block.block_id] = stage_id
return GraphMapping(
layer_mapping,
block_mapping,
p2p_types,
p2p_tensors,
block_to_stage,
)
def create_simplified_graph(self, graph_config: GraphConfig):
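        """Build the deduplicated graph containing only non-subset blocks.

        Identity layers are inserted to model P2P transfers at stage
        boundaries and to tie "same spec" sharding constraints across
        repeated blocks; start, end, and slowest blocks as well as
        replicated or sharded I/O are tagged through layer and tensor
        attributes. Returns the new graph plus a name mapping used to copy
        shape info onto the tensors introduced here.
        """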
new_graph = PipelineGraph.create_graph()
new_graph._io_buffer_mapping = self.graph._io_buffer_mapping
layer_mapping = graph_config.graph_mapping.layer_mapping
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
if trt_input.name not in layer_mapping:
new_graph.add_input(trt_input)
last_blocks = {}
same_spec_mapping = {}
same_spec_layer_mapping = {}
shape_mapping = {}
building_block_id = 0
same_spec_ids = {}
same_spec_count = 0
for block in self.blocks:
if not block.is_subset:
stage_type = None
if not block.is_superset:
if block.block_id == 0:
stage_type = StageType.START
elif block.block_id == len(self.blocks) - 1:
stage_type = StageType.END
input_mapping = block.get_input_mapping(last_blocks)
for from_name, to_name in [*input_mapping.items()]:
if to_name in same_spec_mapping:
input_mapping[from_name] = same_spec_mapping[to_name]
if to_name in layer_mapping:
input_mapping[from_name] = layer_mapping[to_name]
if block.is_superset and block.p2p_type is not None:
for from_name, to_name in [*input_mapping.items()]:
output_tensor = new_graph.get_tensor(to_name)
p2p_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
p2p_layer.name = f"p2p_block{block.block_id}_{from_name}"
p2p_layer.metadata = p2p_layer.name
p2p_tensor = p2p_layer.get_output(0)
p2p_tensor.name = f"{p2p_layer.name}_output"
wrapped_layer = new_graph.register_layer(p2p_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["p2p_type"] = block.p2p_type
input_mapping[from_name] = p2p_tensor.name
shape_mapping[p2p_tensor.name] = from_name
building_block_id += 1
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
wrapped_layer = new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
wrapped_layer.attrs["building_block_id"] = building_block_id
wrapped_layer.attrs["stage_type"] = stage_type
if block.is_superset:
last_blocks[block.type_id] = block
if block.type_id in same_spec_ids:
same_spec_id = same_spec_ids[block.type_id]
update_same_spec_count = False
else:
same_spec_id = same_spec_count
same_spec_ids[block.type_id] = same_spec_id
update_same_spec_count = True
count = same_spec_id
                    for layer_offset, output_index in block.outputs:
layer = self.network.get_layer(block.layer_range.start +
layer_offset)
tensor_name = layer.get_output(output_index).name
output_tensor = new_graph.get_tensor(tensor_name)
same_spec_layer = new_graph.as_trt().add_identity(
output_tensor.as_trt())
same_spec_layer.name = f"{tensor_name}_same_spec"
same_spec_layer.metadata = same_spec_layer.name
same_spec_tensor = same_spec_layer.get_output(0)
same_spec_tensor.name = f"{same_spec_layer.name}_output"
wrapped_layer = new_graph.register_layer(
same_spec_layer)
wrapped_layer.attrs[
"building_block_id"] = building_block_id
wrapped_layer.attrs["same_spec_id"] = count
count += 1
same_spec_mapping[tensor_name] = same_spec_tensor.name
same_spec_layer_mapping[
same_spec_layer.name] = layer.name
shape_mapping[same_spec_tensor.name] = tensor_name
                    for graph_input_index in block.owned_inputs.keys():
input_name = self.network.get_input(
graph_input_index).name
input_tensor = new_graph.get_input(input_name)
input_tensor.attrs["same_spec_id"] = count
count += 1
if update_same_spec_count:
same_spec_count = count
building_block_id += 1
graph_config.graph_mapping.same_spec_layer_mapping = same_spec_layer_mapping
if len(self.backbone_blocks) >= 2:
start_block = self.backbone_blocks[0]
if start_block.is_subset:
start_block = self.blocks[graph_config.graph_mapping.
block_mapping[start_block.block_id]]
for i in start_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_start_block"] = True
end_block = self.backbone_blocks[-1]
if end_block.is_subset:
end_block = self.blocks[graph_config.graph_mapping.
block_mapping[end_block.block_id]]
for i in end_block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_end_block"] = True
slowest_p2p_type = None
if graph_config.has_cross_host:
slowest_p2p_type = P2PType.CROSS_HOST
elif graph_config.has_cross_device:
slowest_p2p_type = P2PType.CROSS_DEVICE
if slowest_p2p_type is not None:
for block in self.blocks:
if block.is_superset and block.p2p_type == slowest_p2p_type:
for i in block.layer_range:
layer_name = self.network.get_layer(i).name
layer = new_graph.get_layer(layer_name)
layer.attrs["in_slowest_block"] = True
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in self.layer_to_block and self.layer_to_block[
output.producer.index].is_subset:
continue
if trt_output.is_shape_tensor:
new_output = new_graph.add_output_shape(trt_output)
else:
new_output = new_graph.add_output(trt_output)
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, new_output.name):
sharded_io = True
break
if not sharded_io:
new_output.producer.attrs["is_replicated"] = True
for input in new_graph.inputs:
input_name = input.name
sharded_io = False
for pattern in self.sharded_io_allowlist:
if re.match(pattern, input_name):
sharded_io = True
break
if not sharded_io:
input.attrs["is_replicated"] = True
for pattern, repl in self.same_spec_io.items():
if re.match(pattern, input_name):
output_name = re.sub(pattern, repl, input_name)
output = new_graph.get_output(output_name)
if output is not None:
if "same_spec_id" in input.attrs:
same_spec_id = input.attrs["same_spec_id"]
else:
same_spec_id = same_spec_count
same_spec_count += 1
input.attrs["same_spec_id"] = same_spec_id
output.attrs["same_spec_id"] = same_spec_id
if math.prod(self.graph.get_input(
input_name).shape) < math.prod(
self.graph.get_output(output_name).shape):
input.attrs["no_memory_footprint"] = True
else:
output.attrs["no_memory_footprint"] = True
return new_graph, shape_mapping
def enrich_shape_info(self, shape_mapping):
shapes = self.shape_info.shapes.copy()
max_shapes = self.shape_info.max_shapes.copy()
values = self.shape_info.values.copy()
shape_layers = self.shape_info.shape_layers
for from_name, to_name in shape_mapping.items():
if to_name in shapes:
shapes[from_name] = shapes[to_name]
if to_name in max_shapes:
max_shapes[from_name] = max_shapes[to_name]
if to_name in values:
values[from_name] = values[to_name]
shape_info = ShapeInfo(shapes, values, shape_layers, max_shapes)
return shape_info
def simplify_graph(
self, phy_mesh: PhysicalDeviceMesh, num_stages: int,
num_devices_per_stage: int) -> Tuple[PipelineGraph, GraphConfig]:
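        """Simplify the network for one pipeline split. Returns
        (None, None) if the backbone block count is not divisible by
        num_stages; otherwise returns the simplified graph, with shapes
        assigned, and the matching GraphConfig.
        """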
num_blocks = len(self.backbone_blocks)
if num_blocks % num_stages != 0:
return None, None
graph_config = GraphConfig()
graph_config.num_micro_batches = self.num_micro_batches
graph_config.num_blocks = num_blocks
graph_config.num_stages = num_stages
graph_config.phy_mesh = phy_mesh
stage_phy_meshes = phy_mesh.split_pipeline_meshes(
num_stages, num_devices_per_stage)
graph_config.stage_phy_meshes = stage_phy_meshes
with silent_trt_logger():
self.clean_blocks()
self.mark_p2p_type(phy_mesh, stage_phy_meshes, graph_config)
graph_config.graph_mapping = self.get_graph_mapping()
new_graph, shape_mapping = self.create_simplified_graph(
graph_config)
shape_info = self.enrich_shape_info(shape_mapping)
new_graph.assign_shapes(shape_info)
return new_graph, graph_config
def get_graph_mapping_for_shape(self):
layer_mapping = {}
tensor_mapping = {}
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
for block in block_list[1:]:
for from_layer_id, to_layer_id in zip(block.layer_range,
head_block.layer_range):
from_layer = self.network.get_layer(from_layer_id)
to_layer = self.network.get_layer(to_layer_id)
layer_mapping[from_layer.name] = to_layer.name
for i in range(from_layer.num_outputs):
tensor_mapping[from_layer.get_output(
i).name] = to_layer.get_output(i).name
return layer_mapping, tensor_mapping
def create_simplified_graph_for_shape(self):
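        """Build a reduced graph for shape inference that keeps only the
        first block of each group, rewiring consumers of removed copies to
        the kept one.
        """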
new_graph = PipelineGraph.create_graph()
for i in range(self.network.num_inputs):
trt_input = self.network.get_input(i)
new_graph.add_input(trt_input)
head_blocks = {}
removed_blocks = set()
removed_layers = set()
for block_list in self.blocks_by_edge_hash.values():
head_block = block_list[0]
head_blocks[head_block.type_id] = head_block
for block in block_list[1:]:
removed_blocks.add(id(block))
for layer_index in block.layer_range:
removed_layers.add(layer_index)
for block in self.blocks:
if not id(block) in removed_blocks:
input_mapping = block.get_input_mapping(head_blocks)
for i in block.sorted_layer_ids:
layer = self.network.get_layer(i)
new_graph.add_layer(
layer,
input_mapping=input_mapping,
)
for i in range(self.network.num_outputs):
trt_output = self.network.get_output(i)
output = self.graph.get_output(trt_output.name)
if output.producer is not None and output.producer.index in removed_layers:
continue
if trt_output.is_shape_tensor:
new_graph.add_output_shape(trt_output)
else:
new_graph.add_output(trt_output)
return new_graph
def get_full_shape_info(self, num_micro_batches):
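        """Run shape inference on the reduced shape graph (regular and MAX
        shapes, using the last optimization profile if one exists) and copy
        the results onto the tensors and layers removed from that graph.
        """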
layer_mapping, tensor_mapping = self.graph_mapping_for_shape
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile)
max_shape_info = get_shape_info(self.graph_for_shape.as_trt(),
optimization_profile,
shape_type=ShapeType.MAX)
shape_info.max_shapes = max_shape_info.shapes
for removed_tensor_name, tensor_name in tensor_mapping.items():
shape_info.shapes[removed_tensor_name] = shape_info.shapes[
tensor_name]
shape_info.max_shapes[removed_tensor_name] = shape_info.max_shapes[
tensor_name]
if tensor_name in shape_info.values:
shape_info.values[removed_tensor_name] = shape_info.values[
tensor_name]
for removed_layer_name, layer_name in layer_mapping.items():
if layer_name in shape_info.shape_layers:
shape_info.shape_layers.add(removed_layer_name)
return shape_info
def init_layer_hash(self):
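        """Assign a structural hash to every layer, derived from the layer
        definition plus inferred shapes, values, and dtypes; these hashes
        are the keys used for module matching.
        """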
with silent_trt_logger():
optimization_profiles = self.llm_network._generate_optimization_profiles(
)
if len(optimization_profiles) > 0:
optimization_profile = optimization_profiles[-1]
else:
optimization_profile = None
shape_info = get_shape_info(self.network, optimization_profile)
dtypes = {tensor.name: tensor.dtype for tensor in self.graph.tensors}
for layer in self.graph.layers:
layer_hash = get_cache_key(
layer.as_trt(),
shape_info.shapes,
shape_info.values,
dtypes,
)
layer.attrs["hash"] = layer_hash
|