File size: 3,899 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#pragma once
#include <acl/acl.h>
#include <aclnn/acl_meta.h>
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <vector>

// Evaluates the ACL runtime expression `x` exactly once. If the result is
// anything other than ACL_ERROR_NONE, prints the error code, the source
// location of the invocation and the text of `x` to stderr, then aborts.
#define ACL_CHECK(x)                                                  \
    do {                                                              \
        const aclError acl_check_rc_ = (x);                           \
        if (acl_check_rc_ != ACL_ERROR_NONE) {                        \
            fprintf(stderr, "ACL error %d at %s:%d : %s\n",           \
                    acl_check_rc_, __FILE__, __LINE__, #x);           \
            std::abort();                                             \
        }                                                             \
    } while (0)

// Evaluates the aclnn expression `x` exactly once. A non-zero status is
// reported on stderr together with the source location, the text of `x` and
// the string returned by aclGetRecentErrMsg() ("(null)" when that pointer is
// null); the process is then aborted.
#define ACLNN_CHECK(x)                                                                      \
    do {                                                                                    \
        const aclnnStatus aclnn_check_rc_ = (x);                                            \
        if (aclnn_check_rc_ != 0) {                                                         \
            const char* aclnn_check_msg_ = aclGetRecentErrMsg();                            \
            fprintf(stderr, "aclnn error %d at %s:%d : %s\n  msg: %s\n",                    \
                    static_cast<int>(aclnn_check_rc_), __FILE__, __LINE__, #x,              \
                    aclnn_check_msg_ ? aclnn_check_msg_ : "(null)");                        \
            std::abort();                                                                   \
        }                                                                                   \
    } while (0)

// Deleter functor for aclTensor handles: a non-null handle is passed to
// aclDestroyTensor, a null handle is ignored. Whatever aclDestroyTensor
// returns is not inspected.
struct AclTensorDel {
    void operator()(aclTensor* tensor) const {
        if (tensor != nullptr) {
            aclDestroyTensor(tensor);
        }
    }
};
// Owning pointer to an aclTensor; destruction goes through AclTensorDel.
using AclTensorPtr = std::unique_ptr<aclTensor, AclTensorDel>;

// Deleter functor for aclTensorList handles: a non-null handle is passed to
// aclDestroyTensorList, a null handle is ignored.
struct AclTensorListDel {
    void operator()(aclTensorList* list) const {
        if (list != nullptr) {
            aclDestroyTensorList(list);
        }
    }
};
// Owning pointer to an aclTensorList; destruction goes through AclTensorListDel.
using AclTensorListPtr = std::unique_ptr<aclTensorList, AclTensorListDel>;

// Deleter functor for aclIntArray handles: a non-null handle is passed to
// aclDestroyIntArray, a null handle is ignored.
struct AclIntArrayDel {
    void operator()(aclIntArray* array) const {
        if (array != nullptr) {
            aclDestroyIntArray(array);
        }
    }
};
// Owning pointer to an aclIntArray; destruction goes through AclIntArrayDel.
using AclIntArrayPtr = std::unique_ptr<aclIntArray, AclIntArrayDel>;

// Create an ACL tensor over `data` with an explicit row-major shape (outermost
// dimension leftmost) and explicit per-dimension strides.
// NOTE: strides are in ELEMENTS, not bytes.
//
//   data         - pointer handed to aclCreateTensor as the tensor data.
//   dt           - element data type.
//   shape        - extent of every dimension.
//   stride_elems - stride of every dimension; must have exactly shape.size()
//                  entries.
//   fmt          - tensor format (ACL_FORMAT_ND unless specified).
//
// Returns whatever aclCreateTensor produced, wrapped in an AclTensorPtr. The
// handle is not checked here; presumably it is null when creation fails —
// confirm against the aclCreateTensor documentation before relying on it.
//
// A shape/stride length mismatch is a caller bug that would otherwise read
// past the end of stride_elems, so it is reported on stderr and the process
// is aborted.
inline AclTensorPtr make_acl_tensor(void* data, aclDataType dt,
                                    const std::vector<int64_t>& shape,
                                    const std::vector<int64_t>& stride_elems,
                                    aclFormat fmt = ACL_FORMAT_ND) {
    const size_t rank = shape.size();
    if (stride_elems.size() != rank) {
        fprintf(stderr,
                "make_acl_tensor: shape has %zu dims but stride_elems has %zu at %s:%d\n",
                rank, stride_elems.size(), __FILE__, __LINE__);
        std::abort();
    }

    // Number of elements the view can reach in the underlying storage:
    // 1 + sum((shape[i] - 1) * stride[i]). A tensor with any zero-extent
    // dimension addresses no element at all, so its storage length is 0
    // (the formula alone would add a meaningless negative term for it).
    int64_t storage_len = 1;
    for (size_t i = 0; i < rank; ++i) {
        if (shape[i] == 0) {
            storage_len = 0;
            break;
        }
        storage_len += (shape[i] - 1) * stride_elems[i];
    }

    // The view starts at offset 0 and the storage is described as a single
    // dimension of storage_len elements.
    aclTensor* t = aclCreateTensor(
        shape.data(), static_cast<uint64_t>(rank), dt,
        stride_elems.data(), 0, fmt,
        &storage_len, 1, data);
    return AclTensorPtr(t);
}

// Element strides of a densely packed row-major tensor: the innermost
// dimension has stride 1 and every other stride is the product of all extents
// to its right. An empty shape gives an empty result.
inline std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& shape) {
    std::vector<int64_t> strides(shape.size());
    int64_t running = 1;
    // Walk both sequences from the innermost dimension outwards.
    auto out = strides.rbegin();
    for (auto dim = shape.rbegin(); dim != shape.rend(); ++dim, ++out) {
        *out = running;
        running *= *dim;
    }
    return strides;
}

