|
|
#version 450 |
|
|
|
|
|
#include "types.comp" |
|
|
|
|
|
// Push constants for the depthwise 2D convolution.
// NOTE(review): these are presumably mirrored by a host-side struct; keep the
// member order and types exactly as they are so the layouts stay in sync.
layout (push_constant) uniform parameter {
    uint ne;                    // number of output elements; main() skips idx >= ne
    uint batches;               // not read by this shader
    uint channels;              // channel count, shared by source, kernel and output
    uint dst_w, dst_h;          // output spatial extent
    uint src_w, src_h;          // input spatial extent
    uint knl_w, knl_h;          // kernel spatial extent
    int  stride_x, stride_y;    // output-to-input step
    int  pad_x, pad_y;          // subtracted from the input coordinate
    int  dilation_x, dilation_y;// spacing between kernel taps
} p;
|
|
|
|
|
// Storage buffers: kernel weights (0) and source tensor (1) are read-only,
// the output tensor (2) is write-only.
layout (binding = 0) readonly  buffer A { A_TYPE knl_data[]; };
layout (binding = 1) readonly  buffer B { B_TYPE src_data[]; };
layout (binding = 2) writeonly buffer D { D_TYPE dst_data[]; };

// Each invocation produces exactly one output element (see main()).
layout (local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
|
// Depthwise 2D convolution for a single output element, WHCN layout
// (x fastest, then y, then channel, then batch).
//
// idx is the linear output index, decomposed as
//   idx = ((n * channels + c) * dst_h + dst_y) * dst_w + dst_x
// The source is addressed the same way using src_w/src_h, and the kernel as
//   (c * knl_h + knl_y) * knl_w + knl_x
// so channel c is convolved only with its own knl_h x knl_w filter.
//
// Padding: source coordinates are computed in uint (the int push constants are
// implicitly converted), so a coordinate that would be negative wraps to a large
// value and is rejected by the same ">= extent" test as one past the far edge.
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
    uint i0    = idx / p.dst_w;
    uint dst_x = idx - i0 * p.dst_w;
    uint i1    = i0 / p.dst_h;
    uint dst_y = i0 - i1 * p.dst_h;
    uint n     = i1 / p.channels;
    uint c     = i1 - n * p.channels;

    uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
    uint knl_i = c * p.knl_h * p.knl_w;

    // Typed zero: a bare 0.0 literal cannot initialize a 16-bit FLOAT_TYPE
    // because GLSL performs no implicit narrowing conversion.
    FLOAT_TYPE sum = FLOAT_TYPE(0.0);

    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
        if (src_y >= p.src_h) { // also rejects src_y < 0 (wrapped around)
            continue;
        }

        // Row bases do not depend on knl_x; compute them once per row.
        uint src_base = src_i + src_y * p.src_w;
        uint knl_base = knl_i + knl_y * p.knl_w;

        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
            if (src_x >= p.src_w) { // also rejects src_x < 0 (wrapped around)
                continue;
            }

            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_base + src_x]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_base + knl_x]);
            sum = fma(v, k, sum);
        }
    }

    return sum;
}
|
|
|
|
|
// Depthwise 2D convolution for a single output element, CWHN layout
// (channel fastest, then x, then y, then batch).
//
// idx is the linear output index, decomposed as
//   idx = ((n * dst_h + dst_y) * dst_w + dst_x) * channels + c
// The source is addressed as
//   n * channels * src_h * src_w + src_y * (src_w * channels) + src_x * channels + c
// and the kernel, which has no batch term, as
//   knl_y * (knl_w * channels) + knl_x * channels + c
//
// Padding: source coordinates are computed in uint (the int push constants are
// implicitly converted), so a coordinate that would be negative wraps to a large
// value and is rejected by the same ">= extent" test as one past the far edge.
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
    uint i0    = idx / p.channels;
    uint c     = idx - i0 * p.channels;
    uint i1    = i0 / p.dst_w;
    uint dst_x = i0 - i1 * p.dst_w;
    uint n     = i1 / p.dst_h;
    uint dst_y = i1 - n * p.dst_h;

    uint src_i   = n * p.channels * p.src_h * p.src_w;
    uint src_row = p.src_w * p.channels; // elements per source row
    uint knl_row = p.knl_w * p.channels; // elements per kernel row

    // Typed zero: a bare 0.0 literal cannot initialize a 16-bit FLOAT_TYPE
    // because GLSL performs no implicit narrowing conversion.
    FLOAT_TYPE sum = FLOAT_TYPE(0.0);

    for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
        uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
        if (src_y >= p.src_h) { // also rejects src_y < 0 (wrapped around)
            continue;
        }

        // Row bases do not depend on knl_x; compute them once per row.
        uint src_base = src_i + src_y * src_row;
        uint knl_base = knl_y * knl_row;

        for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
            uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
            if (src_x >= p.src_w) { // also rejects src_x < 0 (wrapped around)
                continue;
            }

            FLOAT_TYPE v = FLOAT_TYPE(src_data[src_base + src_x * p.channels + c]);
            FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_base + knl_x * p.channels + c]);
            sum = fma(v, k, sum);
        }
    }

    return sum;
}
|
|
|
|
|
// Entry point: one invocation computes one output element.
void main() {
    // Linear output index: x + 512 * (y + 512 * z), i.e. z * 262144 + y * 512 + x.
    const uint idx = gl_GlobalInvocationID.x
                   + 512u * (gl_GlobalInvocationID.y + 512u * gl_GlobalInvocationID.z);

    // Invocations beyond the last output element do nothing.
    if (idx < p.ne) {
#ifdef WHCN
        const FLOAT_TYPE result = conv_2d_dw_whcn(idx);
#else
        const FLOAT_TYPE result = conv_2d_dw_cwhn(idx);
#endif
        dst_data[idx] = D_TYPE(result);
    }
}
|
|
|
|
|
|