File size: 5,142 Bytes
4d35814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

//------------------------------------------------------------------------------
// add
//------------------------------------------------------------------------------

// Generic f32 element-wise addition: dst = src0 + src1.
//
// All three tensors are addressed through explicit byte strides (nb*), so
// non-contiguous layouts work. src1 is broadcast against src0 in every
// dimension (0, 1, 2 and 3) by wrapping each of its coordinates modulo its
// own extent (ne10..ne13).
//
// Work mapping: work-group (g0, g1, g2) handles the row with coordinates
// (i01, i02, i03) = (g0, g1, g2); the work-items of that group walk the ne0
// elements of the row, each advancing by the local work-group size.
//
// Trade-off: fully general, but not tuned for throughput.
kernel void kernel_add(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global char * dst,
        ulong  offsetd,
        int   ne00,
        int   ne01,
        int   ne02,
        int   ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int   ne10,
        int   ne11,
        int   ne12,
        int   ne13,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int   ne0,
        int   ne1,
        int   ne2,
        int   ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    // Row coordinates owned by this work-group.
    const int r1 = get_group_id(0);
    const int r2 = get_group_id(1);
    const int r3 = get_group_id(2);

    // Byte address of this row in each tensor (after applying the buffer
    // offsets). src1's coordinates wrap modulo its extents -> broadcast.
    global const char * a_row = src0 + offset0 + r3*nb03 + r2*nb02 + r1*nb01;
    global const char * b_row = src1 + offset1
                              + (r3 % ne13)*nb13 + (r2 % ne12)*nb12 + (r1 % ne11)*nb11;
    global char       * c_row = dst  + offsetd + r3*nb3  + r2*nb2  + r1*nb1;

    const int step = get_local_size(0);
    int col = get_local_id(0);
    while (col < ne0) {
        const float a = *(global const float *)(a_row + col*nb00);
        const float b = *(global const float *)(b_row + (col % ne10)*nb10);
        *(global float *)(c_row + col*nb0) = a + b;
        col += step;
    }
}

// Row-broadcast f32 addition on float4 vectors:
//     dst[i] = src0[i] + src1[i % ne]
// src1 is treated as a single row of `ne` float4 values that is repeated
// across all of src0. Each work-item handles exactly one float4, so the host
// must launch exactly as many work-items as dst has float4 elements (there is
// no bounds check here).
// NOTE(review): `ne` is presumably the src1 row length divided by 4 — confirm
// against the host-side launcher.
kernel void kernel_add_row(
        global float4 * src0,
        ulong  offset0,
        global float4 * src1,
        ulong  offset1,
        global float4 * dst,
        ulong  offsetd,
        int ne
) {
    // Apply the byte offsets before indexing in float4 units.
    global float4 * lhs = (global float4 *)((global char *)src0 + offset0);
    global float4 * row = (global float4 *)((global char *)src1 + offset1);
    global float4 * out = (global float4 *)((global char *)dst  + offsetd);

    const uint i = get_global_id(0);

    // i % ne written as divide + multiply-subtract; the original authors
    // found this faster than using the % operator.
    const uint q = i / ne;
    out[i] = lhs[i] + row[i - q*ne];
}

// Generic element-wise addition with an f16 destination: dst = src0 + src1.
//
// Each source may be stored as f32 or f16, selected per operand by
// type_src0 / type_src1: a value of 1 reads that operand as float, anything
// else reads it as half.
// NOTE(review): confirm that the host passes a 0/1 "is f32" flag here rather
// than an element-type enum value.
//
// Indexing and broadcasting match kernel_add: byte strides (nb*) allow
// non-contiguous tensors, and src1 is broadcast in dims 0..3 by wrapping its
// coordinates modulo ne10..ne13. Work-group (g0, g1, g2) owns the row
// (i01, i02, i03); its work-items stride over ne0 by the local size.
//
// Precision: both operands are widened to float, added, and rounded to half
// exactly once. Rounding an f32 operand to half before the add and then
// rounding the sum again (double rounding) loses accuracy. For two f16
// operands, a float intermediate yields the same result as a native half add.
kernel void kernel_add_f16(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global char * dst,
        ulong  offsetd,
        int   ne00,
        int   ne01,
        int   ne02,
        int   ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int   ne10,
        int   ne11,
        int   ne12,
        int   ne13,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int   ne0,
        int   ne1,
        int   ne2,
        int   ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int type_src0,
        int type_src1
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst = dst + offsetd;

    int i03 = get_group_id(2);
    int i02 = get_group_id(1);
    int i01 = get_group_id(0);

    // src1 coordinates wrap modulo its extents -> broadcast.
    int i13 = i03 % ne13;
    int i12 = i02 % ne12;
    int i11 = i01 % ne11;

    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;

    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
        const int i10 = i0 % ne10;

        // Widen both operands to float; half -> float is exact.
        float v0, v1;
        if (type_src0 == 1) {
            v0 = *((global float *)(src0_ptr + i0*nb00));
        } else {
            v0 = (float)(*((global half *)(src0_ptr + i0*nb00)));
        }

        if (type_src1 == 1) {
            v1 = *((global float *)(src1_ptr + i10*nb10));
        } else {
            v1 = (float)(*((global half *)(src1_ptr + i10*nb10)));
        }

        // Single rounding to half, after the addition.
        *((global half *)(dst_ptr + i0*nb0)) = convert_half(v0 + v1);
    }
}

// Row-broadcast addition with an f16 destination, four lanes at a time:
//     dst[i] = src0[i] + src1[i % ne]   (i in 4-element vector units)
// Each source may be stored as f32 or f16, selected by type_src0 /
// type_src1: a value of 1 reads that operand as float4, anything else reads
// it as half4.
// NOTE(review): confirm the host passes a 0/1 "is f32" flag, and that `ne` is
// the src1 row length in 4-element units.
// Each work-item handles exactly one vector, and there is no bounds check, so
// the host must launch exactly one work-item per dst vector.
//
// Precision: operands are widened to float4, added, and rounded to half4
// exactly once. This avoids the double rounding that occurs when f32 operands
// are rounded to half before the add.
kernel void kernel_add_row_f16(
        global char * src0,
        ulong  offset0,
        global char * src1,
        ulong  offset1,
        global half4 * dst,
        ulong  offsetd,
        int ne,
        int type_src0,
        int type_src1
) {
    dst = (global half4*)((global char*)dst + offsetd);

    // i % ne via divide + multiply-subtract; the original authors found
    // this faster than using %.
    uint gid = get_global_id(0);
    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne

    // Widen both operands to float4; half -> float is exact.
    float4 v0, v1;
    if (type_src0 == 1) {
        global float4* src0_f32 = (global float4*)((global char*)src0 + offset0);
        v0 = src0_f32[gid];
    } else {
        global half4* src0_f16 = (global half4*)((global char*)src0 + offset0);
        v0 = convert_float4(src0_f16[gid]);
    }

    if (type_src1 == 1) {
        global float4* src1_f32 = (global float4*)((global char*)src1 + offset1);
        v1 = src1_f32[idx1];
    } else {
        global half4* src1_f16 = (global half4*)((global char*)src1 + offset1);
        v1 = convert_float4(src1_f16[idx1]);
    }

    // Single rounding to half, after the addition.
    dst[gid] = convert_half4(v0 + v1);
}