libsr.axera / lib /pysr.py
zheqiushui's picture
Upload 19 files
e3513f8 verified
import ctypes
import numpy as np
from pyaxdev import _lib, AxDeviceType, AxDevices, check_error
class SRInit(ctypes.Structure):
    """ctypes mirror of the initialisation struct handed to ``ax_sr_init``.

    ctypes lays the struct out from ``_fields_`` in order, so names, order and
    types here must match the native declaration exactly.
    """

    _fields_ = [
        # Device backend selector (an AxDeviceType value).
        ("dev_type", AxDeviceType),
        # Device index, stored in a single C char.
        ("devid", ctypes.c_char),
        # Fixed 256-byte buffer for the model file path; presumably read by the
        # native side as a NUL-terminated C string — confirm against the header.
        ("model_path", ctypes.c_char * 256),
    ]
class SRImage(ctypes.Structure):
    """ctypes descriptor for an image passed to or filled by ``ax_sr_run``.

    Only the dimensions and a raw byte pointer are carried. There is no stride,
    channel-count or pixel-format field, so the pixel layout is an implicit
    agreement between the caller and the native library.
    """

    _fields_ = [
        ("width", ctypes.c_int),  # number of pixel columns
        ("height", ctypes.c_int),  # number of pixel rows
        ("pVirAddr", ctypes.POINTER(ctypes.c_ubyte)),  # address of the first byte of pixel data
    ]
# Native prototypes. Declaring argtypes/restype makes ctypes convert and
# arity-check arguments instead of guessing. Each entry point is declared to
# return a C int status; the init and run results are fed to check_error() by
# the SR wrapper.

# int ax_sr_init(SRInit *, void **)  -- as declared here
_lib.ax_sr_init.restype = ctypes.c_int
_lib.ax_sr_init.argtypes = [
    ctypes.POINTER(SRInit),
    ctypes.POINTER(ctypes.c_void_p),
]

# int ax_sr_deinit(void *)  -- as declared here
_lib.ax_sr_deinit.restype = ctypes.c_int
_lib.ax_sr_deinit.argtypes = [ctypes.c_void_p]

# int ax_sr_run(void *, SRImage *, SRImage *)  -- as declared here
_lib.ax_sr_run.restype = ctypes.c_int
_lib.ax_sr_run.argtypes = [
    ctypes.c_void_p,
    ctypes.POINTER(SRImage),
    ctypes.POINTER(SRImage),
]
class SR:
def __init__(self, init_info: dict):
self.handle = None
self.init_info = SRInit()
# 设置初始化参数
self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
self.init_info.devid = init_info.get('devid', 0)
setattr(self.init_info, 'model_path', init_info['model_path'].encode('utf-8'))
handle = ctypes.c_void_p()
check_error(_lib.ax_sr_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
self.handle = handle
def __del__(self):
if self.handle:
_lib.ax_sr_deinit(self.handle)
def __call__(self, image_data: np.ndarray) -> None:
image = SRImage()
image.width = image_data.shape[1]
image.height = image_data.shape[0]
image.pVirAddr = ctypes.cast(image_data.ctypes.data, ctypes.POINTER(ctypes.c_ubyte))
np_sr_image = np.zeros((image.height*2, image.width*2, 3), dtype=np.uint8)
sr_image = SRImage()
sr_image.width = image.width*2
sr_image.height = image.height*2
sr_image.pVirAddr = ctypes.cast(np_sr_image.ctypes.data, ctypes.POINTER(ctypes.c_ubyte))
check_error(_lib.ax_sr_run(self.handle, ctypes.byref(image), ctypes.byref(sr_image)))
return np_sr_image