|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
|
import contextlib |
|
|
import hashlib |
|
|
import inspect |
|
|
import weakref |
|
|
from collections import defaultdict |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Any, Dict, Iterable, List, Optional, OrderedDict, Set, Tuple |
|
|
|
|
|
import numpy as np |
|
|
import tensorrt as trt |
|
|
|
|
|
from tensorrt_llm.module import Module |
|
|
|
|
|
from ._common import set_network |
|
|
from ._utils import get_extra_attr, has_extra_attr, set_extra_attr, trt_gte_10_1 |
|
|
from .logger import logger |
|
|
from .plugin import PluginConfig |
|
|
|
|
|
|
|
|
class _UniqueNameGenerator(object):
    """Callable that produces unique names by appending a per-key counter.

    Every distinct key (after optional module qualification) owns its own
    counter, which starts at 0 and grows by one on each call.
    """

    def __init__(self, prefix=''):
        # Text placed in front of every generated name.
        self.prefix = prefix
        # Next free index for each key; keys never seen before start at 0.
        self.ids = collections.defaultdict(int)

    def __call__(self, key, module_name=''):
        """Return ``<prefix><key>_<index>`` and advance the counter of ``key``.

        When ``module_name`` is not the empty string, its dots are replaced
        by slashes and the result is prepended to ``key`` with a ``/``
        separator; the counter is then kept for that qualified key.
        """
        if module_name != '':
            key = '/'.join((module_name.replace(".", "/"), key))
        tmp = self.ids[key]
        self.ids[key] = tmp + 1
        return f"{self.prefix}{key}_{tmp}"
|
|
|
|
|
|
|
|
class PluginInfo:
    """Bundle of a plugin creator, the plugin name and its field collection.

    During construction the fields of ``pfc`` are additionally exposed
    through two dictionaries keyed by field name:

    * ``pfc_as_ndarray`` holds ``data.copy()`` of every field,
    * ``pfc_as_list`` holds ``data.tolist()`` of every field.
    """
    plugin_creator: trt.IPluginCreator
    plugin_name: str
    pfc: trt.PluginFieldCollection

    def __init__(self, plugin_creator: trt.IPluginCreator, plugin_name: str,
                 pfc: trt.PluginFieldCollection):
        self.plugin_creator = plugin_creator
        self.plugin_name = plugin_name
        self.pfc = pfc
        self._parse_pfc(pfc)

    def _parse_pfc(self, pfc: trt.PluginFieldCollection):
        """Rebuild ``pfc_as_ndarray`` and ``pfc_as_list`` from ``pfc``."""
        self.pfc_as_ndarray = {}
        self.pfc_as_list = {}
        for position in range(len(pfc)):
            entry = pfc[position]
            name = entry.name
            data = entry.data
            # The stored array is a copy; the list form is taken from the
            # field's own data object.
            self.pfc_as_ndarray[name] = data.copy()
            self.pfc_as_list[name] = data.tolist()
|
|
|
|
|
|
|
|
def get_plugin_info(trt_network: trt.INetworkDefinition,
                    layer_name: str) -> PluginInfo:
    """Look up the ``PluginInfo`` recorded for ``layer_name``.

    Returns ``None`` when the network carries no ``"plugin_infos"`` extra
    attribute, or when that mapping has no entry for ``layer_name``.
    """
    if has_extra_attr(trt_network, "plugin_infos"):
        registry = get_extra_attr(trt_network, "plugin_infos")
        if layer_name in registry:
            return registry[layer_name]
    return None
|
|
|
|
|
|
|
|
def set_plugin_info(trt_network: trt.INetworkDefinition, layer_name: str,
                    plugin_info: PluginInfo):
    """Record ``plugin_info`` for ``layer_name`` on the network.

    The ``"plugin_infos"`` extra attribute is created as an empty dict the
    first time it is needed; an existing entry for ``layer_name`` is
    overwritten.
    """
    attr_missing = not has_extra_attr(trt_network, "plugin_infos")
    if attr_missing:
        set_extra_attr(trt_network, "plugin_infos", {})
    # Always re-read the attribute so the write goes to the stored mapping.
    get_extra_attr(trt_network, "plugin_infos")[layer_name] = plugin_info
