| """This file exports ONNX ops for opset 17. |
| |
| Note [ONNX Operators that are added/updated in opset 17] |
| |
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set |
| New operators: |
| BlackmanWindow |
| DFT |
| HammingWindow |
| HannWindow |
| LayerNormalization |
| MelWeightMatrix |
| STFT |
| SequenceMap |
| """ |
|
|
| import functools |
| from typing import Sequence |
|
|
| from torch import _C |
| from torch.onnx import symbolic_helper |
| from torch.onnx._internal import jit_utils, registration |
|
|
| |
| |
|
|
# Public API of this module.
__all__ = [
    "layer_norm",
]

# Shorthand for ``registration.onnx_symbolic`` with ``opset=17`` pre-bound;
# applied as a decorator to the symbolic functions defined in this file.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=17,
)
|
|
|
|
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as an ONNX ``LayerNormalization`` node.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value to normalize.
        normalized_shape: Sizes of the trailing dimensions normalized over.
            Its length selects the normalization axis; it is also the shape
            of any default scale/bias constants created here.
        weight: Elementwise scale, or a ``None`` value when absent.
        bias: Elementwise offset, or a ``None`` value when absent.
        eps: Epsilon for numerical stability, emitted as the ``epsilon``
            attribute.
        cudnn_enable: Part of the ATen schema; not used by this export.

    Returns:
        The output value of the emitted ``LayerNormalization`` node.
    """
    # Local imports keep the module-level import block unchanged.
    import torch
    from torch.onnx import _type_utils

    # layer_norm normalizes over the last len(normalized_shape) dimensions, so
    # the first normalized axis, counted from the end, is -len(normalized_shape).
    axis = -len(normalized_shape)

    # ONNX LayerNormalization requires a Scale input, but aten::layer_norm
    # gets None for weight/bias when the module is not elementwise-affine
    # (e.g. nn.LayerNorm(..., elementwise_affine=False)). Substitute the
    # identity affine transform (scale = ones, bias = zeros) in the input's
    # dtype, falling back to FLOAT when the input's type is unknown.
    weight_missing = symbolic_helper._is_none(weight)
    bias_missing = symbolic_helper._is_none(bias)
    if weight_missing or bias_missing:
        dtype = _type_utils.JitScalarType.from_value(
            input, _type_utils.JitScalarType.FLOAT
        ).dtype()
        if weight_missing:
            weight = g.op(
                "Constant", value_t=torch.ones(normalized_shape, dtype=dtype)
            )
        if bias_missing:
            bias = g.op(
                "Constant", value_t=torch.zeros(normalized_shape, dtype=dtype)
            )

    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=axis,
    )
|
|