File size: 13,853 Bytes
568f19a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
/*******************************************************************************

* Copyright 2024 Intel Corporation

*

* Licensed under the Apache License, Version 2.0 (the "License");

* you may not use this file except in compliance with the License.

* You may obtain a copy of the License at

*

*     http://www.apache.org/licenses/LICENSE-2.0

*

* Unless required by applicable law or agreed to in writing, software

* distributed under the License is distributed on an "AS IS" BASIS,

* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

* See the License for the specific language governing permissions and

* limitations under the License.

*******************************************************************************/

/// @file
/// ukernel C API

#ifndef ONEAPI_DNNL_DNNL_UKERNEL_H
#define ONEAPI_DNNL_DNNL_UKERNEL_H

#include "oneapi/dnnl/dnnl.h"
#include "oneapi/dnnl/dnnl_ukernel_types.h"

#ifdef __cplusplus
extern "C" {
#endif

/// @addtogroup dnnl_api
/// @{

/// @addtogroup dnnl_api_ukernel
/// @{

#ifdef DNNL_EXPERIMENTAL_UKERNEL

/// Creates a memory storage object for ukernel attributes.
///
/// @param attr_params Output handle of the created ukernel attributes memory
///     storage.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa dnnl_ukernel_attr_params_destroy
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_create(
        dnnl_ukernel_attr_params_t *attr_params);

/// Stores the post-operations arguments in a ukernel attributes storage.
///
/// @param attr_params Memory pointers storage object to update.
/// @param post_ops_args Pointer to an array of pointers to post_ops storages.
///     The pointers are expected to be packed together.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_post_ops_args(
        dnnl_ukernel_attr_params_t attr_params, const void **post_ops_args);

/// Stores the tensor A scales argument in a ukernel attributes storage.
///
/// @param attr_params Memory pointers storage object to update.
/// @param a_scales Pointer to the tensor A scales storage.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_A_scales(
        dnnl_ukernel_attr_params_t attr_params, const void *a_scales);

/// Stores the tensor B scales argument in a ukernel attributes storage.
///
/// When `dnnl_brgemm_set_B_scales` was called with a mask of 2, the storage is
/// expected to hold at least N values of the selected data type.
///
/// @param attr_params Memory pointers storage object to update.
/// @param b_scales Pointer to the tensor B scales storage.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_B_scales(
        dnnl_ukernel_attr_params_t attr_params, const void *b_scales);

/// Stores the tensor D scales argument in a ukernel attributes storage.
///
/// @param attr_params Memory pointers storage object to update.
/// @param d_scales Pointer to the tensor D scales storage.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_set_D_scales(
        dnnl_ukernel_attr_params_t attr_params, const void *d_scales);

/// Destroys a memory storage object for ukernel attributes.
///
/// @param attr_params Memory pointers storage object to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa dnnl_ukernel_attr_params_create
dnnl_status_t DNNL_API dnnl_ukernel_attr_params_destroy(
        dnnl_ukernel_attr_params_t attr_params);

/// @addtogroup dnnl_api_ukernel_brgemm
/// @{

/// Creates a BRGeMM ukernel object. The object operates by the formula
/// `C = [A x B]`.
///
/// @param brgemm Output BRGeMM ukernel object.
/// @param M Dimension M of tensor A.
/// @param N Dimension N of tensor B.
/// @param K Dimension K of tensors A and B.
/// @param batch_size Number of batches to process.
/// @param lda Leading dimension of tensor A.
/// @param ldb Leading dimension of tensor B.
/// @param ldc Leading dimension of tensor C.
/// @param a_dt Data type of tensor A.
/// @param b_dt Data type of tensor B.
/// @param c_dt Data type of tensor C. Must be `dnnl_f32`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_create(dnnl_brgemm_t *brgemm, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t batch_size, dnnl_dim_t lda,
        dnnl_dim_t ldb, dnnl_dim_t ldc, dnnl_data_type_t a_dt,
        dnnl_data_type_t b_dt, dnnl_data_type_t c_dt);

/// Selects whether an intermediate result is added to the output tensor C
/// (`C += [A x B]`) instead of being written to it.
///
/// @param brgemm BRGeMM ukernel object.
/// @param add_C Addition flag: `0` skips the addition, `1` applies it.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_add_C(dnnl_brgemm_t brgemm, int add_C);

