|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from abc import ABC |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import torch |
|
|
from lightning.pytorch.core.module import _jit_is_scripting |
|
|
|
|
|
from nemo.core.classes import typecheck |
|
|
from nemo.core.neural_types import NeuralType |
|
|
from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names |
|
|
from nemo.utils import logging, monkeypatched |
|
|
from nemo.utils.export_utils import ( |
|
|
ExportFormat, |
|
|
augment_filename, |
|
|
get_export_format, |
|
|
parse_input_example, |
|
|
rename_onnx_io, |
|
|
replace_for_export, |
|
|
verify_runtime, |
|
|
verify_torchscript, |
|
|
wrap_forward_method, |
|
|
) |
|
|
|
|
|
__all__ = ['ExportFormat', 'Exportable'] |
|
|
|
|
|
|
|
|
class Exportable(ABC):
    """
    This interface should be implemented by particular classes derived from
    nemo.core.NeuralModule or nemo.core.ModelPT.
    It gives these entities the ability to be exported for deployment to formats
    such as ONNX and TorchScript.

    Usage:
        # exporting pre-trained model to ONNX file for deployment.
        model.eval()
        model.to('cuda')  # or to('cpu') if you don't have GPU

        model.export('mymodel.onnx', [options])  # all arguments apart from `output` are optional.
    """