|
|
|
|
|
|
|
|
def delete_plugin_info(trt_network: trt.INetworkDefinition, layer_name: str):
    """Drop the ``PluginInfo`` entry of ``layer_name``, if there is one.

    Nothing happens when the network has no ``"plugin_infos"`` extra
    attribute or when ``layer_name`` is not present in it.
    """
    if has_extra_attr(trt_network, "plugin_infos"):
        registry = get_extra_attr(trt_network, "plugin_infos")
        if layer_name in registry:
            del registry[layer_name]
|
|
|
|
|
|
|
|
|
|
|
def get_np_weight(trt_network: trt.INetworkDefinition,
                  layer_name: str) -> np.array:
    """Look up the weight recorded for ``layer_name``.

    Returns ``None`` when the network carries no ``"np_weights"`` extra
    attribute, or when that mapping has no entry for ``layer_name``.
    """
    if has_extra_attr(trt_network, "np_weights"):
        weight_table = get_extra_attr(trt_network, "np_weights")
        if layer_name in weight_table:
            return weight_table[layer_name]
    return None
|
|
|
|
|
|
|
|
|
|
|
def set_np_weight(trt_network: trt.INetworkDefinition, layer_name: str,
                  np_weight: np.array):
    """Record ``np_weight`` for ``layer_name`` on the network.

    The ``"np_weights"`` extra attribute is created as an empty dict the
    first time it is needed; an existing entry for ``layer_name`` is
    overwritten.
    """
    attr_missing = not has_extra_attr(trt_network, "np_weights")
    if attr_missing:
        set_extra_attr(trt_network, "np_weights", {})
    # Always re-read the attribute so the write goes to the stored mapping.
    get_extra_attr(trt_network, "np_weights")[layer_name] = np_weight
|
|
|
|
|
|
|
|
class Network(object): |
|
|
|
|
|
def __init__(self, **kwargs):
    """Create the bookkeeping used while the graph is being rewritten.

    The keyword arguments are accepted but not read here. The TensorRT
    network itself is attached later through ``_init``.
    """
    # Imported here rather than at module level.
    from .graph_rewriting import FLayerInfoMemo

    self.flayer_memo = FLayerInfoMemo()
    # Cleared by `_get_network_hash` once a full hash has been computed.
    self.is_graph_altered = False
    # Names of the layers that were marked as removed.
    self._removed_layers: Set[str] = set()
|
|
|
|
|
def _init(self, trt_network):
    """Attach ``trt_network`` and reset all per-network state.

    Returns:
        This ``Network`` instance, so the call can be chained.
    """
    self._trt_network = trt_network

    # Containers filled while the network is being populated.
    self._inputs = {}
    self._registered_ndarrays = []
    self._unfilled_weights: Dict[str, Tuple[np.array, np.array]] = {}

    # Values that stay unset until assigned elsewhere.
    self._named_parameters = None
    self._dtype = None
    self._auto_parallel_config: Dict[str, Any] = None

    # Helper objects.
    self._name_generator = _UniqueNameGenerator()
    self._plugin_config = PluginConfig()
    self._module_call_stack = _TrtLlmModuleCallStack()

    # Whether the attached network carries the STRONGLY_TYPED creation flag.
    self._strongly_typed = trt.INetworkDefinition.get_flag(
        self._trt_network, trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)

    return self
|
|
|
|
|
def _register_unfilled_weights(self, layer_name: str, weights: np.array,
                               values: np.array):
    """Remember a ``(weights, values)`` pair for ``layer_name``.

    The pair is consumed later by ``_fill_weights``; registering the same
    layer name again replaces the earlier pair.
    """
    pending = self._unfilled_weights
    pending[layer_name] = (weights, values)
