// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "fuse_conv1d_batchnorm1d.h"
#include "pass_level2.h"
#include <math.h>
#include <string.h>
namespace pnnx {
class fuse_conv1d_batchnorm1d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.Conv1d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
nn.BatchNorm1d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}
const char* type_str() const
{
return "nn.Conv1d";
}
const char* name_str() const
{
return "convbn1d";
}
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = captured_params.at("padding_mode");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;
// resolve merged conv2d weight and bias
int channels = captured_params.at("num_features").i;
float bn_eps = captured_params.at("eps").f;
bool has_bn_affine = captured_params.at("affine").b;
bool has_conv_bias = captured_params.at("bias").b;
auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();
// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
// value = value * b + a
std::vector<float> a(channels);
std::vector<float> b(channels);
for (int i = 0; i < channels; i++)
{
double sqrt_var = sqrt(bn_running_var[i] + bn_eps);
if (has_bn_affine)
{
a[i] = (float)(bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var);
b[i] = (float)(bn_weight[i] / sqrt_var);
}
else
{
a[i] = (float)(-bn_running_mean[i] / sqrt_var);
b[i] = (float)(1.f / sqrt_var);
}
}
op->attrs["weight"] = captured_attrs.at("op_0.weight");
if (has_conv_bias)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
else
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}
auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();
const int outch = captured_params.at("out_channels").i;
const int weight_per_outch = op->attrs["weight"].elemcount() / outch;
for (int i = 0; i < channels; i++)
{
float* conv_weight_outch = (float*)conv_weight.data() + weight_per_outch * i;
for (int j = 0; j < weight_per_outch; j++)
{
conv_weight_outch[j] *= b[i];
}
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}
op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};
// Applies fuse_conv1d_batchnorm1d_pass to the given graph through
// pnnx_graph_rewrite. The operator index counter handed to the rewriter
// starts at 0.
void fuse_conv1d_batchnorm1d(Graph& graph)
{
    fuse_conv1d_batchnorm1d_pass conv_bn_pass;

    int opindex = 0;
    pnnx_graph_rewrite(graph, &conv_bn_pass, opindex);
}
} // namespace pnnx