// ncnn/src/layer/x86/convolution1d_packed.h
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
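// convolution1d_transform_kernel_packed repacks the raw kernel so that the
// compute loops below can load whole SIMD registers of weights. Output
// channels are tiled at the widest width the build allows (AVX512 -> 16,
// AVX -> 8, SSE2 -> 4, otherwise 2/1), then input channels are tiled the
// same way inside each output tile.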
static void convolution1d_transform_kernel_packed(const Mat& kernel, Mat& kernel_tm, int inh, int outh, int kernel_w)
{
// src = kw-inh-outh
// dst = pb-pa-kw-inh/pa-outh/pb
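// pa is the input-channel pack size and pb the output-channel pack size;
// both step down through 16/8/4/2/1 for the tail channels.
// Example with inh = 16, outh = 16 on AVX512: each kernel_tm channel holds
// kernel_w tiles of 16x16 floats, and within tile k the weight for input
// channel p and output channel q lands at offset p * 16 + q.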
// clang-format off
// *INDENT-OFF*
#if __SSE2__
#if __AVX__
#if __AVX512F__
if (outh >= 16)
{
if (inh >= 16)
kernel_tm.create(16 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 8)
kernel_tm.create(16 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 4)
kernel_tm.create(16 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 2)
kernel_tm.create(16 * 2 * kernel_w, inh / 2 + inh % 2, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else
kernel_tm.create(16 * kernel_w, inh, outh / 16 + (outh % 16) / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
}
else
#endif // __AVX512F__
if (outh >= 8)
{
#if __AVX512F__
if (inh >= 16)
kernel_tm.create(8 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else
#endif // __AVX512F__
if (inh >= 8)
kernel_tm.create(8 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 4)
kernel_tm.create(8 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 2)
kernel_tm.create(8 * 2 * kernel_w, inh / 2 + inh % 2, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
else
kernel_tm.create(8 * kernel_w, inh, outh / 8 + (outh % 8) / 4 + (outh % 4) / 2 + outh % 2);
}
else
#endif // __AVX__
if (outh >= 4)
{
#if __AVX__
#if __AVX512F__
if (inh >= 16)
kernel_tm.create(4 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2);
else
#endif // __AVX512F__
if (inh >= 8)
kernel_tm.create(4 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2);
else
#endif // __AVX__
if (inh >= 4)
kernel_tm.create(4 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2);
else if (inh >= 2)
kernel_tm.create(4 * 2 * kernel_w, inh / 2 + inh % 2, outh / 4 + (outh % 4) / 2 + outh % 2);
else
kernel_tm.create(4 * kernel_w, inh, outh / 4 + (outh % 4) / 2 + outh % 2);
}
else
#endif // __SSE2__
if (outh >= 2)
{
#if __SSE2__
#if __AVX__
#if __AVX512F__
if (inh >= 16)
kernel_tm.create(2 * 16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2);
else
#endif // __AVX512F__
if (inh >= 8)
kernel_tm.create(2 * 8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2);
else
#endif // __AVX__
if (inh >= 4)
kernel_tm.create(2 * 4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh / 2 + outh % 2);
else
#endif // __SSE2__
if (inh >= 2)
kernel_tm.create(2 * 2 * kernel_w, inh / 2 + inh % 2, outh / 2 + outh % 2);
else
kernel_tm.create(2 * kernel_w, inh, outh / 2 + outh % 2);
}
else
{
#if __SSE2__
#if __AVX__
#if __AVX512F__
if (inh >= 16)
kernel_tm.create(16 * kernel_w, inh / 16 + (inh % 16) / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh);
else
#endif // __AVX512F__
if (inh >= 8)
kernel_tm.create(8 * kernel_w, inh / 8 + (inh % 8) / 4 + (inh % 4) / 2 + inh % 2, outh);
else
#endif // __AVX__
if (inh >= 4)
kernel_tm.create(4 * kernel_w, inh / 4 + (inh % 4) / 2 + inh % 2, outh);
else
#endif // __SSE2__
if (inh >= 2)
kernel_tm.create(2 * kernel_w, inh / 2 + inh % 2, outh);
else
kernel_tm.create(kernel_w, inh, outh);
}
// *INDENT-ON*
// clang-format on
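// Repack the weights tile by tile. Each block below gathers (or copies)
// one kernel tap across a tile of channels and stores it transposed, so
// that the innermost dimension of kernel_tm enumerates output channels,
// matching the broadcast-FMA loops in convolution1d_packed.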
int q = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; q + 15 < outh; q += 16)
{
const float* kptr0 = (const float*)kernel + q * inh * kernel_w;
const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w;
const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w;
const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w;
const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w;
const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w;
const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w;
const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w;
const float* kptr8 = (const float*)kernel + (q + 8) * inh * kernel_w;
const float* kptr9 = (const float*)kernel + (q + 9) * inh * kernel_w;
const float* kptra = (const float*)kernel + (q + 10) * inh * kernel_w;
const float* kptrb = (const float*)kernel + (q + 11) * inh * kernel_w;
const float* kptrc = (const float*)kernel + (q + 12) * inh * kernel_w;
const float* kptrd = (const float*)kernel + (q + 13) * inh * kernel_w;
const float* kptre = (const float*)kernel + (q + 14) * inh * kernel_w;
const float* kptrf = (const float*)kernel + (q + 15) * inh * kernel_w;
float* g00 = kernel_tm.channel(q / 16);
__m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(kernel_w));
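// lane i of _vindex reads k0[i * kernel_w]: the same kernel tap for 16
// consecutive input channels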
int p = 0;
for (; p + 15 < inh; p += 16)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
const float* k8 = kptr8 + k;
const float* k9 = kptr9 + k;
const float* ka = kptra + k;
const float* kb = kptrb + k;
const float* kc = kptrc + k;
const float* kd = kptrd + k;
const float* ke = kptre + k;
const float* kf = kptrf + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex, k0, sizeof(float));
__m512 _k1 = _mm512_i32gather_ps(_vindex, k1, sizeof(float));
__m512 _k2 = _mm512_i32gather_ps(_vindex, k2, sizeof(float));
__m512 _k3 = _mm512_i32gather_ps(_vindex, k3, sizeof(float));
__m512 _k4 = _mm512_i32gather_ps(_vindex, k4, sizeof(float));
__m512 _k5 = _mm512_i32gather_ps(_vindex, k5, sizeof(float));
__m512 _k6 = _mm512_i32gather_ps(_vindex, k6, sizeof(float));
__m512 _k7 = _mm512_i32gather_ps(_vindex, k7, sizeof(float));
__m512 _k8 = _mm512_i32gather_ps(_vindex, k8, sizeof(float));
__m512 _k9 = _mm512_i32gather_ps(_vindex, k9, sizeof(float));
__m512 _ka = _mm512_i32gather_ps(_vindex, ka, sizeof(float));
__m512 _kb = _mm512_i32gather_ps(_vindex, kb, sizeof(float));
__m512 _kc = _mm512_i32gather_ps(_vindex, kc, sizeof(float));
__m512 _kd = _mm512_i32gather_ps(_vindex, kd, sizeof(float));
__m512 _ke = _mm512_i32gather_ps(_vindex, ke, sizeof(float));
__m512 _kf = _mm512_i32gather_ps(_vindex, kf, sizeof(float));
transpose16x16_ps(_k0, _k1, _k2, _k3, _k4, _k5, _k6, _k7, _k8, _k9, _ka, _kb, _kc, _kd, _ke, _kf);
_mm512_store_ps(g00, _k0);
_mm512_store_ps(g00 + 16, _k1);
_mm512_store_ps(g00 + 16 * 2, _k2);
_mm512_store_ps(g00 + 16 * 3, _k3);
_mm512_store_ps(g00 + 16 * 4, _k4);
_mm512_store_ps(g00 + 16 * 5, _k5);
_mm512_store_ps(g00 + 16 * 6, _k6);
_mm512_store_ps(g00 + 16 * 7, _k7);
_mm512_store_ps(g00 + 16 * 8, _k8);
_mm512_store_ps(g00 + 16 * 9, _k9);
_mm512_store_ps(g00 + 16 * 10, _ka);
_mm512_store_ps(g00 + 16 * 11, _kb);
_mm512_store_ps(g00 + 16 * 12, _kc);
_mm512_store_ps(g00 + 16 * 13, _kd);
_mm512_store_ps(g00 + 16 * 14, _ke);
_mm512_store_ps(g00 + 16 * 15, _kf);
g00 += 256;
}
kptr0 += kernel_w * 16;
kptr1 += kernel_w * 16;
kptr2 += kernel_w * 16;
kptr3 += kernel_w * 16;
kptr4 += kernel_w * 16;
kptr5 += kernel_w * 16;
kptr6 += kernel_w * 16;
kptr7 += kernel_w * 16;
kptr8 += kernel_w * 16;
kptr9 += kernel_w * 16;
kptra += kernel_w * 16;
kptrb += kernel_w * 16;
kptrc += kernel_w * 16;
kptrd += kernel_w * 16;
kptre += kernel_w * 16;
kptrf += kernel_w * 16;
}
_vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(inh));
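// tail input channels: rescale the gather stride to inh * kernel_w so
// that lane i of _vindex now steps across the 16 output channels q..q+15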
for (; p + 7 < inh; p += 8)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
for (int i = 0; i < 8; i++)
{
__m512 _k0 = _mm512_i32gather_ps(_vindex, k0, sizeof(float));
_mm512_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 16;
}
}
kptr0 += kernel_w * 8;
}
for (; p + 3 < inh; p += 4)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
for (int i = 0; i < 4; i++)
{
__m512 _k0 = _mm512_i32gather_ps(_vindex, k0, sizeof(float));
_mm512_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 16;
}
}
kptr0 += kernel_w * 4;
}
for (; p + 1 < inh; p += 2)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
for (int i = 0; i < 2; i++)
{
__m512 _k0 = _mm512_i32gather_ps(_vindex, k0, sizeof(float));
_mm512_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 16;
}
}
kptr0 += kernel_w * 2;
}
for (; p < inh; p++)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex, k0, sizeof(float));
_mm512_store_ps(g00, _k0);
g00 += 16;
}
}
}
#endif // __AVX512F__
for (; q + 7 < outh; q += 8)
{
const float* kptr0 = (const float*)kernel + q * inh * kernel_w;
const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w;
const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w;
const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w;
const float* kptr4 = (const float*)kernel + (q + 4) * inh * kernel_w;
const float* kptr5 = (const float*)kernel + (q + 5) * inh * kernel_w;
const float* kptr6 = (const float*)kernel + (q + 6) * inh * kernel_w;
const float* kptr7 = (const float*)kernel + (q + 7) * inh * kernel_w;
#if __AVX512F__
float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8);
#else
float* g00 = kernel_tm.channel(q / 8);
#endif
#if __AVX2__
__m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
_vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(kernel_w));
#if __AVX512F__
__m512i _vindex_512 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex_512 = _mm512_mullo_epi32(_vindex_512, _mm512_set1_epi32(kernel_w));
#endif // __AVX512F__
#endif // __AVX2__
int p = 0;
#if __AVX512F__
for (; p + 15 < inh; p += 16)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex_512, k0, sizeof(float));
__m512 _k1 = _mm512_i32gather_ps(_vindex_512, k1, sizeof(float));
__m512 _k2 = _mm512_i32gather_ps(_vindex_512, k2, sizeof(float));
__m512 _k3 = _mm512_i32gather_ps(_vindex_512, k3, sizeof(float));
__m512 _k4 = _mm512_i32gather_ps(_vindex_512, k4, sizeof(float));
__m512 _k5 = _mm512_i32gather_ps(_vindex_512, k5, sizeof(float));
__m512 _k6 = _mm512_i32gather_ps(_vindex_512, k6, sizeof(float));
__m512 _k7 = _mm512_i32gather_ps(_vindex_512, k7, sizeof(float));
transpose16x8_ps(_k0, _k1, _k2, _k3, _k4, _k5, _k6, _k7);
_mm512_storeu_ps(g00, _k0);
_mm512_storeu_ps(g00 + 16, _k1);
_mm512_storeu_ps(g00 + 16 * 2, _k2);
_mm512_storeu_ps(g00 + 16 * 3, _k3);
_mm512_storeu_ps(g00 + 16 * 4, _k4);
_mm512_storeu_ps(g00 + 16 * 5, _k5);
_mm512_storeu_ps(g00 + 16 * 6, _k6);
_mm512_storeu_ps(g00 + 16 * 7, _k7);
g00 += 128;
}
kptr0 += kernel_w * 16;
kptr1 += kernel_w * 16;
kptr2 += kernel_w * 16;
kptr3 += kernel_w * 16;
kptr4 += kernel_w * 16;
kptr5 += kernel_w * 16;
kptr6 += kernel_w * 16;
kptr7 += kernel_w * 16;
}
#endif // __AVX512F__
for (; p + 7 < inh; p += 8)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex, sizeof(float));
__m256 _k1 = _mm256_i32gather_ps(k1, _vindex, sizeof(float));
__m256 _k2 = _mm256_i32gather_ps(k2, _vindex, sizeof(float));
__m256 _k3 = _mm256_i32gather_ps(k3, _vindex, sizeof(float));
__m256 _k4 = _mm256_i32gather_ps(k4, _vindex, sizeof(float));
__m256 _k5 = _mm256_i32gather_ps(k5, _vindex, sizeof(float));
__m256 _k6 = _mm256_i32gather_ps(k6, _vindex, sizeof(float));
__m256 _k7 = _mm256_i32gather_ps(k7, _vindex, sizeof(float));
transpose8x8_ps(_k0, _k1, _k2, _k3, _k4, _k5, _k6, _k7);
_mm256_store_ps(g00, _k0);
_mm256_store_ps(g00 + 8, _k1);
_mm256_store_ps(g00 + 8 * 2, _k2);
_mm256_store_ps(g00 + 8 * 3, _k3);
_mm256_store_ps(g00 + 8 * 4, _k4);
_mm256_store_ps(g00 + 8 * 5, _k5);
_mm256_store_ps(g00 + 8 * 6, _k6);
_mm256_store_ps(g00 + 8 * 7, _k7);
g00 += 64;
#else // __AVX2__
for (int i = 0; i < 8; i++)
{
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
g00[4] = k4[0];
g00[5] = k5[0];
g00[6] = k6[0];
g00[7] = k7[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
k4 += kernel_w;
k5 += kernel_w;
k6 += kernel_w;
k7 += kernel_w;
g00 += 8;
}
#endif // __AVX2__
}
kptr0 += kernel_w * 8;
kptr1 += kernel_w * 8;
kptr2 += kernel_w * 8;
kptr3 += kernel_w * 8;
kptr4 += kernel_w * 8;
kptr5 += kernel_w * 8;
kptr6 += kernel_w * 8;
kptr7 += kernel_w * 8;
}
#if __AVX2__
_vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(inh));
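// tail input channels: _vindex now strides by inh * kernel_w, stepping
// across the 8 output channels q..q+7 instead of across input channels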
#endif // __AVX2__
for (; p + 3 < inh; p += 4)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
#if !__AVX2__
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
#endif // !__AVX2__
for (int i = 0; i < 4; i++)
{
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex, sizeof(float));
_mm256_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 8;
#else // __AVX2__
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
g00[4] = k4[0];
g00[5] = k5[0];
g00[6] = k6[0];
g00[7] = k7[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
k4 += kernel_w;
k5 += kernel_w;
k6 += kernel_w;
k7 += kernel_w;
g00 += 8;
#endif // __AVX2__
}
}
kptr0 += kernel_w * 4;
#if !__AVX2__
kptr1 += kernel_w * 4;
kptr2 += kernel_w * 4;
kptr3 += kernel_w * 4;
kptr4 += kernel_w * 4;
kptr5 += kernel_w * 4;
kptr6 += kernel_w * 4;
kptr7 += kernel_w * 4;
#endif // !__AVX2__
}
for (; p + 1 < inh; p += 2)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
#if !__AVX2__
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
#endif // !__AVX2__
for (int i = 0; i < 2; i++)
{
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex, sizeof(float));
_mm256_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 8;
#else // __AVX2__
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
g00[4] = k4[0];
g00[5] = k5[0];
g00[6] = k6[0];
g00[7] = k7[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
k4 += kernel_w;
k5 += kernel_w;
k6 += kernel_w;
k7 += kernel_w;
g00 += 8;
#endif // __AVX2__
}
}
kptr0 += kernel_w * 2;
#if !__AVX2__
kptr1 += kernel_w * 2;
kptr2 += kernel_w * 2;
kptr3 += kernel_w * 2;
kptr4 += kernel_w * 2;
kptr5 += kernel_w * 2;
kptr6 += kernel_w * 2;
kptr7 += kernel_w * 2;
#endif // !__AVX2__
}
for (; p < inh; p++)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex, sizeof(float));
_mm256_store_ps(g00, _k0);
g00 += 8;
#else // __AVX2__
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
const float* k4 = kptr4 + k;
const float* k5 = kptr5 + k;
const float* k6 = kptr6 + k;
const float* k7 = kptr7 + k;
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
g00[4] = k4[0];
g00[5] = k5[0];
g00[6] = k6[0];
g00[7] = k7[0];
g00 += 8;
#endif // __AVX2__
}
}
}
#endif // __AVX__
for (; q + 3 < outh; q += 4)
{
const float* kptr0 = (const float*)kernel + q * inh * kernel_w;
const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w;
const float* kptr2 = (const float*)kernel + (q + 2) * inh * kernel_w;
const float* kptr3 = (const float*)kernel + (q + 3) * inh * kernel_w;
#if __AVX512F__
float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4);
#elif __AVX__
float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4);
#else
float* g00 = kernel_tm.channel(q / 4);
#endif
#if __AVX2__
__m128i _vindex = _mm_setr_epi32(0, 1, 2, 3);
_vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(kernel_w));
__m256i _vindex_256 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
_vindex_256 = _mm256_mullo_epi32(_vindex_256, _mm256_set1_epi32(kernel_w));
#if __AVX512F__
__m512i _vindex_512 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex_512 = _mm512_mullo_epi32(_vindex_512, _mm512_set1_epi32(kernel_w));
#endif // __AVX512F__
#endif // __AVX2__
int p = 0;
#if __AVX__
#if __AVX512F__
for (; p + 15 < inh; p += 16)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex_512, k0, sizeof(float));
__m512 _k1 = _mm512_i32gather_ps(_vindex_512, k1, sizeof(float));
__m512 _k2 = _mm512_i32gather_ps(_vindex_512, k2, sizeof(float));
__m512 _k3 = _mm512_i32gather_ps(_vindex_512, k3, sizeof(float));
transpose16x4_ps(_k0, _k1, _k2, _k3);
_mm512_storeu_ps(g00, _k0);
_mm512_storeu_ps(g00 + 16, _k1);
_mm512_storeu_ps(g00 + 16 * 2, _k2);
_mm512_storeu_ps(g00 + 16 * 3, _k3);
g00 += 64;
}
kptr0 += kernel_w * 16;
kptr1 += kernel_w * 16;
kptr2 += kernel_w * 16;
kptr3 += kernel_w * 16;
}
#endif // __AVX512F__
for (; p + 7 < inh; p += 8)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex_256, sizeof(float));
__m256 _k1 = _mm256_i32gather_ps(k1, _vindex_256, sizeof(float));
__m256 _k2 = _mm256_i32gather_ps(k2, _vindex_256, sizeof(float));
__m256 _k3 = _mm256_i32gather_ps(k3, _vindex_256, sizeof(float));
transpose8x4_ps(_k0, _k1, _k2, _k3);
_mm256_storeu_ps(g00, _k0);
_mm256_storeu_ps(g00 + 8, _k1);
_mm256_storeu_ps(g00 + 8 * 2, _k2);
_mm256_storeu_ps(g00 + 8 * 3, _k3);
g00 += 32;
#else // __AVX2__
for (int i = 0; i < 8; i++)
{
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
g00 += 4;
}
#endif // __AVX2__
}
kptr0 += kernel_w * 8;
kptr1 += kernel_w * 8;
kptr2 += kernel_w * 8;
kptr3 += kernel_w * 8;
}
#endif // __AVX__
for (; p + 3 < inh; p += 4)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
#if __AVX2__
__m128 _k0 = _mm_i32gather_ps(k0, _vindex, sizeof(float));
__m128 _k1 = _mm_i32gather_ps(k1, _vindex, sizeof(float));
__m128 _k2 = _mm_i32gather_ps(k2, _vindex, sizeof(float));
__m128 _k3 = _mm_i32gather_ps(k3, _vindex, sizeof(float));
_MM_TRANSPOSE4_PS(_k0, _k1, _k2, _k3);
_mm_store_ps(g00, _k0);
_mm_store_ps(g00 + 4, _k1);
_mm_store_ps(g00 + 4 * 2, _k2);
_mm_store_ps(g00 + 4 * 3, _k3);
g00 += 16;
#else // __AVX2__
for (int i = 0; i < 4; i++)
{
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
g00 += 4;
}
#endif // __AVX2__
}
kptr0 += kernel_w * 4;
kptr1 += kernel_w * 4;
kptr2 += kernel_w * 4;
kptr3 += kernel_w * 4;
}
#if __AVX2__
_vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(inh));
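// tail input channels: _vindex now strides by inh * kernel_w across the
// 4 output channels q..q+3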
#endif // __AVX2__
for (; p + 1 < inh; p += 2)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
#if !__AVX2__
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
#endif // !__AVX2__
for (int i = 0; i < 2; i++)
{
#if __AVX2__
__m128 _k0 = _mm_i32gather_ps(k0, _vindex, sizeof(float));
_mm_store_ps(g00, _k0);
k0 += kernel_w;
g00 += 4;
#else // __AVX2__
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
k0 += kernel_w;
k1 += kernel_w;
k2 += kernel_w;
k3 += kernel_w;
g00 += 4;
#endif // __AVX2__
}
}
kptr0 += kernel_w * 2;
#if !__AVX2__
kptr1 += kernel_w * 2;
kptr2 += kernel_w * 2;
kptr3 += kernel_w * 2;
#endif // !__AVX2__
}
for (; p < inh; p++)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
#if __AVX2__
__m128 _k0 = _mm_i32gather_ps(k0, _vindex, sizeof(float));
_mm_store_ps(g00, _k0);
g00 += 4;
#else // __AVX2__
const float* k1 = kptr1 + k;
const float* k2 = kptr2 + k;
const float* k3 = kptr3 + k;
g00[0] = k0[0];
g00[1] = k1[0];
g00[2] = k2[0];
g00[3] = k3[0];
g00 += 4;
#endif // __AVX2__
}
}
}
#endif // __SSE2__
for (; q + 1 < outh; q += 2)
{
const float* kptr0 = (const float*)kernel + q * inh * kernel_w;
const float* kptr1 = (const float*)kernel + (q + 1) * inh * kernel_w;
#if __AVX512F__
float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2);
#elif __AVX__
float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2);
#elif __SSE2__
float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2);
#else
float* g00 = kernel_tm.channel(q / 2);
#endif
#if __AVX2__
__m128i _vindex = _mm_setr_epi32(0, 1, 2, 3);
_vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(kernel_w));
__m256i _vindex_256 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
_vindex_256 = _mm256_mullo_epi32(_vindex_256, _mm256_set1_epi32(kernel_w));
#if __AVX512F__
__m512i _vindex_512 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex_512 = _mm512_mullo_epi32(_vindex_512, _mm512_set1_epi32(kernel_w));
#endif // __AVX512F__
#endif // __AVX2__
int p = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; p + 15 < inh; p += 16)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex_512, k0, sizeof(float));
__m512 _k1 = _mm512_i32gather_ps(_vindex_512, k1, sizeof(float));
_mm512_storeu_ps(g00, _k0);
_mm512_storeu_ps(g00 + 16, _k1);
g00 += 32;
}
kptr0 += kernel_w * 16;
kptr1 += kernel_w * 16;
}
#endif // __AVX512F__
for (; p + 7 < inh; p += 8)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex_256, sizeof(float));
__m256 _k1 = _mm256_i32gather_ps(k1, _vindex_256, sizeof(float));
_mm256_storeu_ps(g00, _k0);
_mm256_storeu_ps(g00 + 8, _k1);
g00 += 16;
#else // __AVX2__
g00[0] = k0[0];
g00[1] = k0[kernel_w];
g00[2] = k0[kernel_w * 2];
g00[3] = k0[kernel_w * 3];
g00[4] = k0[kernel_w * 4];
g00[5] = k0[kernel_w * 5];
g00[6] = k0[kernel_w * 6];
g00[7] = k0[kernel_w * 7];
g00[8] = k1[0];
g00[9] = k1[kernel_w];
g00[10] = k1[kernel_w * 2];
g00[11] = k1[kernel_w * 3];
g00[12] = k1[kernel_w * 4];
g00[13] = k1[kernel_w * 5];
g00[14] = k1[kernel_w * 6];
g00[15] = k1[kernel_w * 7];
g00 += 16;
#endif // __AVX2__
}
kptr0 += kernel_w * 8;
kptr1 += kernel_w * 8;
}
#endif // __AVX__
for (; p + 3 < inh; p += 4)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
#if __AVX2__
__m128 _k0 = _mm_i32gather_ps(k0, _vindex, sizeof(float));
__m128 _k1 = _mm_i32gather_ps(k1, _vindex, sizeof(float));
_mm_storeu_ps(g00, _k0);
_mm_storeu_ps(g00 + 4, _k1);
g00 += 8;
#else // __AVX2__
g00[0] = k0[0];
g00[1] = k0[kernel_w];
g00[2] = k0[kernel_w * 2];
g00[3] = k0[kernel_w * 3];
g00[4] = k1[0];
g00[5] = k1[kernel_w];
g00[6] = k1[kernel_w * 2];
g00[7] = k1[kernel_w * 3];
g00 += 8;
#endif // __AVX2__
}
kptr0 += kernel_w * 4;
kptr1 += kernel_w * 4;
}
#endif // __SSE2__
for (; p + 1 < inh; p += 2)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
for (int i = 0; i < 2; i++)
{
g00[0] = k0[0];
g00[1] = k1[0];
k0 += kernel_w;
k1 += kernel_w;
g00 += 2;
}
}
kptr0 += kernel_w * 2;
kptr1 += kernel_w * 2;
}
for (; p < inh; p++)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr0 + k;
const float* k1 = kptr1 + k;
g00[0] = k0[0];
g00[1] = k1[0];
g00 += 2;
}
}
}
for (; q < outh; q++)
{
const float* kptr = (const float*)kernel + q * inh * kernel_w;
#if __AVX512F__
float* g00 = kernel_tm.channel(q / 16 + (q % 16) / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2);
#elif __AVX__
float* g00 = kernel_tm.channel(q / 8 + (q % 8) / 4 + (q % 4) / 2 + q % 2);
#elif __SSE2__
float* g00 = kernel_tm.channel(q / 4 + (q % 4) / 2 + q % 2);
#else
float* g00 = kernel_tm.channel(q / 2 + q % 2);
#endif
#if __AVX2__
__m128i _vindex = _mm_setr_epi32(0, 1, 2, 3);
_vindex = _mm_mullo_epi32(_vindex, _mm_set1_epi32(kernel_w));
__m256i _vindex_256 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
_vindex_256 = _mm256_mullo_epi32(_vindex_256, _mm256_set1_epi32(kernel_w));
#if __AVX512F__
__m512i _vindex_512 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
_vindex_512 = _mm512_mullo_epi32(_vindex_512, _mm512_set1_epi32(kernel_w));
#endif // __AVX512F__
#endif // __AVX2__
int p = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
for (; p + 15 < inh; p += 16)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr + k;
__m512 _k0 = _mm512_i32gather_ps(_vindex_512, k0, sizeof(float));
_mm512_storeu_ps(g00, _k0);
g00 += 16;
}
kptr += kernel_w * 16;
}
#endif // __AVX512F__
for (; p + 7 < inh; p += 8)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr + k;
#if __AVX2__
__m256 _k0 = _mm256_i32gather_ps(k0, _vindex_256, sizeof(float));
_mm256_storeu_ps(g00, _k0);
g00 += 8;
#else // __AVX2__
for (int i = 0; i < 8; i++)
{
g00[0] = k0[0];
k0 += kernel_w;
g00 += 1;
}
#endif // __AVX2__
}
kptr += kernel_w * 8;
}
#endif // __AVX__
for (; p + 3 < inh; p += 4)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr + k;
#if __AVX2__
__m128 _k0 = _mm_i32gather_ps(k0, _vindex, sizeof(float));
_mm_storeu_ps(g00, _k0);
g00 += 4;
#else // __AVX2__
for (int i = 0; i < 4; i++)
{
g00[0] = k0[0];
k0 += kernel_w;
g00 += 1;
}
#endif // __AVX2__
}
kptr += kernel_w * 4;
}
#endif // __SSE2__
for (; p + 1 < inh; p += 2)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr + k;
for (int i = 0; i < 2; i++)
{
g00[0] = k0[0];
k0 += kernel_w;
g00 += 1;
}
}
kptr += kernel_w * 2;
}
for (; p < inh; p++)
{
for (int k = 0; k < kernel_w; k++)
{
const float* k0 = kptr + k;
g00[0] = k0[0];
g00++;
}
}
}
}
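// convolution1d_packed consumes the repacked weights: for each output
// position j it broadcasts input samples one scalar at a time and
// accumulates kernel_w * inh fused multiply-adds into a register holding
// 16/8/4 output channels. The bottom and top blobs may arrive with any
// elempack (1/4/8/16), so every inner loop carries one addressing variant
// per pack size.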
static void convolution1d_packed(const Mat& bottom_blob, Mat& top_blob, const Mat& weight_data_tm, const Mat& bias_data, int kernel_w, int dilation_w, int stride_w, int activation_type, const Mat& activation_params, const Option& opt)
{
const int elempack = bottom_blob.elempack;
const int inh = bottom_blob.h * elempack;
const int N = bottom_blob.w * elempack;
const int outw = top_blob.w;
const int out_elempack = top_blob.elempack;
const int outh = top_blob.h * out_elempack;
const int M = top_blob.w * out_elempack;
const float* bias_data_ptr = bias_data;
int nn_outh = 0;
int remain_outh_start = 0;
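// nn_outh counts the full tiles at the current output-channel width;
// remain_outh_start marks where the narrower tail tiles begin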
#if __SSE2__
#if __AVX__
#if __AVX512F__
nn_outh = outh / 16;
#pragma omp parallel for num_threads(opt.num_threads)
for (int pp = 0; pp < nn_outh; pp++)
{
const int p = pp * 16;
// shadowed variables for fewer openmp task args
const int elempack = bottom_blob.elempack;
const int inh = bottom_blob.h * elempack;
const int outw = top_blob.w;
const int out_elempack = top_blob.elempack;
float* outptr = top_blob.row(p / out_elempack);
for (int j = 0; j < outw; j++)
{
__m512 _sum0 = _mm512_setzero_ps();
__m512 _sum1 = _mm512_setzero_ps();
__m512 _sum2 = _mm512_setzero_ps();
__m512 _sum3 = _mm512_setzero_ps();
if (bias_data_ptr)
{
_sum0 = _mm512_loadu_ps(bias_data_ptr + p);
}
const float* kptr = weight_data_tm.channel(p / 16);
int q = 0;
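// consume input channels in blocks of 16/8/4/2/1; the addressing of r0
// depends on how the bottom blob rows are packed (elempack)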
for (; q + 15 < inh; q += 16)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 16)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
__m512 _w8 = _mm512_load_ps(kptr + 16 * 8);
__m512 _w9 = _mm512_load_ps(kptr + 16 * 9);
__m512 _wa = _mm512_load_ps(kptr + 16 * 10);
__m512 _wb = _mm512_load_ps(kptr + 16 * 11);
__m512 _wc = _mm512_load_ps(kptr + 16 * 12);
__m512 _wd = _mm512_load_ps(kptr + 16 * 13);
__m512 _we = _mm512_load_ps(kptr + 16 * 14);
__m512 _wf = _mm512_load_ps(kptr + 16 * 15);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3);
_sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r0[8]), _sum0);
_sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r0[9]), _sum1);
_sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r0[10]), _sum2);
_sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r0[11]), _sum3);
_sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r0[12]), _sum0);
_sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r0[13]), _sum1);
_sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r0[14]), _sum2);
_sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r0[15]), _sum3);
r0 += dilation_w * 16;
kptr += 256;
}
}
if (elempack == 8)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
__m512 _w8 = _mm512_load_ps(kptr + 16 * 8);
__m512 _w9 = _mm512_load_ps(kptr + 16 * 9);
__m512 _wa = _mm512_load_ps(kptr + 16 * 10);
__m512 _wb = _mm512_load_ps(kptr + 16 * 11);
__m512 _wc = _mm512_load_ps(kptr + 16 * 12);
__m512 _wd = _mm512_load_ps(kptr + 16 * 13);
__m512 _we = _mm512_load_ps(kptr + 16 * 14);
__m512 _wf = _mm512_load_ps(kptr + 16 * 15);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3);
_sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r1[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r1[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r1[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r1[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r1[4]), _sum0);
_sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r1[5]), _sum1);
_sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r1[6]), _sum2);
_sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r1[7]), _sum3);
r0 += dilation_w * 8;
r1 += dilation_w * 8;
kptr += 256;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
const float* r2 = r0 + N * 2;
const float* r3 = r0 + N * 3;
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
__m512 _w8 = _mm512_load_ps(kptr + 16 * 8);
__m512 _w9 = _mm512_load_ps(kptr + 16 * 9);
__m512 _wa = _mm512_load_ps(kptr + 16 * 10);
__m512 _wb = _mm512_load_ps(kptr + 16 * 11);
__m512 _wc = _mm512_load_ps(kptr + 16 * 12);
__m512 _wd = _mm512_load_ps(kptr + 16 * 13);
__m512 _we = _mm512_load_ps(kptr + 16 * 14);
__m512 _wf = _mm512_load_ps(kptr + 16 * 15);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r1[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r1[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r1[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r1[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r2[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r2[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r2[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r2[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r3[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r3[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r3[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r3[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
r2 += dilation_w * 4;
r3 += dilation_w * 4;
kptr += 256;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
__m512 _w8 = _mm512_load_ps(kptr + 16 * 8);
__m512 _w9 = _mm512_load_ps(kptr + 16 * 9);
__m512 _wa = _mm512_load_ps(kptr + 16 * 10);
__m512 _wb = _mm512_load_ps(kptr + 16 * 11);
__m512 _wc = _mm512_load_ps(kptr + 16 * 12);
__m512 _wd = _mm512_load_ps(kptr + 16 * 13);
__m512 _we = _mm512_load_ps(kptr + 16 * 14);
__m512 _wf = _mm512_load_ps(kptr + 16 * 15);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[N * 7]), _sum3);
_sum0 = _mm512_fmadd_ps(_w8, _mm512_set1_ps(r0[N * 8]), _sum0);
_sum1 = _mm512_fmadd_ps(_w9, _mm512_set1_ps(r0[N * 9]), _sum1);
_sum2 = _mm512_fmadd_ps(_wa, _mm512_set1_ps(r0[N * 10]), _sum2);
_sum3 = _mm512_fmadd_ps(_wb, _mm512_set1_ps(r0[N * 11]), _sum3);
_sum0 = _mm512_fmadd_ps(_wc, _mm512_set1_ps(r0[N * 12]), _sum0);
_sum1 = _mm512_fmadd_ps(_wd, _mm512_set1_ps(r0[N * 13]), _sum1);
_sum2 = _mm512_fmadd_ps(_we, _mm512_set1_ps(r0[N * 14]), _sum2);
_sum3 = _mm512_fmadd_ps(_wf, _mm512_set1_ps(r0[N * 15]), _sum3);
r0 += dilation_w;
kptr += 256;
}
}
}
for (; q + 7 < inh; q += 8)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 8)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[4]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[5]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[6]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[7]), _sum3);
r0 += dilation_w * 8;
kptr += 128;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r1[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r1[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r1[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r1[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
kptr += 128;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr + 16 * 0);
__m512 _w1 = _mm512_load_ps(kptr + 16 * 1);
__m512 _w2 = _mm512_load_ps(kptr + 16 * 2);
__m512 _w3 = _mm512_load_ps(kptr + 16 * 3);
__m512 _w4 = _mm512_load_ps(kptr + 16 * 4);
__m512 _w5 = _mm512_load_ps(kptr + 16 * 5);
__m512 _w6 = _mm512_load_ps(kptr + 16 * 6);
__m512 _w7 = _mm512_load_ps(kptr + 16 * 7);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm512_fmadd_ps(_w4, _mm512_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm512_fmadd_ps(_w5, _mm512_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm512_fmadd_ps(_w6, _mm512_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm512_fmadd_ps(_w7, _mm512_set1_ps(r0[N * 7]), _sum3);
r0 += dilation_w;
kptr += 128;
}
}
}
for (; q + 3 < inh; q += 4)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 4)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
__m512 _w2 = _mm512_load_ps(kptr + 32);
__m512 _w3 = _mm512_load_ps(kptr + 48);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[1]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[3]), _sum3);
r0 += dilation_w * 4;
kptr += 64;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
__m512 _w2 = _mm512_load_ps(kptr + 32);
__m512 _w3 = _mm512_load_ps(kptr + 48);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1);
_sum2 = _mm512_fmadd_ps(_w2, _mm512_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm512_fmadd_ps(_w3, _mm512_set1_ps(r0[N * 3]), _sum3);
r0 += dilation_w;
kptr += 64;
}
}
}
for (; q + 1 < inh; q += 2)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
_sum0 = _mm512_fmadd_ps(_w0, _mm512_set1_ps(r0[0]), _sum0);
_sum1 = _mm512_fmadd_ps(_w1, _mm512_set1_ps(r0[N]), _sum1);
r0 += dilation_w;
kptr += 32;
}
}
}
for (; q < inh; q++)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _val = _mm512_set1_ps(r0[0]);
__m512 _w = _mm512_load_ps(kptr);
_sum0 = _mm512_fmadd_ps(_val, _w, _sum0);
r0 += dilation_w;
kptr += 16;
}
}
}
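// fold the four accumulators into _sum0; keeping them separate above
// shortens the FMA dependency chains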
_sum0 = _mm512_add_ps(_sum0, _sum1);
_sum2 = _mm512_add_ps(_sum2, _sum3);
_sum0 = _mm512_add_ps(_sum0, _sum2);
_sum0 = activation_avx512(_sum0, activation_type, activation_params);
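// scatter the 16 channel results according to the output pack size;
// M is the top blob row stride in floats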
if (out_elempack == 16)
{
_mm512_store_ps(outptr, _sum0);
outptr += 16;
}
if (out_elempack == 8)
{
_mm256_store_ps(outptr, _mm512_extractf32x8_ps(_sum0, 0));
_mm256_store_ps(outptr + M, _mm512_extractf32x8_ps(_sum0, 1));
outptr += 8;
}
if (out_elempack == 4)
{
_mm_store_ps(outptr, _mm512_extractf32x4_ps(_sum0, 0));
_mm_store_ps(outptr + M, _mm512_extractf32x4_ps(_sum0, 1));
_mm_store_ps(outptr + M * 2, _mm512_extractf32x4_ps(_sum0, 2));
_mm_store_ps(outptr + M * 3, _mm512_extractf32x4_ps(_sum0, 3));
outptr += 4;
}
if (out_elempack == 1)
{
float sum[16];
_mm512_storeu_ps(sum, _sum0);
outptr[0] = sum[0];
outptr[M] = sum[1];
outptr[M * 2] = sum[2];
outptr[M * 3] = sum[3];
outptr[M * 4] = sum[4];
outptr[M * 5] = sum[5];
outptr[M * 6] = sum[6];
outptr[M * 7] = sum[7];
outptr[M * 8] = sum[8];
outptr[M * 9] = sum[9];
outptr[M * 10] = sum[10];
outptr[M * 11] = sum[11];
outptr[M * 12] = sum[12];
outptr[M * 13] = sum[13];
outptr[M * 14] = sum[14];
outptr[M * 15] = sum[15];
outptr += 1;
}
}
}
remain_outh_start += nn_outh * 16;
nn_outh = (outh - remain_outh_start) / 8;
#else // __AVX512F__
nn_outh = (outh - remain_outh_start) / 8;
#pragma omp parallel for num_threads(opt.num_threads)
#endif // __AVX512F__
for (int pp = 0; pp < nn_outh; pp++)
{
const int p = remain_outh_start + pp * 8;
// shadowed variables for fewer openmp task args
const int elempack = bottom_blob.elempack;
const int inh = bottom_blob.h * elempack;
const int outw = top_blob.w;
const int out_elempack = top_blob.elempack;
float* outptr = top_blob.row(p / out_elempack);
for (int j = 0; j < outw; j++)
{
__m256 _sum0 = _mm256_setzero_ps();
__m256 _sum1 = _mm256_setzero_ps();
__m256 _sum2 = _mm256_setzero_ps();
__m256 _sum3 = _mm256_setzero_ps();
if (bias_data_ptr)
{
_sum0 = _mm256_loadu_ps(bias_data_ptr + p);
}
#if __AVX512F__
const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8);
#else
const float* kptr = weight_data_tm.channel(p / 8);
#endif
int q = 0;
#if __AVX512F__
for (; q + 15 < inh; q += 16)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 16)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr + 8 * 0);
__m256 _w1 = _mm256_load_ps(kptr + 8 * 1);
__m256 _w2 = _mm256_load_ps(kptr + 8 * 2);
__m256 _w3 = _mm256_load_ps(kptr + 8 * 3);
__m256 _w4 = _mm256_load_ps(kptr + 8 * 4);
__m256 _w5 = _mm256_load_ps(kptr + 8 * 5);
__m256 _w6 = _mm256_load_ps(kptr + 8 * 6);
__m256 _w7 = _mm256_load_ps(kptr + 8 * 7);
__m256 _w8 = _mm256_load_ps(kptr + 8 * 8);
__m256 _w9 = _mm256_load_ps(kptr + 8 * 9);
__m256 _wa = _mm256_load_ps(kptr + 8 * 10);
__m256 _wb = _mm256_load_ps(kptr + 8 * 11);
__m256 _wc = _mm256_load_ps(kptr + 8 * 12);
__m256 _wd = _mm256_load_ps(kptr + 8 * 13);
__m256 _we = _mm256_load_ps(kptr + 8 * 14);
__m256 _wf = _mm256_load_ps(kptr + 8 * 15);
_sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0);
_sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1);
_sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2);
_sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3);
_sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r0[8]), _sum0);
_sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r0[9]), _sum1);
_sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r0[10]), _sum2);
_sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r0[11]), _sum3);
_sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r0[12]), _sum0);
_sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r0[13]), _sum1);
_sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r0[14]), _sum2);
_sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r0[15]), _sum3);
r0 += dilation_w * 16;
kptr += 128;
}
}
if (elempack == 8)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr + 8 * 0);
__m256 _w1 = _mm256_load_ps(kptr + 8 * 1);
__m256 _w2 = _mm256_load_ps(kptr + 8 * 2);
__m256 _w3 = _mm256_load_ps(kptr + 8 * 3);
__m256 _w4 = _mm256_load_ps(kptr + 8 * 4);
__m256 _w5 = _mm256_load_ps(kptr + 8 * 5);
__m256 _w6 = _mm256_load_ps(kptr + 8 * 6);
__m256 _w7 = _mm256_load_ps(kptr + 8 * 7);
__m256 _w8 = _mm256_load_ps(kptr + 8 * 8);
__m256 _w9 = _mm256_load_ps(kptr + 8 * 9);
__m256 _wa = _mm256_load_ps(kptr + 8 * 10);
__m256 _wb = _mm256_load_ps(kptr + 8 * 11);
__m256 _wc = _mm256_load_ps(kptr + 8 * 12);
__m256 _wd = _mm256_load_ps(kptr + 8 * 13);
__m256 _we = _mm256_load_ps(kptr + 8 * 14);
__m256 _wf = _mm256_load_ps(kptr + 8 * 15);
_sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0);
_sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1);
_sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2);
_sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3);
_sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r1[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r1[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r1[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r1[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r1[4]), _sum0);
_sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r1[5]), _sum1);
_sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r1[6]), _sum2);
_sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r1[7]), _sum3);
r0 += dilation_w * 8;
r1 += dilation_w * 8;
kptr += 128;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
const float* r2 = r0 + N * 2;
const float* r3 = r0 + N * 3;
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr + 8 * 0);
__m256 _w1 = _mm256_load_ps(kptr + 8 * 1);
__m256 _w2 = _mm256_load_ps(kptr + 8 * 2);
__m256 _w3 = _mm256_load_ps(kptr + 8 * 3);
__m256 _w4 = _mm256_load_ps(kptr + 8 * 4);
__m256 _w5 = _mm256_load_ps(kptr + 8 * 5);
__m256 _w6 = _mm256_load_ps(kptr + 8 * 6);
__m256 _w7 = _mm256_load_ps(kptr + 8 * 7);
__m256 _w8 = _mm256_load_ps(kptr + 8 * 8);
__m256 _w9 = _mm256_load_ps(kptr + 8 * 9);
__m256 _wa = _mm256_load_ps(kptr + 8 * 10);
__m256 _wb = _mm256_load_ps(kptr + 8 * 11);
__m256 _wc = _mm256_load_ps(kptr + 8 * 12);
__m256 _wd = _mm256_load_ps(kptr + 8 * 13);
__m256 _we = _mm256_load_ps(kptr + 8 * 14);
__m256 _wf = _mm256_load_ps(kptr + 8 * 15);
_sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r1[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r1[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r1[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r1[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r2[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r2[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r2[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r2[3]), _sum3);
_sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r3[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r3[1]), _sum1);
_sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r3[2]), _sum2);
_sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r3[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
r2 += dilation_w * 4;
r3 += dilation_w * 4;
kptr += 128;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr + 8 * 0);
__m256 _w1 = _mm256_load_ps(kptr + 8 * 1);
__m256 _w2 = _mm256_load_ps(kptr + 8 * 2);
__m256 _w3 = _mm256_load_ps(kptr + 8 * 3);
__m256 _w4 = _mm256_load_ps(kptr + 8 * 4);
__m256 _w5 = _mm256_load_ps(kptr + 8 * 5);
__m256 _w6 = _mm256_load_ps(kptr + 8 * 6);
__m256 _w7 = _mm256_load_ps(kptr + 8 * 7);
__m256 _w8 = _mm256_load_ps(kptr + 8 * 8);
__m256 _w9 = _mm256_load_ps(kptr + 8 * 9);
__m256 _wa = _mm256_load_ps(kptr + 8 * 10);
__m256 _wb = _mm256_load_ps(kptr + 8 * 11);
__m256 _wc = _mm256_load_ps(kptr + 8 * 12);
__m256 _wd = _mm256_load_ps(kptr + 8 * 13);
__m256 _we = _mm256_load_ps(kptr + 8 * 14);
__m256 _wf = _mm256_load_ps(kptr + 8 * 15);
_sum0 = _mm256_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1);
_sum2 = _mm256_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm256_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm256_fmadd_ps(_w4, _mm256_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm256_fmadd_ps(_w5, _mm256_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm256_fmadd_ps(_w6, _mm256_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm256_fmadd_ps(_w7, _mm256_set1_ps(r0[N * 7]), _sum3);
_sum0 = _mm256_fmadd_ps(_w8, _mm256_set1_ps(r0[N * 8]), _sum0);
_sum1 = _mm256_fmadd_ps(_w9, _mm256_set1_ps(r0[N * 9]), _sum1);
_sum2 = _mm256_fmadd_ps(_wa, _mm256_set1_ps(r0[N * 10]), _sum2);
_sum3 = _mm256_fmadd_ps(_wb, _mm256_set1_ps(r0[N * 11]), _sum3);
_sum0 = _mm256_fmadd_ps(_wc, _mm256_set1_ps(r0[N * 12]), _sum0);
_sum1 = _mm256_fmadd_ps(_wd, _mm256_set1_ps(r0[N * 13]), _sum1);
_sum2 = _mm256_fmadd_ps(_we, _mm256_set1_ps(r0[N * 14]), _sum2);
_sum3 = _mm256_fmadd_ps(_wf, _mm256_set1_ps(r0[N * 15]), _sum3);
r0 += dilation_w;
kptr += 128;
}
}
}
#endif // __AVX512F__
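// outside the AVX512F guard the code uses _mm256_comp_fmadd_ps, ncnn's
// compatibility wrapper that emits a true FMA when FMA is available and
// a mul+add pair otherwise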
for (; q + 7 < inh; q += 8)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 8)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
__m256 _w2 = _mm256_load_ps(kptr + 16);
__m256 _w3 = _mm256_load_ps(kptr + 24);
__m256 _w4 = _mm256_load_ps(kptr + 32);
__m256 _w5 = _mm256_load_ps(kptr + 40);
__m256 _w6 = _mm256_load_ps(kptr + 48);
__m256 _w7 = _mm256_load_ps(kptr + 56);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
_sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r0[4]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r0[5]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r0[6]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r0[7]), _sum3);
r0 += dilation_w * 8;
kptr += 64;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
__m256 _w2 = _mm256_load_ps(kptr + 16);
__m256 _w3 = _mm256_load_ps(kptr + 24);
__m256 _w4 = _mm256_load_ps(kptr + 32);
__m256 _w5 = _mm256_load_ps(kptr + 40);
__m256 _w6 = _mm256_load_ps(kptr + 48);
__m256 _w7 = _mm256_load_ps(kptr + 56);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
_sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r1[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r1[1]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r1[2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r1[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
kptr += 64;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
__m256 _w2 = _mm256_load_ps(kptr + 16);
__m256 _w3 = _mm256_load_ps(kptr + 24);
__m256 _w4 = _mm256_load_ps(kptr + 32);
__m256 _w5 = _mm256_load_ps(kptr + 40);
__m256 _w6 = _mm256_load_ps(kptr + 48);
__m256 _w7 = _mm256_load_ps(kptr + 56);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm256_comp_fmadd_ps(_w4, _mm256_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w5, _mm256_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w6, _mm256_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w7, _mm256_set1_ps(r0[N * 7]), _sum3);
r0 += dilation_w;
kptr += 64;
}
}
}
for (; q + 3 < inh; q += 4)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 4)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
__m256 _w2 = _mm256_load_ps(kptr + 16);
__m256 _w3 = _mm256_load_ps(kptr + 24);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[1]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[3]), _sum3);
r0 += dilation_w * 4;
kptr += 32;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
__m256 _w2 = _mm256_load_ps(kptr + 16);
__m256 _w3 = _mm256_load_ps(kptr + 24);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1);
_sum2 = _mm256_comp_fmadd_ps(_w2, _mm256_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm256_comp_fmadd_ps(_w3, _mm256_set1_ps(r0[N * 3]), _sum3);
r0 += dilation_w;
kptr += 32;
}
}
}
for (; q + 1 < inh; q += 2)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
_sum0 = _mm256_comp_fmadd_ps(_w0, _mm256_set1_ps(r0[0]), _sum0);
_sum1 = _mm256_comp_fmadd_ps(_w1, _mm256_set1_ps(r0[N]), _sum1);
r0 += dilation_w;
kptr += 16;
}
}
}
for (; q < inh; q++)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _val = _mm256_set1_ps(r0[0]);
__m256 _w = _mm256_load_ps(kptr);
_sum0 = _mm256_comp_fmadd_ps(_val, _w, _sum0);
r0 += dilation_w;
kptr += 8;
}
}
}
_sum0 = _mm256_add_ps(_sum0, _sum1);
_sum2 = _mm256_add_ps(_sum2, _sum3);
_sum0 = _mm256_add_ps(_sum0, _sum2);
_sum0 = activation_avx(_sum0, activation_type, activation_params);
if (out_elempack == 8)
{
_mm256_store_ps(outptr, _sum0);
outptr += 8;
}
if (out_elempack == 4)
{
_mm_store_ps(outptr, _mm256_extractf128_ps(_sum0, 0));
_mm_store_ps(outptr + M, _mm256_extractf128_ps(_sum0, 1));
outptr += 4;
}
if (out_elempack == 1)
{
float sum[8];
_mm256_storeu_ps(sum, _sum0);
outptr[0] = sum[0];
outptr[M] = sum[1];
outptr[M * 2] = sum[2];
outptr[M * 3] = sum[3];
outptr[M * 4] = sum[4];
outptr[M * 5] = sum[5];
outptr[M * 6] = sum[6];
outptr[M * 7] = sum[7];
outptr += 1;
}
}
}
remain_outh_start += nn_outh * 8;
nn_outh = (outh - remain_outh_start) / 4;
#else // __AVX__
nn_outh = (outh - remain_outh_start) / 4;
#pragma omp parallel for num_threads(opt.num_threads)
#endif // __AVX__
for (int pp = 0; pp < nn_outh; pp++)
{
const int p = remain_outh_start + pp * 4;
// shadowed variables for fewer openmp task args
const int elempack = bottom_blob.elempack;
const int inh = bottom_blob.h * elempack;
const int outw = top_blob.w;
const int out_elempack = top_blob.elempack;
float* outptr = top_blob.row(p / out_elempack);
for (int j = 0; j < outw; j++)
{
__m128 _sum0 = _mm_setzero_ps();
__m128 _sum1 = _mm_setzero_ps();
__m128 _sum2 = _mm_setzero_ps();
__m128 _sum3 = _mm_setzero_ps();
if (bias_data_ptr)
{
_sum0 = _mm_loadu_ps(bias_data_ptr + p);
}
#if __AVX512F__
const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4);
#elif __AVX__
const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4);
#else
const float* kptr = weight_data_tm.channel(p / 4);
#endif
int q = 0;
#if __AVX__
#if __AVX512F__
for (; q + 15 < inh; q += 16)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 16)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr + 4 * 0);
__m128 _w1 = _mm_load_ps(kptr + 4 * 1);
__m128 _w2 = _mm_load_ps(kptr + 4 * 2);
__m128 _w3 = _mm_load_ps(kptr + 4 * 3);
__m128 _w4 = _mm_load_ps(kptr + 4 * 4);
__m128 _w5 = _mm_load_ps(kptr + 4 * 5);
__m128 _w6 = _mm_load_ps(kptr + 4 * 6);
__m128 _w7 = _mm_load_ps(kptr + 4 * 7);
__m128 _w8 = _mm_load_ps(kptr + 4 * 8);
__m128 _w9 = _mm_load_ps(kptr + 4 * 9);
__m128 _wa = _mm_load_ps(kptr + 4 * 10);
__m128 _wb = _mm_load_ps(kptr + 4 * 11);
__m128 _wc = _mm_load_ps(kptr + 4 * 12);
__m128 _wd = _mm_load_ps(kptr + 4 * 13);
__m128 _we = _mm_load_ps(kptr + 4 * 14);
__m128 _wf = _mm_load_ps(kptr + 4 * 15);
_sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
_sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0);
_sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1);
_sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2);
_sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3);
_sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r0[8]), _sum0);
_sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r0[9]), _sum1);
_sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r0[10]), _sum2);
_sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r0[11]), _sum3);
_sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r0[12]), _sum0);
_sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r0[13]), _sum1);
_sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r0[14]), _sum2);
_sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r0[15]), _sum3);
r0 += dilation_w * 16;
kptr += 64;
}
}
if (elempack == 8)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr + 4 * 0);
__m128 _w1 = _mm_load_ps(kptr + 4 * 1);
__m128 _w2 = _mm_load_ps(kptr + 4 * 2);
__m128 _w3 = _mm_load_ps(kptr + 4 * 3);
__m128 _w4 = _mm_load_ps(kptr + 4 * 4);
__m128 _w5 = _mm_load_ps(kptr + 4 * 5);
__m128 _w6 = _mm_load_ps(kptr + 4 * 6);
__m128 _w7 = _mm_load_ps(kptr + 4 * 7);
__m128 _w8 = _mm_load_ps(kptr + 4 * 8);
__m128 _w9 = _mm_load_ps(kptr + 4 * 9);
__m128 _wa = _mm_load_ps(kptr + 4 * 10);
__m128 _wb = _mm_load_ps(kptr + 4 * 11);
__m128 _wc = _mm_load_ps(kptr + 4 * 12);
__m128 _wd = _mm_load_ps(kptr + 4 * 13);
__m128 _we = _mm_load_ps(kptr + 4 * 14);
__m128 _wf = _mm_load_ps(kptr + 4 * 15);
_sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
_sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0);
_sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1);
_sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2);
_sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3);
_sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r1[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r1[1]), _sum1);
_sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r1[2]), _sum2);
_sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r1[3]), _sum3);
_sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r1[4]), _sum0);
_sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r1[5]), _sum1);
_sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r1[6]), _sum2);
_sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r1[7]), _sum3);
r0 += dilation_w * 8;
r1 += dilation_w * 8;
kptr += 64;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
const float* r2 = r0 + N * 2;
const float* r3 = r0 + N * 3;
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr + 4 * 0);
__m128 _w1 = _mm_load_ps(kptr + 4 * 1);
__m128 _w2 = _mm_load_ps(kptr + 4 * 2);
__m128 _w3 = _mm_load_ps(kptr + 4 * 3);
__m128 _w4 = _mm_load_ps(kptr + 4 * 4);
__m128 _w5 = _mm_load_ps(kptr + 4 * 5);
__m128 _w6 = _mm_load_ps(kptr + 4 * 6);
__m128 _w7 = _mm_load_ps(kptr + 4 * 7);
__m128 _w8 = _mm_load_ps(kptr + 4 * 8);
__m128 _w9 = _mm_load_ps(kptr + 4 * 9);
__m128 _wa = _mm_load_ps(kptr + 4 * 10);
__m128 _wb = _mm_load_ps(kptr + 4 * 11);
__m128 _wc = _mm_load_ps(kptr + 4 * 12);
__m128 _wd = _mm_load_ps(kptr + 4 * 13);
__m128 _we = _mm_load_ps(kptr + 4 * 14);
__m128 _wf = _mm_load_ps(kptr + 4 * 15);
_sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
_sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r1[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r1[1]), _sum1);
_sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r1[2]), _sum2);
_sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r1[3]), _sum3);
_sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r2[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r2[1]), _sum1);
_sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r2[2]), _sum2);
_sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r2[3]), _sum3);
_sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r3[0]), _sum0);
_sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r3[1]), _sum1);
_sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r3[2]), _sum2);
_sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r3[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
r2 += dilation_w * 4;
r3 += dilation_w * 4;
kptr += 64;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr + 4 * 0);
__m128 _w1 = _mm_load_ps(kptr + 4 * 1);
__m128 _w2 = _mm_load_ps(kptr + 4 * 2);
__m128 _w3 = _mm_load_ps(kptr + 4 * 3);
__m128 _w4 = _mm_load_ps(kptr + 4 * 4);
__m128 _w5 = _mm_load_ps(kptr + 4 * 5);
__m128 _w6 = _mm_load_ps(kptr + 4 * 6);
__m128 _w7 = _mm_load_ps(kptr + 4 * 7);
__m128 _w8 = _mm_load_ps(kptr + 4 * 8);
__m128 _w9 = _mm_load_ps(kptr + 4 * 9);
__m128 _wa = _mm_load_ps(kptr + 4 * 10);
__m128 _wb = _mm_load_ps(kptr + 4 * 11);
__m128 _wc = _mm_load_ps(kptr + 4 * 12);
__m128 _wd = _mm_load_ps(kptr + 4 * 13);
__m128 _we = _mm_load_ps(kptr + 4 * 14);
__m128 _wf = _mm_load_ps(kptr + 4 * 15);
_sum0 = _mm_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1);
_sum2 = _mm_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm_fmadd_ps(_w4, _mm_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm_fmadd_ps(_w5, _mm_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm_fmadd_ps(_w6, _mm_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm_fmadd_ps(_w7, _mm_set1_ps(r0[N * 7]), _sum3);
_sum0 = _mm_fmadd_ps(_w8, _mm_set1_ps(r0[N * 8]), _sum0);
_sum1 = _mm_fmadd_ps(_w9, _mm_set1_ps(r0[N * 9]), _sum1);
_sum2 = _mm_fmadd_ps(_wa, _mm_set1_ps(r0[N * 10]), _sum2);
_sum3 = _mm_fmadd_ps(_wb, _mm_set1_ps(r0[N * 11]), _sum3);
_sum0 = _mm_fmadd_ps(_wc, _mm_set1_ps(r0[N * 12]), _sum0);
_sum1 = _mm_fmadd_ps(_wd, _mm_set1_ps(r0[N * 13]), _sum1);
_sum2 = _mm_fmadd_ps(_we, _mm_set1_ps(r0[N * 14]), _sum2);
_sum3 = _mm_fmadd_ps(_wf, _mm_set1_ps(r0[N * 15]), _sum3);
r0 += dilation_w;
kptr += 64;
}
}
}
#endif // __AVX512F__
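// reduce 8 input channels per pass, gathering from 1/4/8-packed rows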
for (; q + 7 < inh; q += 8)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 8)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
__m128 _w2 = _mm_load_ps(kptr + 8);
__m128 _w3 = _mm_load_ps(kptr + 12);
__m128 _w4 = _mm_load_ps(kptr + 16);
__m128 _w5 = _mm_load_ps(kptr + 20);
__m128 _w6 = _mm_load_ps(kptr + 24);
__m128 _w7 = _mm_load_ps(kptr + 28);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
_sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r0[4]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r0[5]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r0[6]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r0[7]), _sum3);
r0 += dilation_w * 8;
kptr += 32;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
__m128 _w2 = _mm_load_ps(kptr + 8);
__m128 _w3 = _mm_load_ps(kptr + 12);
__m128 _w4 = _mm_load_ps(kptr + 16);
__m128 _w5 = _mm_load_ps(kptr + 20);
__m128 _w6 = _mm_load_ps(kptr + 24);
__m128 _w7 = _mm_load_ps(kptr + 28);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
_sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r1[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r1[1]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r1[2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r1[3]), _sum3);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
kptr += 32;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
__m128 _w2 = _mm_load_ps(kptr + 8);
__m128 _w3 = _mm_load_ps(kptr + 12);
__m128 _w4 = _mm_load_ps(kptr + 16);
__m128 _w5 = _mm_load_ps(kptr + 20);
__m128 _w6 = _mm_load_ps(kptr + 24);
__m128 _w7 = _mm_load_ps(kptr + 28);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3);
_sum0 = _mm_comp_fmadd_ps(_w4, _mm_set1_ps(r0[N * 4]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w5, _mm_set1_ps(r0[N * 5]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w6, _mm_set1_ps(r0[N * 6]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w7, _mm_set1_ps(r0[N * 7]), _sum3);
r0 += dilation_w;
kptr += 32;
}
}
}
#endif // __AVX__
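// reduce 4 input channels per pass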
for (; q + 3 < inh; q += 4)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 4)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
__m128 _w2 = _mm_load_ps(kptr + 8);
__m128 _w3 = _mm_load_ps(kptr + 12);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[1]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[3]), _sum3);
r0 += dilation_w * 4;
kptr += 16;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
__m128 _w2 = _mm_load_ps(kptr + 8);
__m128 _w3 = _mm_load_ps(kptr + 12);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1);
_sum2 = _mm_comp_fmadd_ps(_w2, _mm_set1_ps(r0[N * 2]), _sum2);
_sum3 = _mm_comp_fmadd_ps(_w3, _mm_set1_ps(r0[N * 3]), _sum3);
r0 += dilation_w;
kptr += 16;
}
}
}
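// input channel tails, two then one at a time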
for (; q + 1 < inh; q += 2)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
_sum0 = _mm_comp_fmadd_ps(_w0, _mm_set1_ps(r0[0]), _sum0);
_sum1 = _mm_comp_fmadd_ps(_w1, _mm_set1_ps(r0[N]), _sum1);
r0 += dilation_w;
kptr += 8;
}
}
}
for (; q < inh; q++)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _val = _mm_set1_ps(r0[0]);
__m128 _w = _mm_load_ps(kptr);
_sum0 = _mm_comp_fmadd_ps(_val, _w, _sum0);
r0 += dilation_w;
kptr += 4;
}
}
}
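// fold the four partial accumulators before the activation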
_sum0 = _mm_add_ps(_sum0, _sum1);
_sum2 = _mm_add_ps(_sum2, _sum3);
_sum0 = _mm_add_ps(_sum0, _sum2);
_sum0 = activation_sse(_sum0, activation_type, activation_params);
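// store the 4 results according to the output packing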
if (out_elempack == 4)
{
_mm_storeu_ps(outptr, _sum0);
outptr += 4;
}
if (out_elempack == 1)
{
float sum[4];
_mm_storeu_ps(sum, _sum0);
outptr[0] = sum[0];
outptr[M] = sum[1];
outptr[M * 2] = sum[2];
outptr[M * 3] = sum[3];
outptr += 1;
}
}
}
remain_outh_start += nn_outh * 4;
nn_outh = (outh - remain_outh_start) / 2;
#else // __SSE2__
nn_outh = (outh - remain_outh_start) / 2;
#pragma omp parallel for num_threads(opt.num_threads)
#endif // __SSE2__
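// tile over 2 output channels, accumulating with the widest unit available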
for (int pp = 0; pp < nn_outh; pp++)
{
const int p = remain_outh_start + pp * 2;
// shadowed variables for fewer openmp task args
const int elempack = bottom_blob.elempack;
const int inh = bottom_blob.h * elempack;
const int outw = top_blob.w;
float* outptr0 = top_blob.row(p);
float* outptr1 = top_blob.row(p + 1);
for (int j = 0; j < outw; j++)
{
float sum0 = 0.f;
float sum1 = 0.f;
if (bias_data_ptr)
{
sum0 = bias_data_ptr[p];
sum1 = bias_data_ptr[p + 1];
}
#if __AVX512F__
const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2);
#elif __AVX__
const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2);
#elif __SSE2__
const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2);
#else
const float* kptr = weight_data_tm.channel(p / 2);
#endif
int q = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
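// 16-wide partial sums, reduced horizontally after this q loop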
__m512 _sum0_avx512 = _mm512_setzero_ps();
__m512 _sum1_avx512 = _mm512_setzero_ps();
for (; q + 15 < inh; q += 16)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 16)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_load_ps(r0);
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
_sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512);
_sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512);
r0 += dilation_w * 16;
kptr += 32;
}
}
if (elempack == 8)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_load_ps(r0)), _mm256_load_ps(r1), 1);
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
_sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512);
_sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512);
r0 += dilation_w * 8;
r1 += dilation_w * 8;
kptr += 32;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
const float* r2 = r0 + N * 2;
const float* r3 = r0 + N * 3;
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r2)), _mm_load_ps(r3), 1), 1);
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
_sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512);
_sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
r2 += dilation_w * 4;
r3 += dilation_w * 4;
kptr += 32;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_set_ps(r0[N * 15], r0[N * 14], r0[N * 13], r0[N * 12], r0[N * 11], r0[N * 10], r0[N * 9], r0[N * 8], r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m512 _w0 = _mm512_load_ps(kptr);
__m512 _w1 = _mm512_load_ps(kptr + 16);
_sum0_avx512 = _mm512_fmadd_ps(_r0, _w0, _sum0_avx512);
_sum1_avx512 = _mm512_fmadd_ps(_r0, _w1, _sum1_avx512);
r0 += dilation_w;
kptr += 32;
}
}
}
sum0 += _mm512_comp_reduce_add_ps(_sum0_avx512);
sum1 += _mm512_comp_reduce_add_ps(_sum1_avx512);
#endif // __AVX512F__
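// 8-wide partial sums for the next groups of 8 input channels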
__m256 _sum0_avx = _mm256_setzero_ps();
__m256 _sum1_avx = _mm256_setzero_ps();
for (; q + 7 < inh; q += 8)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 8)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_load_ps(r0);
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
_sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx);
_sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx);
r0 += dilation_w * 8;
kptr += 16;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1);
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
_sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx);
_sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
kptr += 16;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_set_ps(r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m256 _w0 = _mm256_load_ps(kptr);
__m256 _w1 = _mm256_load_ps(kptr + 8);
_sum0_avx = _mm256_comp_fmadd_ps(_r0, _w0, _sum0_avx);
_sum1_avx = _mm256_comp_fmadd_ps(_r0, _w1, _sum1_avx);
r0 += dilation_w;
kptr += 16;
}
}
}
sum0 += _mm256_reduce_add_ps(_sum0_avx);
sum1 += _mm256_reduce_add_ps(_sum1_avx);
#endif // __AVX__
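// 4-wide partial sums for the remaining groups of 4 input channels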
__m128 _sum0 = _mm_setzero_ps();
__m128 _sum1 = _mm_setzero_ps();
for (; q + 3 < inh; q += 4)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 4)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _r0 = _mm_load_ps(r0);
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
_sum0 = _mm_comp_fmadd_ps(_r0, _w0, _sum0);
_sum1 = _mm_comp_fmadd_ps(_r0, _w1, _sum1);
r0 += dilation_w * 4;
kptr += 8;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _r0 = _mm_set_ps(r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m128 _w0 = _mm_load_ps(kptr);
__m128 _w1 = _mm_load_ps(kptr + 4);
_sum0 = _mm_comp_fmadd_ps(_r0, _w0, _sum0);
_sum1 = _mm_comp_fmadd_ps(_r0, _w1, _sum1);
r0 += dilation_w;
kptr += 8;
}
}
}
sum0 += _mm_reduce_add_ps(_sum0);
sum1 += _mm_reduce_add_ps(_sum1);
#endif // __SSE2__
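// scalar tail, and the whole reduction when __SSE2__ is unavailable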
for (; q + 1 < inh; q += 2)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
sum0 += r0[0] * kptr[0];
sum1 += r0[0] * kptr[1];
sum0 += r0[N] * kptr[2];
sum1 += r0[N] * kptr[3];
r0 += dilation_w;
kptr += 4;
}
}
}
for (; q < inh; q++)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
float val = r0[0];
sum0 += val * kptr[0];
sum1 += val * kptr[1];
r0 += dilation_w;
kptr += 2;
}
}
}
sum0 = activation_ss(sum0, activation_type, activation_params);
sum1 = activation_ss(sum1, activation_type, activation_params);
outptr0[0] = sum0;
outptr1[0] = sum1;
outptr0 += 1;
outptr1 += 1;
}
}
remain_outh_start += nn_outh * 2;
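// remaining output channels, one at a time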
for (int p = remain_outh_start; p < outh; p++)
{
float* outptr = top_blob.row(p);
for (int j = 0; j < outw; j++)
{
float sum = 0.f;
if (bias_data_ptr)
{
sum = bias_data_ptr[p];
}
#if __AVX512F__
const float* kptr = weight_data_tm.channel(p / 16 + (p % 16) / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2);
#elif __AVX__
const float* kptr = weight_data_tm.channel(p / 8 + (p % 8) / 4 + (p % 4) / 2 + p % 2);
#elif __SSE2__
const float* kptr = weight_data_tm.channel(p / 4 + (p % 4) / 2 + p % 2);
#else
const float* kptr = weight_data_tm.channel(p / 2 + p % 2);
#endif
int q = 0;
#if __SSE2__
#if __AVX__
#if __AVX512F__
__m512 _sum_avx512 = _mm512_setzero_ps();
for (; q + 15 < inh; q += 16)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 16)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_load_ps(r0);
__m512 _w = _mm512_load_ps(kptr);
_sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512);
r0 += dilation_w * 16;
kptr += 16;
}
}
if (elempack == 8)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_load_ps(r0)), _mm256_load_ps(r1), 1);
__m512 _w = _mm512_load_ps(kptr);
_sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512);
r0 += dilation_w * 8;
r1 += dilation_w * 8;
kptr += 16;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
const float* r2 = r0 + N * 2;
const float* r3 = r0 + N * 3;
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1)), _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r2)), _mm_load_ps(r3), 1), 1);
__m512 _w = _mm512_load_ps(kptr);
_sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
r2 += dilation_w * 4;
r3 += dilation_w * 4;
kptr += 16;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m512 _r0 = _mm512_set_ps(r0[N * 15], r0[N * 14], r0[N * 13], r0[N * 12], r0[N * 11], r0[N * 10], r0[N * 9], r0[N * 8], r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m512 _w = _mm512_load_ps(kptr);
_sum_avx512 = _mm512_fmadd_ps(_r0, _w, _sum_avx512);
r0 += dilation_w;
kptr += 16;
}
}
}
sum += _mm512_comp_reduce_add_ps(_sum_avx512);
#endif // __AVX512F__
__m256 _sum_avx = _mm256_setzero_ps();
for (; q + 7 < inh; q += 8)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 8)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_load_ps(r0);
__m256 _w = _mm256_load_ps(kptr);
_sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx);
r0 += dilation_w * 8;
kptr += 8;
}
}
if (elempack == 4)
{
const float* r1 = r0 + N;
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_mm_load_ps(r0)), _mm_load_ps(r1), 1);
__m256 _w = _mm256_load_ps(kptr);
_sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx);
r0 += dilation_w * 4;
r1 += dilation_w * 4;
kptr += 8;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m256 _r0 = _mm256_set_ps(r0[N * 7], r0[N * 6], r0[N * 5], r0[N * 4], r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m256 _w = _mm256_load_ps(kptr);
_sum_avx = _mm256_comp_fmadd_ps(_r0, _w, _sum_avx);
r0 += dilation_w;
kptr += 8;
}
}
}
sum += _mm256_reduce_add_ps(_sum_avx);
#endif // __AVX__
__m128 _sum = _mm_setzero_ps();
for (; q + 3 < inh; q += 4)
{
const float* r0 = bottom_blob.row(q / elempack) + j * stride_w * elempack;
if (elempack == 4)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _r0 = _mm_load_ps(r0);
__m128 _w = _mm_load_ps(kptr);
_sum = _mm_comp_fmadd_ps(_r0, _w, _sum);
r0 += dilation_w * 4;
kptr += 4;
}
}
if (elempack == 1)
{
for (int k = 0; k < kernel_w; k++)
{
__m128 _r0 = _mm_set_ps(r0[N * 3], r0[N * 2], r0[N], r0[0]);
__m128 _w = _mm_load_ps(kptr);
_sum = _mm_comp_fmadd_ps(_r0, _w, _sum);
r0 += dilation_w;
kptr += 4;
}
}
}
sum += _mm_reduce_add_ps(_sum);
#endif // __SSE2__
for (; q + 1 < inh; q += 2)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
sum += r0[0] * kptr[0];
sum += r0[N] * kptr[1];
r0 += dilation_w;
kptr += 2;
}
}
}
for (; q < inh; q++)
{
const float* r0 = bottom_blob.row(q) + j * stride_w;
// if (elempack == 1) always holds here, the test is elided
{
for (int k = 0; k < kernel_w; k++)
{
float val = r0[0];
sum += val * kptr[0];
r0 += dilation_w;
kptr += 1;
}
}
}
sum = activation_ss(sum, activation_type, activation_params);
outptr[0] = sum;
outptr += 1;
}
}
}