// yala is pleased to support the open source community by making ncnn available.
//
//
// Copyright (C) 2022 yala <zhaojunchao@loongson.cn>;<junchao82@qq.com>. All rights reserved.
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "cast_loongarch.h"
#if __loongarch_sx
#include <lsxintrin.h>
#endif // __loongarch_sx
namespace ncnn {

// LoongArch implementation of the Cast layer.
// Type codes used throughout (see the per-branch comments in forward):
//   1 = float32, 2 = float16, 3 = int8, 4 = bfloat16.
// Packed layouts are accepted because every conversion below is a purely
// element-wise loop over w * h * d * elempack scalars per channel.
Cast_loongarch::Cast_loongarch()
{
    support_packing = true;
}

// Converts bottom_blob from type_from to type_to and writes the result to top_blob.
//
// Implemented conversions:
//   float32  -> float16   (8 lanes per step with LSX when __loongarch_sx)
//   float16  -> float32   (8 lanes per step with LSX when __loongarch_sx)
//   int8     -> float32
//   bfloat16 -> float32
//   float32  -> bfloat16
// If type_from == type_to, top_blob is simply assigned from bottom_blob.
//
// Returns 0 on success, -100 if top_blob could not be allocated.
// NOTE(review): any other (type_from, type_to) pair (e.g. float32 -> int8)
// allocates top_blob with the target element size but never fills it, and
// still returns 0 -- confirm callers never request such a pair.
int Cast_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
    if (type_from == type_to)
    {
        top_blob = bottom_blob;
        return 0;
    }

    int w = bottom_blob.w;
    int h = bottom_blob.h;
    int d = bottom_blob.d;
    int channels = bottom_blob.c;
    int dims = bottom_blob.dims;
    size_t elemsize = bottom_blob.elemsize;
    int elempack = bottom_blob.elempack;

    // The output keeps the input's elempack; only the per-scalar width changes.
    size_t out_elemsize = elemsize;
    if (type_to == 1)
    {
        // float32
        // int8 -> float32 is converted by the scalar loop below. It is not also
        // delegated to Cast::forward here: doing both converted the data twice
        // and discarded the base call's status.
        out_elemsize = 4 * elempack;
    }
    else if (type_to == 2)
    {
        // float16
        out_elemsize = 2 * elempack;
    }
    else if (type_to == 3)
    {
        // int8
        out_elemsize = elempack;
    }
    else if (type_to == 4)
    {
        // bfloat16
        out_elemsize = 2 * elempack;
    }

    // Allocate top_blob with the same shape/packing as bottom_blob.
    if (dims == 1)
    {
        top_blob.create(w, out_elemsize, elempack, opt.blob_allocator);
    }
    else if (dims == 2)
    {
        top_blob.create(w, h, out_elemsize, elempack, opt.blob_allocator);
    }
    else if (dims == 3)
    {
        top_blob.create(w, h, channels, out_elemsize, elempack, opt.blob_allocator);
    }
    else if (dims == 4)
    {
        top_blob.create(w, h, d, channels, out_elemsize, elempack, opt.blob_allocator);
    }
    if (top_blob.empty())
        return -100;

    // Number of scalars per channel. Assumes d == 1 (and c == 1 for dims <= 2)
    // for lower-rank blobs -- TODO confirm against Mat's defaults.
    int size = w * h * d * elempack;

    if (type_from == 1 && type_to == 2)
    {
        // float32 -> float16
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const float* ptr = bottom_blob.channel(q);
            unsigned short* outptr = top_blob.channel(q);

            int i = 0;
#if __loongarch_sx
            for (; i + 7 < size; i += 8)
            {
                __builtin_prefetch(ptr + 16);
                __m128 _p0 = (__m128)__lsx_vld(ptr, 0);
                __m128 _p1 = (__m128)__lsx_vld(ptr + 4, 0);
                // The second operand (_p0, elements 0..3) fills the low half of
                // the result, so the 8 halves are stored in source order.
                __m128i _p = __lsx_vfcvt_h_s(_p1, _p0);
                __lsx_vst(_p, outptr, 0);

                ptr += 8;
                outptr += 8;
            }
#endif // __loongarch_sx
            // Scalar tail (or the whole channel without LSX).
            for (; i < size; i++)
            {
                *outptr = float32_to_float16(*ptr);
                outptr++;
                ptr++;
            }
        }
    }

    if (type_from == 2 && type_to == 1)
    {
        // float16 -> float32
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const unsigned short* ptr = bottom_blob.channel(q);
            float* outptr = top_blob.channel(q);

            int i = 0;
#if __loongarch_sx
            for (; i + 7 < size; i += 8)
            {
                __builtin_prefetch(ptr + 16);
                __m128i _p = __lsx_vld(ptr, 0);
                // Low 4 halves -> _p0, high 4 halves -> _p1.
                __m128 _p0 = __lsx_vfcvtl_s_h(_p);
                __m128 _p1 = __lsx_vfcvth_s_h(_p);
                __lsx_vst(_p0, outptr, 0);
                __lsx_vst(_p1, outptr + 4, 0);

                ptr += 8;
                outptr += 8;
            }
#endif // __loongarch_sx
            // Scalar tail (or the whole channel without LSX).
            for (; i < size; i++)
            {
                *outptr = float16_to_float32(*ptr);
                outptr++;
                ptr++;
            }
        }
    }

    if (type_from == 3 && type_to == 1)
    {
        // int8 -> float32
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const signed char* ptr = bottom_blob.channel(q);
            float* outptr = top_blob.channel(q);

            for (int i = 0; i < size; i++)
            {
                outptr[i] = (float)ptr[i];
            }
        }
    }

    if (type_from == 4 && type_to == 1)
    {
        // bfloat16 -> float32
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const unsigned short* ptr = bottom_blob.channel(q);
            float* outptr = top_blob.channel(q);

            for (int i = 0; i < size; i++)
            {
                *outptr = bfloat16_to_float32(*ptr);
                outptr++;
                ptr++;
            }
        }
    }

    if (type_from == 1 && type_to == 4)
    {
        // float32 -> bfloat16
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int q = 0; q < channels; q++)
        {
            const float* ptr = bottom_blob.channel(q);
            unsigned short* outptr = top_blob.channel(q);

            for (int i = 0; i < size; i++)
            {
                *outptr = float32_to_bfloat16(*ptr);
                outptr++;
                ptr++;
            }
        }
    }

    return 0;
}

} // namespace ncnn