// ncnn/src/layer/vulkan/shader/convolution_3x3s1d1_winograd43_transform_output.comp
// (mirrored copy by camenduru - "thanks to ncnn ❤" - revision be903e2)
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#version 450
// Optional 16-bit extensions, switched on by NCNN_fp16_storage / NCNN_fp16_arithmetic.
// NOTE(review): these switches are presumably injected by ncnn's shader preprocessing
// and select the types behind sfp (storage) / afp (arithmetic) used below - confirm.
#if NCNN_fp16_storage
#extension GL_EXT_shader_16bit_storage: require
#endif
#if NCNN_fp16_arithmetic
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
#endif
// Required for the #include directive that follows.
#extension GL_GOOGLE_include_directive: enable
// Presumably provides activation_afp(), which main() calls for the fused activation.
#include "vulkan_activation.comp"
// Pipeline specialization constants.
// bias_term == 1 makes main() add the per-channel bias to every output value.
// activation_type / activation_param_0 / activation_param_1 are forwarded unchanged
// to activation_afp() for every output value.
layout (constant_id = 0) const int bias_term = 0;
layout (constant_id = 1) const int activation_type = 0;
layout (constant_id = 2) const float activation_param_0 = 0;
layout (constant_id = 3) const float activation_param_1 = 0;
// Shape constants occupy constant_id 4..10 and are always read through psc().
// NOTE(review): psc() is defined outside this file; it presumably falls back to the
// same-named push-constant field p.* when the specialization constant is left at 0 - confirm.
#define shape_constant_id_offset 4
// number of channels; invocations with gl_GlobalInvocationID.z >= c exit immediately
layout (constant_id = shape_constant_id_offset + 0) const int c = 0;
// buffer path: stride in top_tm_blob_data between transformed element k and k + 1
layout (constant_id = shape_constant_id_offset + 1) const int cstep = 0;
// number of 4x4 output tiles along x / y (bounds of the x / y dispatch)
layout (constant_id = shape_constant_id_offset + 2) const int block_x = 0;
layout (constant_id = shape_constant_id_offset + 3) const int block_y = 0;
// output width / height: buffer-path row stride and edge clipping of the last tiles
layout (constant_id = shape_constant_id_offset + 4) const int outw = 0;
layout (constant_id = shape_constant_id_offset + 5) const int outh = 0;
// buffer path: stride in top_blob_data between output channels
layout (constant_id = shape_constant_id_offset + 6) const int outcstep = 0;
// Resource bindings. With NCNN_image_shader the blobs are 3D images, otherwise flat
// storage buffers; the binding indices are the same in both modes.
#if NCNN_image_shader
// input: transformed tile values, read at (gy * block_x + gx, channel, element 0..35)
layout (binding = 0) uniform unfp sampler3D top_tm_blob;
// output: written at (x, y, channel)
layout (binding = 1, imfmtc1) writeonly uniform unfp image3D top_blob;
// per-channel bias, read at (channel, 0, 0) when bias_term == 1
layout (binding = 2) uniform unfp sampler3D bias_blob;
#else
// input: element k of tile (gx, gy), channel gz at k * cstep + gz * block_x * block_y + gy * block_x + gx
layout (binding = 0) readonly buffer top_tm_blob { sfp top_tm_blob_data[]; };
// output: pixel (x, y) of channel gz at gz * outcstep + y * outw + x
layout (binding = 1) writeonly buffer top_blob { sfp top_blob_data[]; };
// per-channel bias, indexed by channel, read when bias_term == 1
layout (binding = 2) readonly buffer bias_blob { sfp bias_data[]; };
#endif
// Push-constant copy of the shape values, one field per shape specialization constant.
// Member order defines the push-constant memory layout and must match what the host pushes.
// NOTE(review): presumably read via psc() only when the matching specialization constant is 0.
layout (push_constant) uniform parameter
{
int c;
int cstep;
int block_x;
int block_y;
int outw;
int outh;
int outcstep;
} p;
// Winograd F(4x4, 3x3) output transform.
//
// One invocation produces one 4x4 output tile of one channel:
//   gl_GlobalInvocationID.x -> tile column gx in [0, block_x)
//   gl_GlobalInvocationID.y -> tile row    gy in [0, block_y)
//   gl_GlobalInvocationID.z -> channel     gz in [0, c)
// It loads the 6x6 transformed-domain values of the tile, applies the 4x6 matrix
// otm (see below) along both tile axes, optionally adds the channel bias, applies
// the configured activation and stores the 4x4 result at (gx * 4, gy * 4).
void main()
{
    int gx = int(gl_GlobalInvocationID.x);
    int gy = int(gl_GlobalInvocationID.y);
    int gz = int(gl_GlobalInvocationID.z);

    // invocations outside the tile grid / channel range do nothing
    if (gx >= psc(block_x) || gy >= psc(block_y) || gz >= psc(c))
        return;

    // load 36: tile element vIJ is transformed element k = I * 6 + J
#if NCNN_image_shader
    // image layout: (tile index, channel, k)
    int sx = gy * psc(block_x) + gx;

    afp v00 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 0));
    afp v01 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 1));
    afp v02 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 2));
    afp v03 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 3));
    afp v04 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 4));
    afp v05 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 5));
    afp v10 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 6));
    afp v11 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 7));
    afp v12 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 8));
    afp v13 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 9));
    afp v14 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 10));
    afp v15 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 11));
    afp v20 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 12));
    afp v21 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 13));
    afp v22 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 14));
    afp v23 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 15));
    afp v24 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 16));
    afp v25 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 17));
    afp v30 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 18));
    afp v31 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 19));
    afp v32 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 20));
    afp v33 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 21));
    afp v34 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 22));
    afp v35 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 23));
    afp v40 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 24));
    afp v41 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 25));
    afp v42 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 26));
    afp v43 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 27));
    afp v44 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 28));
    afp v45 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 29));
    afp v50 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 30));
    afp v51 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 31));
    afp v52 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 32));
    afp v53 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 33));
    afp v54 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 34));
    afp v55 = image3d_ld1(top_tm_blob, ivec3(sx, gz, 35));
