File size: 4,235 Bytes
1634b64 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import ctypes
import os
from typing import List, Tuple
import numpy as np
import platform
from pyaxdev import _lib, AxDeviceType, AxDevices, check_error
class ModelType(ctypes.c_int):
ax_det_model_type_unknown = -1
ax_det_model_type_yolov5 = 0
ax_det_model_type_yolov8 = 1
ax_det_model_type_yolov8_pose = 2
ax_det_model_type_yolo11 = 3
ax_det_model_type_yolo11_pose = 4
class DetInit(ctypes.Structure):
_fields_ = [
('dev_type', AxDeviceType),
('devid', ctypes.c_char),
('model_type', ModelType),
('model_path', ctypes.c_char * 256),
('num_classes', ctypes.c_int),
('num_kpt', ctypes.c_int),
('threshold', ctypes.c_float),
('mean', ctypes.c_float * 3),
('std', ctypes.c_float * 3),
]
class DetImage(ctypes.Structure):
_fields_ = [
('width', ctypes.c_int),
('height', ctypes.c_int),
('channels', ctypes.c_int),
('stride', ctypes.c_int),
('data', ctypes.POINTER(ctypes.c_ubyte)),
]
class ObjectItem(ctypes.Structure):
_fields_ = [
('box', ctypes.c_int * 4),
('kpts', ctypes.c_int * 2 * 32),
('num_kpt', ctypes.c_int),
('score', ctypes.c_float),
('label', ctypes.c_int),
]
class ObjectResult(ctypes.Structure):
_fields_ = [
('objects', ObjectItem * 64),
('num_objs', ctypes.c_int),
]
_lib.ax_det_init.argtypes = [ctypes.POINTER(DetInit), ctypes.POINTER(ctypes.c_void_p)]
_lib.ax_det_init.restype = ctypes.c_int
_lib.ax_det_deinit.argtypes = [ctypes.c_void_p]
_lib.ax_det_deinit.restype = ctypes.c_int
_lib.ax_det.argtypes = [ctypes.c_void_p, ctypes.POINTER(DetImage), ctypes.POINTER(ObjectResult)]
_lib.ax_det.restype = ctypes.c_int
class Object:
def __init__(self, box: List[int], score: float, label: int, kpts: List[int] = []):
self.box = box
self.score = score
self.label = label
self.kpts = kpts
def __repr__(self):
return f"Object(box={self.box}, score={self.score:.2f}, label={self.label}, kpts={self.kpts})"
class AXDet:
def __init__(self, model_path: str, model_type: ModelType, num_classes: int,
num_kpt: int = 0,
threshold: float= 0.25,
mean: List[float] = [0,0,0], std: List[float] = [1,1,1],
dev_type: AxDeviceType = AxDeviceType.axcl_device,
devid: int = 0):
self.handle = None
self.init_info = DetInit()
# 设置初始化参数
self.init_info.dev_type = dev_type
self.init_info.devid = devid
# 设置路径
self.init_info.model_type = model_type
self.init_info.model_path = model_path.encode('utf-8')
self.init_info.num_classes = num_classes
self.init_info.num_kpt = num_kpt
self.init_info.threshold = threshold
for i in range(3):
self.init_info.mean[i] = mean[i]
self.init_info.std[i] = std[i]
# 创建CLIP实例
handle = ctypes.c_void_p()
check_error(_lib.ax_det_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
self.handle = handle
def __del__(self):
if self.handle:
_lib.ax_det_deinit(self.handle)
def detect(self, image_data: np.ndarray):
image = DetImage()
image.data = ctypes.cast(image_data.ctypes.data, ctypes.POINTER(ctypes.c_ubyte))
image.width = image_data.shape[1]
image.height = image_data.shape[0]
image.channels = image_data.shape[2]
image.stride = image_data.shape[1] * image_data.shape[2]
result = ObjectResult()
check_error(_lib.ax_det(self.handle, ctypes.byref(image), ctypes.byref(result)))
objects = []
for i in range(result.num_objs):
_obj = result.objects[i]
obj = Object(
box=[_obj.box[0], _obj.box[1], _obj.box[2], _obj.box[3]],
score=_obj.score,
label=_obj.label,
kpts=[(_obj.kpts[j][0], _obj.kpts[j][1]) for j in range(_obj.num_kpt)],
)
objects.append(obj)
return objects
|