| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | static void convolution_winograd_dot_pack4_msa(Mat& bottom_blob_tm, int outch, const Mat& kernel_tm, Mat& top_blob_tm, const Option& opt) |
| | { |
| | |
| |
|
| | const int tiles = bottom_blob_tm.w; |
| | const int batch = bottom_blob_tm.h; |
| | const int inch = bottom_blob_tm.c; |
| |
|
| | |
| | Mat bottom_blob_tm2; |
| | if (tiles >= 12) |
| | bottom_blob_tm2.create(12 * inch, tiles / 12 + (tiles % 12) / 8 + (tiles % 12 % 8) / 4 + (tiles % 12 % 4) / 2 + tiles % 12 % 2, batch, 16u, 4, opt.workspace_allocator); |
| | else if (tiles >= 8) |
| | bottom_blob_tm2.create(8 * inch, tiles / 8 + (tiles % 8) / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 4, opt.workspace_allocator); |
| | else if (tiles >= 4) |
| | bottom_blob_tm2.create(4 * inch, tiles / 4 + (tiles % 4) / 2 + tiles % 2, batch, 16u, 4, opt.workspace_allocator); |
| | else if (tiles >= 2) |
| | bottom_blob_tm2.create(2 * inch, tiles / 2 + tiles % 2, batch, 16u, 4, opt.workspace_allocator); |
| | else |
| | bottom_blob_tm2.create(1 * inch, tiles, batch, 16u, 4, opt.workspace_allocator); |
| |
|
| | #pragma omp parallel for num_threads(opt.num_threads) |
| | for (int r = 0; r < batch; r++) |
| | { |
| | Mat tm2 = bottom_blob_tm2.channel(r); |
| |
|
| | |
| | int i = 0; |
| | for (; i + 11 < tiles; i += 12) |
| | { |
| | float* tmpptr = tm2.row(i / 12); |
| |
|
| | const float* r0 = bottom_blob_tm; |
| |
|
| | r0 += (r * tiles + i) * 4; |
| |
|
| | for (int q = 0; q < inch; q++) |
| | { |
| | |
| | v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0); |
| | v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0); |
| | v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); |
| | v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); |
| | v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); |
| | v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); |
| | v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0); |
| | v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0); |
| | v4f32 _r8 = (v4f32)__msa_ld_w(r0 + 4 * 8, 0); |
| | v4f32 _r9 = (v4f32)__msa_ld_w(r0 + 4 * 9, 0); |
| | v4f32 _ra = (v4f32)__msa_ld_w(r0 + 4 * 10, 0); |
| | v4f32 _rb = (v4f32)__msa_ld_w(r0 + 4 * 11, 0); |
| |
|
| | v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2); |
| | v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2); |
| | v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4); |
| | v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4); |
| | v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6); |
| | v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6); |
| | v4i32 _r89r = __msa_ilvr_w((v4i32)_r9, (v4i32)_r8); |
| | v4i32 _r89l = __msa_ilvl_w((v4i32)_r9, (v4i32)_r8); |
| | v4i32 _rabr = __msa_ilvr_w((v4i32)_rb, (v4i32)_ra); |
| | v4i32 _rabl = __msa_ilvl_w((v4i32)_rb, (v4i32)_ra); |
| | v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l); |
| | v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l); |
| | v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r); |
| | v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r); |
| | v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l); |
| | v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l); |
| | v2i64 _r89ab_0 = __msa_ilvr_d((v2i64)_rabr, (v2i64)_r89r); |
| | v2i64 _r89ab_1 = __msa_ilvl_d((v2i64)_rabr, (v2i64)_r89r); |
| | v2i64 _r89ab_2 = __msa_ilvr_d((v2i64)_rabl, (v2i64)_r89l); |
| | v2i64 _r89ab_3 = __msa_ilvl_d((v2i64)_rabl, (v2i64)_r89l); |
| |
|
| | __msa_st_w((v4i32)_r0123_0, tmpptr, 0); |
| | __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0); |
| | __msa_st_w((v4i32)_r89ab_0, tmpptr + 4 * 2, 0); |
| | __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 3, 0); |
| | __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 4, 0); |
| | __msa_st_w((v4i32)_r89ab_1, tmpptr + 4 * 5, 0); |
| | __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 6, 0); |
| | __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 7, 0); |
| | __msa_st_w((v4i32)_r89ab_2, tmpptr + 4 * 8, 0); |
| | __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 9, 0); |
| | __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 10, 0); |
| | __msa_st_w((v4i32)_r89ab_3, tmpptr + 4 * 11, 0); |
| |
|
| | r0 += bottom_blob_tm.cstep * 4; |
| | tmpptr += 48; |
| | } |
| | } |
| | for (; i + 7 < tiles; i += 8) |
| | { |
| | float* tmpptr = tm2.row(i / 12 + (i % 12) / 8); |
| |
|
| | const float* r0 = bottom_blob_tm; |
| |
|
| | r0 += (r * tiles + i) * 4; |
| |
|
| | for (int q = 0; q < inch; q++) |
| | { |
| | |
| | v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0); |
| | v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0); |
| | v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); |
| | v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); |
| | v4f32 _r4 = (v4f32)__msa_ld_w(r0 + 4 * 4, 0); |
| | v4f32 _r5 = (v4f32)__msa_ld_w(r0 + 4 * 5, 0); |
| | v4f32 _r6 = (v4f32)__msa_ld_w(r0 + 4 * 6, 0); |
| | v4f32 _r7 = (v4f32)__msa_ld_w(r0 + 4 * 7, 0); |
| |
|
| | v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2); |
| | v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2); |
| | v4i32 _r45r = __msa_ilvr_w((v4i32)_r5, (v4i32)_r4); |
| | v4i32 _r45l = __msa_ilvl_w((v4i32)_r5, (v4i32)_r4); |
| | v4i32 _r67r = __msa_ilvr_w((v4i32)_r7, (v4i32)_r6); |
| | v4i32 _r67l = __msa_ilvl_w((v4i32)_r7, (v4i32)_r6); |
| | v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l); |
| | v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l); |
| | v2i64 _r4567_0 = __msa_ilvr_d((v2i64)_r67r, (v2i64)_r45r); |
| | v2i64 _r4567_1 = __msa_ilvl_d((v2i64)_r67r, (v2i64)_r45r); |
| | v2i64 _r4567_2 = __msa_ilvr_d((v2i64)_r67l, (v2i64)_r45l); |
| | v2i64 _r4567_3 = __msa_ilvl_d((v2i64)_r67l, (v2i64)_r45l); |
| |
|
| | __msa_st_w((v4i32)_r0123_0, tmpptr, 0); |
| | __msa_st_w((v4i32)_r4567_0, tmpptr + 4, 0); |
| | __msa_st_w((v4i32)_r0123_1, tmpptr + 4 * 2, 0); |
| | __msa_st_w((v4i32)_r4567_1, tmpptr + 4 * 3, 0); |
| | __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 4, 0); |
| | __msa_st_w((v4i32)_r4567_2, tmpptr + 4 * 5, 0); |
| | __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 6, 0); |
| | __msa_st_w((v4i32)_r4567_3, tmpptr + 4 * 7, 0); |
| |
|
| | r0 += bottom_blob_tm.cstep * 4; |
| | tmpptr += 32; |
| | } |
| | } |
| | for (; i + 3 < tiles; i += 4) |
| | { |
| | float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); |
| |
|
| | const float* r0 = bottom_blob_tm; |
| |
|
| | r0 += (r * tiles + i) * 4; |
| |
|
| | for (int q = 0; q < inch; q++) |
| | { |
| | |
| | v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0); |
| | v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0); |
| | v4f32 _r2 = (v4f32)__msa_ld_w(r0 + 4 * 2, 0); |
| | v4f32 _r3 = (v4f32)__msa_ld_w(r0 + 4 * 3, 0); |
| |
|
| | v4i32 _r01r = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r01l = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r23r = __msa_ilvr_w((v4i32)_r3, (v4i32)_r2); |
| | v4i32 _r23l = __msa_ilvl_w((v4i32)_r3, (v4i32)_r2); |
| | v2i64 _r0123_0 = __msa_ilvr_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_1 = __msa_ilvl_d((v2i64)_r23r, (v2i64)_r01r); |
| | v2i64 _r0123_2 = __msa_ilvr_d((v2i64)_r23l, (v2i64)_r01l); |
| | v2i64 _r0123_3 = __msa_ilvl_d((v2i64)_r23l, (v2i64)_r01l); |
| |
|
| | __msa_st_w((v4i32)_r0123_0, tmpptr, 0); |
| | __msa_st_w((v4i32)_r0123_1, tmpptr + 4, 0); |
| | __msa_st_w((v4i32)_r0123_2, tmpptr + 4 * 2, 0); |
| | __msa_st_w((v4i32)_r0123_3, tmpptr + 4 * 3, 0); |
| |
|
| | r0 += bottom_blob_tm.cstep * 4; |
| | tmpptr += 16; |
| | } |
| | } |
| | for (; i + 1 < tiles; i += 2) |
| | { |
| | float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); |
| |
|
| | const float* r0 = bottom_blob_tm; |
| |
|
| | r0 += (r * tiles + i) * 4; |
| |
|
| | for (int q = 0; q < inch; q++) |
| | { |
| | |
| | v4f32 _r0 = (v4f32)__msa_ld_w(r0, 0); |
| | v4f32 _r1 = (v4f32)__msa_ld_w(r0 + 4, 0); |
| |
|
| | v4i32 _r01_0 = __msa_ilvr_w((v4i32)_r1, (v4i32)_r0); |
| | v4i32 _r01_1 = __msa_ilvl_w((v4i32)_r1, (v4i32)_r0); |
| |
|
| | __msa_st_w((v4i32)_r01_0, tmpptr, 0); |
| | __msa_st_w((v4i32)_r01_1, tmpptr + 4, 0); |
| |
|
| | r0 += bottom_blob_tm.cstep * 4; |
| | tmpptr += 8; |
| | } |
| | } |
| | for (; i < tiles; i++) |
| | { |
| | float* tmpptr = tm2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); |
| |
|
| | const float* r0 = bottom_blob_tm; |
| |
|
| | r0 += (r * tiles + i) * 4; |
| |
|
| | for (int q = 0; q < inch; q++) |
| | { |
| | v4f32 _val = (v4f32)__msa_ld_w(r0, 0); |
| | __msa_st_w((v4i32)_val, tmpptr, 0); |
| |
|
| | r0 += bottom_blob_tm.cstep * 4; |
| | tmpptr += 4; |
| | } |
| | } |
| | } |
| |
|
| | bottom_blob_tm = Mat(); |
| | |
| |
|
| | top_blob_tm.create(tiles, batch, outch, 16u, 4, opt.workspace_allocator); |
| |
|
| | #pragma omp parallel for num_threads(opt.num_threads) |
| | for (int p = 0; p < outch; p++) |
| | { |
| | float* output0_tm = top_blob_tm.channel(p); |
| |
|
| | const Mat kernel0_tm = kernel_tm.channel(p); |
| |
|
| | for (int r = 0; r < batch; r++) |
| | { |
| | const Mat bb2 = bottom_blob_tm2.channel(r); |
| |
|
| | int i = 0; |
| | for (; i + 11 < tiles; i += 12) |
| | { |
| | const float* r0 = bb2.row(i / 12); |
| | const float* k0 = kernel0_tm.row(r); |
| |
|
| | int nn = inch * 4; |
| |
|
| | v4f32 _sum0 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum1 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum2 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum3 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum4 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum5 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum6 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum7 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum8 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum9 = (v4f32)__msa_fill_w(0); |
| | v4f32 _suma = (v4f32)__msa_fill_w(0); |
| | v4f32 _sumb = (v4f32)__msa_fill_w(0); |
| |
|
| | for (int j = 0; j < nn; j++) |
| | { |
| | __builtin_prefetch(r0 + 48); |
| | __builtin_prefetch(k0 + 16); |
| | v4i32 _val0123 = __msa_ld_w(r0, 0); |
| | v4i32 _val4567 = __msa_ld_w(r0 + 4, 0); |
| | v4i32 _val89ab = __msa_ld_w(r0 + 8, 0); |
| | v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0); |
| | _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0); |
| | _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0); |
| | _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0); |
| | _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0); |
| | _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0); |
| | _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0); |
| | _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0); |
| | _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0); |
| | _sum8 = __msa_fmadd_w(_sum8, (v4f32)__msa_splati_w(_val89ab, 0), _w0); |
| | _sum9 = __msa_fmadd_w(_sum9, (v4f32)__msa_splati_w(_val89ab, 1), _w0); |
| | _suma = __msa_fmadd_w(_suma, (v4f32)__msa_splati_w(_val89ab, 2), _w0); |
| | _sumb = __msa_fmadd_w(_sumb, (v4f32)__msa_splati_w(_val89ab, 3), _w0); |
| |
|
| | r0 += 12; |
| | k0 += 4; |
| | } |
| |
|
| | __msa_st_w((v4i32)_sum0, output0_tm, 0); |
| | __msa_st_w((v4i32)_sum1, output0_tm + 4, 0); |
| | __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0); |
| | __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0); |
| | __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0); |
| | __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0); |
| | __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0); |
| | __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0); |
| | __msa_st_w((v4i32)_sum8, output0_tm + 4 * 8, 0); |
| | __msa_st_w((v4i32)_sum9, output0_tm + 4 * 9, 0); |
| | __msa_st_w((v4i32)_suma, output0_tm + 4 * 10, 0); |
| | __msa_st_w((v4i32)_sumb, output0_tm + 4 * 11, 0); |
| |
|
| | output0_tm += 4 * 12; |
| | } |
| | for (; i + 7 < tiles; i += 8) |
| | { |
| | const float* r0 = bb2.row(i / 12 + (i % 12) / 8); |
| | const float* k0 = kernel0_tm.row(r); |
| |
|
| | int nn = inch * 4; |
| |
|
| | v4f32 _sum0 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum1 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum2 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum3 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum4 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum5 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum6 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum7 = (v4f32)__msa_fill_w(0); |
| |
|
| | for (int j = 0; j < nn; j++) |
| | { |
| | __builtin_prefetch(r0 + 32); |
| | __builtin_prefetch(k0 + 16); |
| | v4i32 _val0123 = __msa_ld_w(r0, 0); |
| | v4i32 _val4567 = __msa_ld_w(r0 + 4, 0); |
| | v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0); |
| | _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0); |
| | _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0); |
| | _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0); |
| | _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0); |
| | _sum4 = __msa_fmadd_w(_sum4, (v4f32)__msa_splati_w(_val4567, 0), _w0); |
| | _sum5 = __msa_fmadd_w(_sum5, (v4f32)__msa_splati_w(_val4567, 1), _w0); |
| | _sum6 = __msa_fmadd_w(_sum6, (v4f32)__msa_splati_w(_val4567, 2), _w0); |
| | _sum7 = __msa_fmadd_w(_sum7, (v4f32)__msa_splati_w(_val4567, 3), _w0); |
| |
|
| | r0 += 8; |
| | k0 += 4; |
| | } |
| |
|
| | __msa_st_w((v4i32)_sum0, output0_tm, 0); |
| | __msa_st_w((v4i32)_sum1, output0_tm + 4, 0); |
| | __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0); |
| | __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0); |
| | __msa_st_w((v4i32)_sum4, output0_tm + 4 * 4, 0); |
| | __msa_st_w((v4i32)_sum5, output0_tm + 4 * 5, 0); |
| | __msa_st_w((v4i32)_sum6, output0_tm + 4 * 6, 0); |
| | __msa_st_w((v4i32)_sum7, output0_tm + 4 * 7, 0); |
| |
|
| | output0_tm += 4 * 8; |
| | } |
| | for (; i + 3 < tiles; i += 4) |
| | { |
| | const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4); |
| | const float* k0 = kernel0_tm.row(r); |
| |
|
| | int nn = inch * 4; |
| |
|
| | v4f32 _sum0 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum1 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum2 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum3 = (v4f32)__msa_fill_w(0); |
| |
|
| | for (int j = 0; j < nn; j++) |
| | { |
| | __builtin_prefetch(r0 + 16); |
| | __builtin_prefetch(k0 + 16); |
| | v4i32 _val0123 = __msa_ld_w(r0, 0); |
| | v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0); |
| | _sum0 = __msa_fmadd_w(_sum0, (v4f32)__msa_splati_w(_val0123, 0), _w0); |
| | _sum1 = __msa_fmadd_w(_sum1, (v4f32)__msa_splati_w(_val0123, 1), _w0); |
| | _sum2 = __msa_fmadd_w(_sum2, (v4f32)__msa_splati_w(_val0123, 2), _w0); |
| | _sum3 = __msa_fmadd_w(_sum3, (v4f32)__msa_splati_w(_val0123, 3), _w0); |
| |
|
| | r0 += 4; |
| | k0 += 4; |
| | } |
| |
|
| | __msa_st_w((v4i32)_sum0, output0_tm, 0); |
| | __msa_st_w((v4i32)_sum1, output0_tm + 4, 0); |
| | __msa_st_w((v4i32)_sum2, output0_tm + 4 * 2, 0); |
| | __msa_st_w((v4i32)_sum3, output0_tm + 4 * 3, 0); |
| |
|
| | output0_tm += 4 * 4; |
| | } |
| | for (; i + 1 < tiles; i += 2) |
| | { |
| | const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2); |
| | const float* k0 = kernel0_tm.row(r); |
| |
|
| | int nn = inch * 4; |
| |
|
| | v4f32 _sum0 = (v4f32)__msa_fill_w(0); |
| | v4f32 _sum1 = (v4f32)__msa_fill_w(0); |
| |
|
| | for (int j = 0; j < nn; j++) |
| | { |
| | __builtin_prefetch(r0 + 8); |
| | __builtin_prefetch(k0 + 16); |
| | v4f32 _val0 = __msa_fill_w_f32(*r0++); |
| | v4f32 _val1 = __msa_fill_w_f32(*r0++); |
| | v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0); |
| | _sum0 = __msa_fmadd_w(_sum0, _val0, _w0); |
| | _sum1 = __msa_fmadd_w(_sum1, _val1, _w0); |
| |
|
| | k0 += 4; |
| | } |
| |
|
| | __msa_st_w((v4i32)_sum0, output0_tm, 0); |
| | __msa_st_w((v4i32)_sum1, output0_tm + 4, 0); |
| |
|
| | output0_tm += 4 * 2; |
| | } |
| | for (; i < tiles; i++) |
| | { |
| | const float* r0 = bb2.row(i / 12 + (i % 12) / 8 + (i % 12 % 8) / 4 + (i % 12 % 4) / 2 + i % 12 % 2); |
| | const float* k0 = kernel0_tm.row(r); |
| |
|
| | int nn = inch * 4; |
| |
|
| | v4f32 _sum = (v4f32)__msa_fill_w(0); |
| |
|
| | for (int j = 0; j < nn; j++) |
| | { |
| | __builtin_prefetch(r0 + 4); |
| | __builtin_prefetch(k0 + 16); |
| | v4f32 _val0 = __msa_fill_w_f32(*r0++); |
| | v4f32 _w0 = (v4f32)__msa_ld_w(k0, 0); |
| | _sum = __msa_fmadd_w(_sum, _val0, _w0); |
| |
|
| | k0 += 4; |
| | } |
| |
|
| | __msa_st_w((v4i32)_sum, output0_tm, 0); |
| |
|
| | output0_tm += 4; |
| | } |
| | } |
| | } |
| | } |
| |
|