|
|
|
|
|
def _fill_weights(self):
    """Fill every pending weight buffer and empty the pending table.

    Each buffer is registered through ``register_ndarray``. If values were
    supplied they are copied into the buffer with ``casting='no'``;
    otherwise ``Parameter.xavier_init`` is applied to the buffer.
    """
    from tensorrt_llm.parameter import Parameter

    # Snapshot the keys first: entries are popped while looping.
    for layer_name in tuple(self._unfilled_weights):
        weights, values = self._unfilled_weights.pop(layer_name)
        self.register_ndarray(weights)
        if values is None:
            Parameter.xavier_init(weights)
        else:
            np.copyto(weights, values, casting='no')
|
|
|
|
|
@property
def dtype(self) -> trt.DataType:
    """The data type stored on this network, or ``None`` if none was set."""
    return self._dtype

@dtype.setter
def dtype(self, dtype: trt.DataType):
    """Store ``dtype``; only a ``trt.DataType`` or ``None`` is accepted."""
    assert dtype is None or isinstance(dtype, trt.DataType)
    self._dtype = dtype
|
|
|
|
|
@property
def trt_network(self) -> trt.INetworkDefinition:
    """The TensorRT network object that was passed to ``_init``."""
    wrapped = self._trt_network
    return wrapped
|
|
|
|
|
@property
def plugin_config(self) -> PluginConfig:
    """The ``PluginConfig`` currently attached to this network."""
    return self._plugin_config

@plugin_config.setter
def plugin_config(self, cfg: PluginConfig):
    """Replace the plugin configuration; ``cfg`` must be a ``PluginConfig``."""
    is_valid = isinstance(cfg, PluginConfig)
    assert is_valid, f"Expecting a PluginConfig object, got {type(cfg)}"
    self._plugin_config = cfg
|
|
|
|
|
@property
def strongly_typed(self) -> bool:
    """Value of the STRONGLY_TYPED creation flag captured in ``_init``."""
    flag = self._strongly_typed
    return flag
|
|
|
|
|
@property
def auto_parallel_config(self) -> Dict[str, Any]:
    """The stored auto-parallel configuration (``None`` after ``_init``)."""
    config = self._auto_parallel_config
    return config
|
|
|
|
|
def _add_input(self,
               tensor,
               name,
               dtype,
               shape,
               dim_range: OrderedDict = None):
    """Create a network input and bind it to ``tensor``.

    Parameters:
        tensor: object whose ``trt_tensor`` attribute receives the new input
        name: name of the input; also the key used in ``self._inputs``
        dtype: a ``trt.DataType``
        shape: shape forwarded to ``add_input``
        dim_range: optional ordered mapping whose keys are used as
            dimension names
    """
    assert isinstance(dtype, trt.DataType)
    tensor.trt_tensor = self.trt_network.add_input(name=name,
                                                   shape=shape,
                                                   dtype=dtype)
    assert tensor.trt_tensor is not None, f"Couldn't create TRT tensor for {name} {dtype} {shape}"

    if dim_range is None:
        logger.debug(f'Add input: {name}, shape: {shape}, dtype: {dtype}')
    else:
        logger.debug(
            f'Add input: {name}, shape: {shape}, dtype: {dtype}, dimension names:{list(dim_range.keys())}'
        )
        # Dimension names are only set when `trt_gte_10_1()` is true.
        if trt_gte_10_1():
            for axis, dim_name in enumerate(dim_range.keys()):
                tensor.trt_tensor.set_dimension_name(axis, str(dim_name))

    self._inputs[name] = tensor
|
|
|
|
|
def _mark_output(self, tensor, name, dtype):
    """Mark ``tensor`` as a network output called ``name``.

    On a strongly typed network a non-``None`` ``dtype`` is applied by
    casting the tensor before it is marked. Otherwise the dtype is assigned
    directly on the TensorRT tensor, keeping its current dtype when
    ``dtype`` is falsy.
    """
    from .functional import cast

    is_strong = self.strongly_typed
    if is_strong and dtype is not None:
        tensor = cast(tensor, dtype)

    self.trt_network.mark_output(tensor.trt_tensor)
    tensor.trt_tensor.name = name

    if not is_strong:
        tensor.trt_tensor.dtype = dtype or tensor.trt_tensor.dtype

    logger.debug(f'Mark output: {name}, dtype: {dtype}')
