File size: 2,115 Bytes
e3513f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import ctypes
import numpy as np
from pyaxdev import _lib, AxDeviceType, AxDevices, check_error


class SRInit(ctypes.Structure):
    """Initialization parameters passed by pointer to ``ax_sr_init``.

    Field order and field types define the C memory layout and must match
    the native struct declaration exactly; do not reorder or retype them.
    """
    _fields_ = [
        ('dev_type', AxDeviceType),  # device kind; type is provided by pyaxdev
        # A single C ``char``: ctypes accepts an int in 0..255 (or a 1-byte
        # bytes object) on assignment and returns a 1-byte bytes object on
        # read. NOTE(review): confirm the native field really is a char and
        # not a wider integer type.
        ('devid', ctypes.c_char),
        # Fixed 256-byte buffer. ctypes stores a value of exactly 256 bytes
        # with no trailing NUL; presumably the C side reads this as a
        # NUL-terminated string, so at most 255 bytes are usable — confirm.
        ('model_path', ctypes.c_char * 256)
    ]

class SRImage(ctypes.Structure):
    """Image descriptor passed by pointer to ``ax_sr_run`` (input and output).

    Only width, height and a raw byte pointer are carried. There is no
    channel-count or row-stride field, so the pixel buffer must be laid out
    exactly as the native library assumes. NOTE(review): the ``SR`` wrapper
    allocates its output as a packed 3-channel uint8 array; presumably the
    input must be packed 3-channel uint8 as well — confirm with the C API.
    """
    _fields_ = [
        ('width', ctypes.c_int),  # image width (SR sets it from shape[1])
        ('height', ctypes.c_int),  # image height (SR sets it from shape[0])
        # Start of the pixel buffer. The struct does not own this memory;
        # whoever supplies the buffer must keep it alive for the native call.
        ('pVirAddr', ctypes.POINTER(ctypes.c_ubyte))
    ]


def _declare(func, restype, *argtypes):
    """Attach a C prototype (argument types, then return type) to ``func``."""
    func.argtypes = list(argtypes)
    func.restype = restype


# (SRInit *, void **) -> int
_declare(_lib.ax_sr_init, ctypes.c_int,
         ctypes.POINTER(SRInit), ctypes.POINTER(ctypes.c_void_p))

# (void *) -> int
_declare(_lib.ax_sr_deinit, ctypes.c_int,
         ctypes.c_void_p)

# (void *, SRImage *, SRImage *) -> int
_declare(_lib.ax_sr_run, ctypes.c_int,
         ctypes.c_void_p, ctypes.POINTER(SRImage), ctypes.POINTER(SRImage))


class SR:
    """Wrapper around the native ``ax_sr_*`` super-resolution API.

    A native handle is created in ``__init__`` and released by ``close()``,
    which ``__del__`` also calls. Use the instance as a context manager to
    release the handle deterministically instead of relying on garbage
    collection. Calling the instance runs 2x super-resolution on one image.
    """

    # Must match the length of the ``SRInit.model_path`` char array. One byte
    # is reserved for a NUL terminator, since the C side presumably reads the
    # path as a NUL-terminated string.
    _MODEL_PATH_CAPACITY = 256

    def __init__(self, init_info: dict):
        """Create the native super-resolution instance.

        Args:
            init_info: Configuration dict. ``'model_path'`` (str) is required;
                ``'dev_type'`` defaults to ``AxDeviceType.axcl_device`` and
                ``'devid'`` defaults to ``0``.

        Raises:
            KeyError: if ``'model_path'`` is missing.
            ValueError: if the UTF-8 encoded model path does not fit, with a
                NUL terminator, into the fixed-size ``model_path`` field.
            Whatever ``check_error`` raises when ``ax_sr_init`` fails.
        """
        # Defined first so close()/__del__ are safe if anything below raises.
        self.handle = None

        model_path = init_info['model_path'].encode('utf-8')
        if len(model_path) >= self._MODEL_PATH_CAPACITY:
            # ctypes would silently store exactly 256 bytes without a NUL.
            raise ValueError(
                f"model_path is {len(model_path)} bytes when UTF-8 encoded; "
                f"at most {self._MODEL_PATH_CAPACITY - 1} bytes fit"
            )

        # Fill in the initialization parameters.
        self.init_info = SRInit()
        self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
        self.init_info.devid = init_info.get('devid', 0)
        self.init_info.model_path = model_path

        handle = ctypes.c_void_p()
        check_error(_lib.ax_sr_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    def close(self) -> None:
        """Release the native handle. Safe to call more than once."""
        # Clear the attribute before calling into C so that a repeated close()
        # (or __del__ after an explicit close) never deinitializes twice.
        handle = getattr(self, 'handle', None)
        self.handle = None
        if handle:
            # Status code ignored, as before, so that __del__ never raises.
            _lib.ax_sr_deinit(handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def __del__(self):
        self.close()

    def __call__(self, image_data: np.ndarray) -> np.ndarray:
        """Run super-resolution on one image and return the 2x upscaled result.

        Args:
            image_data: ``uint8`` array of shape ``(H, W, 3)``. It is copied
                into a C-contiguous buffer first if it is not already one
                (e.g. a crop or a channel-reversed view).

        Returns:
            A new ``uint8`` array of shape ``(2*H, 2*W, 3)`` filled by the
            native library.

        Raises:
            RuntimeError: if ``close()`` has already released the handle.
            ValueError: if ``image_data`` is not a 3-channel ``uint8`` image.
            Whatever ``check_error`` raises when ``ax_sr_run`` fails.
        """
        if not self.handle:
            raise RuntimeError("SR instance has been closed")

        # SRImage carries no stride, so the pixels must be one tightly packed
        # buffer. This copies only when the input is a non-contiguous view.
        src = np.ascontiguousarray(image_data)
        # The descriptor has no channel/dtype field either; the output below
        # is 3-channel uint8, and the input is presumably expected to match.
        if src.dtype != np.uint8 or src.ndim != 3 or src.shape[2] != 3:
            raise ValueError(
                f"expected a uint8 array of shape (H, W, 3), "
                f"got dtype={src.dtype} with shape={src.shape}"
            )
        height, width = src.shape[:2]

        # ``src`` stays referenced by this frame until ax_sr_run returns, so
        # the buffer cannot be freed while the native code reads it.
        image = SRImage()
        image.width = width
        image.height = height
        image.pVirAddr = src.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))

        np_sr_image = np.zeros((height * 2, width * 2, 3), dtype=np.uint8)
        sr_image = SRImage()
        sr_image.width = width * 2
        sr_image.height = height * 2
        sr_image.pVirAddr = np_sr_image.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))

        check_error(_lib.ax_sr_run(self.handle, ctypes.byref(image), ctypes.byref(sr_image)))
        return np_sr_image