xiaoanyu123's picture
Add files using upload-large-folder tool
6a22ec9 verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Fuses BatchNormalization nodes into preceding nodes. Supported fusion patterns:
- BatchNormalization ∘ Conv -> Conv
- BatchNormalization ∘ ConvTranspose -> ConvTranspose
- BatchNormalization ∘ Gemm -> Gemm
Approach:
Given an inbound operation output: Y = W * X + B
And a BatchNormalization outputs: Y_BN = (gamma * (Y - μ) / std) + β, where std = sqrt(var + eps)
The fusion updates the inbound weights as follows:
- W_fused = W * (gamma / std)
- B_fused = (B - μ) * (gamma / std) + β
"""
from abc import ABC, abstractmethod
from typing import Mapping
import numpy as np
from onnxscript import ir
from onnxscript.rewriter._basics import MatchResult
from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet
def _reshape_for_broadcast(x: np.ndarray, rank: int, axis: int = 1) -> np.ndarray:
# Build shape: 1s everywhere except -1 at the target axis
broadcast_shape = [1 if axis != i else -1 for i in range(rank)]
return np.reshape(x, broadcast_shape)
class _FuseBatchNormBase(RewriteRuleClassBase, ABC):
    """Interface for BatchNormalization nodes fusion.

    Subclasses provide ``pattern`` (a BatchNormalization fed by the inbound op) and
    ``get_filters_axis``. This base class implements the shared ``check`` and
    ``rewrite`` logic, which folds the BatchNormalization statistics into the
    inbound op's constant weights and bias.
    """

    def __init__(
        self,
        op_type: str,
        name: str | None = None,
        remove_nodes: bool = True,
        as_function: bool = False,
    ) -> None:
        """Create the rule.

        Args:
            op_type: Op type of the inbound node; ``rewrite`` re-emits a node of this type.
            name: Forwarded to ``RewriteRuleClassBase``.
            remove_nodes: Forwarded to ``RewriteRuleClassBase``.
            as_function: Forwarded to ``RewriteRuleClassBase``.
        """
        super().__init__(name=name, remove_nodes=remove_nodes, as_function=as_function)
        self.op_type = op_type

    @abstractmethod
    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        """Return the axis along which BatchNorm scale should be broadcasted."""

    @staticmethod
    def _get_bias(node: ir.Node) -> ir.Value | None:
        """Return the inbound node's optional bias input, or None when it is absent."""
        # An omitted optional input shows up either as a shorter input list or as a
        # ``None`` entry (empty input name in the serialized model).
        if len(node.inputs) > 2:
            return node.inputs[2]
        return None

    def rewrite(self, op, x: ir.Value, inbound_out: ir.Value, batchnorm_out: ir.Value):
        """Emit a single inbound op whose weights and bias absorb the BatchNormalization."""
        batchnorm_node = batchnorm_out.producer()
        # Get BatchNorm parameters
        gamma, bn_bias, input_mean, input_var = [
            inp.const_value.numpy() for inp in batchnorm_node.inputs[1:]
        ]

        # 1e-5 is the default value for epsilon according to
        # https://onnx.ai/onnx/operators/onnx__BatchNormalization.html#attributes
        default_eps = ir.Attr("epsilon", ir.AttributeType.FLOAT, 1e-5)
        eps = batchnorm_node.attributes.get("epsilon", default_eps).as_float()

        # Compute the scale_factor to update the inbound weights and bias
        scale_factor = gamma / np.sqrt(input_var + eps)

        inbound_node = inbound_out.producer()
        weights_value = inbound_node.inputs[1]
        weights = weights_value.const_value.numpy()

        # Reshape the scale factor so it broadcasts along the filters axis. Results are
        # cast back to the weights' dtype: BatchNormalization parameters may use another
        # float type than its input (opset 15+), while Conv/Gemm require W and B to share
        # the data input's type.
        axis = self.get_filters_axis(inbound_node.attributes)
        scaled_weights = weights * _reshape_for_broadcast(scale_factor, weights.ndim, axis=axis)
        fused_weights = ir.tensor(scaled_weights.astype(weights.dtype))

        # Gemm computes ``alpha * A' * B' + beta * C``. Fold its ``beta`` into the bias and
        # reset it to 1.0 on the new node. Conv/ConvTranspose have no ``beta`` attribute.
        attributes = inbound_node.attributes
        bias_multiplier = 1.0
        beta_attr = attributes.get("beta")
        if beta_attr is not None:
            bias_multiplier = beta_attr.as_float()
            attributes = {**attributes, "beta": ir.Attr("beta", ir.AttributeType.FLOAT, 1.0)}

        # Update bias
        bias_value = self._get_bias(inbound_node)
        if bias_value is not None:
            original_bias = bias_multiplier * bias_value.const_value.numpy()
            bias_name = bias_value.name
        else:
            original_bias = np.zeros_like(input_mean)
            # Name the new bias after the weights initializer (always named) rather than
            # after ``x``: ``x`` may be unnamed, and several inbound nodes can share the
            # same ``x``, which would produce colliding initializer names.
            bias_name = f"{weights_value.name}_bias"
        fused_bias_array = (original_bias - input_mean) * scale_factor + bn_bias
        fused_bias = ir.tensor(fused_bias_array.astype(weights.dtype))

        return op.op(
            self.op_type,
            inputs=[
                x,
                op.initializer(fused_weights, name=weights_value.name),
                op.initializer(fused_bias, name=bias_name),
            ],
            attributes=attributes,
        )

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Verify that the matched subgraph can be folded into a single inbound op."""
        del context  # Unused
        check_result = MatchResult()

        inbound_node = inbound_out.producer()
        batchnorm_node = batchnorm_out.producer()

        # In training mode BatchNormalization normalizes with the current batch's
        # statistics rather than the stored mean/var, so it cannot be folded.
        training_mode = batchnorm_node.attributes.get("training_mode")
        if training_mode is not None and training_mode.as_int() != 0:
            return check_result.fail("BatchNormalization is in training mode.")

        # Check that inbound weights + (inbound bias) + batchnorm params are initializers
        # and that they are not graph inputs
        initializers = [inbound_node.inputs[1], *batchnorm_node.inputs[1:]]
        bias_value = self._get_bias(inbound_node)
        if bias_value is not None:
            initializers.append(bias_value)

        for initializer in initializers:
            if not initializer.is_initializer() or initializer.const_value is None:
                return check_result.fail(f"{initializer.name} is not a constant initializer.")
            if initializer.is_graph_input():
                return check_result.fail(f"{initializer.name} is a graph input.")

        # The weights axis the scale is broadcast along must hold exactly one entry per
        # BatchNorm channel (e.g. grouped ConvTranspose does not); otherwise the fused
        # weights would fail to broadcast or end up with the wrong shape.
        weights_dims = tuple(inbound_node.inputs[1].const_value.shape)
        num_channels = batchnorm_node.inputs[1].const_value.numpy().size
        axis = self.get_filters_axis(inbound_node.attributes)
        if axis >= len(weights_dims) or weights_dims[axis] != num_channels:
            return check_result.fail(
                f"Weights shape {weights_dims} is incompatible with "
                f"{num_channels} BatchNormalization channels."
            )

        return check_result
class FuseBatchNormIntoConv(_FuseBatchNormBase):
    """Folds a ``BatchNormalization`` that consumes a ``Conv`` output into the ``Conv``.

    Matches ``BatchNormalization(Conv(x))`` and emits a single ``Conv(x)``.
    """

    def __init__(self):
        super().__init__("Conv")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # The scale is always broadcast along axis 0 of the Conv weights,
        # independent of any attribute.
        del attributes
        return 0

    def pattern(self, op, x):
        conv_out = op.Conv(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            conv_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
class FuseBatchNormIntoConvTranspose(_FuseBatchNormBase):
    """Replaces ``BatchNormalization(ConvTranspose(x))`` with ``ConvTranspose(x)``.

    Only ``group == 1`` is fused. ConvTranspose weights are laid out as
    ``(C_in, C_out / group, k1, ...)``, so with ``group > 1`` axis 1 does not hold one
    entry per output channel and the per-channel scale cannot be applied directly.
    """

    def __init__(self):
        super().__init__("ConvTranspose")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # Output channels live on axis 1 of the ConvTranspose weights.
        return 1

    def check(self, context, x, inbound_out: ir.Value, batchnorm_out: ir.Value) -> MatchResult:
        """Reject grouped ConvTranspose, then run the shared fusion checks."""
        group = inbound_out.producer().attributes.get("group")
        if group is not None and group.as_int() != 1:
            return MatchResult().fail("Grouped ConvTranspose is not supported.")
        return super().check(context, x, inbound_out, batchnorm_out)

    def pattern(self, op, x):
        return op.BatchNormalization(
            op.ConvTranspose(x, _allow_other_inputs=True, _outputs=["inbound_out"]),
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
class FuseBatchNormIntoGemm(_FuseBatchNormBase):
    """Folds a ``BatchNormalization`` that consumes a ``Gemm`` output into the ``Gemm``.

    Matches ``BatchNormalization(Gemm(x))`` and emits a single ``Gemm(x)``.
    """

    def __init__(self):
        super().__init__("Gemm")

    def get_filters_axis(self, attributes: Mapping[str, ir.Attr]) -> int:
        # Return axis 0 when a truthy ``transB`` attribute is present,
        # otherwise axis 1.
        trans_b = attributes.get("transB")
        if trans_b is not None and trans_b.as_int():
            return 0
        return 1

    def pattern(self, op, x):
        gemm_out = op.Gemm(x, _allow_other_inputs=True, _outputs=["inbound_out"])
        return op.BatchNormalization(
            gemm_out,
            _allow_other_inputs=True,
            _outputs=["batchnorm_out"],
        )
# Module-level rule objects, one per supported inbound op type, each built by
# calling ``.rule()`` on a fresh instance of the corresponding rule class.
fuse_batchnorm_into_conv_rule = FuseBatchNormIntoConv().rule()
fuse_batchnorm_into_convtranspose_rule = FuseBatchNormIntoConvTranspose().rule()
fuse_batchnorm_into_gemm_rule = FuseBatchNormIntoGemm().rule()
def fuse_batchnorm_rule_set() -> RewriteRuleSet:
    """Build the rule set that folds BatchNormalization into its producer.

    The set bundles the Conv, ConvTranspose and Gemm fusion rules, in that order.

    Returns:
        RewriteRuleSet
    """
    rules = [
        fuse_batchnorm_into_conv_rule,
        fuse_batchnorm_into_convtranspose_rule,
        fuse_batchnorm_into_gemm_rule,
    ]
    return RewriteRuleSet(rules)