|
|
|
|
|
def set_named_parameters(self, named_parameters):
    """Store ``named_parameters`` so the ``named_parameters`` property returns it."""
    setattr(self, "_named_parameters", named_parameters)
|
|
|
|
|
@property
def named_parameters(self):
    """Whatever was last given to ``set_named_parameters`` (``None`` after ``_init``)."""
    stored = self._named_parameters
    return stored
|
|
|
|
|
def _set_layer_name(self, layer):
    """Give ``layer`` and its outputs unique, scope-qualified names.

    The name is produced by ``self._name_generator`` from:

    * the layer type (plus the plugin type for PLUGIN_V2 layers, or the
      op for UNARY / REDUCE / ELEMENTWISE layers), and
    * the current module name followed by a chain of
      ``<function>_L<line>`` entries describing the Python call stack.

    Outputs are renamed to ``<layer name>_output_<idx>``. If a PLUGIN_V2
    layer ends up with a different name, its ``PluginInfo`` entry is moved
    to the new name. The final name is also written to ``layer.metadata``.
    """
    original_layer_name = layer.name
    kind = layer.type
    layer_name = str(kind).split('.')[-1]
    current_module = self._module_call_stack.get_current_module()

    # Walk outwards, starting two frames above this function. The walk ends
    # at a function named "forward", after 10 collected entries, or when
    # the stack is exhausted. Entries are kept outermost-first.
    func_stack = []
    frame = inspect.currentframe().f_back.f_back
    while frame:
        func_name = frame.f_code.co_name
        if func_name == "forward":
            break
        line_num = frame.f_lineno
        func_stack.insert(0, f"{func_name}_L{line_num}")
        if len(func_stack) >= 10:
            break
        frame = frame.f_back
    current_module = f"{current_module}.{'.'.join(func_stack)}"

    # Refine the type-based name for layer kinds that carry a sub-type.
    if kind == trt.LayerType.PLUGIN_V2:
        detail = str(layer.plugin.plugin_type).split('.')[-1]
        layer_name = '_'.join([layer_name, detail])
    elif kind in [
            trt.LayerType.UNARY, trt.LayerType.REDUCE,
            trt.LayerType.ELEMENTWISE
    ]:
        detail = str(layer.op).split('.')[-1]
        layer_name = '_'.join([layer_name, detail])

    layer.name = self._name_generator(layer_name, current_module)
    for idx in range(layer.num_outputs):
        layer.get_output(idx).name = f"{layer.name}_output_{idx}"

    # Keep the plugin-info table keyed by the layer's current name.
    was_renamed = original_layer_name != layer.name
    if was_renamed and kind == trt.LayerType.PLUGIN_V2:
        plugin_info = get_plugin_info(self.trt_network, original_layer_name)
        if plugin_info is not None:
            set_plugin_info(self.trt_network, layer.name, plugin_info)
            delete_plugin_info(self.trt_network, original_layer_name)

    layer.metadata = layer.name
|
|
|
|
|
def register_ndarray(self, ndarray: np.ndarray) -> None:
    '''Keep ``ndarray`` alive for as long as this Network object lives.

    Functional APIs that build a local numpy array and hand it to a constant
    (or other) layer as weights must register that array here. The TRT
    network only holds weak references to weights, so TRT-LLM is responsible
    for keeping them valid after the functional API has returned, throughout
    network construction and engine building.
    '''
    keep_alive = self._registered_ndarrays
    keep_alive.append(ndarray)