/// Attaches post-operations to a BRGeMM ukernel object:
/// `D = post-operations(C)`.
///
/// Post-operations apply when at least one of the following holds:
/// * Non-empty attributes are specified.
/// * The output data type `d_dt` differs from the accumulation data type
///   `c_dt`.
///
/// When any of these conditions holds, the final call of the accumulation
/// chain must be `dnnl_brgemm_execute_postops`; otherwise it must be
/// `dnnl_brgemm_execute`.
///
/// @param brgemm BRGeMM ukernel object.
/// @param ldd Leading dimension of tensor D.
/// @param d_dt Data type of tensor D.
/// @param post_ops Primitive post-operations attribute that extends the
///     kernel operations.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_post_ops(dnnl_brgemm_t brgemm,
        dnnl_dim_t ldd, dnnl_data_type_t d_dt, const_dnnl_post_ops_t post_ops);

/// Sets the tensor A scales mask on a BRGeMM ukernel object.
///
/// For the quantization flavor, tensor A scales are applied to the
/// accumulation buffer once C is ready.
///
/// @param brgemm BRGeMM ukernel object.
/// @param a_scale_mask Tensor A scale mask. Can be `0` only.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_A_scales(
        dnnl_brgemm_t brgemm, int a_scale_mask);

/// Sets the tensor B scales mask on a BRGeMM ukernel object.
///
/// For the quantization flavor, tensor B scales are applied to the
/// accumulation buffer once C is ready.
///
/// @param brgemm BRGeMM ukernel object.
/// @param b_scale_mask Tensor B scale mask. Can be `0` or `2` only.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_B_scales(
        dnnl_brgemm_t brgemm, int b_scale_mask);

/// Sets the tensor D scales mask on a BRGeMM ukernel object.
///
/// For the quantization flavor, tensor D scales are applied after all
/// post-ops have been applied.
///
/// @param brgemm BRGeMM ukernel object.
/// @param d_scale_mask Tensor D scale mask. Can be `0` only.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_D_scales(
        dnnl_brgemm_t brgemm, int d_scale_mask);

/// Completes the initialization of a BRGeMM ukernel object.
///
/// This step is mandatory before information can be queried from the object.
///
/// @param brgemm BRGeMM ukernel object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_finalize(dnnl_brgemm_t brgemm);

// NOTE(review): the two value names quoted in the documentation below do not
// carry the `dnnl_pack_type_` prefix that the documentation of
// dnnl_transform_create uses for values of the same type; confirm them
// against the dnnl_pack_type_t definition.

/// Queries the packing type that a BRGeMM ukernel object expects for
/// tensor B.
///
/// @param brgemm BRGeMM ukernel object.
/// @param pack_type Output packing type. Can be `dnnl_brgemm_no_pack` if
///     packing is not expected, and `dnnl_brgemm_pack_32`, otherwise.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_get_B_pack_type(
        const_dnnl_brgemm_t brgemm, dnnl_pack_type_t *pack_type);

/// Queries the size of the scratchpad memory that a BRGeMM ukernel object
/// needs.
///
/// @param brgemm BRGeMM ukernel object.
/// @param size Output size of the buffer required for the BRGeMM ukernel
///     object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_get_scratchpad_size(
        const_dnnl_brgemm_t brgemm, size_t *size);

/// Queries whether a call to `dnnl_brgemm_execute_postops` is valid for a
/// BRGeMM ukernel object.
///
/// @param brgemm BRGeMM ukernel object.
/// @param valid Output flag: `1` if `dnnl_brgemm_execute_postops` is valid for
///     the given ukernel object, and `0` otherwise.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_is_execute_postops_valid(
        const_dnnl_brgemm_t brgemm, int *valid);

/// Initializes the hardware-specific context. When no initialization is
/// required, the success status is returned.
///
/// @param brgemm BRGeMM ukernel object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_set_hw_context(const_dnnl_brgemm_t brgemm);

