|
|
// Turn on the cl_khr_fp16 extension so `half` can be used as a regular
// scalar type (struct members, arithmetic, half literals such as 16.h).
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Fixed-width integer aliases expressed through the OpenCL C built-in
// scalar types. Signed variants first, then the unsigned ones.
typedef char   int8_t;
typedef short  int16_t;
typedef int    int32_t;

typedef uchar  uint8_t;
typedef ushort uint16_t;
typedef uint   uint32_t;

// Number of 4-bit quantized values held by one block_q4_0
// (the block stores QK4_0 / 2 bytes, two values per byte).
#define QK4_0 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// One Q4_0 quantization block: a half-precision scale followed by
// QK4_0 / 2 bytes of packed 4-bit values (two values per byte).
struct block_q4_0 {
    half    d;               // per-block scale factor
    uint8_t qs[QK4_0 / 2];   // packed 4-bit quants
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Dequantize 16 of the values stored in one Q4_0 block into a float16.
 *
 * xb  - block to read (scale `d` followed by the packed 4-bit quants).
 * il  - selects which 4-bit field of every byte is decoded:
 *       0       -> bits 0..3 of each byte,
 *       nonzero -> bits 4..7 of each byte.
 * reg - receives the 16 results. Components 2k and 2k+1 come from the k-th
 *       16-bit word of the quant array (its low byte and high byte,
 *       respectively).
 *
 * Every output equals d * q - 8 * d, where q is the selected 4-bit value.
 * The nibbles are never shifted down; instead the scale is pre-divided by
 * the positional weight of the masked field (16 and/or 256).
 */
void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
    // View the quants as 16-bit words; the +1 steps over the 2-byte scale.
    global ushort * words = (global ushort *) xb + 1;

    const half scale = xb->d;

    // Scale for the field taken from the low byte of each word. When the
    // upper nibble is selected the masked value is 16x too large, so the
    // scale is divided by 16 (in half precision, as a half expression).
    float lo_scale;
    if (il) {
        lo_scale = scale / 16.h;
    } else {
        lo_scale = scale;
    }

    // The field taken from the high byte sits 8 bits further up.
    const float hi_scale = lo_scale / 256.f;

    // Constant offset that centres the 4-bit range: -8 * d.
    const float bias = -8.h * scale;

    const ushort lo_mask = il ? 0x00F0 : 0x000F;
    const ushort hi_mask = lo_mask << 8;

    const ushort w0 = words[0];
    const ushort w1 = words[1];
    const ushort w2 = words[2];
    const ushort w3 = words[3];
    const ushort w4 = words[4];
    const ushort w5 = words[5];
    const ushort w6 = words[6];
    const ushort w7 = words[7];