|
|
|
|
|
@property |
|
|
def input_module(self): |
|
|
return self |
|
|
|
|
|
@property |
|
|
def output_module(self): |
|
|
return self |
|
|
|
|
|
def export( |
|
|
self, |
|
|
output: str, |
|
|
input_example=None, |
|
|
verbose=False, |
|
|
do_constant_folding=True, |
|
|
onnx_opset_version=None, |
|
|
check_trace: Union[bool, List[torch.Tensor]] = False, |
|
|
dynamic_axes=None, |
|
|
check_tolerance=0.01, |
|
|
export_modules_as_functions=False, |
|
|
keep_initializers_as_inputs=None, |
|
|
use_dynamo=False, |
|
|
): |
|
|
""" |
|
|
Exports the model to the specified format. The format is inferred from the file extension of the output file. |
|
|
|
|
|
Args: |
|
|
output (str): Output file name. File extension be .onnx, .pt, or .ts, and is used to select export |
|
|
path of the model. |
|
|
input_example (list or dict): Example input to the model's forward function. This is used to |
|
|
trace the model and export it to ONNX/TorchScript. If the model takes multiple inputs, then input_example |
|
|
should be a list of input examples. If the model takes named inputs, then input_example |
|
|
should be a dictionary of input examples. |
|
|
verbose (bool): If True, will print out a detailed description of the model's export steps, along with |
|
|
the internal trace logs of the export process. |
|
|
do_constant_folding (bool): If True, will execute constant folding optimization on the model's graph |
|
|
before exporting. This is ONNX specific. |
|
|
onnx_opset_version (int): The ONNX opset version to export the model to. If None, will use a reasonable |
|
|
default version. |
|
|
check_trace (bool): If True, will verify that the model's output matches the output of the traced |
|
|
model, upto some tolerance. |
|
|
dynamic_axes (dict): A dictionary mapping input and output names to their dynamic axes. This is |
|
|
used to specify the dynamic axes of the model's inputs and outputs. If the model takes multiple inputs, |
|
|
then dynamic_axes should be a list of dictionaries. If the model takes named inputs, then dynamic_axes |
|
|
should be a dictionary of dictionaries. If None, will use the dynamic axes of the input_example |
|
|
derived from the NeuralType of the input and output of the model. |
|
|
check_tolerance (float): The tolerance to use when checking the model's output against the traced |
|
|
model's output. This is only used if check_trace is True. Note the high tolerance is used because |
|
|
the traced model is not guaranteed to be 100% accurate. |
|
|
export_modules_as_functions (bool): If True, will export the model's submodules as functions. This is |
|
|
ONNX specific. |
|
|
keep_initializers_as_inputs (bool): If True, will keep the model's initializers as inputs in the onnx graph. |
|
|
This is ONNX specific. |
|
|
use_dynamo (bool): If True, use onnx.dynamo_export() instead of onnx.export(). This is ONNX specific. |
|
|
|
|
|
Returns: |
|
|
A tuple of two outputs. |
|
|
Item 0 in the output is a list of outputs, the outputs of each subnet exported. |
|
|
Item 1 in the output is a list of string descriptions. The description of each subnet exported can be |
|
|
used for logging purposes. |
|
|
""" |
|
|
all_out = [] |
|
|
all_descr = [] |
|
|
for subnet_name in self.list_export_subnets(): |
|
|
model = self.get_export_subnet(subnet_name) |
|
|
out_name = augment_filename(output, subnet_name) |
|
|
out, descr, out_example = model._export( |
|
|
out_name, |
|
|
input_example=input_example, |
|
|
verbose=verbose, |
|
|
do_constant_folding=do_constant_folding, |
|
|
onnx_opset_version=onnx_opset_version, |
|
|
check_trace=check_trace, |
|
|
dynamic_axes=dynamic_axes, |
|
|
check_tolerance=check_tolerance, |
|
|
export_modules_as_functions=export_modules_as_functions, |
|
|
keep_initializers_as_inputs=keep_initializers_as_inputs, |
|
|
use_dynamo=use_dynamo, |
|
|
) |
|
|
|
|
|
if input_example is not None: |
|
|
input_example = out_example |
|
|
all_out.append(out) |
|
|
all_descr.append(descr) |
|
|
logging.info("Successfully exported {} to {}".format(model.__class__.__name__, out_name)) |
|
|
return (all_out, all_descr) |
|
|
|
|
|
    def _export(
        self,
        output: str,
        input_example=None,
        verbose=False,
        do_constant_folding=True,
        onnx_opset_version=None,
        check_trace: Union[bool, List[torch.Tensor]] = False,
        dynamic_axes=None,
        check_tolerance=0.01,
        export_modules_as_functions=False,
        keep_initializers_as_inputs=None,
        use_dynamo=False,
    ):
        """
        Exports a single (sub)model to the format implied by ``output``'s extension.

        Called by ``export()`` once per subnet; see ``export()`` for argument semantics.
        Returns a tuple of (output file name, description string, the eager-mode output
        example produced while preparing the trace).
        """
        # Snapshot keyword arguments so they can be forwarded to _prepare_for_export().
        # Must run before any other locals are created, since locals() captures them all.
        my_args = locals().copy()
        my_args.pop('self')

        # Export is inference-only: switch to eval mode and freeze all parameters.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

        # Collect all Exportable submodules (self included via self.modules()) so that
        # each can run its own _prepare_for_export() hook.
        exportables = []
        for m in self.modules():
            if isinstance(m, Exportable):
                exportables.append(m)

        qual_name = self.__module__ + '.' + self.__class__.__qualname__
        format = get_export_format(output)
        output_descr = f"{qual_name} exported to {format}"

        # Default ONNX opset when the caller did not pin one.
        if onnx_opset_version is None:
            onnx_opset_version = 17

        try:
            # Disable NeMo typechecking for the duration of the export.
            typecheck.set_typecheck_enabled(enabled=False)

            # Allow the model to substitute an export-friendly forward(); the original
            # is restored in the finally block.
            # NOTE(review): if wrap_forward_method() itself raises, `forward_method` is
            # never bound and the finally block below would raise NameError.
            forward_method, old_forward_method = wrap_forward_method(self)

            # Trace under inference mode, with JIT optimized execution and the
            # scripting flag set.
            with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():

                if input_example is None:
                    input_example = self.input_module.input_example()

                # _prepare_for_export() hooks receive only the remaining options.
                my_args.pop('output')
                my_args.pop('input_example')

                # Submodules get noreplace=True so module replacement is performed
                # exactly once, on self, just below.
                for ex in exportables:
                    ex._prepare_for_export(**my_args, noreplace=True)
                self._prepare_for_export(output=output, input_example=input_example, **my_args)

                input_list, input_dict = parse_input_example(input_example)
                input_names = self.input_names
                output_names = self.output_names
                # One eager forward pass: the result is returned to the caller (and may
                # be used as the next subnet's input example).
                output_example = self.forward(*input_list, **input_dict)
                if not isinstance(output_example, tuple):
                    output_example = (output_example,)

                if check_trace:
                    if isinstance(check_trace, bool):
                        check_trace_input = [input_example]
                    else:
                        # check_trace may carry explicit verification inputs.
                        check_trace_input = check_trace

                if format == ExportFormat.TORCHSCRIPT:
                    jitted_model = torch.jit.trace_module(
                        self,
                        {"forward": tuple(input_list) + tuple(input_dict.values())},
                        strict=True,
                        check_trace=check_trace,
                        check_tolerance=check_tolerance,
                    )
                    jitted_model = torch.jit.freeze(jitted_model)
                    if verbose:
                        logging.info(f"JIT code:\n{jitted_model.code}")
                    jitted_model.save(output)
                    # Reload from disk so verification runs against the saved artifact.
                    jitted_model = torch.jit.load(output)

                    if check_trace:
                        verify_torchscript(jitted_model, output, check_trace_input, check_tolerance)
                elif format == ExportFormat.ONNX:

                    # Derive dynamic axes from NeuralTypes unless the caller supplied them.
                    if dynamic_axes is None:
                        dynamic_axes = self.dynamic_shapes_for_export(use_dynamo)
                    if use_dynamo:
                        typecheck.enable_wrapping(enabled=False)

                        # flatten_parameters is stubbed out during export — presumably it
                        # interferes with tracing RNN modules; TODO confirm.
                        with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
                            logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n")

                            # Estimate model size in bytes (parameters + buffers).
                            mem_params = sum([param.nelement() * param.element_size() for param in self.parameters()])
                            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
                            mem = mem_params + mem_bufs

                            # For models over ~2GB, pre-export with torch.export and save
                            # weights separately via model_state — presumably to work
                            # around the ONNX 2GB protobuf limit; TODO confirm.
                            if mem > 2 * 1000 * 1000 * 1000:
                                ex_model = torch.export.export(
                                    self,
                                    tuple(input_list),
                                    kwargs=input_dict,
                                    dynamic_shapes=dynamic_axes,
                                    strict=False,
                                )
                                ex_model = ex_model.run_decompositions()
                                model_state = ex_model.state_dict
                            else:
                                model_state = None
                                ex_model = self

                            options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True)
                            ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
                            ex.save(output, model_state=model_state)

                            # Release the exported program before post-processing the file.
                            del ex
                            del ex_model

                            # dynamo_export does not apply our input/output names; rewrite
                            # them directly in the saved ONNX graph.
                            rename_onnx_io(output, input_names, output_names)
                    else:
                        torch.onnx.export(
                            self,
                            input_example,
                            output,
                            input_names=input_names,
                            output_names=output_names,
                            verbose=verbose,
                            do_constant_folding=do_constant_folding,
                            dynamic_axes=dynamic_axes,
                            opset_version=onnx_opset_version,
                            keep_initializers_as_inputs=keep_initializers_as_inputs,
                            export_modules_as_functions=export_modules_as_functions,
                        )

                    if check_trace:
                        verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance)
                else:
                    raise ValueError(f'Encountered unknown export format {format}.')
        finally:
            # Restore typechecking and the original forward(), then run teardown hooks,
            # regardless of whether the export succeeded.
            typecheck.enable_wrapping(enabled=True)
            typecheck.set_typecheck_enabled(enabled=True)
            if forward_method:
                type(self).forward = old_forward_method
            self._export_teardown()
        return (output, output_descr, output_example)