/// Releases the hardware-specific context. Must be used after all the
/// execution calls to BRGeMM ukernel objects.
///
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
// The parameter list is spelled `(void)` so that C compilers treat this as a
// prototype taking no arguments; in C++ it is equivalent to `()`.
dnnl_status_t DNNL_API dnnl_brgemm_release_hw_context(void);

/// Generates the executable part of a BRGeMM ukernel object.
///
/// @param brgemm BRGeMM ukernel object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_generate(dnnl_brgemm_t brgemm);

/// Runs a BRGeMM ukernel object.
///
/// @param brgemm BRGeMM ukernel object.
/// @param A_ptr Base pointer to a tensor A.
/// @param B_ptr Base pointer to a tensor B.
/// @param A_B_offsets Pointer to the set of tensor A and tensor B offsets for
///     each batch. The set must be contiguous in memory. A single batch
///     should supply offsets for both tensors A and B simultaneously. The
///     number of batches must coincide with the `batch_size` value passed at
///     the creation stage.
/// @param C_ptr Pointer to a tensor C (accumulation buffer).
/// @param scratchpad_ptr Pointer to a scratchpad buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_execute(const_dnnl_brgemm_t brgemm,
        const void *A_ptr, const void *B_ptr, const dnnl_dim_t *A_B_offsets,
        void *C_ptr, void *scratchpad_ptr);

/// Runs a BRGeMM ukernel object together with its post operations.
///
/// @param brgemm BRGeMM ukernel object.
/// @param A Base pointer to a tensor A.
/// @param B Base pointer to a tensor B.
/// @param A_B_offsets Pointer to the set of tensor A and tensor B offsets for
///     each batch. The set must be contiguous in memory. A single batch
///     should supply offsets for both tensors A and B simultaneously. The
///     number of batches must coincide with the `batch_size` value passed at
///     the creation stage.
/// @param C_ptr Pointer to a tensor C (accumulation buffer).
/// @param D_ptr Pointer to a tensor D (output buffer).
/// @param scratchpad_ptr Pointer to a scratchpad buffer.
/// @param attr_params Ukernel attributes memory storage.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_brgemm_execute_postops(const_dnnl_brgemm_t brgemm,
        const void *A, const void *B, const dnnl_dim_t *A_B_offsets,
        const void *C_ptr, void *D_ptr, void *scratchpad_ptr,
        const_dnnl_ukernel_attr_params_t attr_params);

/// Destroys a BRGeMM ukernel object.
///
/// @param brgemm The BRGeMM ukernel object to be destroyed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa dnnl_brgemm_create
dnnl_status_t DNNL_API dnnl_brgemm_destroy(dnnl_brgemm_t brgemm);

/// Creates a transform object.
///
/// @param transform Output transform object.
/// @param K Dimension K.
/// @param N Dimension N.
/// @param in_pack_type Input packing type. Must be either
///     `dnnl_pack_type_no_trans` or `dnnl_pack_type_trans`.
/// @param in_ld Input leading dimension.
/// @param out_ld Output leading dimension. When data is packed, it specifies
///     a block by the N dimension.
/// @param in_dt Input data type.
/// @param out_dt Output data type.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_transform_create(dnnl_transform_t *transform,
        dnnl_dim_t K, dnnl_dim_t N, dnnl_pack_type_t in_pack_type,
        dnnl_dim_t in_ld, dnnl_dim_t out_ld, dnnl_data_type_t in_dt,
        dnnl_data_type_t out_dt);

/// Generates the executable part of a transform object.
///
/// @param transform Transform object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_transform_generate(dnnl_transform_t transform);

/// Runs a transform object.
///
/// @param transform Transform object.
/// @param in_ptr Pointer to the input buffer.
/// @param out_ptr Pointer to the output buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_transform_execute(
        const_dnnl_transform_t transform, const void *in_ptr, void *out_ptr);

/// Destroys a transform object.
///
/// @param transform The transform object to be destroyed.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa dnnl_transform_create
dnnl_status_t DNNL_API dnnl_transform_destroy(dnnl_transform_t transform);

/// @} dnnl_api_ukernel_brgemm

#endif

/// @} dnnl_api_ukernel

/// @} dnnl_api

#ifdef __cplusplus
}
#endif

#endif /* ONEAPI_DNNL_DNNL_UKERNEL_H */