| | |
| | |
| | """Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns: |
| | - BatchNormalization ∘ Conv -> Conv |
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
| | - BatchNormalization ∘ Gemm -> Gemm |
| | |
| | Approach: |
| | Given an inbound operation output: Y = W * X + B |
And the BatchNormalization output: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)
| | |
| | The fusion updates the inbound weights as follows: |
| | - W_fused = W * (gamma / std) |
| | - B_fused = (B - μ) * (gamma / std) + β |
| | """ |
| |
|
| | from abc import ABC, abstractmethod |
| | from typing import Mapping |
| |
|
| | import numpy as np |
| |
|
| | from onnxscript import ir |
| | from onnxscript.rewriter._basics import MatchResult |
| | from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet |
| |
|
| |
|
| | def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray: |
| | |
| | broadcast_shape = [1 if axis != i else -1 for i in range(rank)] |
| | return np.reshape(x, broadcast_shape) |
| |
|
| |
|
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion.

    Subclasses pick the inbound op type absorbing the BatchNorm ("Conv",
    "ConvTranspose", or "Gemm") and report, via ``get_filters_axis``, the
    weight-tensor axis along which the per-channel BatchNorm scale broadcasts.
    """

    def __init__(
        self,
        op_type: str,
        name: str | None = None,
        remove_nodes: bool = True,
        as_function: bool = False,
    ) -> None:
        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
        # ONNX op type of the inbound node that the BatchNorm is folded into.
        self.op_type = op_type

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        """Build the fused replacement node.

        Folds the BatchNorm's (gamma, beta, mean, var) into the inbound node's
        weights and bias:
            W_fused = W * (gamma / std)
            B_fused = (B - mean) * (gamma / std) + beta, with std = sqrt(var + eps)
        The check() method guarantees these inputs are constant initializers.
        """
        batchnorm_node = batchnorm_out.producer()
        # BatchNormalization inputs after X are: scale (gamma), bias (beta),
        # running mean, running variance — all constants per check().
        gamma, beta, input_mean, input_var = [
            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
        ]

        # 1e-5 is the ONNX default for the epsilon attribute when absent.
        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()

        # Per-channel multiplier applied to both weights and bias.
        scale_factor = gamma / np.sqrt(input_var + eps)

        # Scale the inbound node's weights along its channel axis.
        inbound_node = inbound_out.producer()
        weights = inbound_node.inputs[1].const_value.numpy()

        # Reshape scale_factor so it broadcasts over the weight layout
        # (axis depends on op type / transB — see get_filters_axis).
        axis = self.get_filters_axis(inbound_node.attributes)
        fused_weights = ir.tensor(
            weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
        )

        # Fold mean/beta into the bias; synthesize a zero bias when the
        # inbound node has none (inputs[2] absent).
        if len(inbound_node.inputs) > 2:
            original_bias = inbound_node.inputs[2].const_value.numpy()
            bias_name = inbound_node.inputs[2].name
        else:
            original_bias = np.zeros_like(input_mean)
            # NOTE(review): assumes x.name is not None here — verify matcher
            # guarantees named values.
            bias_name = x.name + "_bias"
        fused_bias = ir.tensor((original_bias - input_mean) * scale_factor + beta)

        # Re-emit the inbound op with fused initializers; original attributes
        # (strides, pads, transA/transB, ...) are preserved unchanged.
        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=inbound_node.inputs[1].name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=inbound_node.attributes,
        )

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Validate that fusion is safe: all folded tensors must be constant
        initializers and not graph inputs (otherwise they could change at
        runtime and the precomputed fusion would be wrong)."""
        del context
        check_result = MatchResult()

        inbound_node = inbound_out.producer()
        batchnorm_node = batchnorm_out.producer()

        # Tensors that rewrite() reads as constants: inbound weights, the four
        # BatchNorm parameters, and the inbound bias when present.
        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
        if len(inbound_node.inputs) > 2:
            initializers.append(inbound_node.inputs[2])

        for initializer in initializers:
            if not initializer.is_initializer() or initializer.const_value is None:
                return check_result.fail(f"{initializer.name} is not a constant initializer.")
            if initializer.is_graph_input():
                return check_result.fail(f"{initializer.name} is a graph input.")

        return check_result
| |
|
| |
|
class FuseBatchNormIntoConv(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Conv(x))`` with ``Conv(x)``."""

    def __init__(self):
        super().__init__("Conv")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # Conv weights are laid out with output channels on axis 0.
        return 0

    def pattern(self, op, x):
        conv_out = op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            conv_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
| |
|
| |
|
class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``."""

    def __init__(self):
        super().__init__("ConvTranspose")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # ConvTranspose weights put output channels on axis 1.
        return 1

    def pattern(self, op, x):
        convt_out = op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            convt_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
| |
|
| |
|
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(Gemm(x))`` with ``Gemm(x)``."""

    def __init__(self):
        super().__init__("Gemm")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # With transB=1 the B matrix is (N, K) so output channels sit on
        # axis 0; otherwise B is (K, N) and they sit on axis 1.
        trans_b = attributes.get("transB")
        if trans_b is not None and trans_b.as_int():
            return 0
        return 1

    def pattern(self, op, x):
        gemm_out = op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            gemm_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
| |
|
| |
|
# Module-level rule instances, ready to be registered individually or via
# fuse_batchnorm_rule_set().
fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()
| |
|
| |
|
def fuse_batchnorm_rule_set() -> RewriteRuleSet:
    """Returns a set of rewrite rules that fuse BatchNormalization nodes
    into preceding nodes such as Conv, ConvTranspose, and Gemm.

    Returns:
        RewriteRuleSet
    """
    fusion_rules = [
        fuse_batchnorm_into_conv_rule,
        fuse_batchnorm_into_convtranspose_rule,
        fuse_batchnorm_into_gemm_rule,
    ]
    return RewriteRuleSet(fusion_rules)
| |
|