File size: 4,235 Bytes
1634b64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import ctypes
import os
import platform
from typing import List, Optional, Sequence, Tuple

import numpy as np

from pyaxdev import _lib, AxDeviceType, AxDevices, check_error

class ModelType(ctypes.c_int):
    """Integer codes identifying the detector model family.

    Subclasses ``ctypes.c_int`` so the class itself can be used as the type of
    the ``model_type`` field in ``DetInit``. The attributes below are plain
    Python ints.

    NOTE(review): the numeric values presumably mirror an enum in the native
    header -- confirm before changing any of them.
    """
    ax_det_model_type_unknown = -1
    ax_det_model_type_yolov5 = 0
    ax_det_model_type_yolov8 = 1
    ax_det_model_type_yolov8_pose = 2
    ax_det_model_type_yolo11 = 3
    ax_det_model_type_yolo11_pose = 4

class DetInit(ctypes.Structure):
    """Initialization parameters passed by pointer to ``_lib.ax_det_init``.

    Field order and types define the in-memory layout handed to the native
    library, so they must not be reordered.
    NOTE(review): presumably mirrors a struct in the native header -- confirm.
    """
    _fields_ = [
        # Device type selector (project-defined AxDeviceType).
        ('dev_type', AxDeviceType),
        # Device id stored in a single byte; AXDet.__init__ assigns an int here.
        ('devid', ctypes.c_char),
        # One of the ModelType integer codes.
        ('model_type', ModelType),
        # Fixed 256-byte buffer; AXDet.__init__ stores the UTF-8 encoded path.
        ('model_path', ctypes.c_char * 256),
        ('num_classes', ctypes.c_int),
        ('num_kpt', ctypes.c_int),
        ('threshold', ctypes.c_float),
        # Three floats each; presumably per-channel normalization -- TODO confirm.
        ('mean', ctypes.c_float * 3),
        ('std', ctypes.c_float * 3),
    ]

class DetImage(ctypes.Structure):
    """Image descriptor passed by pointer to ``_lib.ax_det``.

    AXDet.detect fills ``height``/``width``/``channels`` from the first three
    dimensions of the input array's shape and sets ``stride`` to
    ``width * channels``.
    """
    _fields_ = [
        ('width', ctypes.c_int),
        ('height', ctypes.c_int),
        ('channels', ctypes.c_int),
        # Set by AXDet.detect to width * channels; presumably the row pitch in
        # bytes for tightly packed 8-bit data -- TODO confirm.
        ('stride', ctypes.c_int),
        # Pointer to the first unsigned byte of the pixel buffer.
        ('data', ctypes.POINTER(ctypes.c_ubyte)),
    ]


class ObjectItem(ctypes.Structure):
    """One detection record inside ``ObjectResult.objects``."""
    _fields_ = [
        # Four ints describing the box; the coordinate convention is not
        # visible in this file -- TODO confirm against the native header.
        ('box', ctypes.c_int * 4),
        # 32 slots of 2 ints each; AXDet.detect reads kpts[j][0] and kpts[j][1].
        ('kpts', ctypes.c_int * 2 * 32),
        # Number of kpts slots AXDet.detect reads for this object.
        ('num_kpt', ctypes.c_int),
        ('score', ctypes.c_float),
        ('label', ctypes.c_int),
    ]
    
class ObjectResult(ctypes.Structure):
    """Output buffer passed by pointer to ``_lib.ax_det``."""
    _fields_ = [
        # Fixed-capacity array of 64 detection records.
        ('objects', ObjectItem * 64),
        # Number of leading entries of ``objects`` that AXDet.detect reads.
        ('num_objs', ctypes.c_int),
    ]

# ctypes prototypes for the native detector entry points used by AXDet.
# Each function is declared to return a C int; AXDet passes the results of
# ax_det_init and ax_det to check_error.

# ax_det_init(DetInit *init, void **handle): the second argument is passed by
# reference and afterwards kept by AXDet as the detector handle.
_lib.ax_det_init.argtypes = [ctypes.POINTER(DetInit), ctypes.POINTER(ctypes.c_void_p)]
_lib.ax_det_init.restype = ctypes.c_int

# ax_det_deinit(void *handle): called when an AXDet instance is torn down.
_lib.ax_det_deinit.argtypes = [ctypes.c_void_p]
_lib.ax_det_deinit.restype = ctypes.c_int

# ax_det(void *handle, DetImage *image, ObjectResult *result)
_lib.ax_det.argtypes = [ctypes.c_void_p, ctypes.POINTER(DetImage), ctypes.POINTER(ObjectResult)]
_lib.ax_det.restype = ctypes.c_int

class Object:
    """A single detection in plain Python types.

    Attributes:
        box: Four integers describing the bounding box.
        score: Detection score.
        label: Integer class label.
        kpts: Keypoints as a list of 2-tuples of ints; empty when the
            detection carries no keypoints.
    """

    def __init__(self, box: List[int], score: float, label: int,
                 kpts: Optional[List[Tuple[int, int]]] = None):
        """Store the detection fields.

        Args:
            box: Four integers describing the bounding box.
            score: Detection score.
            label: Integer class label.
            kpts: Optional keypoint list. When omitted, each instance gets its
                own new empty list.
        """
        self.box = box
        self.score = score
        self.label = label
        # A ``[]`` default would be created once and shared by every instance
        # built without keypoints, so mutating one would leak into the others.
        self.kpts = [] if kpts is None else kpts

    def __repr__(self):
        return f"Object(box={self.box}, score={self.score:.2f}, label={self.label}, kpts={self.kpts})"
    
