# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import Dict, List, Optional, Union
import torch
from lightning.pytorch.core.module import _jit_is_scripting
from nemo.core.classes import typecheck
from nemo.core.neural_types import NeuralType
from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names
from nemo.utils import logging, monkeypatched
from nemo.utils.export_utils import (
ExportFormat,
augment_filename,
get_export_format,
parse_input_example,
rename_onnx_io,
replace_for_export,
verify_runtime,
verify_torchscript,
wrap_forward_method,
)
__all__ = ['ExportFormat', 'Exportable']
class Exportable(ABC):
"""
This interface should be implemented by classes derived from nemo.core.NeuralModule or nemo.core.ModelPT.
It gives these entities the ability to be exported for deployment to formats such as ONNX.
Usage:
# Exporting a pre-trained model to an ONNX file for deployment.
model.eval()
model.to('cuda') # or to('cpu') if you don't have a GPU
model.export('mymodel.onnx', [options]) # all arguments apart from `output` are optional.
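# The target format is inferred from the file extension, so the same call can
# produce a TorchScript module instead (assuming the model supports tracing):
# model.export('mymodel.ts')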
"""
@property
def input_module(self):
return self
@property
def output_module(self):
return self
def export(
self,
output: str,
input_example=None,
verbose=False,
do_constant_folding=True,
onnx_opset_version=None,
check_trace: Union[bool, List[torch.Tensor]] = False,
dynamic_axes=None,
check_tolerance=0.01,
export_modules_as_functions=False,
keep_initializers_as_inputs=None,
use_dynamo=False,
):
"""
Exports the model to the specified format. The format is inferred from the file extension of the output file.
Args:
output (str): Output file name. The file extension must be .onnx, .pt, or .ts; it is used to select the
export format of the model.
input_example (list or dict): Example input to the model's forward function. This is used to
trace the model and export it to ONNX/TorchScript. If the model takes multiple inputs, then input_example
should be a list of input examples. If the model takes named inputs, then input_example
should be a dictionary of input examples.
verbose (bool): If True, will print out a detailed description of the model's export steps, along with
the internal trace logs of the export process.
do_constant_folding (bool): If True, will execute constant folding optimization on the model's graph
before exporting. This is ONNX specific.
onnx_opset_version (int): The ONNX opset version to export the model to. If None, will use a reasonable
default version.
check_trace (bool or list of torch.Tensor): If True, verifies that the exported model's output matches the
original model's output, up to some tolerance. A list of input tensors may be passed instead, in which
case the verification is run on those inputs.
dynamic_axes (dict): A dictionary mapping input and output names to their dynamic axes. If the model takes
multiple inputs, then dynamic_axes should be a list of dictionaries. If the model takes named inputs,
then dynamic_axes should be a dictionary of dictionaries. If None, the dynamic axes are derived from the
NeuralTypes of the model's inputs and outputs.
check_tolerance (float): The tolerance to use when comparing the exported model's output against the
original model's output. This is only used if check_trace is set. Note that a relatively high tolerance
is used because the exported model is not guaranteed to be numerically identical to the original.
export_modules_as_functions (bool): If True, will export the model's submodules as functions. This is
ONNX specific.
keep_initializers_as_inputs (bool): If True, will keep the model's initializers as inputs in the onnx graph.
This is ONNX specific.
use_dynamo (bool): If True, use torch.onnx.dynamo_export() instead of torch.onnx.export(). This is ONNX specific.
Returns:
A tuple of two lists.
Item 0 is a list of the outputs of each exported subnet.
Item 1 is a list of string descriptions of each exported subnet, which can be
used for logging purposes.
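Example:
# A minimal sketch. The input name 'audio_signal' and its dynamic axes are
# hypothetical; the real names come from the model's input_types.
model.export(
    'mymodel.onnx',
    check_trace=True,
    check_tolerance=0.01,
    dynamic_axes={'audio_signal': {0: 'batch', 1: 'time'}},
)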
"""
all_out = []
all_descr = []
for subnet_name in self.list_export_subnets():
model = self.get_export_subnet(subnet_name)
out_name = augment_filename(output, subnet_name)
out, descr, out_example = model._export(
out_name,
input_example=input_example,
verbose=verbose,
do_constant_folding=do_constant_folding,
onnx_opset_version=onnx_opset_version,
check_trace=check_trace,
dynamic_axes=dynamic_axes,
check_tolerance=check_tolerance,
export_modules_as_functions=export_modules_as_functions,
keep_initializers_as_inputs=keep_initializers_as_inputs,
use_dynamo=use_dynamo,
)
# Propagate input example (default scenario; may need to be overridden)
if input_example is not None:
input_example = out_example
all_out.append(out)
all_descr.append(descr)
logging.info("Successfully exported {} to {}".format(model.__class__.__name__, out_name))
return (all_out, all_descr)
def _export(
self,
output: str,
input_example=None,
verbose=False,
do_constant_folding=True,
onnx_opset_version=None,
check_trace: Union[bool, List[torch.Tensor]] = False,
dynamic_axes=None,
check_tolerance=0.01,
export_modules_as_functions=False,
keep_initializers_as_inputs=None,
use_dynamo=False,
):
my_args = locals().copy()
my_args.pop('self')
self.eval()
for param in self.parameters():
param.requires_grad = False
exportables = []
for m in self.modules():
if isinstance(m, Exportable):
exportables.append(m)
qual_name = self.__module__ + '.' + self.__class__.__qualname__
format = get_export_format(output)
output_descr = f"{qual_name} exported to {format}"
# PyTorch's default opset version is too low; use a reasonably recent one instead
if onnx_opset_version is None:
onnx_opset_version = 17
try:
# Disable typechecks
typecheck.set_typecheck_enabled(enabled=False)
# Allow user to completely override forward method to export
forward_method, old_forward_method = wrap_forward_method(self)
# Set module mode
with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting():
if input_example is None:
input_example = self.input_module.input_example()
# Remove i/o examples from args we propagate to enclosed Exportables
my_args.pop('output')
my_args.pop('input_example')
# Run (possibly overridden) prepare methods before calling forward()
for ex in exportables:
ex._prepare_for_export(**my_args, noreplace=True)
self._prepare_for_export(output=output, input_example=input_example, **my_args)
input_list, input_dict = parse_input_example(input_example)
input_names = self.input_names
output_names = self.output_names
output_example = self.forward(*input_list, **input_dict)
if not isinstance(output_example, tuple):
output_example = (output_example,)
if check_trace:
if isinstance(check_trace, bool):
check_trace_input = [input_example]
else:
check_trace_input = check_trace
if format == ExportFormat.TORCHSCRIPT:
jitted_model = torch.jit.trace_module(
self,
{"forward": tuple(input_list) + tuple(input_dict.values())},
strict=True,
check_trace=check_trace,
check_tolerance=check_tolerance,
)
jitted_model = torch.jit.freeze(jitted_model)
if verbose:
logging.info(f"JIT code:\n{jitted_model.code}")
jitted_model.save(output)
jitted_model = torch.jit.load(output)
if check_trace:
verify_torchscript(jitted_model, output, check_trace_input, check_tolerance)
elif format == ExportFormat.ONNX:
# dynamic_axes is a mapping from input/output name => list of "dynamic" axis indices
if dynamic_axes is None:
dynamic_axes = self.dynamic_shapes_for_export(use_dynamo)
if use_dynamo:
typecheck.enable_wrapping(enabled=False)
# https://github.com/pytorch/pytorch/issues/126339
with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n")
# We have to use different types of arguments for dynamo_export to achieve
# the same external-weights behaviour as onnx.export:
# https://github.com/pytorch/pytorch/issues/126479
# https://github.com/pytorch/pytorch/issues/126269
mem_params = sum([param.nelement() * param.element_size() for param in self.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem_params + mem_bufs
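# Protobuf-based ONNX files are limited to 2 GB, so larger models need their
# weights stored externally; the torch.export path below provides the state dict
# that dynamo_export saves alongside the graph.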
if mem > 2 * 1000 * 1000 * 1000:
ex_model = torch.export.export(
self,
tuple(input_list),
kwargs=input_dict,
dynamic_shapes=dynamic_axes,
strict=False,
)
ex_model = ex_model.run_decompositions()
model_state = ex_model.state_dict
else:
model_state = None
ex_model = self
options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True)
ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
ex.save(output, model_state=model_state)
del ex
del ex_model
# Rename I/O after save - don't want to risk modifying ex._model_proto
rename_onnx_io(output, input_names, output_names)
else:
torch.onnx.export(
self,
input_example,
output,
input_names=input_names,
output_names=output_names,
verbose=verbose,
do_constant_folding=do_constant_folding,
dynamic_axes=dynamic_axes,
opset_version=onnx_opset_version,
keep_initializers_as_inputs=keep_initializers_as_inputs,
export_modules_as_functions=export_modules_as_functions,
)
if check_trace:
verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance)
else:
raise ValueError(f'Encountered unknown export format {format}.')
finally:
typecheck.enable_wrapping(enabled=True)
typecheck.set_typecheck_enabled(enabled=True)
if forward_method:
type(self).forward = old_forward_method
self._export_teardown()
return (output, output_descr, output_example)
@property
def disabled_deployment_input_names(self) -> List[str]:
"""Implement this method to return a set of input names disabled for export"""
return []
@property
def disabled_deployment_output_names(self) -> List[str]:
"""Implement this method to return a set of output names disabled for export"""
return []
@property
def supported_export_formats(self) -> List[ExportFormat]:
"""Implement this method to return a set of export formats supported. Default is all types."""
return [ExportFormat.ONNX, ExportFormat.TORCHSCRIPT]
def _prepare_for_export(self, **kwargs):
"""
Override this method to prepare the module for export. This is an in-place operation.
The base version performs the common module replacements needed for export (Apex etc.).
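A hypothetical override (a sketch; replace_my_attention is an illustrative helper, not a NeMo API):
def _prepare_for_export(self, **kwargs):
    replace_my_attention(self)  # swap in an export-friendly attention variant
    super()._prepare_for_export(**kwargs)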
"""
if 'noreplace' not in kwargs:
replace_for_export(self)
def _export_teardown(self):
"""
Override this method for any teardown code after export.
"""
pass
@property
def input_names(self):
return get_io_names(self.input_module.input_types_for_export, self.disabled_deployment_input_names)
@property
def output_names(self):
return get_io_names(self.output_module.output_types_for_export, self.disabled_deployment_output_names)
@property
def input_types_for_export(self) -> Optional[Dict[str, NeuralType]]:
return self.input_types
@property
def output_types_for_export(self):
return self.output_types
def dynamic_shapes_for_export(self, use_dynamo=False):
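"""
Returns the dynamic axes derived from the NeuralTypes of the module's inputs.
When use_dynamo is True, the result is in the dynamic_shapes form expected by
torch.export; otherwise it is in the dynamic_axes form expected by torch.onnx.export.
"""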
return get_dynamic_axes(self.input_module.input_types_for_export, self.input_names, use_dynamo)
def get_export_subnet(self, subnet=None):
"""
Returns the Exportable subnet model/module to export.
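For example, get_export_subnet('encoder') returns self.encoder;
None or 'self' returns the model itself.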
"""
if subnet is None or subnet == 'self':
return self
else:
return getattr(self, subnet)
def list_export_subnets(self):
"""
Returns the default list of subnet names exported for this model.
The first subnet listed is the one receiving the input (input_example).
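A hypothetical two-part model could override this to return ['encoder', 'decoder'],
producing one exported file per subnet (file names are augmented per subnet).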
"""
return ['self']
def get_export_config(self):
"""
Returns the export_config dictionary.
"""
return getattr(self, 'export_config', {})
def set_export_config(self, args):
"""
Sets/updates the export_config dictionary.
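Example (a sketch; 'cache_support' is a hypothetical key, valid keys are model-specific):
model.set_export_config({'cache_support': True})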
"""
ex_config = self.get_export_config()
ex_config.update(args)
self.export_config = ex_config