qqc1989 commited on
Commit
e01c900
·
verified ·
1 Parent(s): 815b9e2

Delete python/axengine

Browse files
python/axengine/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- # thanks to community contributors list below:
9
- # zylo117: https://github.com/zylo117, first implementation of the axclrt backend
10
-
11
- from ._providers import axengine_provider_name, axclrt_provider_name
12
- from ._providers import get_all_providers, get_available_providers
13
-
14
- # check if axclrt is installed, or is a supported chip(e.g. AX650, AX620E etc.)
15
- _available_providers = get_available_providers()
16
- if not _available_providers:
17
- raise ImportError(
18
- f"No providers found. Please make sure you have installed one of the following: {get_all_providers()}")
19
- print("[INFO] Available providers: ", _available_providers)
20
-
21
- from ._node import NodeArg
22
- from ._session import SessionOptions, InferenceSession
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axclrt.py DELETED
@@ -1,372 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
- # first implementation of AXCLRTSession contributed by zylo117
8
-
9
- import atexit
10
- import os
11
- import time
12
- from typing import Any, Sequence
13
-
14
- import ml_dtypes as mldt
15
- import numpy as np
16
-
17
- from ._axclrt_capi import axclrt_cffi, axclrt_lib
18
- from ._axclrt_types import VNPUType, ModelType
19
- from ._base_session import Session, SessionOptions
20
- from ._node import NodeArg
21
-
22
- __all__: ["AXCLRTSession"]
23
-
24
- _is_axclrt_initialized = False
25
- _is_axclrt_engine_initialized = False
26
-
27
-
28
- def _transform_dtype(dtype):
29
- if dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT8):
30
- return np.dtype(np.uint8)
31
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT8):
32
- return np.dtype(np.int8)
33
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT16):
34
- return np.dtype(np.uint16)
35
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT16):
36
- return np.dtype(np.int16)
37
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_UINT32):
38
- return np.dtype(np.uint32)
39
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_INT32):
40
- return np.dtype(np.int32)
41
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_FP32):
42
- return np.dtype(np.float32)
43
- elif dtype == axclrt_cffi.cast("axclrtEngineDataType", axclrt_lib.AXCL_DATA_TYPE_BF16):
44
- return np.dtype(mldt.bfloat16)
45
- else:
46
- raise ValueError(f"Unsupported data type '{dtype}'.")
47
-
48
- def _initialize_axclrt():
49
- global _is_axclrt_initialized
50
- ret = axclrt_lib.axclInit([])
51
- if ret != 0:
52
- raise RuntimeError(f"Failed to initialize axcl runtime. {ret}.")
53
- _is_axclrt_initialized = True
54
-
55
-
56
- def _finalize_axclrt():
57
- global _is_axclrt_initialized, _is_axclrt_engine_initialized
58
- if _is_axclrt_engine_initialized:
59
- axclrt_lib.axclrtEngineFinalize()
60
- _is_axclrt_engine_initialized = False
61
- if _is_axclrt_initialized:
62
- axclrt_lib.axclFinalize()
63
- _is_axclrt_initialized = False
64
-
65
-
66
- _initialize_axclrt()
67
- atexit.register(_finalize_axclrt)
68
-
69
-
70
- def _get_vnpu_type() -> VNPUType:
71
- vnpu_type = axclrt_cffi.new("axclrtEngineVNpuKind *")
72
- ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu_type)
73
- if ret != 0:
74
- raise RuntimeError("Failed to get VNPU attribute.")
75
- return VNPUType(vnpu_type[0])
76
-
77
-
78
- def _get_version():
79
- major, minor, patch = axclrt_cffi.new('int32_t *'), axclrt_cffi.new('int32_t *'), axclrt_cffi.new(
80
- 'int32_t *')
81
- axclrt_lib.axclrtGetVersion(major, minor, patch)
82
- return f'{major[0]}.{minor[0]}.{patch[0]}'
83
-
84
-
85
- class AXCLRTSession(Session):
86
- def __init__(
87
- self,
88
- path_or_bytes: str | bytes | os.PathLike,
89
- sess_options: SessionOptions | None = None,
90
- provider_options: dict[Any, Any] | None = None,
91
- **kwargs,
92
- ) -> None:
93
- super().__init__()
94
-
95
- self._device_index = 0
96
-
97
- if provider_options is not None and "device_id" in provider_options[0]:
98
- self._device_index = provider_options[0].get("device_id", 0)
99
-
100
- lst = axclrt_cffi.new("axclrtDeviceList *")
101
- ret = axclrt_lib.axclrtGetDeviceList(lst)
102
- if ret != 0 or lst.num == 0:
103
- raise RuntimeError(f"Get AXCL device failed 0x{ret:08x}, find total {lst.num} device.")
104
-
105
- if self._device_index >= lst.num:
106
- raise RuntimeError(f"Device index {self._device_index} is out of range, total {lst.num} device.")
107
-
108
- self._device_id = lst.devices[self._device_index]
109
- ret = axclrt_lib.axclrtSetDevice(self._device_id)
110
- if ret != 0 or lst.num == 0:
111
- raise RuntimeError(f"Set AXCL device failed 0x{ret:08x}.")
112
-
113
- global _is_axclrt_engine_initialized
114
- vnpu_type = axclrt_cffi.cast(
115
- "axclrtEngineVNpuKind", VNPUType.DISABLED.value
116
- )
117
- # try to initialize NPU as disabled
118
- ret = axclrt_lib.axclrtEngineInit(vnpu_type)
119
- # if failed, try to get vnpu type
120
- if 0 != ret:
121
- vnpu = axclrt_cffi.new("axclrtEngineVNpuKind *")
122
- ret = axclrt_lib.axclrtEngineGetVNpuKind(vnpu)
123
- # if failed, that means the NPU is not available
124
- if ret != 0:
125
- raise RuntimeError(f"axclrtEngineInit as {vnpu.value} failed 0x{ret:08x}.")
126
- # if success, that means the NPU is already initialized as vnpu.value
127
- # so the initialization is failed.
128
- # this means the other users maybe uninitialized the NPU suddenly
129
- # and the app would be terminated unexpectedly at that moment.
130
- # but we can't do anything to fix this issue, just print a warning message.
131
- # it because the api looks like onnxruntime, so there no window avoid this.
132
- # such as the life.
133
- else:
134
- print(f"[WARNING] Failed to initialize NPU as {vnpu_type}, NPU is already initialized as {vnpu.value}.")
135
- # initialize NPU successfully, mark the flag to ensure the engine will be finalized
136
- else:
137
- _is_axclrt_engine_initialized = True
138
-
139
- self.soc_name = axclrt_cffi.string(axclrt_lib.axclrtGetSocName()).decode()
140
- print(f"[INFO] SOC Name: {self.soc_name}")
141
-
142
- # model handle, context, info, io
143
- self._model_id = axclrt_cffi.new("uint64_t *")
144
- self._context_id = axclrt_cffi.new("uint64_t *")
145
-
146
- # get vnpu type
147
- self._vnpu_type = _get_vnpu_type()
148
- print(f"[INFO] VNPU type: {self._vnpu_type}")
149
-
150
- # load model
151
- ret = self._load(path_or_bytes)
152
- if 0 != ret:
153
- raise RuntimeError("Failed to load model.")
154
- print(f"[INFO] Compiler version: {self._get_model_tool_version()}")
155
-
156
- # get model info
157
- self._info = self._get_info()
158
- self._shape_count = self._get_shape_count()
159
- self._inputs = self._get_inputs()
160
- self._outputs = self._get_outputs()
161
-
162
- # prepare io
163
- self._io = self._prepare_io()
164
-
165
- def __del__(self):
166
- self._unload()
167
-
168
- def _load(self, path_or_bytes):
169
- # model buffer, almost copied from onnx runtime
170
- if isinstance(path_or_bytes, (str, os.PathLike)):
171
- _model_path = axclrt_cffi.new("char[]", path_or_bytes.encode('utf-8'))
172
- ret = axclrt_lib.axclrtEngineLoadFromFile(_model_path, self._model_id)
173
- if ret != 0:
174
- raise RuntimeError("axclrtEngineLoadFromFile failed.")
175
- elif isinstance(path_or_bytes, bytes):
176
- _model_buffer = axclrt_cffi.new("char[]", path_or_bytes)
177
- _model_buffer_size = len(path_or_bytes)
178
-
179
- dev_mem_ptr = axclrt_cffi.new('void **', axclrt_cffi.NULL)
180
- ret = axclrt_lib.axclrtMalloc(dev_mem_ptr, _model_buffer_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
181
- if ret != 0:
182
- raise RuntimeError("axclrtMalloc failed.")
183
-
184
- ret = axclrt_lib.axclrtMemcpy(dev_mem_ptr[0], _model_buffer, _model_buffer_size, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
185
- if ret != 0:
186
- axclrt_lib.axclrtFree(dev_mem_ptr[0])
187
- raise RuntimeError("axclrtMemcpy failed.")
188
-
189
- ret = axclrt_lib.axclrtEngineLoadFromMem(dev_mem_ptr[0], _model_buffer_size, self._model_id)
190
- axclrt_lib.axclrtFree(dev_mem_ptr[0])
191
- if ret != 0:
192
- raise RuntimeError("axclrtEngineLoadFromMem failed.")
193
- else:
194
- raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")
195
-
196
- ret = axclrt_lib.axclrtEngineCreateContext(self._model_id[0], self._context_id)
197
- if ret != 0:
198
- raise RuntimeError("axclrtEngineCreateContext failed")
199
- return ret
200
-
201
- def _unload(self):
202
- if self._io is not None:
203
- dev_size = axclrt_cffi.new("uint64_t *")
204
- dev_prt = axclrt_cffi.new("void **")
205
- for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
206
- axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
207
- axclrt_lib.axclrtFree(dev_prt[0])
208
- for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
209
- axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
210
- axclrt_lib.axclrtFree(dev_prt[0])
211
- axclrt_lib.axclrtEngineDestroyIO(self._io)
212
- self._io = None
213
- if self._model_id[0] is not None:
214
- axclrt_lib.axclrtEngineUnload(self._model_id[0])
215
- self._model_id[0] = 0
216
-
217
- def _get_model_tool_version(self):
218
- model_tool_version = axclrt_lib.axclrtEngineGetModelCompilerVersion(self._model_id[0])
219
- return axclrt_cffi.string(model_tool_version).decode()
220
-
221
- def _get_info(self):
222
- io_info = axclrt_cffi.new("axclrtEngineIOInfo *")
223
- ret = axclrt_lib.axclrtEngineGetIOInfo(self._model_id[0], io_info)
224
- if ret != 0:
225
- raise RuntimeError("axclrtEngineGetIOInfo failed.")
226
- return io_info
227
-
228
- def _get_shape_count(self):
229
- count = axclrt_cffi.new("int32_t *")
230
- ret = axclrt_lib.axclrtEngineGetShapeGroupsCount(self._info[0], count)
231
- if ret != 0:
232
- axclrt_lib.axclrtEngineUnload(self._model_id[0])
233
- raise RuntimeError("axclrtEngineGetShapeGroupsCount failed.")
234
- return count[0]
235
-
236
- def _get_inputs(self):
237
- inputs = []
238
- for group in range(self._shape_count):
239
- one_group_io = []
240
- for index in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
241
- cffi_name = axclrt_lib.axclrtEngineGetInputNameByIndex(self._info[0], index)
242
- name = axclrt_cffi.string(cffi_name).decode("utf-8")
243
-
244
- cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
245
- ret = axclrt_lib.axclrtEngineGetInputDataType(self._info[0], index, cffi_dtype)
246
- if ret != 0:
247
- raise RuntimeError("axclrtEngineGetInputDataType failed.")
248
- dtype = _transform_dtype(cffi_dtype[0])
249
-
250
- cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
251
- ret = axclrt_lib.axclrtEngineGetInputDims(self._info[0], group, index, cffi_dims)
252
- if ret != 0:
253
- raise RuntimeError("axclrtEngineGetInputDims failed.")
254
- shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]
255
-
256
- meta = NodeArg(name, dtype, shape)
257
- one_group_io.append(meta)
258
- inputs.append(one_group_io)
259
- return inputs
260
-
261
- def _get_outputs(self):
262
- outputs = []
263
- for group in range(self._shape_count):
264
- one_group_io = []
265
- for index in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
266
- name = axclrt_lib.axclrtEngineGetOutputNameByIndex(self._info[0], index)
267
-
268
- cffi_dtype = axclrt_cffi.new("axclrtEngineDataType *")
269
- ret = axclrt_lib.axclrtEngineGetOutputDataType(self._info[0], index, cffi_dtype)
270
- if ret != 0:
271
- raise RuntimeError("axclrtEngineGetOutputDataType failed.")
272
- dtype = _transform_dtype(cffi_dtype[0])
273
-
274
- cffi_dims = axclrt_cffi.new("axclrtEngineIODims *")
275
- ret = axclrt_lib.axclrtEngineGetOutputDims(self._info[0], group, index, cffi_dims)
276
- if ret != 0:
277
- raise RuntimeError("axclrtEngineGetOutputDims failed.")
278
- shape = [cffi_dims.dims[i] for i in range(cffi_dims.dimCount)]
279
-
280
- meta = NodeArg(name, dtype, shape)
281
- one_group_io.append(meta)
282
- outputs.append(one_group_io)
283
- return outputs
284
-
285
- def _prepare_io(self):
286
- _io = axclrt_cffi.new("axclrtEngineIO *")
287
- ret = axclrt_lib.axclrtEngineCreateIO(self._info[0], _io)
288
- if ret != 0:
289
- raise RuntimeError(f"axclrtEngineCreateIO failed 0x{ret:08x}.")
290
- for i in range(axclrt_lib.axclrtEngineGetNumInputs(self._info[0])):
291
- max_size = 0
292
- for group in range(self._shape_count):
293
- size = axclrt_lib.axclrtEngineGetInputSizeByIndex(self._info[0], group, i)
294
- max_size = max(max_size, size)
295
- dev_ptr = axclrt_cffi.new("void **")
296
- ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
297
- if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
298
- raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for input {i}.")
299
- ret = axclrt_lib.axclrtEngineSetInputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
300
- if 0 != ret:
301
- raise RuntimeError(f"axclrtEngineSetInputBufferByIndex failed 0x{ret:08x} for input {i}.")
302
- for i in range(axclrt_lib.axclrtEngineGetNumOutputs(self._info[0])):
303
- max_size = 0
304
- for group in range(self._shape_count):
305
- size = axclrt_lib.axclrtEngineGetOutputSizeByIndex(self._info[0], group, i)
306
- max_size = max(max_size, size)
307
- dev_ptr = axclrt_cffi.new("void **")
308
- ret = axclrt_lib.axclrtMalloc(dev_ptr, max_size, axclrt_lib.AXCL_MEM_MALLOC_NORMAL_ONLY)
309
- if 0 != ret or dev_ptr[0] == axclrt_cffi.NULL:
310
- raise RuntimeError(f"axclrtMalloc failed 0x{ret:08x} for output {i}.")
311
- ret = axclrt_lib.axclrtEngineSetOutputBufferByIndex(_io[0], i, dev_ptr[0], max_size)
312
- if 0 != ret:
313
- raise RuntimeError(f"axclrtEngineSetOutputBufferByIndex failed 0x{ret:08x} for output {i}.")
314
- return _io[0]
315
-
316
- def run(
317
- self,
318
- output_names: list[str],
319
- input_feed: dict[str, np.ndarray],
320
- run_options=None
321
- ):
322
- self._validate_input(input_feed)
323
- self._validate_output(output_names)
324
-
325
- if None is output_names:
326
- output_names = [o.name for o in self.get_outputs()]
327
-
328
- # fill model io
329
- dev_prt = axclrt_cffi.new("void **")
330
- dev_size = axclrt_cffi.new("uint64_t *")
331
- for key, npy in input_feed.items():
332
- for i, one in enumerate(self.get_inputs()):
333
- if one.name == key:
334
- assert (
335
- list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
336
- ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, howerver gets input with shape {npy.shape} and dtype {npy.dtype}"
337
-
338
- if not (
339
- not npy.flags.c_contiguous
340
- and npy.flags.f_contiguous
341
- and npy.flags.contiguous
342
- ):
343
- npy = np.ascontiguousarray(npy)
344
- npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
345
- ret = axclrt_lib.axclrtEngineGetInputBufferByIndex(self._io, i, dev_prt, dev_size)
346
- if 0 != ret:
347
- raise RuntimeError(f"axclrtEngineGetInputBufferByIndex failed for input {i}.")
348
- ret = axclrt_lib.axclrtMemcpy(dev_prt[0], npy_ptr, npy.nbytes, axclrt_lib.AXCL_MEMCPY_HOST_TO_DEVICE)
349
- if 0 != ret:
350
- raise RuntimeError(f"axclrtMemcpy failed for input {i}.")
351
-
352
- # execute model
353
- ret = axclrt_lib.axclrtEngineExecute(self._model_id[0], self._context_id[0], 0, self._io)
354
-
355
- # get output
356
- outputs = []
357
- if 0 == ret:
358
- for i in range(len(self.get_outputs())):
359
- ret = axclrt_lib.axclrtEngineGetOutputBufferByIndex(self._io, i, dev_prt, dev_size)
360
- if 0 != ret:
361
- raise RuntimeError(f"axclrtEngineGetOutputBufferByIndex failed for output {i}.")
362
- npy = np.zeros(self.get_outputs()[i].shape, dtype=self.get_outputs()[i].dtype)
363
- npy_ptr = axclrt_cffi.cast("void *", npy.ctypes.data)
364
- ret = axclrt_lib.axclrtMemcpy(npy_ptr, dev_prt[0], npy.nbytes, axclrt_lib.AXCL_MEMCPY_DEVICE_TO_HOST)
365
- if 0 != ret:
366
- raise RuntimeError(f"axclrtMemcpy failed for output {i}.")
367
- name = self.get_outputs()[i].name
368
- if name in output_names:
369
- outputs.append(npy)
370
- return outputs
371
- else:
372
- raise RuntimeError(f"axclrtEngineExecute failed 0x{ret:08x}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axclrt_capi.py DELETED
@@ -1,198 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- import ctypes.util
9
-
10
- from cffi import FFI
11
-
12
- __all__: ["axclrt_cffi", "axclrt_lib"]
13
-
14
- axclrt_cffi = FFI()
15
-
16
- # axcl_base.h
17
- axclrt_cffi.cdef(
18
- """
19
- #define AXCL_MAX_DEVICE_COUNT 256
20
- typedef int32_t axclError;
21
- typedef void *axclrtContext;
22
- """
23
- )
24
-
25
- # axcl_rt_type.h
26
- axclrt_cffi.cdef(
27
- """
28
- typedef struct axclrtDeviceList {
29
- uint32_t num;
30
- int32_t devices[AXCL_MAX_DEVICE_COUNT];
31
- } axclrtDeviceList;
32
-
33
- typedef enum axclrtMemMallocPolicy {
34
- AXCL_MEM_MALLOC_HUGE_FIRST,
35
- AXCL_MEM_MALLOC_HUGE_ONLY,
36
- AXCL_MEM_MALLOC_NORMAL_ONLY
37
- } axclrtMemMallocPolicy;
38
-
39
- typedef enum axclrtMemcpyKind {
40
- AXCL_MEMCPY_HOST_TO_HOST,
41
- AXCL_MEMCPY_HOST_TO_DEVICE, //!< host vir -> device phy
42
- AXCL_MEMCPY_DEVICE_TO_HOST, //!< host vir <- device phy
43
- AXCL_MEMCPY_DEVICE_TO_DEVICE,
44
- AXCL_MEMCPY_HOST_PHY_TO_DEVICE, //!< host phy -> device phy
45
- AXCL_MEMCPY_DEVICE_TO_HOST_PHY, //!< host phy <- device phy
46
- } axclrtMemcpyKind;
47
- """
48
- )
49
-
50
- # axcl_rt_engine_type.h
51
- axclrt_cffi.cdef(
52
- """
53
- #define AXCLRT_ENGINE_MAX_DIM_CNT 32
54
- typedef void* axclrtEngineIOInfo;
55
- typedef void* axclrtEngineIO;
56
-
57
- typedef enum axclrtEngineVNpuKind {
58
- AXCL_VNPU_DISABLE = 0,
59
- AXCL_VNPU_ENABLE = 1,
60
- AXCL_VNPU_BIG_LITTLE = 2,
61
- AXCL_VNPU_LITTLE_BIG = 3,
62
- } axclrtEngineVNpuKind;
63
-
64
- typedef enum axclrtEngineDataType {
65
- AXCL_DATA_TYPE_NONE = 0,
66
- AXCL_DATA_TYPE_INT4 = 1,
67
- AXCL_DATA_TYPE_UINT4 = 2,
68
- AXCL_DATA_TYPE_INT8 = 3,
69
- AXCL_DATA_TYPE_UINT8 = 4,
70
- AXCL_DATA_TYPE_INT16 = 5,
71
- AXCL_DATA_TYPE_UINT16 = 6,
72
- AXCL_DATA_TYPE_INT32 = 7,
73
- AXCL_DATA_TYPE_UINT32 = 8,
74
- AXCL_DATA_TYPE_INT64 = 9,
75
- AXCL_DATA_TYPE_UINT64 = 10,
76
- AXCL_DATA_TYPE_FP4 = 11,
77
- AXCL_DATA_TYPE_FP8 = 12,
78
- AXCL_DATA_TYPE_FP16 = 13,
79
- AXCL_DATA_TYPE_BF16 = 14,
80
- AXCL_DATA_TYPE_FP32 = 15,
81
- AXCL_DATA_TYPE_FP64 = 16,
82
- } axclrtEngineDataType;
83
-
84
- typedef enum axclrtEngineDataLayout {
85
- AXCL_DATA_LAYOUT_NONE = 0,
86
- AXCL_DATA_LAYOUT_NHWC = 0,
87
- AXCL_DATA_LAYOUT_NCHW = 1,
88
- } axclrtEngineDataLayout;
89
-
90
- typedef struct axclrtEngineIODims {
91
- int32_t dimCount;
92
- int32_t dims[AXCLRT_ENGINE_MAX_DIM_CNT];
93
- } axclrtEngineIODims;
94
- """
95
- )
96
-
97
- # axcl.h
98
- axclrt_cffi.cdef(
99
- """
100
- axclError axclInit(const char *config);
101
- axclError axclFinalize();
102
- """
103
- )
104
-
105
- # axcl_rt.h
106
- axclrt_cffi.cdef(
107
- """
108
- axclError axclrtGetVersion(int32_t *major, int32_t *minor, int32_t *patch);
109
- const char *axclrtGetSocName();
110
- """
111
- )
112
-
113
- # axcl_rt_device.h
114
- axclrt_cffi.cdef(
115
- """
116
- axclError axclrtGetDeviceList(axclrtDeviceList *deviceList);
117
- axclError axclrtSetDevice(int32_t deviceId);
118
- axclError axclrtResetDevice(int32_t deviceId);
119
- """
120
- )
121
-
122
- # axcl_rt_context.h
123
- axclrt_cffi.cdef(
124
- """
125
- axclError axclrtCreateContext(axclrtContext *context, int32_t deviceId);
126
- axclError axclrtDestroyContext(axclrtContext context);
127
- axclError axclrtSetCurrentContext(axclrtContext context);
128
- axclError axclrtGetCurrentContext(axclrtContext *context);
129
- axclError axclrtGetDefaultContext(axclrtContext *context, int32_t deviceId);
130
- """
131
- )
132
-
133
- # axcl_rt_engine.h
134
- axclrt_cffi.cdef(
135
- """
136
- axclError axclrtEngineInit(axclrtEngineVNpuKind npuKind);
137
- axclError axclrtEngineGetVNpuKind(axclrtEngineVNpuKind *npuKind);
138
- axclError axclrtEngineFinalize();
139
-
140
- axclError axclrtEngineLoadFromFile(const char *modelPath, uint64_t *modelId);
141
- axclError axclrtEngineLoadFromMem(const void *model, uint64_t modelSize, uint64_t *modelId);
142
- const char* axclrtEngineGetModelCompilerVersion(uint64_t modelId);
143
- axclError axclrtEngineUnload(uint64_t modelId);
144
-
145
- axclError axclrtEngineGetIOInfo(uint64_t modelId, axclrtEngineIOInfo *ioInfo);
146
- axclError axclrtEngineGetShapeGroupsCount(axclrtEngineIOInfo ioInfo, int32_t *count);
147
-
148
- uint32_t axclrtEngineGetNumInputs(axclrtEngineIOInfo ioInfo);
149
- uint32_t axclrtEngineGetNumOutputs(axclrtEngineIOInfo ioInfo);
150
-
151
- uint64_t axclrtEngineGetInputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
152
- uint64_t axclrtEngineGetOutputSizeByIndex(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index);
153
-
154
- axclError axclrtEngineGetInputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
155
- axclError axclrtEngineGetOutputDims(axclrtEngineIOInfo ioInfo, uint32_t group, uint32_t index, axclrtEngineIODims *dims);
156
-
157
- const char *axclrtEngineGetInputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
158
- const char *axclrtEngineGetOutputNameByIndex(axclrtEngineIOInfo ioInfo, uint32_t index);
159
-
160
- int32_t axclrtEngineGetInputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
161
- int32_t axclrtEngineGetOutputDataType(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataType *type);
162
-
163
- int32_t axclrtEngineGetInputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
164
- int32_t axclrtEngineGetOutputDataLayout(axclrtEngineIOInfo ioInfo, uint32_t index, axclrtEngineDataLayout *layout);
165
-
166
- axclError axclrtEngineCreateIO(axclrtEngineIOInfo ioInfo, axclrtEngineIO *io);
167
- axclError axclrtEngineDestroyIO(axclrtEngineIO io);
168
-
169
- axclError axclrtEngineSetInputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
170
- axclError axclrtEngineSetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, const void *dataBuffer, uint64_t size);
171
- axclError axclrtEngineGetInputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
172
- axclError axclrtEngineGetOutputBufferByIndex(axclrtEngineIO io, uint32_t index, void **dataBuffer, uint64_t *size);
173
-
174
- axclError axclrtEngineCreateContext(uint64_t modelId, uint64_t *contextId);
175
-
176
- axclError axclrtEngineExecute(uint64_t modelId, uint64_t contextId, uint32_t group, axclrtEngineIO io);
177
- """
178
- )
179
-
180
- # axcl_rt_memory.h
181
- axclrt_cffi.cdef(
182
- """
183
- axclError axclrtMalloc(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
184
- axclError axclrtMallocCached(void **devPtr, size_t size, axclrtMemMallocPolicy policy);
185
- axclError axclrtMemcpy(void *dstPtr, const void *srcPtr, size_t count, axclrtMemcpyKind kind);
186
- axclError axclrtFree(void *devPtr);
187
- axclError axclrtMemFlush(void *devPtr, size_t size);
188
- """
189
- )
190
-
191
- rt_name = "axcl_rt"
192
- rt_path = ctypes.util.find_library(rt_name)
193
- assert (
194
- rt_path is not None
195
- ), f"Failed to find library {rt_name}. Please ensure it is installed and in the library path."
196
-
197
- axclrt_lib = axclrt_cffi.dlopen(rt_path)
198
- assert axclrt_lib is not None, f"Failed to load library {rt_path}. Please ensure it is installed and in the library path."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axclrt_types.py DELETED
@@ -1,21 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- from enum import Enum
9
-
10
-
11
- class VNPUType(Enum):
12
- DISABLED = 0
13
- ENABLED = 1
14
- BIG_LITTLE = 2
15
- LITTLE_BIG = 3
16
-
17
-
18
- class ModelType(Enum):
19
- SINGLE = 0
20
- DUAL = 1
21
- TRIPLE = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axe.py DELETED
@@ -1,399 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- import atexit
9
- import os
10
- from typing import Any, Sequence
11
-
12
- import ml_dtypes as mldt
13
- import numpy as np
14
-
15
- from ._axe_capi import sys_lib, engine_cffi, engine_lib
16
- from ._axe_types import VNPUType, ModelType, ChipType
17
- from ._base_session import Session, SessionOptions
18
- from ._node import NodeArg
19
-
20
- __all__: ["AXEngineSession"]
21
-
22
- _is_sys_initialized = False
23
- _is_engine_initialized = False
24
-
25
-
26
- def _transform_dtype(dtype):
27
- if dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT8):
28
- return np.dtype(np.uint8)
29
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT8):
30
- return np.dtype(np.int8)
31
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT16):
32
- return np.dtype(np.uint16)
33
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT16):
34
- return np.dtype(np.int16)
35
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_UINT32):
36
- return np.dtype(np.uint32)
37
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_SINT32):
38
- return np.dtype(np.int32)
39
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_FLOAT32):
40
- return np.dtype(np.float32)
41
- elif dtype == engine_cffi.cast("AX_ENGINE_DATA_TYPE_T", engine_lib.AX_ENGINE_DT_BFLOAT16):
42
- return np.dtype(mldt.bfloat16)
43
- else:
44
- raise ValueError(f"Unsupported data type '{dtype}'.")
45
-
46
-
47
- def _check_cffi_func_exists(lib, func_name):
48
- try:
49
- getattr(lib, func_name)
50
- return True
51
- except AttributeError:
52
- return False
53
-
54
-
55
- def _get_chip_type():
56
- if not _check_cffi_func_exists(engine_lib, "AX_ENGINE_SetAffinity"):
57
- return ChipType.M57H
58
- elif not _check_cffi_func_exists(engine_lib, "AX_ENGINE_GetTotalOps"):
59
- return ChipType.MC50
60
- else:
61
- return ChipType.MC20E
62
-
63
-
64
- def _get_version():
65
- engine_version = engine_lib.AX_ENGINE_GetVersion()
66
- return engine_cffi.string(engine_version).decode("utf-8")
67
-
68
-
69
- def _get_vnpu_type() -> VNPUType:
70
- vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
71
- ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
72
- if 0 != ret:
73
- raise RuntimeError("Failed to get VNPU attribute.")
74
- return VNPUType(vnpu_type.eHardMode)
75
-
76
-
77
- def _initialize_engine():
78
- global _is_sys_initialized, _is_engine_initialized
79
-
80
- ret = sys_lib.AX_SYS_Init()
81
- if ret != 0:
82
- raise RuntimeError("Failed to initialize ax sys.")
83
- _is_sys_initialized = True
84
-
85
- # disabled mode by default
86
- vnpu_type = engine_cffi.new("AX_ENGINE_NPU_ATTR_T *")
87
- ret = engine_lib.AX_ENGINE_GetVNPUAttr(vnpu_type)
88
- if 0 != ret:
89
- # this means the NPU was not initialized
90
- vnpu_type.eHardMode = engine_cffi.cast(
91
- "AX_ENGINE_NPU_MODE_T", VNPUType.DISABLED.value
92
- )
93
- ret = engine_lib.AX_ENGINE_Init(vnpu_type)
94
- if ret != 0:
95
- raise RuntimeError("Failed to initialize ax sys engine.")
96
- _is_engine_initialized = True
97
-
98
- print(f"[INFO] Chip type: {_get_chip_type()}")
99
- print(f"[INFO] VNPU type: {_get_vnpu_type()}")
100
- print(f"[INFO] Engine version: {_get_version()}")
101
-
102
-
103
- def _finalize_engine():
104
- global _is_sys_initialized, _is_engine_initialized
105
-
106
- if _is_engine_initialized:
107
- engine_lib.AX_ENGINE_Deinit()
108
- if _is_sys_initialized:
109
- sys_lib.AX_SYS_Deinit()
110
-
111
-
112
- _initialize_engine()
113
- atexit.register(_finalize_engine)
114
-
115
-
116
- class AXEngineSession(Session):
117
- def __init__(
118
- self,
119
- path_or_bytes: str | bytes | os.PathLike,
120
- sess_options: SessionOptions | None = None,
121
- provider_options: dict[Any, Any] | None = None,
122
- **kwargs,
123
- ) -> None:
124
- super().__init__()
125
-
126
- self._chip_type = _get_chip_type()
127
- self._vnpu_type = _get_vnpu_type()
128
-
129
- # handle, context, info, io
130
- self._handle = engine_cffi.new("uint64_t **")
131
- self._context = engine_cffi.new("uint64_t **")
132
- self._io = engine_cffi.new("AX_ENGINE_IO_T *")
133
-
134
- # model buffer, almost copied from onnx runtime
135
- if isinstance(path_or_bytes, (str, os.PathLike)):
136
- self._model_name = os.path.splitext(os.path.basename(path_or_bytes))[0]
137
- with open(path_or_bytes, "rb") as f:
138
- data = f.read()
139
- self._model_buffer = engine_cffi.new("char[]", data)
140
- self._model_buffer_size = len(data)
141
- elif isinstance(path_or_bytes, bytes):
142
- self._model_buffer = engine_cffi.new("char[]", path_or_bytes)
143
- self._model_buffer_size = len(path_or_bytes)
144
- else:
145
- raise TypeError(f"Unable to load model from type '{type(path_or_bytes)}'")
146
-
147
- # get model type
148
- self._model_type = self._get_model_type()
149
- if self._chip_type is ChipType.MC20E:
150
- if self._model_type is ModelType.FULL:
151
- print(f"[INFO] Model type: {self._model_type.value} (full core)")
152
- if self._model_type is ModelType.HALF:
153
- print(f"[INFO] Model type: {self._model_type.value} (half core)")
154
- if self._chip_type is ChipType.MC50:
155
- if self._model_type is ModelType.SINGLE:
156
- print(f"[INFO] Model type: {self._model_type.value} (single core)")
157
- if self._model_type is ModelType.DUAL:
158
- print(f"[INFO] Model type: {self._model_type.value} (dual core)")
159
- if self._model_type is ModelType.TRIPLE:
160
- print(f"[INFO] Model type: {self._model_type.value} (triple core)")
161
- if self._chip_type is ChipType.M57H:
162
- print(f"[INFO] Model type: {self._model_type.value} (single core)")
163
-
164
- # check model type
165
- if self._chip_type is ChipType.MC50:
166
- # all types (single or dual or triple) of model are allowed in vnpu mode disabled
167
- # only single core model is allowed in vnpu mode enabled
168
- # only triple core model is NOT allowed in vnpu mode big-little or little-big
169
- if self._vnpu_type is VNPUType.ENABLED:
170
- if self._model_type is not ModelType.SINGLE:
171
- raise ValueError(
172
- f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
173
- )
174
- if (
175
- self._vnpu_type is VNPUType.BIG_LITTLE
176
- or self._vnpu_type is VNPUType.LITTLE_BIG
177
- ):
178
- if self._model_type is ModelType.TRIPLE:
179
- raise ValueError(
180
- f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
181
- )
182
- if self._chip_type is ChipType.MC20E:
183
- # all types of full or half core model are allowed in vnpu mode disabled
184
- # only half core model is allowed in vnpu mode enabled
185
- if self._vnpu_type is VNPUType.ENABLED:
186
- if self._model_type is ModelType.FULL:
187
- raise ValueError(
188
- f"Model type '{self._model_type}' is not allowed when vnpu is inited as {self._vnpu_type}."
189
- )
190
- # if self._chip_type is ChipType.M57H:
191
- # there only one type of model will be compiled, so no need to check
192
-
193
- # load model
194
- ret = self._load()
195
- if 0 != ret:
196
- raise RuntimeError("Failed to load model.")
197
- print(f"[INFO] Compiler version: {self._get_model_tool_version()}")
198
-
199
- # get shape group count
200
- try:
201
- self._shape_count = self._get_shape_count()
202
- except AttributeError as e:
203
- print(f"[WARNING] {e}")
204
- self._shape_count = 1
205
-
206
- # get model shape
207
- self._info = self._get_info()
208
- self._inputs = self._get_inputs()
209
- self._outputs = self._get_outputs()
210
-
211
- # fill model io
212
- self._align = 128
213
- self._cmm_token = engine_cffi.new("AX_S8[]", b"PyEngine")
214
- self._io[0].nInputSize = len(self.get_inputs())
215
- self._io[0].nOutputSize = len(self.get_outputs())
216
- self._io[0].pInputs = engine_cffi.new(
217
- "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nInputSize)
218
- )
219
- self._io[0].pOutputs = engine_cffi.new(
220
- "AX_ENGINE_IO_BUFFER_T[{}]".format(self._io[0].nOutputSize)
221
- )
222
- for i in range(len(self.get_inputs())):
223
- max_buf = 0
224
- for j in range(self._shape_count):
225
- max_buf = max(max_buf, self._info[j][0].pInputs[i].nSize)
226
- self._io[0].pInputs[i].nSize = max_buf
227
- phy = engine_cffi.new("AX_U64*")
228
- vir = engine_cffi.new("AX_VOID**")
229
- ret = sys_lib.AX_SYS_MemAllocCached(
230
- phy, vir, self._io[0].pInputs[i].nSize, self._align, self._cmm_token
231
- )
232
- if 0 != ret:
233
- raise RuntimeError("Failed to allocate memory for input.")
234
- self._io[0].pInputs[i].phyAddr = phy[0]
235
- self._io[0].pInputs[i].pVirAddr = vir[0]
236
- for i in range(len(self.get_outputs())):
237
- max_buf = 0
238
- for j in range(self._shape_count):
239
- max_buf = max(max_buf, self._info[j][0].pOutputs[i].nSize)
240
- self._io[0].pOutputs[i].nSize = max_buf
241
- phy = engine_cffi.new("AX_U64*")
242
- vir = engine_cffi.new("AX_VOID**")
243
- ret = sys_lib.AX_SYS_MemAllocCached(
244
- phy, vir, self._io[0].pOutputs[i].nSize, self._align, self._cmm_token
245
- )
246
- if 0 != ret:
247
- raise RuntimeError("Failed to allocate memory for output.")
248
- self._io[0].pOutputs[i].phyAddr = phy[0]
249
- self._io[0].pOutputs[i].pVirAddr = vir[0]
250
-
251
- def __del__(self):
252
- self._unload()
253
-
254
- def _get_model_type(self) -> ModelType:
255
- model_type = engine_cffi.new("AX_ENGINE_MODEL_TYPE_T *")
256
- ret = engine_lib.AX_ENGINE_GetModelType(
257
- self._model_buffer, self._model_buffer_size, model_type
258
- )
259
- if 0 != ret:
260
- raise RuntimeError("Failed to get model type.")
261
- return ModelType(model_type[0])
262
-
263
- def _get_model_tool_version(self):
264
- model_tool_version = engine_lib.AX_ENGINE_GetModelToolsVersion(
265
- self._handle[0]
266
- )
267
- return engine_cffi.string(model_tool_version).decode("utf-8")
268
-
269
- def _load(self):
270
- extra = engine_cffi.new("AX_ENGINE_HANDLE_EXTRA_T *")
271
- extra_name = engine_cffi.new("char[]", self._model_name.encode("utf-8"))
272
- extra.pName = extra_name
273
-
274
- # for onnx runtime do not support one model multiple context running in multi-thread as far as I know, so
275
- # the engine handle and context will create only once
276
- ret = engine_lib.AX_ENGINE_CreateHandleV2(
277
- self._handle, self._model_buffer, self._model_buffer_size, extra
278
- )
279
- if 0 == ret:
280
- ret = engine_lib.AX_ENGINE_CreateContextV2(
281
- self._handle[0], self._context
282
- )
283
- return ret
284
-
285
- def _get_info(self):
286
- total_info = []
287
- if 1 == self._shape_count:
288
- info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
289
- ret = engine_lib.AX_ENGINE_GetIOInfo(self._handle[0], info)
290
- if 0 != ret:
291
- raise RuntimeError("Failed to get model shape.")
292
- total_info.append(info)
293
- else:
294
- for i in range(self._shape_count):
295
- info = engine_cffi.new("AX_ENGINE_IO_INFO_T **")
296
- ret = engine_lib.AX_ENGINE_GetGroupIOInfo(
297
- self._handle[0], i, info
298
- )
299
- if 0 != ret:
300
- raise RuntimeError(f"Failed to get model the {i}th shape.")
301
- total_info.append(info)
302
- return total_info
303
-
304
- def _get_shape_count(self):
305
- count = engine_cffi.new("AX_U32 *")
306
- ret = engine_lib.AX_ENGINE_GetGroupIOInfoCount(self._handle[0], count)
307
- if 0 != ret:
308
- raise RuntimeError("Failed to get model shape group.")
309
- return count[0]
310
-
311
- def _unload(self):
312
- if self._handle[0] is not None:
313
- engine_lib.AX_ENGINE_DestroyHandle(self._handle[0])
314
- self._handle[0] = engine_cffi.NULL
315
-
316
- def _get_io(self, io_type: str):
317
- io_info = []
318
- for group in range(self._shape_count):
319
- one_group_io = []
320
- for index in range(getattr(self._info[group][0], f'n{io_type}Size')):
321
- current_io = getattr(self._info[group][0], f'p{io_type}s')[index]
322
- name = engine_cffi.string(current_io.pName).decode("utf-8")
323
- shape = [current_io.pShape[i] for i in range(current_io.nShapeSize)]
324
- dtype = _transform_dtype(current_io.eDataType)
325
- meta = NodeArg(name, dtype, shape)
326
- one_group_io.append(meta)
327
- io_info.append(one_group_io)
328
- return io_info
329
-
330
- def _get_inputs(self):
331
- return self._get_io('Input')
332
-
333
- def _get_outputs(self):
334
- return self._get_io('Output')
335
-
336
- def run(
337
- self,
338
- output_names: list[str],
339
- input_feed: dict[str, np.ndarray],
340
- run_options=None
341
- ):
342
- self._validate_input(input_feed)
343
- self._validate_output(output_names)
344
-
345
- if None is output_names:
346
- output_names = [o.name for o in self.get_outputs()]
347
-
348
- # fill model io
349
- for key, npy in input_feed.items():
350
- for i, one in enumerate(self.get_inputs()):
351
- if one.name == key:
352
- assert (
353
- list(one.shape) == list(npy.shape) and one.dtype == npy.dtype
354
- ), f"model inputs({key}) expect shape {one.shape} and dtype {one.dtype}, however gets input with shape {npy.shape} and dtype {npy.dtype}"
355
-
356
- if not (
357
- not npy.flags.c_contiguous
358
- and npy.flags.f_contiguous
359
- and npy.flags.contiguous
360
- ):
361
- npy = np.ascontiguousarray(npy)
362
- npy_ptr = engine_cffi.cast("void *", npy.ctypes.data)
363
-
364
- engine_cffi.memmove(
365
- self._io[0].pInputs[i].pVirAddr, npy_ptr, npy.nbytes
366
- )
367
- sys_lib.AX_SYS_MflushCache(
368
- self._io[0].pInputs[i].phyAddr,
369
- self._io[0].pInputs[i].pVirAddr,
370
- self._io[0].pInputs[i].nSize,
371
- )
372
- break
373
-
374
- # execute model
375
- ret = engine_lib.AX_ENGINE_RunSyncV2(
376
- self._handle[0], self._context[0], self._io
377
- )
378
-
379
- # flush output
380
- outputs = []
381
- if 0 == ret:
382
- for i in range(len(self.get_outputs())):
383
- sys_lib.AX_SYS_MinvalidateCache(
384
- self._io[0].pOutputs[i].phyAddr,
385
- self._io[0].pOutputs[i].pVirAddr,
386
- self._io[0].pOutputs[i].nSize,
387
- )
388
- npy = np.frombuffer(
389
- engine_cffi.buffer(
390
- self._io[0].pOutputs[i].pVirAddr, self._io[0].pOutputs[i].nSize
391
- ),
392
- dtype=self.get_outputs()[i].dtype,
393
- ).reshape(self.get_outputs()[i].shape)
394
- name = self.get_outputs()[i].name
395
- if name in output_names:
396
- outputs.append(npy)
397
- return outputs
398
- else:
399
- raise RuntimeError("Failed to run model.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axe_capi.py DELETED
@@ -1,323 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- import ctypes.util
9
- import platform
10
-
11
- from cffi import FFI
12
-
13
- __all__: ["sys_lib", "sys_cffi", "engine_lib", "engine_cffi"]
14
-
15
- sys_cffi = FFI()
16
-
17
- # ax_base_type.h
18
- sys_cffi.cdef(
19
- """
20
- typedef int AX_S32;
21
- typedef unsigned int AX_U32;
22
- typedef unsigned long long int AX_U64;
23
- typedef signed char AX_S8;
24
- typedef void AX_VOID;
25
- """
26
- )
27
-
28
- # ax_sys_api.h
29
- sys_cffi.cdef(
30
- """
31
- AX_S32 AX_SYS_Init(AX_VOID);
32
- AX_S32 AX_SYS_Deinit(AX_VOID);
33
- AX_S32 AX_SYS_MemAllocCached(AX_U64 *phyaddr, AX_VOID **pviraddr, AX_U32 size, AX_U32 align, const AX_S8 *token);
34
- AX_S32 AX_SYS_MemFree(AX_U64 phyaddr, AX_VOID *pviraddr);
35
- AX_S32 AX_SYS_MflushCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
36
- AX_S32 AX_SYS_MinvalidateCache(AX_U64 phyaddr, AX_VOID *pviraddr, AX_U32 size);
37
- """
38
- )
39
-
40
- sys_name = "ax_sys"
41
- sys_path = ctypes.util.find_library(sys_name)
42
- assert (
43
- sys_path is not None
44
- ), f"Failed to find library {sys_name}. Please ensure it is installed and in the library path."
45
-
46
- sys_lib = sys_cffi.dlopen(sys_path)
47
- assert sys_lib is not None, f"Failed to load library {sys_path}. Please ensure it is installed and in the library path."
48
-
49
- engine_cffi = FFI()
50
-
51
- # ax_base_type.h
52
- engine_cffi.cdef(
53
- """
54
- typedef unsigned long long int AX_U64;
55
- typedef unsigned int AX_U32;
56
- typedef unsigned char AX_U8;
57
- typedef int AX_S32;
58
- typedef signed char AX_S8;
59
- typedef char AX_CHAR;
60
- typedef void AX_VOID;
61
-
62
- typedef enum {
63
- AX_FALSE = 0,
64
- AX_TRUE = 1,
65
- } AX_BOOL;
66
- """
67
- )
68
-
69
- # ax_engine_type.h, base type
70
- engine_cffi.cdef(
71
- """
72
- typedef AX_U32 AX_ENGINE_NPU_SET_T;
73
- """
74
- )
75
-
76
- # ax_engine_type.h, enum
77
- engine_cffi.cdef(
78
- """
79
- typedef enum _AX_ENGINE_TENSOR_LAYOUT_E
80
- {
81
- AX_ENGINE_TENSOR_LAYOUT_UNKNOWN = 0,
82
- AX_ENGINE_TENSOR_LAYOUT_NHWC = 1,
83
- AX_ENGINE_TENSOR_LAYOUT_NCHW = 2,
84
- } AX_ENGINE_TENSOR_LAYOUT_T;
85
-
86
- typedef enum
87
- {
88
- AX_ENGINE_MT_PHYSICAL = 0,
89
- AX_ENGINE_MT_VIRTUAL = 1,
90
- AX_ENGINE_MT_OCM = 2,
91
- } AX_ENGINE_MEMORY_TYPE_T;
92
-
93
- typedef enum
94
- {
95
- AX_ENGINE_DT_UNKNOWN = 0,
96
- AX_ENGINE_DT_UINT8 = 1,
97
- AX_ENGINE_DT_UINT16 = 2,
98
- AX_ENGINE_DT_FLOAT32 = 3,
99
- AX_ENGINE_DT_SINT16 = 4,
100
- AX_ENGINE_DT_SINT8 = 5,
101
- AX_ENGINE_DT_SINT32 = 6,
102
- AX_ENGINE_DT_UINT32 = 7,
103
- AX_ENGINE_DT_FLOAT64 = 8,
104
- AX_ENGINE_DT_BFLOAT16 = 9,
105
- AX_ENGINE_DT_UINT10_PACKED = 100,
106
- AX_ENGINE_DT_UINT12_PACKED = 101,
107
- AX_ENGINE_DT_UINT14_PACKED = 102,
108
- AX_ENGINE_DT_UINT16_PACKED = 103,
109
- } AX_ENGINE_DATA_TYPE_T;
110
-
111
- typedef enum
112
- {
113
- AX_ENGINE_CS_FEATUREMAP = 0,
114
- AX_ENGINE_CS_RAW8 = 12,
115
- AX_ENGINE_CS_RAW10 = 1,
116
- AX_ENGINE_CS_RAW12 = 2,
117
- AX_ENGINE_CS_RAW14 = 11,
118
- AX_ENGINE_CS_RAW16 = 3,
119
- AX_ENGINE_CS_NV12 = 4,
120
- AX_ENGINE_CS_NV21 = 5,
121
- AX_ENGINE_CS_RGB = 6,
122
- AX_ENGINE_CS_BGR = 7,
123
- AX_ENGINE_CS_RGBA = 8,
124
- AX_ENGINE_CS_GRAY = 9,
125
- AX_ENGINE_CS_YUV444 = 10,
126
- } AX_ENGINE_COLOR_SPACE_T;
127
- """
128
- )
129
-
130
- # ax_engine_type.h, architecturally agnostic struct
131
- engine_cffi.cdef(
132
- """
133
- typedef enum {
134
- AX_ENGINE_VIRTUAL_NPU_DISABLE = 0,
135
- } AX_ENGINE_NPU_MODE_T;
136
-
137
- typedef enum {
138
- AX_ENGINE_MODEL_TYPE0 = 0,
139
- } AX_ENGINE_MODEL_TYPE_T;
140
-
141
- typedef struct {
142
- AX_ENGINE_NPU_MODE_T eHardMode;
143
- AX_U32 reserve[8];
144
- } AX_ENGINE_NPU_ATTR_T;
145
-
146
- typedef struct _AX_ENGINE_IO_META_EX_T
147
- {
148
- AX_ENGINE_COLOR_SPACE_T eColorSpace;
149
- AX_U64 u64Reserved[18];
150
- } AX_ENGINE_IO_META_EX_T;
151
-
152
- typedef struct {
153
- AX_ENGINE_NPU_SET_T nNpuSet;
154
- AX_S8* pName;
155
- AX_U32 reserve[8];
156
- } AX_ENGINE_HANDLE_EXTRA_T;
157
-
158
- typedef struct _AX_ENGINE_CMM_INFO_T
159
- {
160
- AX_U32 nCMMSize;
161
- } AX_ENGINE_CMM_INFO_T;
162
-
163
- typedef struct _AX_ENGINE_IO_SETTING_T
164
- {
165
- AX_U32 nWbtIndex;
166
- AX_U64 u64Reserved[7];
167
- }AX_ENGINE_IO_SETTING_T;
168
- """
169
- )
170
-
171
- # check architecture, 32bit or 64bit
172
- arch = platform.architecture()[0]
173
-
174
- # ax_engine_type.h, struct
175
- if arch == "64bit":
176
- engine_cffi.cdef(
177
- """
178
- typedef struct _AX_ENGINE_IO_META_T
179
- {
180
- AX_CHAR* pName;
181
- AX_S32* pShape;
182
- AX_U8 nShapeSize;
183
- AX_ENGINE_TENSOR_LAYOUT_T eLayout;
184
- AX_ENGINE_MEMORY_TYPE_T eMemoryType;
185
- AX_ENGINE_DATA_TYPE_T eDataType;
186
- AX_ENGINE_IO_META_EX_T* pExtraMeta;
187
- AX_U32 nSize;
188
- AX_U32 nQuantizationValue;
189
- AX_S32* pStride;
190
- AX_U64 u64Reserved[9];
191
- } AX_ENGINE_IO_META_T;
192
-
193
- typedef struct _AX_ENGINE_IO_INFO_T
194
- {
195
- AX_ENGINE_IO_META_T* pInputs;
196
- AX_U32 nInputSize;
197
- AX_ENGINE_IO_META_T* pOutputs;
198
- AX_U32 nOutputSize;
199
- AX_U32 nMaxBatchSize;
200
- AX_BOOL bDynamicBatchSize;
201
- AX_U64 u64Reserved[11];
202
- } AX_ENGINE_IO_INFO_T;
203
-
204
- typedef struct _AX_ENGINE_IO_BUFFER_T
205
- {
206
- AX_U64 phyAddr;
207
- AX_VOID* pVirAddr;
208
- AX_U32 nSize;
209
- AX_S32* pStride;
210
- AX_U8 nStrideSize;
211
- AX_U64 u64Reserved[11];
212
- } AX_ENGINE_IO_BUFFER_T;
213
-
214
- typedef struct _AX_ENGINE_IO_T
215
- {
216
- AX_ENGINE_IO_BUFFER_T* pInputs;
217
- AX_U32 nInputSize;
218
- AX_ENGINE_IO_BUFFER_T* pOutputs;
219
- AX_U32 nOutputSize;
220
- AX_U32 nBatchSize;
221
- AX_ENGINE_IO_SETTING_T* pIoSetting;
222
- AX_U64 u64Reserved[10];
223
- } AX_ENGINE_IO_T;
224
- """
225
- )
226
- else:
227
- engine_cffi.cdef(
228
- """
229
- typedef struct _AX_ENGINE_IO_META_T
230
- {
231
- AX_CHAR* pName;
232
- AX_S32* pShape;
233
- AX_U8 nShapeSize;
234
- AX_ENGINE_TENSOR_LAYOUT_T eLayout;
235
- AX_ENGINE_MEMORY_TYPE_T eMemoryType;
236
- AX_ENGINE_DATA_TYPE_T eDataType;
237
- AX_ENGINE_IO_META_EX_T* pExtraMeta;
238
- AX_U32 nSize;
239
- AX_U32 nQuantizationValue;
240
- AX_S32* pStride;
241
- AX_U64 u64Reserved[11];
242
- } AX_ENGINE_IO_META_T;
243
-
244
- typedef struct _AX_ENGINE_IO_INFO_T
245
- {
246
- AX_ENGINE_IO_META_T* pInputs;
247
- AX_U32 nInputSize;
248
- AX_ENGINE_IO_META_T* pOutputs;
249
- AX_U32 nOutputSize;
250
- AX_U32 nMaxBatchSize;
251
- AX_BOOL bDynamicBatchSize;
252
- AX_U64 u64Reserved[13];
253
- } AX_ENGINE_IO_INFO_T;
254
-
255
- typedef struct _AX_ENGINE_IO_BUFFER_T
256
- {
257
- AX_U64 phyAddr;
258
- AX_VOID* pVirAddr;
259
- AX_U32 nSize;
260
- AX_S32* pStride;
261
- AX_U8 nStrideSize;
262
- AX_U64 u64Reserved[13];
263
- } AX_ENGINE_IO_BUFFER_T;
264
-
265
- typedef struct _AX_ENGINE_IO_T
266
- {
267
- AX_ENGINE_IO_BUFFER_T* pInputs;
268
- AX_U32 nInputSize;
269
- AX_ENGINE_IO_BUFFER_T* pOutputs;
270
- AX_U32 nOutputSize;
271
- AX_U32 nBatchSize;
272
- AX_ENGINE_IO_SETTING_T* pIoSetting;
273
- AX_U64 u64Reserved[12];
274
- } AX_ENGINE_IO_T;
275
- """
276
- )
277
-
278
- # ax_engine_api.h
279
- engine_cffi.cdef(
280
- """
281
- const AX_CHAR* AX_ENGINE_GetVersion(AX_VOID);
282
-
283
- AX_VOID AX_ENGINE_NPUReset(AX_VOID);
284
- AX_S32 AX_ENGINE_Init(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
285
- AX_S32 AX_ENGINE_GetVNPUAttr(AX_ENGINE_NPU_ATTR_T* pNpuAttr);
286
- AX_S32 AX_ENGINE_Deinit(AX_VOID);
287
-
288
- AX_S32 AX_ENGINE_GetModelType(const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_MODEL_TYPE_T* pModelType);
289
-
290
- AX_S32 AX_ENGINE_CreateHandleV2(uint64_t** pHandle, const AX_VOID* pData, AX_U32 nDataSize, AX_ENGINE_HANDLE_EXTRA_T* pExtraParam);
291
- AX_S32 AX_ENGINE_DestroyHandle(uint64_t* nHandle);
292
-
293
- AX_S32 AX_ENGINE_GetIOInfo(uint64_t* nHandle, AX_ENGINE_IO_INFO_T** pIO);
294
- AX_S32 AX_ENGINE_GetGroupIOInfoCount(uint64_t* nHandle, AX_U32* pCount);
295
- AX_S32 AX_ENGINE_GetGroupIOInfo(uint64_t* nHandle, AX_U32 nIndex, AX_ENGINE_IO_INFO_T** pIO);
296
-
297
- AX_S32 AX_ENGINE_GetHandleModelType(uint64_t* nHandle, AX_ENGINE_MODEL_TYPE_T* pModelType);
298
-
299
- AX_S32 AX_ENGINE_CreateContextV2(uint64_t* nHandle, uint64_t** pContext);
300
-
301
- AX_S32 AX_ENGINE_RunSyncV2(uint64_t* handle, uint64_t* context, AX_ENGINE_IO_T* pIO);
302
- AX_S32 AX_ENGINE_RunGroupIOSync(uint64_t* handle, uint64_t* context, AX_U32 nIndex, AX_ENGINE_IO_T* pIO);
303
-
304
- AX_S32 AX_ENGINE_SetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T nNpuSet);
305
- AX_S32 AX_ENGINE_GetAffinity(uint64_t* nHandle, AX_ENGINE_NPU_SET_T* pNpuSet);
306
-
307
- AX_S32 AX_ENGINE_GetCMMUsage(uint64_t* nHandle, AX_ENGINE_CMM_INFO_T* pCMMInfo);
308
-
309
- const AX_CHAR* AX_ENGINE_GetModelToolsVersion(uint64_t* nHandle);
310
-
311
- // internal use api, remember no question
312
- AX_S32 AX_ENGINE_GetTotalOps();
313
- """
314
- )
315
-
316
- engine_name = "ax_engine"
317
- engine_path = ctypes.util.find_library(engine_name)
318
- assert (
319
- engine_path is not None
320
- ), f"Failed to find library {engine_name}. Please ensure it is installed and in the library path."
321
-
322
- engine_lib = engine_cffi.dlopen(engine_path)
323
- assert engine_lib is not None, f"Failed to load library {engine_path}. Please ensure it is installed and in the library path."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_axe_types.py DELETED
@@ -1,29 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- from enum import Enum
9
-
10
-
11
- class VNPUType(Enum):
12
- DISABLED = 0
13
- ENABLED = 1
14
- BIG_LITTLE = 2
15
- LITTLE_BIG = 3
16
-
17
-
18
- class ModelType(Enum):
19
- HALF = 0 # for MC20E, which means chip is AX630C(x), or AX620Q(x)
20
- FULL = 1 # for MC20E
21
- SINGLE = 0 # for MC50, which means chip is AX650A or AX650N, and M57H
22
- DUAL = 1 # for MC50
23
- TRIPLE = 2 # for MC50
24
-
25
-
26
- class ChipType(Enum):
27
- MC20E = 0
28
- MC50 = 1
29
- M57H = 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_base_session.py DELETED
@@ -1,59 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- from abc import ABC, abstractmethod
9
-
10
- import numpy as np
11
-
12
- from ._node import NodeArg
13
-
14
-
15
- class SessionOptions:
16
- pass
17
-
18
-
19
- class Session(ABC):
20
- def __init__(self) -> None:
21
- self._shape_count = 0
22
- self._inputs = []
23
- self._outputs = []
24
-
25
- def _validate_input(self, feed_input_names: dict[str, np.ndarray]):
26
- missing_input_names = []
27
- for i in self.get_inputs():
28
- if i.name not in feed_input_names:
29
- missing_input_names.append(i.name)
30
- if missing_input_names:
31
- raise ValueError(
32
- f"Required inputs ({missing_input_names}) are missing from input feed ({feed_input_names}).")
33
-
34
- def _validate_output(self, output_names: list[str]):
35
- if output_names is not None:
36
- for name in output_names:
37
- if name not in [o.name for o in self.get_outputs()]:
38
- raise ValueError(f"Output name '{name}' is not in model outputs name list.")
39
-
40
- def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
41
- if shape_group > self._shape_count:
42
- raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
43
- selected_info = self._inputs[shape_group]
44
- return selected_info
45
-
46
- def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
47
- if shape_group > self._shape_count:
48
- raise ValueError(f"Shape group '{shape_group}' is out of range, total {self._shape_count}.")
49
- selected_info = self._outputs[shape_group]
50
- return selected_info
51
-
52
- @abstractmethod
53
- def run(
54
- self,
55
- output_names: list[str] | None,
56
- input_feed: dict[str, np.ndarray],
57
- run_options=None
58
- ) -> list[np.ndarray]:
59
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_node.py DELETED
@@ -1,13 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
-
9
- class NodeArg(object):
10
- def __init__(self, name, dtype, shape):
11
- self.name = name
12
- self.dtype = dtype
13
- self.shape = shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_providers.py DELETED
@@ -1,31 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- import ctypes.util as cutil
9
-
10
- providers = []
11
- axengine_provider_name = 'AxEngineExecutionProvider'
12
- axclrt_provider_name = 'AXCLRTExecutionProvider'
13
-
14
- _axengine_lib_name = 'ax_engine'
15
- _axclrt_lib_name = 'axcl_rt'
16
-
17
- # check if axcl_rt is installed, so if available, it's the default provider
18
- if cutil.find_library(_axclrt_lib_name) is not None:
19
- providers.append(axclrt_provider_name)
20
-
21
- # check if ax_engine is installed
22
- if cutil.find_library(_axengine_lib_name) is not None:
23
- providers.append(axengine_provider_name)
24
-
25
-
26
- def get_all_providers():
27
- return [axengine_provider_name, axclrt_provider_name]
28
-
29
-
30
- def get_available_providers():
31
- return providers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
python/axengine/_session.py DELETED
@@ -1,117 +0,0 @@
1
- # Copyright (c) 2019-2024 Axera Semiconductor Co., Ltd. All Rights Reserved.
2
- #
3
- # This source file is the property of Axera Semiconductor Co., Ltd. and
4
- # may not be copied or distributed in any isomorphic form without the prior
5
- # written consent of Axera Semiconductor Co., Ltd.
6
- #
7
-
8
- import os
9
- from typing import Any, Sequence
10
-
11
- import numpy as np
12
-
13
- from ._base_session import SessionOptions
14
- from ._node import NodeArg
15
- from ._providers import axclrt_provider_name, axengine_provider_name
16
- from ._providers import get_available_providers
17
-
18
-
19
class InferenceSession:
    """Pick an execution provider and forward all work to its backend session.

    Args:
        path_or_bytes: Model given as a path or as raw bytes; handed to the
            backend session unchanged.
        sess_options: Optional session options; stored and handed to the
            backend session.
        providers: ``None`` to use the first available provider, a single
            provider name, or a list/tuple whose items are either a provider
            name or a ``(name, options_dict)`` pair. The first available
            entry wins.
        provider_options: Handed to the backend session unchanged.
        **kwargs: Extra keyword arguments handed to the backend session.

    Raises:
        ValueError: If none of the requested providers is available, or a
            ``(name, options)`` entry does not have exactly two elements.
        TypeError: If an entry of ``providers`` has an unsupported type.
        RuntimeError: If no backend session could be created for the
            selected provider.
    """

    def __init__(
        self,
        path_or_bytes: str | bytes | os.PathLike,
        sess_options: SessionOptions | None = None,
        providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None,
        provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs,
    ) -> None:
        self._sess = None
        self._sess_options = sess_options
        self._provider = None
        self._provider_options = None
        self._available_providers = get_available_providers()

        if providers is None:
            # At least one provider is expected to be available here (the
            # package __init__ refuses to import otherwise); use the first
            # one as the default.
            self._provider = self._available_providers[0]
        else:
            chosen, options, unavailable = self._resolve_provider(
                providers, self._available_providers)
            self._provider = chosen
            self._provider_options = options
            if unavailable:
                if chosen is None:
                    raise ValueError(f"Selected provider(s): {unavailable} is(are) not available.")
                print(f"[WARNING] Selected provider(s): {unavailable} is(are) not available.")

        # Still reachable when `providers` is empty or of an unsupported
        # container type, so the check has to stay.
        if self._provider is None:
            raise ValueError(f"No available provider found in {providers}.")
        print(f"[INFO] Using provider: {self._provider}")

        # NOTE(review): options supplied in (name, options) form are kept in
        # self._provider_options but only the `provider_options` argument is
        # forwarded to the backend - confirm which one the backends expect.
        # Backends are imported lazily so only the selected one is loaded.
        if self._provider == axclrt_provider_name:
            from ._axclrt import AXCLRTSession
            self._sess = AXCLRTSession(path_or_bytes, sess_options, provider_options, **kwargs)
        elif self._provider == axengine_provider_name:
            from ._axe import AXEngineSession
            self._sess = AXEngineSession(path_or_bytes, sess_options, provider_options, **kwargs)
        if self._sess is None:
            raise RuntimeError(f"Create session failed with provider: {self._provider}")

    @staticmethod
    def _resolve_provider(
        providers: str | list | tuple,
        available: list[str],
    ) -> tuple[str | None, dict[Any, Any] | None, list[str]]:
        """Choose the first requested provider that is available.

        Returns:
            ``(chosen, options, unavailable)``: the selected provider name
            (``None`` if nothing matched), the options dict that came with it
            when it was given as a ``(name, options)`` pair (else ``None``),
            and the requested names that are not available.

        Raises:
            ValueError: If a single provider name is given and is not
                available, or a pair does not have exactly two elements.
            TypeError: If an entry is neither a string nor a tuple, or a
                pair is not ``(str, dict)``.
        """
        # A single provider name.
        if isinstance(providers, str):
            if providers not in available:
                raise ValueError(f"Selected provider: '{providers}' is not available.")
            return providers, None, []

        chosen = None
        options = None
        unavailable = []
        # Several candidates: lists and tuples are treated alike.
        if isinstance(providers, (list, tuple)):
            for p in providers:
                if isinstance(p, str):
                    name, p_options = p, None
                elif isinstance(p, tuple):
                    if len(p) != 2:
                        raise ValueError(f"Invalid provider type: {p}. Must be tuple with 2 elements.")
                    if not isinstance(p[0], str):
                        raise TypeError(f"Invalid provider type: {type(p[0])}. Must be str.")
                    if not isinstance(p[1], dict):
                        raise TypeError(f"Invalid provider type: {type(p[1])}. Must be dict.")
                    name, p_options = p
                else:
                    raise TypeError(f"Invalid provider type: {type(p)}. Must be str or tuple.")

                if name not in available:
                    unavailable.append(name)
                elif chosen is None:
                    chosen = name
                    # FIXME: check provider options
                    options = p_options
        return chosen, options, unavailable

    # add to support 'with' statement
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning False lets any exception raised in the block propagate.
        return False

    def get_session_options(self):
        """
        Return the session options. See :class:`axengine.SessionOptions`.
        """
        return self._sess_options

    def get_providers(self):
        """
        Return the name of the execution provider selected for this session.

        The value is a single provider name string, not a list.
        """
        return self._provider

    def get_inputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the backend session's inputs for the given shape group."""
        return self._sess.get_inputs(shape_group)

    def get_outputs(self, shape_group: int = 0) -> list[NodeArg]:
        """Return the backend session's outputs for the given shape group."""
        return self._sess.get_outputs(shape_group)

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, np.ndarray],
        run_options=None
    ) -> list[np.ndarray]:
        """Forward one inference call to the backend session and return its result."""
        return self._sess.run(output_names, input_feed, run_options)