    *reg = (float16)(
        lo_scale * (w0 & lo_mask) + bias, hi_scale * (w0 & hi_mask) + bias,
        lo_scale * (w1 & lo_mask) + bias, hi_scale * (w1 & hi_mask) + bias,
        lo_scale * (w2 & lo_mask) + bias, hi_scale * (w2 & hi_mask) + bias,
        lo_scale * (w3 & lo_mask) + bias, hi_scale * (w3 & hi_mask) + bias,
        lo_scale * (w4 & lo_mask) + bias, hi_scale * (w4 & hi_mask) + bias,
        lo_scale * (w5 & lo_mask) + bias, hi_scale * (w5 & hi_mask) + bias,
        lo_scale * (w6 & lo_mask) + bias, hi_scale * (w6 & hi_mask) + bias,
        lo_scale * (w7 & lo_mask) + bias, hi_scale * (w7 & hi_mask) + bias
    );
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
 * Gather rows of 32-bit floats.
 *
 * Each work-group handles one index (i10, i11, i12) = group id in
 * dimensions 0, 1, 2. It reads the row number r from src1 at that position
 * and copies ne00 floats from row r of src0 into the matching row of dst.
 * Work-items of the group split the row along local dimension 0.
 *
 * src0, offset0   - source data and a byte offset applied to it.
 * src1, offset1   - int row indices and a byte offset applied to them.
 * dst,  offsetd   - destination floats and a byte offset applied to them.
 * ne00            - number of elements copied per row.
 * nb01, nb02, nb03 - byte strides of src0 for r, i11 and i12.
 * ne10            - not used by this kernel.
 * nb10, nb11, nb12 - byte strides of src1 for i10, i11 and i12.
 * nb1,  nb2,  nb3  - byte strides of dst for i10, i11 and i12.
 *
 * The row number read from src1 is used as-is; it is not range-checked here.
 */
kernel void kernel_get_rows_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst  = (global float*)((global char*)dst + offsetd);

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Both row addresses are the same for every element, so compute them once.
    global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition already bounds ind, so no extra guard is needed.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        dst_row[ind] = src_row[ind];
    }
}
|
|
|
|
|
/*
 * Gather rows of half-precision values and widen them to float.
 *
 * Each work-group handles one index (i10, i11, i12) = group id in
 * dimensions 0, 1, 2. It reads the row number r from src1 at that position
 * and converts ne00 halves from row r of src0 into floats in the matching
 * row of dst. Work-items of the group split the row along local dimension 0.
 *
 * src0, offset0   - source data (read as half) and a byte offset applied to it.
 * src1, offset1   - int row indices and a byte offset applied to them.
 * dst,  offsetd   - destination floats and a byte offset applied to them.
 * ne00            - number of elements converted per row.
 * nb01, nb02, nb03 - byte strides of src0 for r, i11 and i12.
 * ne10            - not used by this kernel.
 * nb10, nb11, nb12 - byte strides of src1 for i10, i11 and i12.
 * nb1,  nb2,  nb3  - byte strides of dst for i10, i11 and i12.
 *
 * The row number read from src1 is used as-is; it is not range-checked here.
 */
kernel void kernel_get_rows_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst  = (global float*)((global char*)dst + offsetd);

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Both row addresses are the same for every element, so compute them once.
    global half  * src_row = (global half *)  ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    // The loop condition already bounds ind, so no extra guard is needed.
    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
        // Implicit half -> float conversion on assignment.
        dst_row[ind] = src_row[ind];
    }
}
|
|
|
|
|
/*
 * Gather rows stored as Q4_0 blocks and dequantize them to float.
 *
 * Each work-group handles one index (i10, i11, i12) = group id in
 * dimensions 0, 1, 2. It reads the row number r from src1 at that position,
 * then dequantizes row r of src0 into the matching row of dst in chunks of
 * 16 floats. Work-items of the group split the chunks along local
 * dimension 0. ne00/16 chunks are produced, so any remainder of ne00 modulo
 * 16 is not written.
 *
 * src0, offset0   - source data (read as struct block_q4_0) and a byte
 *                   offset applied to it.
 * src1, offset1   - int row indices and a byte offset applied to them.
 * dst,  offsetd   - destination floats and a byte offset applied to them.
 * ne00            - number of dequantized elements per row.
 * nb01, nb02, nb03 - byte strides of src0 for r, i11 and i12.
 * ne10            - not used by this kernel.
 * nb10, nb11, nb12 - byte strides of src1 for i10, i11 and i12.
 * nb1,  nb2,  nb3  - byte strides of dst for i10, i11 and i12.
 *
 * The row number read from src1 is used as-is; it is not range-checked here.
 * NOTE(review): the destination row is written through a float16 pointer,
 * which presumably requires the row address to be float16-aligned - confirm
 * against the host-side allocation and offsets.
 */
kernel void kernel_get_rows_q4_0(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    dst  = (global float*)((global char*)dst + offsetd);

    // Each block yields NL chunks of 16 floats (one per value of `il`).
    const int NL = 2;

    int i10 = get_group_id(0);
    int i11 = get_group_id(1);
    int i12 = get_group_id(2);

    // Row of src0 selected for this work-group.
    int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];

    int i02 = i11;
    int i03 = i12;

    // Both row addresses are the same for every chunk, so compute them once.
    global struct block_q4_0 * src_row =
        (global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03);
    global float16 * dst_row =
        (global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1);

    // ind counts 16-float chunks; the loop condition already bounds it, so
    // no extra guard is needed inside the body.
    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
        float16 temp;
        // Chunk ind lives in block ind/NL; ind%NL picks which half of it.
        dequantize_q4_0_f32(src_row + ind/NL, ind%NL, &temp);
        *(dst_row + ind) = temp;
    }
}
|
|
|