|
|
|
|
|
def _generate_optimization_profiles(self) -> List[trt.IOptimizationProfile]:
    """Create one optimization profile per profile index of the inputs.

    The number of profiles is taken from the first registered input; every
    input is then indexed with each profile index. Shape-tensor inputs are
    configured with ``set_shape_input``, all others with ``set_shape``.

    Returns:
        The list of created profiles; empty when there are no inputs.
    """
    input_tensors = self._inputs
    if len(input_tensors) == 0:
        return []

    first_input = next(iter(input_tensors.values()))
    num_profiles = len(first_input.profiles)

    profiles = []
    for i in range(num_profiles):
        logger.debug(f'Adding optimization profile {i+1}/{num_profiles}')
        profile = self._trt_network.builder.create_optimization_profile()
        for input_name, input_tensor in input_tensors.items():
            shape_profile = input_tensor.profiles[i]
            min_shape, opt_shape, max_shape = (list(shape_profile.min),
                                               list(shape_profile.opt),
                                               list(shape_profile.max))
            if input_tensor.trt_tensor.is_shape_tensor:
                configure = profile.set_shape_input
            else:
                configure = profile.set_shape
            configure(input_name, min_shape, opt_shape, max_shape)
            logger.debug(
                f'{input_name}, min: {min_shape}, opt: {opt_shape}, max: {max_shape}'
            )
        profiles.append(profile)
    return profiles
|
|
|
|
|
def get_inputs(self):
    '''
    Get the inputs of the network.

    Returns:
        Iterable[Tensor]: the values view of the name-to-tensor mapping
        filled by ``_add_input``.
    '''
    registered = self._inputs
    return registered.values()
|
|
|
|
|
def get_outputs(self):
    '''
    Get the outputs of the network.

    This is a generator: every output reported by the TensorRT network is
    wrapped in a new ``Tensor`` as it is requested.

    Returns:
        Iterable[Tensor]
    '''
    from .functional import Tensor

    trt_net = self._trt_network
    for position in range(trt_net.num_outputs):
        yield Tensor(trt_tensor=trt_net.get_output(position),
                     network=self,
                     is_network_input=False)
|
|
|
|
|
def is_input(self, tensor) -> bool:
    '''
    Tell if a tensor is an input of the network.

    The lookup is done by name against the inputs registered through
    ``_add_input``.

    Parameters:
        tensor: Union[Tensor, str, trt.ITensor]

    Returns:
        bool: True if an input with that name is registered, else False.

    Raises:
        ValueError: if ``tensor`` is none of the accepted types.
    '''
    from .functional import Tensor

    if isinstance(tensor, str):
        tensor_name = tensor
    elif isinstance(tensor, (trt.ITensor, Tensor)):
        tensor_name = tensor.name
    else:
        raise ValueError(
            f"tensor should be Tensor, str or ITensor, got {tensor}")

    # Membership test so that a real bool is returned, as annotated, rather
    # than the stored Tensor object.
    return tensor_name in self._inputs
|
|
|
|
|
def is_output(self, tensor) -> bool:
    '''
    Tell if a tensor is an output of the network.

    The check is an identity comparison between ``tensor.trt_tensor`` and
    each output reported by the TensorRT network.

    Parameters:
        tensor: Tensor
    '''
    trt_net = self._trt_network
    return any(tensor.trt_tensor is trt_net.get_output(position)
               for position in range(trt_net.num_outputs))
|
|
|
|
|
def get_layers(self) -> Iterable["Layer"]:
    '''
    Get all the layers of the network.

    This is a generator: each TensorRT layer is wrapped in a new ``Layer``
    object as it is requested.

    Returns:
        Iterable[Layer]
    '''
    from .graph_rewriting import Layer

    trt_net = self._trt_network
    for position in range(trt_net.num_layers):
        yield Layer(network=self, trt_layer=trt_net.get_layer(position))
|
|
|
|
|
def get_layer_by_name(self, name: str) -> Optional["Layer"]:
    """Return the layer called ``name``, or ``None`` if there is none.

    The lookup table is taken from a graph state obtained via ``_get_graph``.
    """
    name_to_layer = self._get_graph().name_to_layer
    return name_to_layer.get(name)
|
|
|
|
|
def get_tensor_users(self, tensor) -> Iterable["Layer"]:
    '''
    Get the layers that consume this tensor.

    This is a generator; the graph state is obtained when iteration starts.
    '''
    consumers = self._get_graph().tensor_to_consumers
    yield from consumers[tensor]
