// ncnn: src/layer/x86/convolution_2x2_pack8.h
// (Mirrored by camenduru, "thanks to ncnn ❤", revision be903e2)
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
// Accumulates one 2x2 kernel tap into one pack-8 output pixel:
//   _sum += sum over c = 0..7 of broadcast(x[c]) * k[c*8 .. c*8+7]
// x points at the 8 packed values of one input pixel; k points at the 64
// weights of one tap (8 groups of 8 floats, group c pairs with input lane c).
// The FMAs are issued in lane order c = 0..7, so the floating-point
// accumulation order is fixed.
static inline __m256 conv2x2s1_pack8_avx_tap(__m256 _sum, const float* x, const float* k)
{
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k), _mm256_broadcast_ss(x), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 8), _mm256_broadcast_ss(x + 1), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 16), _mm256_broadcast_ss(x + 2), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 24), _mm256_broadcast_ss(x + 3), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 32), _mm256_broadcast_ss(x + 4), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 40), _mm256_broadcast_ss(x + 5), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 48), _mm256_broadcast_ss(x + 6), _sum);
    _sum = _mm256_comp_fmadd_ps(_mm256_loadu_ps(k + 56), _mm256_broadcast_ss(x + 7), _sum);
    return _sum;
}

// 2x2 convolution with stride 1 on pack-8 blobs (8 floats per pixel), AVX path.
//
// bottom_blob : input, bottom_blob.c channels of pack-8 pixels.
// top_blob    : output, already allocated; its w/h/c define outw/outh/outch.
//               Each output channel is first filled with the bias, then
//               accumulated over all input channels.
// kernel      : kernel.channel(p).row(q) holds the weights from input channel
//               q to output channel p as 4 consecutive taps of 64 floats, in
//               the order (0,0), (0,1), (1,0), (1,1). Within one tap, floats
//               [c*8, c*8+8) are the 8 output-lane weights for input lane c.
// _bias       : read as 8 floats per output channel (bias + p*8); when its
//               data pointer is null, a zero bias is used.
// opt         : opt.num_threads sets the OpenMP thread count over outch.
//
// NOTE(review): the pointer walk assumes bottom_blob.w == outw + 1 and
// bottom_blob.h >= outh + 1 (no padding). It also assumes the rows of each
// input and output channel are laid out back to back. The caller presumably
// guarantees this; confirm before reusing the kernel elsewhere.
static void conv2x2s1_pack8_avx(const Mat& bottom_blob, Mat& top_blob, const Mat& kernel, const Mat& _bias, const Option& opt)
{
    const int inch = bottom_blob.c;
    const int outw = top_blob.w;
    const int outh = top_blob.h;
    const int outch = top_blob.c;

    const float* bias = _bias;

    #pragma omp parallel for num_threads(opt.num_threads)
    for (int p = 0; p < outch; p++)
    {
        Mat out0 = top_blob.channel(p);

        // Seed every output pixel of this channel with its bias (or zero).
        __m256 _bias0 = _mm256_set1_ps(0.f);
        if (bias)
        {
            _bias0 = _mm256_loadu_ps(bias + p * 8);
        }
        out0.fill(_bias0);

        for (int q = 0; q < inch; q++)
        {
            const Mat img0 = bottom_blob.channel(q);

            // The four taps for this (p, q) pair are stored back to back,
            // 64 floats each, so they are addressed by fixed offsets.
            const float* kbase = (const float*)kernel.channel(p).row(q);
            const float* k00 = kbase;       // tap (0,0)
            const float* k01 = kbase + 64;  // tap (0,1)
            const float* k10 = kbase + 128; // tap (1,0)
            const float* k11 = kbase + 192; // tap (1,1)

            float* outptr0 = out0.row(0);
            const float* r0 = img0.row(0);
            const float* r1 = img0.row(1);

            for (int i = 0; i < outh; i++)
            {
                int j = 0;

                // Two output pixels per pass. The second pixel's window starts
                // one input pixel (8 floats) to the right of the first. The
                // two accumulators are updated tap by tap, interleaved.
                for (; j + 1 < outw; j += 2)
                {
                    __m256 _sum0 = _mm256_loadu_ps(outptr0);
                    __m256 _sum1 = _mm256_loadu_ps(outptr0 + 8);

                    _sum0 = conv2x2s1_pack8_avx_tap(_sum0, r0, k00);
                    _sum1 = conv2x2s1_pack8_avx_tap(_sum1, r0 + 8, k00);
                    _sum0 = conv2x2s1_pack8_avx_tap(_sum0, r0 + 8, k01);
                    _sum1 = conv2x2s1_pack8_avx_tap(_sum1, r0 + 16, k01);
                    _sum0 = conv2x2s1_pack8_avx_tap(_sum0, r1, k10);
                    _sum1 = conv2x2s1_pack8_avx_tap(_sum1, r1 + 8, k10);
                    _sum0 = conv2x2s1_pack8_avx_tap(_sum0, r1 + 8, k11);
                    _sum1 = conv2x2s1_pack8_avx_tap(_sum1, r1 + 16, k11);

                    _mm256_storeu_ps(outptr0, _sum0);
                    _mm256_storeu_ps(outptr0 + 8, _sum1);

                    r0 += 16;
                    r1 += 16;
                    outptr0 += 16;
                }

                // Leftover single output pixel when outw is odd.
                for (; j < outw; j++)
                {
                    __m256 _sum = _mm256_loadu_ps(outptr0);

                    _sum = conv2x2s1_pack8_avx_tap(_sum, r0, k00);
                    _sum = conv2x2s1_pack8_avx_tap(_sum, r0 + 8, k01);
                    _sum = conv2x2s1_pack8_avx_tap(_sum, r1, k10);
                    _sum = conv2x2s1_pack8_avx_tap(_sum, r1 + 8, k11);

                    _mm256_storeu_ps(outptr0, _sum);

                    r0 += 8;
                    r1 += 8;
                    outptr0 += 8;
                }

                // Step over the final input column to reach the next row.
                // This assumes the input row is exactly one pixel wider than
                // the output row.
                r0 += 8;
                r1 += 8;
            }
        }
    }
}