#else
    // buffer layout: element k of this tile lives at v_tm_offset + k * cstep
    int v_tm_offset = gz * psc(block_x) * psc(block_y) + gy * psc(block_x) + gx;

    afp v00 = buffer_ld1(top_tm_blob_data, v_tm_offset + 0 * psc(cstep));
    afp v01 = buffer_ld1(top_tm_blob_data, v_tm_offset + 1 * psc(cstep));
    afp v02 = buffer_ld1(top_tm_blob_data, v_tm_offset + 2 * psc(cstep));
    afp v03 = buffer_ld1(top_tm_blob_data, v_tm_offset + 3 * psc(cstep));
    afp v04 = buffer_ld1(top_tm_blob_data, v_tm_offset + 4 * psc(cstep));
    afp v05 = buffer_ld1(top_tm_blob_data, v_tm_offset + 5 * psc(cstep));
    afp v10 = buffer_ld1(top_tm_blob_data, v_tm_offset + 6 * psc(cstep));
    afp v11 = buffer_ld1(top_tm_blob_data, v_tm_offset + 7 * psc(cstep));
    afp v12 = buffer_ld1(top_tm_blob_data, v_tm_offset + 8 * psc(cstep));
    afp v13 = buffer_ld1(top_tm_blob_data, v_tm_offset + 9 * psc(cstep));
    afp v14 = buffer_ld1(top_tm_blob_data, v_tm_offset + 10 * psc(cstep));
    afp v15 = buffer_ld1(top_tm_blob_data, v_tm_offset + 11 * psc(cstep));
    afp v20 = buffer_ld1(top_tm_blob_data, v_tm_offset + 12 * psc(cstep));
    afp v21 = buffer_ld1(top_tm_blob_data, v_tm_offset + 13 * psc(cstep));
    afp v22 = buffer_ld1(top_tm_blob_data, v_tm_offset + 14 * psc(cstep));
    afp v23 = buffer_ld1(top_tm_blob_data, v_tm_offset + 15 * psc(cstep));
    afp v24 = buffer_ld1(top_tm_blob_data, v_tm_offset + 16 * psc(cstep));
    afp v25 = buffer_ld1(top_tm_blob_data, v_tm_offset + 17 * psc(cstep));
    afp v30 = buffer_ld1(top_tm_blob_data, v_tm_offset + 18 * psc(cstep));
    afp v31 = buffer_ld1(top_tm_blob_data, v_tm_offset + 19 * psc(cstep));
    afp v32 = buffer_ld1(top_tm_blob_data, v_tm_offset + 20 * psc(cstep));
    afp v33 = buffer_ld1(top_tm_blob_data, v_tm_offset + 21 * psc(cstep));
    afp v34 = buffer_ld1(top_tm_blob_data, v_tm_offset + 22 * psc(cstep));
    afp v35 = buffer_ld1(top_tm_blob_data, v_tm_offset + 23 * psc(cstep));
    afp v40 = buffer_ld1(top_tm_blob_data, v_tm_offset + 24 * psc(cstep));
    afp v41 = buffer_ld1(top_tm_blob_data, v_tm_offset + 25 * psc(cstep));
    afp v42 = buffer_ld1(top_tm_blob_data, v_tm_offset + 26 * psc(cstep));
    afp v43 = buffer_ld1(top_tm_blob_data, v_tm_offset + 27 * psc(cstep));
    afp v44 = buffer_ld1(top_tm_blob_data, v_tm_offset + 28 * psc(cstep));
    afp v45 = buffer_ld1(top_tm_blob_data, v_tm_offset + 29 * psc(cstep));
    afp v50 = buffer_ld1(top_tm_blob_data, v_tm_offset + 30 * psc(cstep));
    afp v51 = buffer_ld1(top_tm_blob_data, v_tm_offset + 31 * psc(cstep));
    afp v52 = buffer_ld1(top_tm_blob_data, v_tm_offset + 32 * psc(cstep));
    afp v53 = buffer_ld1(top_tm_blob_data, v_tm_offset + 33 * psc(cstep));
    afp v54 = buffer_ld1(top_tm_blob_data, v_tm_offset + 34 * psc(cstep));
    afp v55 = buffer_ld1(top_tm_blob_data, v_tm_offset + 35 * psc(cstep));
