|
|
"""ONNX exporter exceptions.""" |
|
|
from __future__ import annotations |
|
|
|
|
|
import textwrap |
|
|
from typing import Optional |
|
|
|
|
|
from torch import _C |
|
|
from torch.onnx import _constants |
|
|
from torch.onnx._internal import diagnostics |
|
|
|
|
|
# Public API of this module: the exporter's error and warning types.
__all__ = [
    "OnnxExporterError",
    "OnnxExporterWarning",
    "CallHintViolationWarning",
    "CheckerError",
    "UnsupportedOperatorError",
    "SymbolicValueError",
]
|
|
|
|
|
|
|
|
class OnnxExporterWarning(UserWarning):
    """Base class for all warnings in the ONNX exporter.

    More specific exporter warnings derive from this class, so filtering on
    this category (e.g. via ``warnings.filterwarnings``) covers all of them.
    """
|
|
|
|
|
|
|
|
class CallHintViolationWarning(OnnxExporterWarning):
    """Warning raised when a type hint is violated during a function call.

    Being an ``OnnxExporterWarning``, it is caught by any filter on that
    base category.
    """
|
|
|
|
|
|
|
|
class OnnxExporterError(RuntimeError):
    """Errors raised by the ONNX exporter.

    Root of the exporter's exception hierarchy; it subclasses ``RuntimeError``
    so existing ``except RuntimeError`` handlers keep catching it.
    """
|
|
|
|
|
|
|
|
class CheckerError(OnnxExporterError):
    """Raised when ONNX checker detects an invalid model.

    A specialization of ``OnnxExporterError``; handlers for the general
    exporter error also catch it.
    """
|
|
|
|
|
|
|
|
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter."""

    def __init__(
        self,
        domain: str,
        op_name: str,
        version: int,
        supported_version: Optional[int],
    ):
        """Build the error message and report a matching ERROR diagnostic.

        Three situations are distinguished:

        * ``domain`` is not one of ``""``, ``"aten"``, ``"prim"`` or
          ``"quantized"``: the message tells the user to check how the custom
          operator was registered, and the
          ``missing_custom_symbolic_function`` rule is reported.
        * ``domain`` is one of those namespaces and ``supported_version`` is
          given: the message suggests exporting with that opset version, and
          the ``operator_supported_in_newer_opset_version`` rule is reported.
        * ``domain`` is one of those namespaces and ``supported_version`` is
          ``None``: the message points at the PyTorch GitHub issue tracker,
          and the ``missing_standard_symbolic_function`` rule is reported.

        Args:
            domain: Namespace of the operator that failed to export.
            op_name: Operator name within ``domain``.
            version: Target ONNX opset version of the export.
            supported_version: Opset version in which support for the operator
                was added, or ``None`` when no such version is known.
        """
        qualified_name = f"{domain}::{op_name}"

        if domain not in {"", "aten", "prim", "quantized"}:
            # Unrecognized namespace: treated as a custom operator whose
            # symbolic function could not be found.
            message = (
                "ONNX export failed on an operator with unrecognized namespace "
                f"'{qualified_name}'. "
                "If you are trying to export a custom operator, make sure you registered "
                "it with the right domain and version."
            )
            diagnostics.context.diagnose(
                diagnostics.rules.missing_custom_symbolic_function,
                diagnostics.levels.ERROR,
                message_args=(qualified_name,),
            )
        else:
            message = (
                f"Exporting the operator '{qualified_name}' to ONNX opset version "
                f"{version} is not supported. "
            )
            if supported_version is None:
                # No opset is known to support the operator; ask for a report.
                message += (
                    "Please feel free to request support or submit a pull request "
                    "on PyTorch GitHub: " + _constants.PYTORCH_GITHUB_ISSUES_URL
                )
                diagnostics.context.diagnose(
                    diagnostics.rules.missing_standard_symbolic_function,
                    diagnostics.levels.ERROR,
                    message_args=(
                        qualified_name,
                        version,
                        _constants.PYTORCH_GITHUB_ISSUES_URL,
                    ),
                )
            else:
                # A newer opset supports the operator; point the user at it.
                message += (
                    f"Support for this operator was added in version {supported_version}, "
                    "try exporting with this version."
                )
                diagnostics.context.diagnose(
                    diagnostics.rules.operator_supported_in_newer_opset_version,
                    diagnostics.levels.ERROR,
                    message_args=(qualified_name, version, supported_version),
                )

        super().__init__(message)
|
|
|
|
|
|
|
|
class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes."""

    def __init__(self, msg: str, value: _C.Value):
        """Attach debugging context about ``value`` to ``msg``.

        The final message contains:

        * ``msg``, followed by ``value``, its type, and the kind of the node
          that contains it;
        * the node's source range, if it has one;
        * a numbered listing of the node's inputs and outputs with their
          types. If producing this listing raises ``AttributeError``, a short
          fallback note is appended instead.

        Args:
            msg: Description of what went wrong.
            value: The TorchScript value that caused the error.
        """
        # Look up the containing node once instead of on every use.
        node = value.node()
        message = (
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{node.kind()}'.] "
        )

        code_location = node.sourceRange()
        if code_location:
            message += f"\n (node defined in {code_location})"

        # The separator precedes both the I/O dump and the fallback note.
        message += "\n\n"
        try:
            # Only the input/output dump is best-effort; keep it alone in the try.
            io_dump = textwrap.indent(
                self._format_values("Inputs", node.inputs())
                + "\n"
                + self._format_values("Outputs", node.outputs()),
                " ",
            )
        except AttributeError:
            message += (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )
        else:
            message += io_dump

        super().__init__(message)

    @staticmethod
    def _format_values(header, values) -> str:
        """Render ``values`` as a numbered, typed listing under ``header``.

        Returns ``"<header>:\\n"`` followed by one line per value, or by
        ``" Empty"`` when ``values`` yields nothing.
        """
        listing = "\n".join(
            f" #{i}: {item} (type '{item.type()}')" for i, item in enumerate(values)
        )
        return f"{header}:\n" + (listing or " Empty")
|
|
|