File size: 1,448 Bytes
d1d4335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/*

 * Copyright (c) Meta Platforms, Inc. and affiliates.

 * All rights reserved.

 *

 * This source code is licensed under the BSD-style license found in the

 * LICENSE file in the root directory of this source tree.

 */

#pragma once
#include <chrono>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include "fbgemm/FbgemmBuild.h"
#include "fbgemm/FbgemmSparse.h"
#include "fbgemm/UtilsAvx2.h"
#include "fbgemm/spmmUtilsAvx2.h"

namespace fbgemm {

/// Sparse x dense float matrix multiply (the `Ref` suffix suggests a
/// reference implementation used to validate optimized kernels — confirm
/// against the definition).
///
/// NOTE(review): the (row_ptr, col_idx, values) triple looks like a CSR
/// encoding of the sparse left-hand operand — confirm against the definition.
///
/// @param M        Presumably the number of rows of the sparse operand / C.
/// @param N        Presumably the number of columns of B / C — TODO confirm.
/// @param row_ptr  Row-pointer array of the sparse operand.
/// @param col_idx  Column indices of the sparse operand's stored entries.
/// @param values   Values of the sparse operand's stored entries.
/// @param B        Dense right-hand operand (read-only).
/// @param ldb      Leading dimension of B.
/// @param C        Output buffer.
/// @param ldc      Leading dimension of C.
/// @param accum    Defaults to false; presumably, when true, results are
///                 accumulated into C instead of overwriting it — confirm.
FBGEMM_API void sparseDenseMMRef(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum = false);

/// Sparse (BCSR) x dense uint8 matrix multiply with requantization.
///
/// @tparam FUSE_RELU  Compile-time flag; presumably applies ReLU during
///                    requantization — confirm.
/// @tparam Q_GRAN     Quantization granularity used by requantization.
///
/// @param N            Presumably the number of columns of B / C — confirm.
/// @param bcsr         Sparse operand. It is taken by const reference to its
///                     owning `unique_ptr`, so ownership is NOT transferred.
/// @param B            Dense uint8 right-hand operand (read-only).
/// @param ldb          Leading dimension of B.
/// @param C_i32        int32 output buffer (presumably the raw accumulators).
/// @param C_u8         uint8 output buffer (presumably requantized results).
/// @param ldc          Leading dimension of the output buffers — confirm that
///                     it applies to both C_i32 and C_u8.
/// @param rParams      Requantization parameters. Taken by NON-const
///                     reference (unlike trRequantizeRef); the signature is
///                     kept so it matches the out-of-line definition.
/// @param accum        Defaults to false; presumably accumulates into C_i32
///                     when true — confirm.
/// @param thread_id    Defaults to 0; caller's thread index.
/// @param num_threads  Defaults to 1; total number of cooperating threads.
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void sparseDenseInt8MMRef(
    int N,
    const std::unique_ptr<BCSRMatrix<>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum = false,
    int thread_id = 0,
    int num_threads = 1);

/// Requantizes an int32 block into a uint8 output buffer.
///
/// @tparam FUSE_RELU  Compile-time flag; presumably applies ReLU — confirm.
/// @tparam Q_GRAN     Quantization granularity.
///
/// @param out     uint8 destination buffer.
/// @param inp     int32 source buffer (read-only).
/// @param block   Region to process; see `block_type_t` for its fields.
/// @param ld_out  Leading dimension of `out`.
/// @param ld_in   Leading dimension of `inp`.
/// @param r       Requantization parameters (read-only).
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void trRequantizeRef(
    uint8_t* out,
    const int32_t* inp,
    const block_type_t& block,
    int ld_out,
    int ld_in,
    const trRequantizationParams_t& r);

/// Returns the set of matrix shapes of interest.
///
/// Each inner vector describes one shape. NOTE(review): the order and
/// meaning of the integers in each inner vector are defined by the
/// implementation — confirm before relying on them.
FBGEMM_API std::vector<std::vector<int>> getSparseMatrixShapes();

} // namespace fbgemm