import ctypes

import numpy as np

from pyaxdev import _lib, AxDeviceType, AxDevices, check_error

# Size of the fixed model-path buffer in SRInit (must match the native struct).
_MODEL_PATH_SIZE = 256


class SRInit(ctypes.Structure):
    """Initialization parameters passed to ``ax_sr_init``.

    The field layout mirrors the native struct and must not be changed.
    """

    _fields_ = [
        ('dev_type', AxDeviceType),
        # One byte; ctypes accepts a 1-byte bytes object or an int in 0..255.
        ('devid', ctypes.c_char),
        ('model_path', ctypes.c_char * _MODEL_PATH_SIZE),
    ]


class SRImage(ctypes.Structure):
    """Image descriptor exchanged with ``ax_sr_run``.

    ``pVirAddr`` is a raw pointer into caller-owned memory (a NumPy array in
    this module); that memory must stay alive for the duration of the call.
    """

    _fields_ = [
        ('width', ctypes.c_int),
        ('height', ctypes.c_int),
        ('pVirAddr', ctypes.POINTER(ctypes.c_ubyte)),
    ]


# Declare the C signatures so ctypes converts arguments and return values correctly.
_lib.ax_sr_init.argtypes = [ctypes.POINTER(SRInit), ctypes.POINTER(ctypes.c_void_p)]
_lib.ax_sr_init.restype = ctypes.c_int
_lib.ax_sr_deinit.argtypes = [ctypes.c_void_p]
_lib.ax_sr_deinit.restype = ctypes.c_int
_lib.ax_sr_run.argtypes = [ctypes.c_void_p, ctypes.POINTER(SRImage), ctypes.POINTER(SRImage)]
_lib.ax_sr_run.restype = ctypes.c_int


def _to_sr_image(array: np.ndarray) -> SRImage:
    """Build an ``SRImage`` describing *array* (height = shape[0], width = shape[1]).

    *array* must be C-contiguous. The returned struct only holds a raw pointer,
    so the caller must keep *array* referenced while the struct is in use.
    """
    image = SRImage()
    image.height = array.shape[0]
    image.width = array.shape[1]
    image.pVirAddr = array.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))
    return image


class SR:
    """Wrapper around the native super-resolution API (``ax_sr_*``).

    The native handle is released by ``close()``, on leaving a ``with`` block,
    or (best effort) when the object is garbage collected.
    """

    # Output/input size ratio used to size the output buffer. It presumably has
    # to match the loaded model's upscale factor — TODO confirm.
    scale = 2

    def __init__(self, init_info: dict):
        """Create the native SR instance.

        Args:
            init_info: must contain ``'model_path'`` (str). May contain
                ``'dev_type'`` (default ``AxDeviceType.axcl_device``) and
                ``'devid'`` (default 0).

        Raises:
            KeyError: if ``'model_path'`` is missing.
            ValueError: if the UTF-8 encoded model path does not fit the
                fixed-size buffer together with its NUL terminator.
            Failures of ``ax_sr_init`` are passed to ``check_error``
            (presumably raising).
        """
        self.handle = None
        self.init_info = SRInit()
        # Set initialization parameters.
        self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
        self.init_info.devid = init_info.get('devid', 0)
        model_path = init_info['model_path'].encode('utf-8')
        # A value of exactly _MODEL_PATH_SIZE bytes would fill the buffer with no
        # NUL terminator, which a C-string reader would run past.
        if len(model_path) >= _MODEL_PATH_SIZE:
            raise ValueError(
                f"model_path is {len(model_path)} bytes when UTF-8 encoded; "
                f"at most {_MODEL_PATH_SIZE - 1} bytes are supported"
            )
        self.init_info.model_path = model_path
        handle = ctypes.c_void_p()
        check_error(_lib.ax_sr_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    def close(self) -> None:
        """Release the native handle. Safe to call more than once."""
        # getattr: __init__ may not have run (e.g. instance created via __new__).
        handle = getattr(self, 'handle', None)
        if handle:
            # Clear first so a repeated close()/__del__ never deinitializes twice.
            self.handle = None
            _lib.ax_sr_deinit(handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def __del__(self):
        # Best-effort cleanup; the deinit return code is ignored as before.
        self.close()

    def __call__(self, image_data: np.ndarray) -> np.ndarray:
        """Run super-resolution on one image.

        Args:
            image_data: uint8 array whose first two axes are (height, width).
                Non-contiguous arrays are copied to a contiguous buffer first.

        Returns:
            A new uint8 array of shape (scale*height, scale*width, 3) filled
            by ``ax_sr_run``.

        Raises:
            ValueError: if ``image_data`` has fewer than 2 dimensions or is not
                uint8.
            RuntimeError: if the instance has already been closed.
        """
        if not self.handle:
            raise RuntimeError("SR instance has been closed")
        if image_data.ndim < 2:
            raise ValueError(f"expected an image with at least 2 dimensions, got shape {image_data.shape}")
        if image_data.dtype != np.uint8:
            raise ValueError(f"expected a uint8 image, got dtype {image_data.dtype}")
        # Native code only receives a raw pointer and cannot honour strides, so
        # views such as img[..., ::-1] or crops must be made contiguous first.
        src = np.ascontiguousarray(image_data)
        height, width = src.shape[:2]
        dst = np.zeros((height * self.scale, width * self.scale, 3), dtype=np.uint8)
        # src and dst are locals, so their buffers outlive the native call below.
        src_image = _to_sr_image(src)
        dst_image = _to_sr_image(dst)
        check_error(_lib.ax_sr_run(self.handle, ctypes.byref(src_image), ctypes.byref(dst_image)))
        return dst