#endif

    // Transform coefficients. The expansions are parenthesized so each macro behaves
    // as a single operand wherever it is used; they are #undef'd after the last use.
#define sq2 1.41421356237
#define sq2_m2 (sq2 * 2)
#define sq2_d2 (sq2 / 2)
#define sq2_d4 (sq2 / 4)

    // const float otm[4][6] = {
    //     {1.0f, 1.0f,  1.0f,    1.0f,  1.0f,    0.0f},
    //     {0.0f, sq2/2, -sq2/2,  sq2,   -sq2,    0.0f},
    //     {0.0f, 0.5f,  0.5f,    2.0f,  2.0f,    0.0f},
    //     {0.0f, sq2/4, -sq2/4,  sq2*2, -sq2*2,  1.0f}
    // };

    // implicit transpose
    // first pass:  mIJ = sum_k otm[I][k] * vJk  (note the swapped index order)
    // second pass: vAB = sum_j otm[B][j] * mAj
    // net effect:  vAB = sum_j sum_k otm[A][k] * otm[B][j] * vjk, stored at (x + B, y + A).
    // NOTE(review): this assumes the loaded 6x6 tile is stored transposed, presumably
    // matching the input transform shader - confirm.
    afp m00 = v00 + v01 + v02 + v03 + v04;
    afp m01 = v10 + v11 + v12 + v13 + v14;
    afp m02 = v20 + v21 + v22 + v23 + v24;
    afp m03 = v30 + v31 + v32 + v33 + v34;
    afp m04 = v40 + v41 + v42 + v43 + v44;
    afp m05 = v50 + v51 + v52 + v53 + v54;

    afp m10 = (v01 - v02) * afp(sq2_d2) + (v03 - v04) * afp(sq2);
    afp m11 = (v11 - v12) * afp(sq2_d2) + (v13 - v14) * afp(sq2);
    afp m12 = (v21 - v22) * afp(sq2_d2) + (v23 - v24) * afp(sq2);
    afp m13 = (v31 - v32) * afp(sq2_d2) + (v33 - v34) * afp(sq2);
    afp m14 = (v41 - v42) * afp(sq2_d2) + (v43 - v44) * afp(sq2);
    afp m15 = (v51 - v52) * afp(sq2_d2) + (v53 - v54) * afp(sq2);

    afp m20 = (v01 + v02) * afp(0.5) + (v03 + v04) * afp(2);
    afp m21 = (v11 + v12) * afp(0.5) + (v13 + v14) * afp(2);
    afp m22 = (v21 + v22) * afp(0.5) + (v23 + v24) * afp(2);
    afp m23 = (v31 + v32) * afp(0.5) + (v33 + v34) * afp(2);
    afp m24 = (v41 + v42) * afp(0.5) + (v43 + v44) * afp(2);
    afp m25 = (v51 + v52) * afp(0.5) + (v53 + v54) * afp(2);

    afp m30 = v05 + (v01 - v02) * afp(sq2_d4) + (v03 - v04) * afp(sq2_m2);
    afp m31 = v15 + (v11 - v12) * afp(sq2_d4) + (v13 - v14) * afp(sq2_m2);
    afp m32 = v25 + (v21 - v22) * afp(sq2_d4) + (v23 - v24) * afp(sq2_m2);
    afp m33 = v35 + (v31 - v32) * afp(sq2_d4) + (v33 - v34) * afp(sq2_m2);
    afp m34 = v45 + (v41 - v42) * afp(sq2_d4) + (v43 - v44) * afp(sq2_m2);
    afp m35 = v55 + (v51 - v52) * afp(sq2_d4) + (v53 - v54) * afp(sq2_m2);

    // second pass reuses vA0..vA3 (A = 0..3) as the 4x4 result
    v00 = m00 + m01 + m02 + m03 + m04;
    v10 = m10 + m11 + m12 + m13 + m14;
    v20 = m20 + m21 + m22 + m23 + m24;
    v30 = m30 + m31 + m32 + m33 + m34;

    v01 = (m01 - m02) * afp(sq2_d2) + (m03 - m04) * afp(sq2);
    v11 = (m11 - m12) * afp(sq2_d2) + (m13 - m14) * afp(sq2);
    v21 = (m21 - m22) * afp(sq2_d2) + (m23 - m24) * afp(sq2);
    v31 = (m31 - m32) * afp(sq2_d2) + (m33 - m34) * afp(sq2);

    v02 = (m01 + m02) * afp(0.5) + (m03 + m04) * afp(2);
    v12 = (m11 + m12) * afp(0.5) + (m13 + m14) * afp(2);
    v22 = (m21 + m22) * afp(0.5) + (m23 + m24) * afp(2);
    v32 = (m31 + m32) * afp(0.5) + (m33 + m34) * afp(2);

    v03 = m05 + (m01 - m02) * afp(sq2_d4) + (m03 - m04) * afp(sq2_m2);
    v13 = m15 + (m11 - m12) * afp(sq2_d4) + (m13 - m14) * afp(sq2_m2);
    v23 = m25 + (m21 - m22) * afp(sq2_d4) + (m23 - m24) * afp(sq2_m2);
    v33 = m35 + (m31 - m32) * afp(sq2_d4) + (m33 - m34) * afp(sq2_m2);

    // coefficients are no longer needed; keep them from leaking past this function