|
|
|
|
|
def get_tensor_parent(self, tensor) -> Optional["Layer"]:
    '''
    Get the layer that produces this tensor.

    Returns ``None`` when no producing layer is recorded for the tensor.
    '''
    producers = self._get_graph().tensor_to_producer
    return producers.get(tensor)
|
|
|
|
|
def mark_removed_layer(self, layer: "Layer"):
    """Record ``layer`` as removed and drop its entry from the layer memo.

    Only the layer's name is stored. The memo entry is removed through
    ``FLayerInfoMemo.instance()``.
    """
    from .graph_rewriting import FLayerInfoMemo

    removed_name = layer.name
    self._removed_layers.add(removed_name)
    FLayerInfoMemo.instance().remove(layer.name)
|
|
|
|
|
def is_removed_layer(self, layer: "Layer") -> bool:
    """Tell whether ``layer`` (matched by name) was marked as removed."""
    removed_names = self._removed_layers
    return layer.name in removed_names
|
|
|
|
|
@property
def removed_layers(self) -> Iterable["Layer"]:
    """Iterate over the layers that were marked as removed.

    This is a generator. The name-to-layer table is built once per
    iteration of this property rather than once per removed layer; nothing
    is built when no layer has been marked as removed.

    Raises:
        AssertionError: if a recorded name no longer resolves to a layer.
    """
    if not self._removed_layers:
        return

    # One graph build serves every lookup below.
    name_to_layer = self._get_graph().name_to_layer
    for layer_name in self._removed_layers:
        layer = name_to_layer.get(layer_name, None)
        assert layer, "Invalid layer name"
        yield layer
|
|
|
|
|
def to_dot(self, path=None) -> Optional[str]:
    '''
    Get a graphviz representation of the network.

    NOTE: the graph may contain redundant nodes, because TRT's INetwork does
    not clean up unused inputs and layers automatically.
    TODO: add a flag to hide all the removed layers and their output tensors
    TODO: replace this when TensorRT provides a better way to get the graph of INetworkDefinition
    TODO: a small feature, add blocks in the figure to highlight the subgraphs of Modules

    Parameters:
        path: the path to save the graphviz file to; its text after the last
            "." selects the output format. If not provided, the graphviz
            source code is returned instead.

    Returns:
        The graphviz source when no path is given (or the selected format is
        "text"); ``None`` when the graph was saved or graphviz is missing.
    '''
    out_format = path.split('.')[-1] if path else 'text'

    try:
        import graphviz
    except ImportError:
        logger.error(
            "Failed to import graphviz, please install graphviz to enable Network.to_dot()"
        )
        return None

    as_text = out_format == 'text'
    dot = graphviz.Digraph(
        comment=
        f'TensorRT Graph of {self._get_network_hash(lightweight=False)}',
        format=None if as_text else out_format)

    inputs_names = {x.name for x in self.get_inputs()}
    output_names = {x.name for x in self.get_outputs()}

    # Tensor nodes are white; layer nodes share the style but are light blue.
    node_style = dict(
        shape='box',
        style='rounded,filled,bold',
        fontname='Arial',
        fillcolor='#ffffff',
        color='#303A3A',
        width='1.3',
        height='0.84',
    )
    hl_node_style = dict(node_style, fillcolor='lightblue')

    state = self._get_graph()
    nodes = set()  # names of every node already emitted (tensors and layers)
    tensor_to_alias = {}
    tensor_id = [0]  # one-element list so the nested function can update it

    def get_alias(tensor, tensor_id):
        # Network inputs/outputs keep their name; others become t0, t1, ...
        if tensor in tensor_to_alias:
            return tensor_to_alias[tensor]
        if tensor in inputs_names or tensor in output_names:
            alias = tensor
        else:
            alias = f"t{tensor_id[0]}"
            tensor_id[0] += 1
        tensor_to_alias[tensor] = alias
        return alias

    def create_tensor_node(tensor: str, dtype=None, shape=None):
        tensor_alias = get_alias(tensor, tensor_id)
        if tensor_alias in nodes:
            return tensor_alias
        label = "\n".join((str(dtype), tensor_alias, str(shape)))
        dot.node(tensor_alias, label, **node_style)
        nodes.add(tensor_alias)
        return tensor_alias

    def create_layer_node(layer: str):
        if layer in nodes:
            return
        dot.node(layer, layer, **hl_node_style)
        nodes.add(layer)

    # Edges from each producing layer to the tensor it produces.
    for tensor, producer in state.tensor_to_producer.items():
        tensor_alias = create_tensor_node(tensor.name, tensor.dtype,
                                          tensor.shape)
        create_layer_node(producer.name)
        dot.edge(producer.name, tensor_alias)

    # Edges from each tensor to every layer that consumes it.
    for tensor, consumers in state.tensor_to_consumers.items():
        tensor_alias = create_tensor_node(tensor.name, tensor.dtype,
                                          tensor.shape)
        for consumer in consumers:
            create_layer_node(consumer.name)
            dot.edge(tensor_alias, consumer.name)

    if as_text:
        return dot.source
    dot.save(path)