class AXDet:
    """Python wrapper around a native detector handle.

    The constructor fills a ``DetInit`` structure and calls
    ``_lib.ax_det_init``; ``detect`` runs ``_lib.ax_det`` on one image and
    converts the native result into ``Object`` instances. The handle is
    released by ``close()``, by leaving a ``with`` block, or as a last resort
    when the instance is garbage collected.
    """

    def __init__(self, model_path: str, model_type: ModelType, num_classes: int,
                 num_kpt: int = 0,
                 threshold: float = 0.25,
                 mean: Sequence[float] = (0, 0, 0), std: Sequence[float] = (1, 1, 1),
                 dev_type: AxDeviceType = AxDeviceType.axcl_device,
                 devid: int = 0):
        """Create the native detector.

        Args:
            model_path: Path of the model file; stored UTF-8 encoded in the
                fixed-size ``DetInit.model_path`` buffer.
            model_type: One of the ``ModelType`` integer codes.
            num_classes: Value written to ``DetInit.num_classes``.
            num_kpt: Value written to ``DetInit.num_kpt``.
            threshold: Value written to ``DetInit.threshold``.
            mean: At least three floats; the first three are copied.
            std: At least three floats; the first three are copied.
            dev_type: Device type selector.
            devid: Device id (stored in a single byte).

        Raises:
            ValueError: If the encoded ``model_path`` does not fit in the
                native buffer together with a terminating NUL byte.
            Exception: Whatever ``check_error`` raises when ``ax_det_init``
                reports a failure.
        """
        # Assigned first so close()/__del__ are safe if anything below raises.
        self.handle = None
        self.init_info = DetInit()

        # Target device.
        self.init_info.dev_type = dev_type
        self.init_info.devid = devid

        # ctypes accepts a value that fills the char array completely, which
        # would leave the buffer without a NUL terminator; reject that case.
        # NOTE(review): assumes the native side reads model_path as a C string.
        encoded_path = model_path.encode('utf-8')
        capacity = DetInit.model_path.size
        if len(encoded_path) >= capacity:
            raise ValueError(
                f"model_path is {len(encoded_path)} bytes when UTF-8 encoded; "
                f"the maximum is {capacity - 1}"
            )

        # Model description.
        self.init_info.model_type = model_type
        self.init_info.model_path = encoded_path
        self.init_info.num_classes = num_classes
        self.init_info.num_kpt = num_kpt
        self.init_info.threshold = threshold
        for i in range(3):
            self.init_info.mean[i] = mean[i]
            self.init_info.std[i] = std[i]

        # Create the native detector instance; publish the handle only on success.
        handle = ctypes.c_void_p()
        check_error(_lib.ax_det_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    def close(self):
        """Release the native handle; safe to call more than once."""
        # getattr: __del__ can run on an instance whose __init__ never executed.
        handle = getattr(self, 'handle', None)
        self.handle = None
        if handle:
            # Best-effort teardown: the return code is intentionally not checked.
            _lib.ax_det_deinit(handle)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def detect(self, image_data: np.ndarray) -> List["Object"]:
        """Run the detector on one image.

        Args:
            image_data: 3-D ``uint8`` array whose shape is read as
                ``(height, width, channels)``. Non-contiguous views are copied
                into a C-contiguous buffer before the native call.
                NOTE(review): the expected channel order is not visible here.

        Returns:
            One ``Object`` per detection reported by the native library.

        Raises:
            RuntimeError: If the detector has been closed.
            ValueError: If ``image_data`` is not 3-dimensional.
            TypeError: If ``image_data`` is not of dtype ``uint8``.
            Exception: Whatever ``check_error`` raises when ``ax_det`` fails.
        """
        if not self.handle:
            raise RuntimeError("AXDet handle is closed")
        if image_data.ndim != 3:
            raise ValueError(
                f"image_data must have shape (height, width, channels), got {image_data.shape}"
            )
        if image_data.dtype != np.uint8:
            # The native struct takes an unsigned-byte pointer and the stride is
            # computed in elements, so only 8-bit data is described correctly.
            raise TypeError(f"image_data must have dtype uint8, got {image_data.dtype}")

        # The stride below assumes tightly packed rows; sliced or flipped views
        # do not satisfy that. This is a no-op for already contiguous arrays.
        frame = np.ascontiguousarray(image_data)
        height, width, channels = frame.shape

        image = DetImage()
        # data_as keeps a reference to `frame`, so the buffer outlives the call.
        image.data = frame.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))
        image.width = width
        image.height = height
        image.channels = channels
        image.stride = width * channels

        result = ObjectResult()
        check_error(_lib.ax_det(self.handle, ctypes.byref(image), ctypes.byref(result)))
        return [self._to_object(result.objects[i]) for i in range(result.num_objs)]

    @staticmethod
    def _to_object(item: "ObjectItem") -> "Object":
        """Copy one native ``ObjectItem`` into a plain-Python ``Object``."""
        return Object(
            box=list(item.box),
            score=item.score,
            label=item.label,
            kpts=[(item.kpts[j][0], item.kpts[j][1]) for j in range(item.num_kpt)],
        )