happyme531 commited on
Commit
661e50f
·
verified ·
1 Parent(s): f4851b2

Upload 2 files

Browse files
Files changed (1) hide show
  1. ztu_somemodelruntime_rknnlite2.py +1195 -0
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,1195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 模块级常量和函数
2
+ from rknnlite.api import RKNNLite
3
+ import numpy as np
4
+ import os
5
+ import warnings
6
+ import logging
7
+ from typing import List, Dict, Union, Optional
8
+
9
+ try:
10
+ import onnxruntime as ort
11
+ HAS_ORT = True
12
+ except ImportError:
13
+ HAS_ORT = False
14
+ warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning)
15
+
16
+ # 配置日志
17
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
18
+ logger.setLevel(logging.ERROR) # 默认只输出错误信息
19
+ if not logger.handlers:
20
+ handler = logging.StreamHandler()
21
+ handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
22
+ logger.addHandler(handler)
23
+
24
+ # ONNX Runtime日志级别到Python logging级别的映射
25
+ _LOGGING_LEVEL_MAP = {
26
+ 0: logging.DEBUG, # Verbose
27
+ 1: logging.INFO, # Info
28
+ 2: logging.WARNING, # Warning
29
+ 3: logging.ERROR, # Error
30
+ 4: logging.CRITICAL # Fatal
31
+ }
32
+
33
+ # 检查环境变量中的日志级别设置
34
+ try:
35
+ env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
36
+ if env_log_level is not None:
37
+ log_level = int(env_log_level)
38
+ if log_level in _LOGGING_LEVEL_MAP:
39
+ logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
40
+ logger.info(f"从环境变量设置日志级别: {log_level}")
41
+ else:
42
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {log_level}, 应该是0-4之间的整数")
43
+ except ValueError:
44
+ logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {env_log_level}, 应该是0-4之间的整数")
45
+
46
+
47
+ def set_default_logger_severity(level: int) -> None:
48
+ """
49
+ Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
50
+
51
+ Args:
52
+ level: 日志级别(0-4)
53
+ """
54
+ if level not in _LOGGING_LEVEL_MAP:
55
+ raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
56
+ logger.setLevel(_LOGGING_LEVEL_MAP[level])
57
+
58
+ def set_default_logger_verbosity(level: int) -> None:
59
+ """
60
+ Sets the default logging verbosity level. To activate the verbose log,
61
+ you need to set the default logging severity to 0:Verbose level.
62
+
63
+ Args:
64
+ level: 日志级别(0-4)
65
+ """
66
+ set_default_logger_severity(level)
67
+
68
+ # RKNN tensor type到numpy dtype的映射
69
+ RKNN_DTYPE_MAP = {
70
+ 0: np.float32, # RKNN_TENSOR_FLOAT32
71
+ 1: np.float16, # RKNN_TENSOR_FLOAT16
72
+ 2: np.int8, # RKNN_TENSOR_INT8
73
+ 3: np.uint8, # RKNN_TENSOR_UINT8
74
+ 4: np.int16, # RKNN_TENSOR_INT16
75
+ 5: np.uint16, # RKNN_TENSOR_UINT16
76
+ 6: np.int32, # RKNN_TENSOR_INT32
77
+ 7: np.uint32, # RKNN_TENSOR_UINT32
78
+ 8: np.int64, # RKNN_TENSOR_INT64
79
+ 9: bool, # RKNN_TENSOR_BOOL
80
+ 10: np.int8, # RKNN_TENSOR_INT4 (用int8表示)
81
+ }
82
+
83
+ def get_available_providers() -> List[str]:
84
+ """
85
+ 获取可用的设备提供者列表(为保持接口兼容性的占位函数)
86
+
87
+ Returns:
88
+ list: 可用的设备提供者列表,总是返回["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
89
+ """
90
+ return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
91
+
92
+
93
+ def get_device() -> str:
94
+ """
95
+ 获取当前设备
96
+
97
+ Returns:
98
+ str: 当前设备
99
+ """
100
+ return "RKNN2"
101
+
102
+ def get_version_info() -> Dict[str, str]:
103
+ """
104
+ 获取版本信息
105
+
106
+ Returns:
107
+ dict: 包含API和驱动版本信息的字典
108
+ """
109
+ runtime = RKNNLite()
110
+ version = runtime.get_sdk_version()
111
+ return {
112
+ "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
113
+ "driver_version": version.split('\n')[3].split(': ')[1]
114
+ }
115
+
116
+ class IOTensor:
117
+ """输入/输出张量的信息封装类"""
118
+ def __init__(self, name, shape, type=None):
119
+ self.name = name.decode() if isinstance(name, bytes) else name
120
+ self.shape = shape
121
+ self.type = type
122
+
123
+ def __str__(self):
124
+ return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
125
+
126
+ class SessionOptions:
127
+ """会话选项类"""
128
+ def __init__(self):
129
+ self.enable_profiling = False # 是否使用性能分析
130
+ self.intra_op_num_threads = 1 # 设置RKNN的线程数, 对应rknn的core_mask
131
+ self.log_severity_level = -1 # 另一个设置日志级别的参数
132
+ self.log_verbosity_level = -1 # 另一个设置日志级别的参数
133
+
134
+
135
+ class InferenceSession:
136
+ """
137
+ RKNNLite运行时封装类,API风格类似ONNX Runtime
138
+ """
139
+
140
+ def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
141
+ processed_path = InferenceSession._process_model_path(model_path, sess_options)
142
+ if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
143
+ logger.info("使用ONNX Runtime加载模型")
144
+ if not HAS_ORT:
145
+ raise RuntimeError("未安装onnxruntime,无法加载ONNX模型")
146
+ return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
147
+ else:
148
+ # 如果不是 ONNX 模型,则调用父类的 __new__ 创建 InferenceSession 实例
149
+ instance = super().__new__(cls)
150
+ # 保存处理后的路径
151
+ instance._processed_path = processed_path
152
+ return instance
153
+
154
+ def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
155
+ """
156
+ 初始化运行时并加载模型
157
+
158
+ Args:
159
+ model_path: 模型文件路径(.rknn或.onnx)
160
+ sess_options: 会话选项
161
+ **kwargs: 其他初始化参数
162
+ """
163
+ options = sess_options or SessionOptions()
164
+
165
+ # 只在未设置环境变量时使用SessionOptions中的日志级别
166
+ if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
167
+ if options.log_severity_level != -1:
168
+ set_default_logger_severity(options.log_severity_level)
169
+ if options.log_verbosity_level != -1:
170
+ set_default_logger_verbosity(options.log_verbosity_level)
171
+
172
+ # 使用__new__中处理好的路径
173
+ model_path = getattr(self, '_processed_path', model_path)
174
+ if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
175
+ # 避免重复加载 ONNX 模型
176
+ return
177
+
178
+ # ... 现有的 RKNN 模型加载和初始化代码 ...
179
+ self.model_path = model_path
180
+ if not os.path.exists(self.model_path):
181
+ logger.error(f"模型文件不存在: {self.model_path}")
182
+ raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
183
+
184
+ self.runtime = RKNNLite(verbose=options.enable_profiling)
185
+
186
+ logger.debug(f"正在加载模型: {self.model_path}")
187
+ ret = self.runtime.load_rknn(self.model_path)
188
+ if ret != 0:
189
+ logger.error(f"加载RKNN模型失败: {self.model_path}")
190
+ raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
191
+ logger.debug("模型加载成功")
192
+
193
+
194
+ if options.intra_op_num_threads == 1:
195
+ core_mask = RKNNLite.NPU_CORE_AUTO
196
+ elif options.intra_op_num_threads == 2:
197
+ core_mask = RKNNLite.NPU_CORE_0_1
198
+ elif options.intra_op_num_threads == 3:
199
+ core_mask = RKNNLite.NPU_CORE_0_1_2
200
+ else:
201
+ raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")
202
+
203
+ logger.debug("正在初始化运行时环境")
204
+ ret = self.runtime.init_runtime(core_mask=core_mask)
205
+ if ret != 0:
206
+ logger.error("初始化运行时环境失败")
207
+ raise RuntimeError('初始化运行时环境失败')
208
+
209
+ logger.debug("运行时环境初始化成功")
210
+
211
+ # 在 runtime 初始化后,按环境变量自动注册自定义算子插件库
212
+ try:
213
+ # 注册用户指定路径插件(逗号/分号分隔)
214
+ env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
215
+ if env_custom:
216
+ paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
217
+ ok = 0
218
+ for p in paths:
219
+ if self.register_custom_op_lib(p):
220
+ ok += 1
221
+ if ok > 0:
222
+ logger.info(f"已注册 {ok}/{len(paths)} 个自定义算子插件")
223
+ # 注册系统目录下插件
224
+ if os.getenv('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', '1') == '1':
225
+ cnt = self.register_system_custom_op_lib()
226
+ if cnt > 0:
227
+ logger.info(f"已从系统目录注册 {cnt} 个自定义算子插件")
228
+ except Exception as e:
229
+ logger.warning(f"自动注册自定义算子插件失败: {e}")
230
+
231
+ # 可选:按环境变量注册内置(基于Python)捆绑算子
232
+ if os.getenv('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', '0') == '1':
233
+ logger.info("根据环境变量注册捆绑算子")
234
+ self.register_bundled_ops()
235
+
236
+ self._init_io_info()
237
+ self.options = options
238
+
239
+ def get_performance_info(self) -> Dict[str, float]:
240
+ """
241
+ 获取性能信息
242
+
243
+ Returns:
244
+ dict: 包含性能信息的字典
245
+ """
246
+ if not self.options.perf_debug:
247
+ raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True")
248
+
249
+ perf = self.runtime.rknn_runtime.get_run_perf()
250
+ return {
251
+ "run_duration": perf.run_duration / 1000.0 # 转换为毫秒
252
+ }
253
+
254
+ def set_core_mask(self, core_mask: int) -> None:
255
+ """
256
+ 设置NPU核心使用模式
257
+
258
+ Args:
259
+ core_mask: NPU核心掩码,使用NPU_CORE_*常量
260
+ """
261
+ ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
262
+ if ret != 0:
263
+ raise RuntimeError("设置NPU核心��式失败")
264
+
265
+ @staticmethod
266
+ def _process_model_path(model_path, sess_options):
267
+ """
268
+ 处理模型路径,支持.onnx和.rknn文件
269
+
270
+ Args:
271
+ model_path: 模型文件路径
272
+ """
273
+ # 如果是ONNX文件,检查是否需要自动加载RKNN
274
+ if model_path.lower().endswith('.onnx'):
275
+ logger.info("检测到ONNX模型文件")
276
+
277
+ # 获取需要跳过自动加载的模型列表
278
+ skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
279
+ if skip_models:
280
+ skip_list = [m.strip() for m in skip_models.split(',')]
281
+ # 获取模型文件名(不含路径)用于匹配
282
+ model_name = os.path.basename(model_path)
283
+ if model_name.lower() in [m.lower() for m in skip_list]:
284
+ logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime")
285
+ return model_path
286
+
287
+ # 构造RKNN文件路径
288
+ rknn_path = os.path.splitext(model_path)[0] + '.rknn'
289
+ if os.path.exists(rknn_path):
290
+ logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}")
291
+ return rknn_path
292
+ else:
293
+ logger.info("未找到对应的RKNN模型,将使用ONNX Runtime")
294
+ return model_path
295
+
296
+ return model_path
297
+
298
+ def _convert_nhwc_to_nchw(self, shape):
299
+ """将NHWC格式的shape转换为NCHW格式"""
300
+ if len(shape) == 4:
301
+ # NHWC -> NCHW
302
+ n, h, w, c = shape
303
+ return [n, c, h, w]
304
+ return shape
305
+
306
+ def _init_io_info(self):
307
+ """初始化模型的输入输出信息"""
308
+ runtime = self.runtime.rknn_runtime
309
+
310
+ # 获取输入输出数量
311
+ n_input, n_output = runtime.get_in_out_num()
312
+
313
+ # 获取输入信息
314
+ self.input_tensors = []
315
+ for i in range(n_input):
316
+ attr = runtime.get_tensor_attr(i)
317
+ shape = [attr.dims[j] for j in range(attr.n_dims)]
318
+ # 对四维输入进行NHWC到NCHW的转换
319
+ shape = self._convert_nhwc_to_nchw(shape)
320
+ # 获取dtype
321
+ dtype = RKNN_DTYPE_MAP.get(attr.type, None)
322
+ tensor = IOTensor(attr.name, shape, dtype)
323
+ self.input_tensors.append(tensor)
324
+
325
+ # 获取输出信息
326
+ self.output_tensors = []
327
+ for i in range(n_output):
328
+ attr = runtime.get_tensor_attr(i, is_output=True)
329
+ shape = runtime.get_output_shape(i)
330
+ # 获取dtype
331
+ dtype = RKNN_DTYPE_MAP.get(attr.type, None)
332
+ tensor = IOTensor(attr.name, shape, dtype)
333
+ self.output_tensors.append(tensor)
334
+
335
+ def get_inputs(self):
336
+ """
337
+ 获取模型输入信息
338
+
339
+ Returns:
340
+ list: 包含输入信息的列表
341
+ """
342
+ return self.input_tensors
343
+
344
+ def get_outputs(self):
345
+ """
346
+ 获取模型输出信息
347
+
348
+ Returns:
349
+ list: 包含输出信息的列表
350
+ """
351
+ return self.output_tensors
352
+
353
+ def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
354
+ """
355
+ 执行模型推理
356
+
357
+ Args:
358
+ output_names: 输出节点名称列表,指定需要返回哪些输出
359
+ input_feed: 输入数据字典或列表
360
+ data_format: 输入数据格式,"nchw"或"nhwc"
361
+ **kwargs: 其他运行时参数
362
+
363
+ Returns:
364
+ list: 模型输出结果列表,如果指定了output_names则只返回指定的输出
365
+ """
366
+ if input_feed is None:
367
+ logger.error("input_feed不能为None")
368
+ raise ValueError("input_feed不能为None")
369
+
370
+ # 准备输入数据
371
+ if isinstance(input_feed, dict):
372
+ # 如果是字典,按照模型输入顺序排列
373
+ inputs = []
374
+ input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
375
+ for tensor in self.input_tensors:
376
+ if tensor.name not in input_feed:
377
+ raise ValueError(f"缺少输入: {tensor.name}")
378
+ inputs.append(input_feed[tensor.name])
379
+ elif isinstance(input_feed, (list, tuple)):
380
+ # 如果是列表,确保长度匹配
381
+ if len(input_feed) != len(self.input_tensors):
382
+ raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}")
383
+ inputs = list(input_feed)
384
+ else:
385
+ logger.error("input_feed必须是字典或列表类型")
386
+ raise ValueError("input_feed必须是字典或列表类型")
387
+
388
+ # 执行推理
389
+ try:
390
+ logger.debug("开始执行推理")
391
+ all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
392
+
393
+ # 如果没有指定output_names,返回所有输出
394
+ if output_names is None:
395
+ return all_outputs
396
+
397
+ # 获取指定的输出
398
+ output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
399
+ selected_outputs = []
400
+ for name in output_names:
401
+ if name not in output_map:
402
+ raise ValueError(f"未找到输出节点: {name}")
403
+ selected_outputs.append(all_outputs[output_map[name]])
404
+
405
+ return selected_outputs
406
+
407
+ except Exception as e:
408
+ logger.error(f"推理执行失败: {str(e)}")
409
+ raise RuntimeError(f"推理执行失败: {str(e)}")
410
+
411
+ def close(self):
412
+ """
413
+ 关闭会话,释放资源
414
+ """
415
+ if self.runtime is not None:
416
+ logger.info("正在释放运行时资源")
417
+ self.runtime.release()
418
+ self.runtime = None
419
+
420
+ def __enter__(self):
421
+ return self
422
+
423
+ def __exit__(self, exc_type, exc_val, exc_tb):
424
+ self.close()
425
+
426
+ def end_profiling(self) -> Optional[str]:
427
+ """
428
+ 结束性能分析的存根方法
429
+
430
+ Returns:
431
+ Optional[str]: None
432
+ """
433
+ warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
434
+ return None
435
+
436
+ def get_profiling_start_time_ns(self) -> int:
437
+ """
438
+ 获取性能分析开始时间的存根方法
439
+
440
+ Returns:
441
+ int: 0
442
+ """
443
+ warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
444
+ return 0
445
+
446
+ def get_modelmeta(self) -> Dict[str, str]:
447
+ """
448
+ 获取模型元数据的存根方法
449
+
450
+ Returns:
451
+ Dict[str, str]: 空字典
452
+ """
453
+ warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
454
+ return {}
455
+
456
+ def get_session_options(self) -> SessionOptions:
457
+ """
458
+ 获取会话选项
459
+
460
+ Returns:
461
+ SessionOptions: 当前会话选项
462
+ """
463
+ return self.options
464
+
465
+ def get_providers(self) -> List[str]:
466
+ """
467
+ 获取当前使用的providers的存根方法
468
+
469
+ Returns:
470
+ List[str]: ["CPUExecutionProvider"]
471
+ """
472
+ warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2)
473
+ return ["CPUExecutionProvider"]
474
+
475
+ def get_provider_options(self) -> Dict[str, Dict[str, str]]:
476
+ """
477
+ 获取provider选项的存根方法
478
+
479
+ Returns:
480
+ Dict[str, Dict[str, str]]: 空字典
481
+ """
482
+ warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
483
+ return {}
484
+
485
+ def get_session_config(self) -> Dict[str, str]:
486
+ """
487
+ 获取会话配置的存根方法
488
+
489
+ Returns:
490
+ Dict[str, str]: 空字典
491
+ """
492
+ warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
493
+ return {}
494
+
495
+ def get_session_state(self) -> Dict[str, str]:
496
+ """
497
+ 获取会话状态的存根方法
498
+
499
+ Returns:
500
+ Dict[str, str]: 空字典
501
+ """
502
+ warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
503
+ return {}
504
+
505
+ def set_session_config(self, config: Dict[str, str]) -> None:
506
+ """
507
+ 设置会话配置的存根方法
508
+
509
+ Args:
510
+ config: 会话配置字典
511
+ """
512
+ warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
513
+
514
+ def get_memory_info(self) -> Dict[str, int]:
515
+ """
516
+ 获取内存使用信息的存根方法
517
+
518
+ Returns:
519
+ Dict[str, int]: 空字典
520
+ """
521
+ warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
522
+ return {}
523
+
524
+ def set_memory_pattern(self, enable: bool) -> None:
525
+ """
526
+ 设置内存模式的存根方法
527
+
528
+ Args:
529
+ enable: 是否启用内存模式
530
+ """
531
+ warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
532
+
533
+ def disable_memory_pattern(self) -> None:
534
+ """
535
+ 禁用内存模式的存根方法
536
+ """
537
+ warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
538
+
539
+ def get_optimization_level(self) -> int:
540
+ """
541
+ 获取优化级别的存根方法
542
+
543
+ Returns:
544
+ int: 0
545
+ """
546
+ warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
547
+ return 0
548
+
549
+ def set_optimization_level(self, level: int) -> None:
550
+ """
551
+ 设置优化级别的存根方法
552
+
553
+ Args:
554
+ level: 优化级别
555
+ """
556
+ warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
557
+
558
+ def get_model_metadata(self) -> Dict[str, str]:
559
+ """
560
+ 获取模型元数据的存根方法(与get_modelmeta不同的接口)
561
+
562
+ Returns:
563
+ Dict[str, str]: 空字典
564
+ """
565
+ warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
566
+ return {}
567
+
568
+ def get_model_path(self) -> str:
569
+ """
570
+ 获取模型路径
571
+
572
+ Returns:
573
+ str: 模型文件路径
574
+ """
575
+ return self.model_path
576
+
577
+ def get_input_type_info(self) -> List[Dict[str, str]]:
578
+ """
579
+ 获取输入类型信息的存根方法
580
+
581
+ Returns:
582
+ List[Dict[str, str]]: 空列表
583
+ """
584
+ warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
585
+ return []
586
+
587
+ def get_output_type_info(self) -> List[Dict[str, str]]:
588
+ """
589
+ 获取输出类型信息的存根方法
590
+
591
+ Returns:
592
+ List[Dict[str, str]]: 空列表
593
+ """
594
+ warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2)
595
+ return []
596
+
597
+ ################### 自定义算子 ###################
598
+
599
+ def _init_custom_op_types(self):
600
+ """初始化自定义算子的类型定义"""
601
+ # 常量
602
+ self._RKNN_TENSOR_FLOAT32 = 0
603
+ self._RKNN_TENSOR_UINT8 = 3
604
+ self._RKNN_TENSOR_INT64 = 8
605
+ self._RKNN_TARGET_TYPE_CPU = 1
606
+
607
+ # 结构体定义
608
+ class RKNN_TensorAttr(ctypes.Structure):
609
+ _fields_ = [
610
+ ("index", ctypes.c_uint32),
611
+ ("n_dims", ctypes.c_uint32),
612
+ ("dims", ctypes.c_uint32 * RKNN_MAX_DIMS),
613
+ ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
614
+ ("n_elems", ctypes.c_uint32),
615
+ ("size", ctypes.c_uint32),
616
+ ("fmt", ctypes.c_int),
617
+ ("type", ctypes.c_int),
618
+ ("qnt_type", ctypes.c_int),
619
+ ("fl", ctypes.c_int8),
620
+ ("zp", ctypes.c_int32),
621
+ ("scale", ctypes.c_float),
622
+ ("w_stride", ctypes.c_uint32),
623
+ ("size_with_stride", ctypes.c_uint32),
624
+ ("pass_through", ctypes.c_uint8),
625
+ ("h_stride", ctypes.c_uint32),
626
+ ]
627
+
628
+ class RKNN_TensorMem(ctypes.Structure):
629
+ _fields_ = [
630
+ ("virt_addr", ctypes.c_void_p),
631
+ ("phys_addr", ctypes.c_uint64),
632
+ ("fd", ctypes.c_int32),
633
+ ("offset", ctypes.c_int32),
634
+ ("size", ctypes.c_uint32),
635
+ ("flags", ctypes.c_uint32),
636
+ ("priv_data", ctypes.c_void_p),
637
+ ]
638
+
639
+ class RKNN_CustomOpTensor(ctypes.Structure):
640
+ _fields_ = [
641
+ ("attr", RKNN_TensorAttr),
642
+ ("mem", RKNN_TensorMem),
643
+ ]
644
+
645
+ class RKNN_GPUOpContext(ctypes.Structure):
646
+ _fields_ = [
647
+ ("cl_context", ctypes.c_void_p),
648
+ ("cl_command_queue", ctypes.c_void_p),
649
+ ("cl_kernel", ctypes.c_void_p),
650
+ ]
651
+
652
+ InternalCtxType = (
653
+ ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
654
+ )
655
+
656
+ class RKNN_CustomOpContext(ctypes.Structure):
657
+ _fields_ = [
658
+ ("target", ctypes.c_int),
659
+ ("internal_ctx", InternalCtxType),
660
+ ("gpu_ctx", RKNN_GPUOpContext),
661
+ ("priv_data", ctypes.c_void_p),
662
+ ]
663
+
664
+ class RKNN_CustomOpAttr(ctypes.Structure):
665
+ _fields_ = [
666
+ ("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
667
+ ("dtype", ctypes.c_int),
668
+ ("n_elems", ctypes.c_uint32),
669
+ ("data", ctypes.c_void_p),
670
+ ]
671
+
672
+ CB_SIG = ctypes.CFUNCTYPE(
673
+ ctypes.c_int,
674
+ ctypes.POINTER(RKNN_CustomOpContext),
675
+ ctypes.POINTER(RKNN_CustomOpTensor),
676
+ ctypes.c_uint32,
677
+ ctypes.POINTER(RKNN_CustomOpTensor),
678
+ ctypes.c_uint32,
679
+ )
680
+
681
+ DESTROY_SIG = ctypes.CFUNCTYPE(
682
+ ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext)
683
+ )
684
+
685
+ class RKNN_CustomOp(ctypes.Structure):
686
+ _fields_ = [
687
+ ("version", ctypes.c_uint32),
688
+ ("target", ctypes.c_int),
689
+ ("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN),
690
+ ("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN),
691
+ ("cl_kernel_source", ctypes.c_char_p),
692
+ ("cl_source_size", ctypes.c_uint64),
693
+ ("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN),
694
+ ("init", CB_SIG),
695
+ ("prepare", CB_SIG),
696
+ ("compute", CB_SIG),
697
+ ("compute_native", CB_SIG),
698
+ ("destroy", DESTROY_SIG),
699
+ ]
700
+
701
+ # 保存类型定义
702
+ self._RKNN_TensorAttr = RKNN_TensorAttr
703
+ self._RKNN_TensorMem = RKNN_TensorMem
704
+ self._RKNN_CustomOpTensor = RKNN_CustomOpTensor
705
+ self._RKNN_CustomOpContext = RKNN_CustomOpContext
706
+ self._RKNN_CustomOpAttr = RKNN_CustomOpAttr
707
+ self._RKNN_CustomOp = RKNN_CustomOp
708
+ self._CB_SIG = CB_SIG
709
+ self._DESTROY_SIG = DESTROY_SIG
710
+
711
+ def _create_attr_readers(self, get_op_attr):
712
+ """创建属性读取函数"""
713
+ def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int:
714
+ attr = self._RKNN_CustomOpAttr()
715
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
716
+ if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data:
717
+ return ctypes.c_int64.from_address(attr.data).value
718
+ return default
719
+
720
+ def read_attr_float32(op_ctx_ptr, key: str, default: float = 0) -> float:
721
+ attr = self._RKNN_CustomOpAttr()
722
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
723
+ if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data:
724
+ return ctypes.c_float.from_address(attr.data).value
725
+ return default
726
+
727
+ def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str:
728
+ attr = self._RKNN_CustomOpAttr()
729
+ get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
730
+ if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data:
731
+ buf = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data)
732
+ try:
733
+ return bytes(buf).decode("utf-8", errors="ignore").strip('"')
734
+ except Exception:
735
+ return default
736
+ return default
737
+
738
+
739
+ return read_attr_int64, read_attr_str, read_attr_float32
740
+
741
+ def _build_py_custom_op(self,
742
+ op_type: str,
743
+ n_inputs: int,
744
+ n_outputs: int,
745
+ on_init,
746
+ on_compute):
747
+ """通用的Python自定义算子构造器
748
+
749
+ Args:
750
+ op_type: 算子类型名(字符串)
751
+ n_inputs: 输入个数
752
+ n_outputs: 输出个数
753
+ on_init: 回调,签名 on_init(op_ctx_p, read_attr_int64, read_attr_str) -> state
754
+ on_compute: 回调,签名 on_compute(op_ctx_p, inputs_p, outputs_p, state) -> int(0成功)
755
+ Returns:
756
+ (RKNN_CustomOp对象, 回调tuple)
757
+ """
758
+ @self._CB_SIG
759
+ def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
760
+ try:
761
+ # 允许无需提前读取属性
762
+ runtime = self.runtime.rknn_base.rknn_runtime
763
+ read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
764
+ user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
765
+ # 为该实例分配唯一ID, 并写入priv_data
766
+ if not hasattr(self, "_custom_op_states"):
767
+ self._custom_op_states = {}
768
+ if not hasattr(self, "_next_custom_op_id"):
769
+ self._next_custom_op_id = 1
770
+ inst_id = int(self._next_custom_op_id)
771
+ self._next_custom_op_id += 1
772
+ # 保存Python侧状态
773
+ self._custom_op_states[inst_id] = user_state
774
+ # 将实例ID写入priv_data
775
+ try:
776
+ op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
777
+ except Exception:
778
+ # 回退: 直接写入整数
779
+ op_ctx_p.contents.priv_data = inst_id
780
+ return 0
781
+ except Exception as e:
782
+ logger.error(f"{op_type} init失败: {e}")
783
+ return -1
784
+
785
+ @self._CB_SIG
786
+ def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
787
+ return 0
788
+
789
+ @self._CB_SIG
790
+ def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
791
+ try:
792
+ if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
793
+ return -1
794
+ # 通过priv_data取回该实例的状态
795
+ try:
796
+ inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
797
+ except Exception:
798
+ inst_id = 0
799
+ user_state = None
800
+ if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
801
+ user_state = self._custom_op_states.get(inst_id)
802
+ else:
803
+ logger.error(f"{op_type} compute失败: 找不到实例状态, inst_id={inst_id}")
804
+ return -1
805
+ return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
806
+ except Exception as e:
807
+ logger.error(f"{op_type} compute失败: {e}")
808
+ import traceback
809
+ logger.error(f"{op_type} compute失败: {traceback.format_exc()}")
810
+ return -1
811
+
812
+ @self._DESTROY_SIG
813
+ def _py_destroy(op_ctx_p):
814
+ try:
815
+ # 清理该实例的状态
816
+ try:
817
+ inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
818
+ except Exception:
819
+ inst_id = 0
820
+ if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
821
+ del self._custom_op_states[inst_id]
822
+ # 将priv_data清空
823
+ try:
824
+ op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
825
+ except Exception:
826
+ op_ctx_p.contents.priv_data = 0
827
+ return 0
828
+ except Exception:
829
+ return -1
830
+
831
+ op = self._RKNN_CustomOp()
832
+ op.version = 1
833
+ op.target = self._RKNN_TARGET_TYPE_CPU
834
+ op.op_type = op_type.encode("utf-8")
835
+ op.cl_kernel_name = b""
836
+ op.cl_kernel_source = None
837
+ op.cl_source_size = 0
838
+ op.cl_build_options = b""
839
+ op.init = _py_init
840
+ op.prepare = _py_prepare
841
+ op.compute = _py_compute
842
+ op.compute_native = self._CB_SIG() # NULL
843
+ op.destroy = _py_destroy
844
+
845
+ return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
846
+
847
+
848
+ def _tensor_to_numpy(self, rknn_tensor):
849
+ """将 RKNN_CustomOpTensor 转换为 Numpy 数组视图"""
850
+ # 确定Numpy数据类型
851
+ # 您可以扩展这个映射
852
+ dtype_map = {
853
+ self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32),
854
+ self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8),
855
+ self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64),
856
+ }
857
+ c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None))
858
+ if c_type is None:
859
+ raise TypeError(f"不支持的RKNN张量类型: {rknn_tensor.attr.type}")
860
+
861
+ # 获取内存地址和形状
862
+ addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset)
863
+ ptr = ctypes.cast(addr, ctypes.POINTER(c_type))
864
+ shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims))
865
+
866
+ # 创建Numpy数组视图
867
+ return np.ctypeslib.as_array(ptr, shape=shape)
868
+
869
+
870
+ def _create_onnxscript_op_creator(self,
871
+ op_type: str,
872
+ # 现在接收一个"函数模板构造器"
873
+ onnxscript_func_builder,
874
+ n_inputs: int,
875
+ n_outputs: int,
876
+ attributes: dict = {},
877
+ constants: dict = {}):
878
+ """
879
+ 一个高阶工厂函数,用于创建基于ONNXScript的自定义算子构造器。
880
+ 它在 on_init 阶段动态生成最终的 onnxscript 计算函数。
881
+
882
+ Args:
883
+ op_type (str): 算子类型名。
884
+ onnxscript_func_builder: 一个函数,它接收所有属性和常量作为关键字参数,
885
+ 并返回一个编译好的 onnxscript 函数。
886
+ 例如: def builder(mean, scale):
887
+ @onnxscript.script()
888
+ def compute(like):
889
+ return opset.RandomNormalLike(like, mean=mean, scale=scale)
890
+ return compute
891
+ attributes (dict): 从模型中读取的属性字典。
892
+ constants (dict): 编译时常量字典。
893
+ n_inputs (int): 输入个数。
894
+ n_outputs (int): 输出个数。
895
+ """
896
+
897
+ def creator_func():
898
+ def on_init(op_ctx_p, read_i64, read_s, read_f32):
899
+ # 1. 读取所有动态属性
900
+ attr_values = {}
901
+ for name, (attr_type, default) in attributes.items():
902
+ if attr_type == 'int64':
903
+ attr_values[name] = read_i64(op_ctx_p, name, default)
904
+ elif attr_type == 'str':
905
+ attr_values[name] = read_s(op_ctx_p, name, default)
906
+ elif attr_type == 'float32':
907
+ attr_values[name] = read_f32(op_ctx_p, name, default)
908
+ else:
909
+ raise ValueError(f"不支持的属性类型: {attr_type}")
910
+
911
+ # 2. 合并常量和属性
912
+ final_kwargs = {**constants, **attr_values}
913
+
914
+ # 3. 动态构建 onnxscript 函数! <<<<< 核心修改
915
+ # 这确保了所有属性值都作为常量被闭包捕获
916
+ compute_func = onnxscript_func_builder(**final_kwargs)
917
+
918
+ # 4. 将最终生成的、已编译的函数存入 state
919
+ return {"compute_func": compute_func}
920
+
921
+ def on_compute(op_ctx_p, inputs_p, outputs_p, state):
922
+ compute_func = state["compute_func"]
923
+
924
+ input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)]
925
+ output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)]
926
+
927
+ results = compute_func(*input_nps)
928
+
929
+ if n_outputs == 1:
930
+ result_val = results[0] if isinstance(results, tuple) else results
931
+ output_nps[0][...] = result_val
932
+ else:
933
+ for i in range(n_outputs):
934
+ output_nps[i][...] = results[i]
935
+
936
+ return 0
937
+
938
+ return self._build_py_custom_op(
939
+ op_type=op_type,
940
+ n_inputs=n_inputs,
941
+ n_outputs=n_outputs,
942
+ on_init=on_init,
943
+ on_compute=on_compute
944
+ )
945
+
946
+ return creator_func
947
+
948
+ def _create_gridsample_op(self):
949
+ import onnxscript
950
+ from onnxscript import opset17 as opset
951
+
952
+ def grid_sample_builder(align_corners, mode, padding_mode):
953
+ @onnxscript.script()
954
+ def grid_sample_compute(X, G):
955
+ return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
956
+ return grid_sample_compute
957
+
958
+ grid_sample_creator = self._create_onnxscript_op_creator(
959
+ op_type="GridSample",
960
+ onnxscript_func_builder=grid_sample_builder, # << 传入 builder
961
+ attributes={
962
+ "align_corners": ("int64", 0),
963
+ "mode": ("str", "bilinear"),
964
+ "padding_mode": ("str", "zeros"),
965
+ },
966
+ n_inputs = 2,
967
+ n_outputs = 1
968
+ )
969
+ return grid_sample_creator
970
+
971
+ def _create_scatterelements_op(self):
972
+ import onnxscript
973
+ from onnxscript import opset17 as opset
974
+
975
+ @onnxscript.script()
976
+ def scatter_elements_compute(data, indices, updates):
977
+ indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype)
978
+ return opset.ScatterElements(data, indices_i64, updates)
979
+
980
+ scatter_elements_creator = self._create_onnxscript_op_creator(
981
+ op_type="ScatterElements",
982
+ onnxscript_func_builder=lambda: scatter_elements_compute,
983
+ n_inputs = 3,
984
+ n_outputs = 1
985
+ )
986
+ return scatter_elements_creator
987
+
988
+ def _create_randomnormallike_op(self):
989
+ import onnxscript
990
+ from onnxscript import opset17 as opset
991
+
992
+ def random_normal_like_builder(mean, scale):
993
+ @onnxscript.script()
994
+ def random_normal_like_compute(like):
995
+ return opset.RandomNormalLike(like, mean=mean, scale=scale)
996
+
997
+ return random_normal_like_compute
998
+
999
+ # 3. 使用新的工厂函数
1000
+ random_normal_like_creator = self._create_onnxscript_op_creator(
1001
+ op_type="RandomNormalLike",
1002
+ onnxscript_func_builder=random_normal_like_builder, # << 传入 builder
1003
+ attributes={
1004
+ "mean": ("float32", 0.0),
1005
+ "scale": ("float32", 1.0),
1006
+ },
1007
+ n_inputs = 1,
1008
+ n_outputs = 1
1009
+ )
1010
+ return random_normal_like_creator
1011
+
1012
+ def _create_einsum_op(self):
1013
+ import onnxscript
1014
+ from onnxscript import opset17 as opset
1015
+
1016
+ def einsum_builder(equation):
1017
+
1018
+ @onnxscript.script()
1019
+ def einsum_compute(in1, in2):
1020
+ return opset.Einsum(in1, in2, equation=equation)
1021
+
1022
+ return einsum_compute
1023
+
1024
+ # 3. 使用新的工厂函数
1025
+ einsum_creator = self._create_onnxscript_op_creator(
1026
+ op_type="Einsum",
1027
+ onnxscript_func_builder=einsum_builder, # << 传入 builder
1028
+ attributes={
1029
+ "equation": ("str", ""),
1030
+ },
1031
+ n_inputs = 2,
1032
+ n_outputs = 1
1033
+ )
1034
+ return einsum_creator
1035
+
1036
+ def register_bundled_ops(self) -> None:
1037
+ """注册自定义操作"""
1038
+ if getattr(self, "_custom_ops_registered", False):
1039
+ return
1040
+
1041
+ runtime = self.runtime.rknn_base.rknn_runtime
1042
+ lib = runtime.lib
1043
+ ctx = runtime.context
1044
+
1045
+ try:
1046
+ _ = lib.rknn_register_custom_ops
1047
+ _ = lib.rknn_custom_op_get_op_attr
1048
+ except AttributeError as e:
1049
+ logger.debug(f"SDK不支持自定义算子注册: {e}")
1050
+ return
1051
+
1052
+ self._init_custom_op_types()
1053
+
1054
+ # 注意:插件库注册已在模型加载后由环境变量控制,不在此处重复触发
1055
+
1056
+ # 算子创建函数的列表现在更加清晰
1057
+ op_creator_factories = [
1058
+ self._create_gridsample_op,
1059
+ self._create_scatterelements_op,
1060
+ self._create_randomnormallike_op,
1061
+ self._create_einsum_op,
1062
+ # self._create_my_custom_add_op, # 添加新算子非常简单
1063
+ ]
1064
+
1065
+ ops_to_register = []
1066
+ all_callbacks = []
1067
+
1068
+ for factory in op_creator_factories:
1069
+ try:
1070
+ # 调用工厂获得真正的构造器
1071
+ creator_func = factory()
1072
+ # 调用构造器生成算子实例
1073
+ op, callbacks = creator_func()
1074
+ ops_to_register.append(op)
1075
+ all_callbacks.extend(callbacks)
1076
+ logger.debug(f"成功创建自定义算子: {op.op_type.decode()}")
1077
+ except Exception as e:
1078
+ logger.warning(f"创建自定义算子失败: {e}", exc_info=True)
1079
+
1080
+ if not ops_to_register:
1081
+ logger.debug("没有可注册的自定义算子")
1082
+ return
1083
+
1084
+ # 创建一个ctypes数组以包含所有要注册的算子, 然后一次性注册
1085
+ num_ops = len(ops_to_register)
1086
+ op_array = (self._RKNN_CustomOp * num_ops)(*ops_to_register)
1087
+ ret = lib.rknn_register_custom_ops(ctx, op_array, num_ops)
1088
+ if ret != 0:
1089
+ logger.error(f"注册自定义算子失败, ret={ret} (可能是误报, 继续执行...)")
1090
+ # raise RuntimeError(f"rknn_register_custom_ops 失败, ret={ret}")
1091
+
1092
+ logger.info(f"成功注册 {len(ops_to_register)} 个自定义算子")
1093
+
1094
+ self._custom_ops_registered = True
1095
+ self._registered_ops = ops_to_register
1096
+ self._op_callbacks = all_callbacks
1097
+
1098
+ def _load_and_register_plugin_op(self, so_path: str) -> bool:
1099
+ """加载单个插件库并注册其中的自定义算子。
1100
+
1101
+ 要求插件实现 get_rknn_custom_op(),返回 rknn_custom_op*。
1102
+ 我们将该 C 指针直接传递给 rknn_register_custom_ops,避免复制。
1103
+ """
1104
+ if not os.path.isfile(so_path):
1105
+ logger.warning(f"插件库不存在: {so_path}")
1106
+ return False
1107
+
1108
+ runtime = self.runtime.rknn_base.rknn_runtime
1109
+ lib = runtime.lib
1110
+ ctx = runtime.context
1111
+
1112
+ # 根据平台位宽设置 rknn_context 的 ctypes 类型
1113
+ ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
1114
+ # 设置 rknn_register_custom_ops(ctx, op_ptr, num) 签名。第二参数按 void* 传递,避免结构体布局不一致
1115
+ try:
1116
+ lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32]
1117
+ lib.rknn_register_custom_ops.restype = ctypes.c_int
1118
+ except Exception:
1119
+ pass
1120
+
1121
+ # 加载插件
1122
+ try:
1123
+ handle = ctypes.CDLL(so_path)
1124
+ except Exception as e:
1125
+ logger.error(f"dlopen 失败: {so_path}, err={e}")
1126
+ return False
1127
+
1128
+ # 获取 get_rknn_custom_op 符号
1129
+ try:
1130
+ get_sym = getattr(handle, "get_rknn_custom_op")
1131
+ except AttributeError:
1132
+ logger.error(f"插件缺少符号 get_rknn_custom_op: {so_path}")
1133
+ return False
1134
+
1135
+ # 返回类型直接使用 void*,避免 Python 解析第三方结构体
1136
+ try:
1137
+ get_sym.argtypes = []
1138
+ except Exception:
1139
+ pass
1140
+ get_sym.restype = ctypes.c_void_p
1141
+
1142
+ op_void_ptr = get_sym()
1143
+ if not op_void_ptr:
1144
+ logger.error(f"get_rknn_custom_op 返回空指针: {so_path}")
1145
+ return False
1146
+
1147
+ # 直接使用原生指针注册(零拷贝)
1148
+ ctx_val = ContextCType(runtime.context)
1149
+ ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1)
1150
+ if ret != 0:
1151
+ logger.error(f"rknn_register_custom_ops 失败, ret={ret}, so={so_path} (可能是误报, 继续执行...)")
1152
+ # return False
1153
+
1154
+ # 保留句柄,避免被垃圾回收卸载
1155
+ if not hasattr(self, "_plugin_handles"):
1156
+ self._plugin_handles = []
1157
+ self._plugin_handles.append(handle)
1158
+ logger.info(f"成功注册插件自定义算子: {so_path}")
1159
+ return True
1160
+
1161
+ def register_plugin_ops(self, plugin_paths: List[str]) -> int:
1162
+ """按给定路径列表注册插件库中的自定义算子。返回成功数量。"""
1163
+ if not plugin_paths:
1164
+ return 0
1165
+ success = 0
1166
+ for path in plugin_paths:
1167
+ try:
1168
+ if self._load_and_register_plugin_op(path):
1169
+ success += 1
1170
+ except Exception as e:
1171
+ logger.error(f"注册插件失败: {path}, err={e}")
1172
+ return success
1173
+
1174
+ # 对外API:注册单个自定义算子插件库
1175
+ def register_custom_op_lib(self, path: str) -> bool:
1176
+ return self._load_and_register_plugin_op(path)
1177
+
1178
+ # 对外API:扫描并注册 Linux 系统目录下所有插件库(Android 不处理)
1179
+ def register_system_custom_op_lib(self) -> int:
1180
+ if os.name != 'posix':
1181
+ return 0
1182
+ # 仅 Linux:RKNN 官方默认目录
1183
+ system_dir = "/usr/lib/rknpu/op_plugins/"
1184
+ if not os.path.isdir(system_dir):
1185
+ return 0
1186
+ try:
1187
+ entries = os.listdir(system_dir)
1188
+ except Exception:
1189
+ return 0
1190
+ so_list = []
1191
+ for name in entries:
1192
+ # 官方要求文件名以 librkcst_ 开头
1193
+ if name.startswith("librkcst_") and name.endswith('.so'):
1194
+ so_list.append(os.path.join(system_dir, name))
1195
+ return self.register_plugin_ops(so_list)