|
|
|
|
|
def _get_graph(self) -> "Network._GraphState":
    '''
    Get the graph of the network.

    The current network hash is computed first and handed to
    ``_get_graph_impl``.

    Returns:
        Network._GraphState
    '''
    network_hash = self._get_network_hash()
    return self._get_graph_impl(network_hash)
|
|
|
|
|
|
|
|
def _get_graph_impl(self, network_hash: bytes) -> "Network._GraphState":
    """Build and return a new graph state for this network.

    ``network_hash`` is part of the signature but is not read in the body.
    """
    state = Network._GraphState()
    state.build(self)
    return state
|
|
|
|
|
@dataclass
class _GraphState:
    """Lookup tables describing the producers and consumers in a network.

    All tables start empty and are populated by ``build``.
    """

    # tensor -> layers reading it (missing keys yield a new empty list)
    tensor_to_consumers: Dict[Any, List["Layer"]] = field(
        default_factory=lambda: defaultdict(list))
    # tensor -> layer writing it
    tensor_to_producer: Dict[Any, "Layer"] = field(default_factory=dict)
    # replaced in `build` by the network's inputs / outputs iterables
    inputs: Dict[str, Any] = field(default_factory=OrderedDict)
    outputs: Dict[str, Any] = field(default_factory=OrderedDict)
    # layer name -> Layer wrapper
    name_to_layer: Dict[str, "Layer"] = field(default_factory=dict)

    def build(self, network: "Network") -> None:
        """Populate every table from the layers of ``network``.

        Only tensors whose ``is_trt_wrapper()`` is true are recorded as
        consumed or produced.
        """
        from .graph_rewriting import Layer

        self.inputs = network.get_inputs()
        self.outputs = network.get_outputs()

        consumers = self.tensor_to_consumers
        producers = self.tensor_to_producer
        by_name = self.name_to_layer

        for layer in network.get_layers():
            by_name[layer.name] = Layer(network=network,
                                        trt_layer=layer.trt_layer)

            for position in range(layer.num_inputs):
                consumed = layer.get_inputs(position)[0]
                if consumed.is_trt_wrapper():
                    consumers[consumed].append(layer)

            for position in range(layer.num_outputs):
                produced = layer.get_outputs(position)[0]
                if produced.is_trt_wrapper():
                    producers[produced] = layer