#undef sq2_d4
#undef sq2_d2
#undef sq2_m2
#undef sq2

    // optional per-channel bias
    if (bias_term == 1)
    {
#if NCNN_image_shader
        const afp bias_value = image3d_ld1(bias_blob, ivec3(gz, 0, 0));
#else
        const afp bias_value = buffer_ld1(bias_data, gz);
#endif

        v00 = bias_value + v00;
        v01 = bias_value + v01;
        v02 = bias_value + v02;
        v03 = bias_value + v03;
        v10 = bias_value + v10;
        v11 = bias_value + v11;
        v12 = bias_value + v12;
        v13 = bias_value + v13;
        v20 = bias_value + v20;
        v21 = bias_value + v21;
        v22 = bias_value + v22;
        v23 = bias_value + v23;
        v30 = bias_value + v30;
        v31 = bias_value + v31;
        v32 = bias_value + v32;
        v33 = bias_value + v33;
    }

    // fused activation, configured by specialization constants
    v00 = activation_afp(v00, activation_type, activation_param_0, activation_param_1);
    v01 = activation_afp(v01, activation_type, activation_param_0, activation_param_1);
    v02 = activation_afp(v02, activation_type, activation_param_0, activation_param_1);
    v03 = activation_afp(v03, activation_type, activation_param_0, activation_param_1);
    v10 = activation_afp(v10, activation_type, activation_param_0, activation_param_1);
    v11 = activation_afp(v11, activation_type, activation_param_0, activation_param_1);
    v12 = activation_afp(v12, activation_type, activation_param_0, activation_param_1);
    v13 = activation_afp(v13, activation_type, activation_param_0, activation_param_1);
    v20 = activation_afp(v20, activation_type, activation_param_0, activation_param_1);
    v21 = activation_afp(v21, activation_type, activation_param_0, activation_param_1);
    v22 = activation_afp(v22, activation_type, activation_param_0, activation_param_1);
    v23 = activation_afp(v23, activation_type, activation_param_0, activation_param_1);
    v30 = activation_afp(v30, activation_type, activation_param_0, activation_param_1);
    v31 = activation_afp(v31, activation_type, activation_param_0, activation_param_1);
    v32 = activation_afp(v32, activation_type, activation_param_0, activation_param_1);
    v33 = activation_afp(v33, activation_type, activation_param_0, activation_param_1);

    // store 4x4: vAB goes to output pixel (x + B, y + A)
    int x = gx * 4;
    int y = gy * 4;