|
|
|
|
|
@property |
|
|
def disabled_deployment_input_names(self) -> List[str]: |
|
|
"""Implement this method to return a set of input names disabled for export""" |
|
|
return [] |
|
|
|
|
|
@property |
|
|
def disabled_deployment_output_names(self) -> List[str]: |
|
|
"""Implement this method to return a set of output names disabled for export""" |
|
|
return [] |
|
|
|
|
|
@property |
|
|
def supported_export_formats(self) -> List[ExportFormat]: |
|
|
"""Implement this method to return a set of export formats supported. Default is all types.""" |
|
|
return [ExportFormat.ONNX, ExportFormat.TORCHSCRIPT] |
|
|
|
|
|
def _prepare_for_export(self, **kwargs): |
|
|
""" |
|
|
Override this method to prepare module for export. This is in-place operation. |
|
|
Base version does common necessary module replacements (Apex etc) |
|
|
""" |
|
|
if not 'noreplace' in kwargs: |
|
|
replace_for_export(self) |
|
|
|
|
|
def _export_teardown(self): |
|
|
""" |
|
|
Override this method for any teardown code after export. |
|
|
""" |
|
|
pass |
|
|
|
|
|
@property |
|
|
def input_names(self): |
|
|
return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names) |
|
|
|
|
|
@property |
|
|
def output_names(self): |
|
|
return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names) |
|
|
|
|
|
@property |
|
|
def input_types_for_export(self) -> Optional[Dict[str, NeuralType]]: |
|
|
return self.input_types |
|
|
|
|
|
    @property
    def output_types_for_export(self) -> Optional[Dict[str, NeuralType]]:
        """Output NeuralTypes used for export; defaults to the module's regular output types."""
        return self.output_types
|
|
|
|
|
def dynamic_shapes_for_export(self, use_dynamo=False): |
|
|
return get_dynamic_axes(self.input_module.input_types_for_export, self.input_names, use_dynamo) |
|
|
|
|
|
def get_export_subnet(self, subnet=None): |
|
|
""" |
|
|
Returns Exportable subnet model/module to export |
|
|
""" |
|
|
if subnet is None or subnet == 'self': |
|
|
return self |
|
|
else: |
|
|
return getattr(self, subnet) |
|
|
|
|
|
def list_export_subnets(self): |
|
|
""" |
|
|
Returns default set of subnet names exported for this model |
|
|
First goes the one receiving input (input_example) |
|
|
""" |
|
|
return ['self'] |
|
|
|
|
|
def get_export_config(self): |
|
|
""" |
|
|
Returns export_config dictionary |
|
|
""" |
|
|
return getattr(self, 'export_config', {}) |
|
|
|
|
|
def set_export_config(self, args): |
|
|
""" |
|
|
Sets/updates export_config dictionary |
|
|
""" |
|
|
ex_config = self.get_export_config() |
|
|
ex_config.update(args) |
|
|
self.export_config = ex_config |
|
|
|