|
|
|
|
|
def _get_network_hash(self, lightweight=True) -> bytes:
    """Return a value identifying the current structure of the network.

    In lightweight mode, while ``is_graph_altered`` is false, the layer
    count itself is returned. Otherwise ``is_graph_altered`` is reset to
    false and the hex digest of a SHA-256 hash is returned. The hash covers
    the layer count, the input and output tensor names, the layer names,
    and every layer-to-output and input-to-layer edge.
    """
    num_layers = self.trt_network.num_layers

    if lightweight and not self.is_graph_altered:
        return num_layers
    self.is_graph_altered = False

    def get_tensor_tag(tensor):
        # Tensors that are not TRT wrappers are tagged with the text 'None'.
        if tensor.is_trt_wrapper():
            return tensor.trt_tensor.name
        return 'None'

    data = hashlib.sha256()

    def absorb(text):
        data.update(text.encode())

    absorb(str(num_layers))
    absorb(','.join(get_tensor_tag(tensor) for tensor in self.get_inputs()))
    absorb(','.join(get_tensor_tag(tensor) for tensor in self.get_outputs()))
    absorb(','.join(layer.trt_layer.name for layer in self.get_layers()))

    # Edges: layer -> produced tensor, then consumed tensor -> layer.
    absorb(','.join([
        f'{layer.trt_layer.name}->{get_tensor_tag(tensor)}'
        for layer in self.get_layers() for tensor in layer.get_outputs()
    ]))
    absorb(','.join([
        f'{get_tensor_tag(tensor)}->{layer.trt_layer.name}'
        for layer in self.get_layers() for tensor in layer.get_inputs()
    ]))

    return data.hexdigest()
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def net_guard(network):
    """Make ``network`` the current network for the duration of a ``with`` block.

    The value of ``net`` imported from ``._common`` on entry is passed back
    to ``set_network`` on exit, including when the body of the ``with``
    block raises.

    Parameters:
        network: the ``Network`` instance to install.

    Raises:
        AssertionError: if ``network`` is not a ``Network`` instance.
    """
    from ._common import net
    assert isinstance(
        network, Network
    ), f"Invalid network, can only guard Network instance, got: {network}"

    old_net = net
    set_network(network)
    try:
        yield
    finally:
        # Restore even on error so a failed build does not leave its
        # network installed.
        set_network(old_net)
|
|
|
|
|
|
|
|
class _TrtLlmModuleCallStack(object):
    """Bookkeeping of module names used while a network is being built.

    Holds a call stack, a weak map from module objects to their names, and
    a map from module name to a range of layers.
    """

    def __init__(self):
        super().__init__()
        self.call_stack = []
        # Weak keys: the map does not keep module objects alive.
        self.module_name_map = weakref.WeakKeyDictionary()
        self.module_to_layer_range_map: Dict[str, range] = {}
        self.mod_names_set = False

    def module_names_set(self):
        """Tell whether ``set_module_names`` has been called."""
        return self.mod_names_set

    def set_module_names(self, top_level_module):
        """Record the names of ``top_level_module`` and all modules it reports.

        Names come from ``named_modules`` with the top-level module's own
        name as prefix. A module that already has a name keeps it.
        """
        assert top_level_module, "Expected a top level module"
        root_prefix = top_level_module._get_name()
        for name, mod in top_level_module.named_modules(prefix=root_prefix):
            if mod in self.module_name_map:
                continue
            self.module_name_map[mod] = name
        self.mod_names_set = True

    def get_current_module(self):
        """Return the top entry of the call stack, or ``''`` if it is empty."""
        return self.call_stack[-1] if len(self.call_stack) else ''

    def get_mod_name(self, mod_obj):
        """Return the recorded name of ``mod_obj``, or ``''`` if unknown."""
        known = mod_obj in self.module_name_map
        return self.module_name_map[mod_obj] if known else ''

    def set_layer_range(self, mod_obj: Module, layer_range: range):
        """Store ``layer_range`` under the name of ``mod_obj``.

        Modules without a recorded name are ignored.
        """
        if mod_obj not in self.module_name_map:
            return
        mod_name = self.module_name_map[mod_obj]
        self.module_to_layer_range_map[mod_name] = layer_range

    def get_stack(self):
        """Return the call stack list itself (not a copy)."""
        return self.call_stack

    @contextlib.contextmanager
    def call_stack_mgr(self):
        """Yield the call stack and pop its top entry when the block exits.

        The pop happens even if the block raises.
        """
        stack = self.get_stack()
        try:
            yield stack
        finally:
            stack.pop()
|
|
|