|
|
#include "ggml.h" |
|
|
#include "mmf.hpp" |
|
|
|
|
|
void ggml_zdnn_mul_mat_f( |
|
|
const ggml_backend_zdnn_context * ctx, |
|
|
const ggml_tensor * src0, |
|
|
const ggml_tensor * src1, |
|
|
ggml_tensor * dst) { |
|
|
GGML_TENSOR_BINARY_OP_LOCALS; |
|
|
|
|
|
const enum ggml_type type = src0->type; |
|
|
|
|
|
GGML_ASSERT(ne0 == ne01); |
|
|
GGML_ASSERT(ne1 == ne11); |
|
|
GGML_ASSERT(ne2 == ne12); |
|
|
GGML_ASSERT(ne3 == ne13); |
|
|
|
|
|
|
|
|
GGML_ASSERT(nb00 == ggml_type_size(type)); |
|
|
GGML_ASSERT(nb10 == ggml_type_size(src1->type)); |
|
|
|
|
|
|
|
|
GGML_ASSERT(nb0 == sizeof(float)); |
|
|
GGML_ASSERT(nb0 <= nb1); |
|
|
GGML_ASSERT(nb1 <= nb2); |
|
|
GGML_ASSERT(nb2 <= nb3); |
|
|
|
|
|
const ggml_tensor * weights = src0; |
|
|
const ggml_tensor * inputs = src1; |
|
|
ggml_tensor * output = dst; |
|
|
|
|
|
ggml_backend_zdnn_buffer * weights_extra = (ggml_backend_zdnn_buffer *)weights->extra; |
|
|
ggml_backend_zdnn_buffer * inputs_extra = (ggml_backend_zdnn_buffer *)inputs->extra; |
|
|
ggml_backend_zdnn_buffer * output_extra = (ggml_backend_zdnn_buffer *)output->extra; |
|
|
ggml_backend_zdnn_buffer * bias_extra = (ggml_backend_zdnn_buffer *)output_extra->extra; |
|
|
|
|
|
const int64_t weights_rows = ne01; |
|
|
const int64_t weights_cols = ne00; |
|
|
const int64_t inputs_rows = ne11; |
|
|
const int64_t inputs_cols = ne10; |
|
|
|
|
|
assert(inputs_cols == weights_cols); |
|
|
|
|
|
const int64_t output_rows = ne1; |
|
|
const int64_t output_cols = ne0; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
GGML_ASSERT(weights_extra->pre_tfm_desc.dim1 == weights->ne[0] && "weights_extra->pre_tfm_desc.dim1 must match weights->ne[0]"); |
|
|
GGML_ASSERT(weights_extra->pre_tfm_desc.dim2 == weights->ne[1] && "weights_extra->pre_tfm_desc.dim2 must match weights->ne[1]"); |
|
|
GGML_ASSERT(inputs_extra->pre_tfm_desc.dim1 == inputs->ne[0] && "inputs_extra->pre_tfm_desc.dim1 must match inputs->ne[0]"); |
|
|
GGML_ASSERT(inputs_extra->pre_tfm_desc.dim2 == inputs->ne[1] && "inputs_extra->pre_tfm_desc.dim2 must match inputs->ne[1]"); |
|
|
|
|
|
ZDNN_CHECK(zdnn_matmul_transpose_op(&inputs_extra->ztensor, &weights_extra->ztensor, &bias_extra->ztensor, |
|
|
false, true, MATMUL_OP_ADDITION, &output_extra->ztensor)); |
|
|
|
|
|
ZDNN_CHECK(zdnn_transform_origtensor(&output_extra->ztensor, output->data)); |
|
|
|
|
|
GGML_UNUSED(ctx); |
|
|
GGML_UNUSED(weights_rows); |
|
|
GGML_UNUSED(weights_cols); |
|
|
GGML_UNUSED(inputs_rows); |
|
|
GGML_UNUSED(inputs_cols); |
|
|
GGML_UNUSED(output_rows); |
|
|
GGML_UNUSED(output_cols); |
|
|
} |
|
|
|