// Builds an ACL tensor over `data` whose strides are contiguous_strides(shape);
// data, dt, shape and fmt are forwarded to make_acl_tensor unchanged.
inline AclTensorPtr make_contig_tensor(void* data, aclDataType dt,
                                       const std::vector<int64_t>& shape,
                                       aclFormat fmt = ACL_FORMAT_ND) {
    const std::vector<int64_t> dense_strides = contiguous_strides(shape);
    return make_acl_tensor(data, dt, shape, dense_strides, fmt);
}

// Size in bytes of one element of type `dt`. Only the data types listed in
// the switch are recognised; any other value yields 0.
inline size_t dtype_size(aclDataType dt) {
    switch (dt) {
        case ACL_INT8:
            return 1;
        case ACL_FLOAT16:
        case ACL_BF16:
            return 2;
        case ACL_FLOAT:
        case ACL_INT32:
            return 4;
        case ACL_INT64:
            return 8;
        default:
            return 0;
    }
}

// Move-only owner of a device allocation: memory is obtained through
// aclrtMalloc and handed back through aclrtFree when the object is destroyed,
// re-allocated, or overwritten by move assignment.
struct DeviceBuffer {
    void*  ptr  = nullptr;  // current allocation, or nullptr when none is held
    size_t size = 0;        // byte count of the current allocation, 0 when none

    DeviceBuffer() = default;

    // Allocates `bytes` right away; see alloc().
    explicit DeviceBuffer(size_t bytes) { alloc(bytes); }

    ~DeviceBuffer() {
        if (ptr != nullptr) {
            aclrtFree(ptr);
        }
    }

    // Copying is disabled so that a given allocation has a single owner.
    DeviceBuffer(const DeviceBuffer&) = delete;
    DeviceBuffer& operator=(const DeviceBuffer&) = delete;

    // Takes over the allocation of `other`, leaving it empty.
    DeviceBuffer(DeviceBuffer&& other) noexcept : ptr(other.ptr), size(other.size) {
        other.ptr = nullptr;
        other.size = 0;
    }

    // Frees any allocation held by this object, then takes over the one held
    // by `other`. Self-assignment is a no-op.
    DeviceBuffer& operator=(DeviceBuffer&& other) noexcept {
        if (this == &other) {
            return *this;
        }
        if (ptr != nullptr) {
            aclrtFree(ptr);
        }
        ptr = other.ptr;
        size = other.size;
        other.ptr = nullptr;
        other.size = 0;
        return *this;
    }

    // Frees any allocation currently held, then requests `bytes` from
    // aclrtMalloc with the ACL_MEM_MALLOC_HUGE_FIRST policy. The request goes
    // through ACL_CHECK, so a failed allocation aborts the process.
    void alloc(size_t bytes) {
        if (ptr != nullptr) {
            aclrtFree(ptr);
        }
        ACL_CHECK(aclrtMalloc(&ptr, bytes, ACL_MEM_MALLOC_HUGE_FIRST));
        size = bytes;
    }

    // Raw pointer to the allocation (nullptr when none is held).
    void* get() { return ptr; }
    const void* get() const { return ptr; }
};