#if NCNN_image_shader
    // All 16 texels are stored without edge checks.
    // NOTE(review): when outw/outh are not multiples of 4 this relies on out-of-range
    // image stores having no effect (Vulkan texel output coordinate validation) - confirm.
    image3d_st1(top_blob, ivec3(x, y, gz), v00);
    image3d_st1(top_blob, ivec3(x + 1, y, gz), v01);
    image3d_st1(top_blob, ivec3(x + 2, y, gz), v02);
    image3d_st1(top_blob, ivec3(x + 3, y, gz), v03);
    image3d_st1(top_blob, ivec3(x, y + 1, gz), v10);
    image3d_st1(top_blob, ivec3(x + 1, y + 1, gz), v11);
    image3d_st1(top_blob, ivec3(x + 2, y + 1, gz), v12);
    image3d_st1(top_blob, ivec3(x + 3, y + 1, gz), v13);
    image3d_st1(top_blob, ivec3(x, y + 2, gz), v20);
    image3d_st1(top_blob, ivec3(x + 1, y + 2, gz), v21);
    image3d_st1(top_blob, ivec3(x + 2, y + 2, gz), v22);
    image3d_st1(top_blob, ivec3(x + 3, y + 2, gz), v23);
    image3d_st1(top_blob, ivec3(x, y + 3, gz), v30);
    image3d_st1(top_blob, ivec3(x + 1, y + 3, gz), v31);
    image3d_st1(top_blob, ivec3(x + 2, y + 3, gz), v32);
    image3d_st1(top_blob, ivec3(x + 3, y + 3, gz), v33);
#else
    // v_offset.r/g/b/a = start of output rows y, y + 1, y + 2, y + 3
    ivec4 v_offset = gz * psc(outcstep) + y * psc(outw) + x + ivec4(0, 1, 2, 3) * psc(outw);

    // The tile origin is stored unconditionally; the other 15 values are clipped
    // against outw/outh.
    // NOTE(review): this assumes the host sets block_x = ceil(outw / 4) and
    // block_y = ceil(outh / 4), so that (x, y) is always inside the output - confirm.
    buffer_st1(top_blob_data, v_offset.r + 0, v00);
    if (x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 1, v01);
    if (x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 2, v02);
    if (x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.r + 3, v03);
    if (y + 1 < psc(outh)) buffer_st1(top_blob_data, v_offset.g + 0, v10);
    if (y + 1 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 1, v11);
    if (y + 1 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 2, v12);
    if (y + 1 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.g + 3, v13);
    if (y + 2 < psc(outh)) buffer_st1(top_blob_data, v_offset.b + 0, v20);
    if (y + 2 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 1, v21);
    if (y + 2 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 2, v22);
    if (y + 2 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.b + 3, v23);
    if (y + 3 < psc(outh)) buffer_st1(top_blob_data, v_offset.a + 0, v30);
    if (y + 3 < psc(outh) && x + 1 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 1, v31);
    if (y + 3 < psc(outh) && x + 2 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 2, v32);
    if (y + 3 < psc(outh) && x + 3 < psc(outw)) buffer_st1(top_blob_data, v_offset.a + 3, v33);
#endif
}