File size: 5,824 Bytes
14a799a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import ctypes
import os
from typing import List, Tuple
import numpy as np
import platform
from pyaxdev import _lib, AxDeviceType, AxDevices, check_error
class ClipInit(ctypes.Structure):
    """ctypes mirror of the initialisation struct passed to ``clip_create``.

    Field order, types and sizes define the binary layout shared with the
    native library, so they must match the C declaration exactly.
    """
    _fields_ = [
        ('dev_type', AxDeviceType),  # device backend selector (type comes from pyaxdev)
        ('devid', ctypes.c_char),  # presumably the device index; stored in a single byte
        # Path fields are fixed 128-byte char buffers: a value of 128 bytes or
        # more leaves no room for a NUL terminator.
        ('text_encoder_path', ctypes.c_char * 128),
        ('image_encoder_path', ctypes.c_char * 128),
        ('tokenizer_path', ctypes.c_char * 128),
        ('isCN', ctypes.c_char),  # 1-byte flag; presumably selects Chinese-language models -- TODO confirm
        ('db_path', ctypes.c_char * 128)
    ]
class ClipImage(ctypes.Structure):
    """ctypes mirror of the native image descriptor given to ``clip_add`` / ``clip_match_image``.

    The struct only stores a pointer to pixel memory; it does not copy the
    pixels, so the buffer must stay alive while the native call runs.
    """
    _fields_ = [
        ('data', ctypes.POINTER(ctypes.c_ubyte)),  # first byte of the pixel buffer
        ('width', ctypes.c_int),  # pixels per row
        ('height', ctypes.c_int),  # number of rows
        ('channels', ctypes.c_int),  # channels per pixel (HWC layout as filled by Clip)
        ('stride', ctypes.c_int)  # bytes per row; Clip passes width * channels (tightly packed rows)
    ]
class ClipFeatureItem(ctypes.Structure):
    """ctypes mirror of a native feature vector with a fixed capacity of 768 floats."""
    _fields_ = [
        ('feat', ctypes.c_float * 768),  # storage; only the first `len` entries are meaningful
        ('len', ctypes.c_int)  # number of valid entries in `feat`
    ]
class ClipResultItem(ctypes.Structure):
    """ctypes mirror of one native match result."""
    _fields_ = [
        ('key', ctypes.c_char * 64),  # image key; ctypes returns it as bytes cut at the first NUL
        ('score', ctypes.c_float)  # score reported by the native matcher (scale/ordering not visible here)
    ]
def _declare_native_prototypes(lib) -> None:
    """Attach ctypes argument and return types to every ``clip_*`` entry point of *lib*.

    With explicit prototypes ctypes converts and checks each argument instead
    of guessing; every function is declared to return a C ``int``.
    """
    handle_t = ctypes.c_void_p
    cstr_t = ctypes.c_char_p
    image_ptr_t = ctypes.POINTER(ClipImage)
    feat_ptr_t = ctypes.POINTER(ClipFeatureItem)
    results_ptr_t = ctypes.POINTER(ClipResultItem)

    # Insertion order matches the original declaration order.
    prototypes = {
        'clip_create': [ctypes.POINTER(ClipInit), ctypes.POINTER(ctypes.c_void_p)],
        'clip_destroy': [handle_t],
        'clip_add': [handle_t, cstr_t, image_ptr_t, ctypes.c_char],
        'clip_remove': [handle_t, cstr_t],
        'clip_contain': [handle_t, cstr_t],
        'clip_get_text_feat': [handle_t, cstr_t, feat_ptr_t],
        'clip_match_feat': [handle_t, feat_ptr_t, results_ptr_t, ctypes.c_int],
        'clip_match_text': [handle_t, cstr_t, results_ptr_t, ctypes.c_int],
        'clip_match_image': [handle_t, image_ptr_t, results_ptr_t, ctypes.c_int],
    }
    for symbol, argument_types in prototypes.items():
        native_fn = getattr(lib, symbol)
        native_fn.argtypes = argument_types
        native_fn.restype = ctypes.c_int


_declare_native_prototypes(_lib)
class Clip:
    """Python wrapper around the native ``clip_*`` image/text retrieval API.

    One instance owns one native handle, created by ``clip_create`` in
    ``__init__`` and released by ``clip_destroy`` in ``__del__``. Images are
    registered under string keys with :meth:`add_image` and queried with
    :meth:`match_text`, :meth:`match_image` or :meth:`match_feat`.

    The match methods return exactly ``top_k`` ``(key, score)`` pairs read
    from a zero-initialised native result buffer. If the native side fills
    fewer slots, the remaining pairs presumably decode as ``('', 0.0)`` --
    NOTE(review): confirm against the C implementation.
    """

    def __init__(self, init_info: dict):
        """Create the native CLIP instance.

        Args:
            init_info: Configuration dict. Recognised keys are ``dev_type``
                (default ``AxDeviceType.axcl_device``), ``devid`` (default 0),
                ``isCN`` (default 1), and the paths ``text_encoder_path``,
                ``image_encoder_path``, ``tokenizer_path`` and ``db_path``
                (``str``, ``bytes`` or path-like; ``str`` is UTF-8 encoded).

        Raises:
            ValueError: if an encoded path does not fit, NUL-terminated, in
                its fixed-size native field.
            Whatever ``check_error`` raises when ``clip_create`` reports failure.
        """
        self.handle = None
        self.init_info = ClipInit()
        # Scalar initialisation parameters.
        self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
        self.init_info.devid = init_info.get('devid', 0)
        self.init_info.isCN = init_info.get('isCN', 1)
        # Model / tokenizer / database paths go into fixed-size char buffers.
        for path_name in ('text_encoder_path', 'image_encoder_path', 'tokenizer_path', 'db_path'):
            if path_name in init_info:
                setattr(self.init_info, path_name, self._encode_path(path_name, init_info[path_name]))
        # Create the native instance; the handle is only stored once creation succeeded.
        handle = ctypes.c_void_p()
        check_error(_lib.clip_create(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    @staticmethod
    def _encode_path(field_name: str, value) -> bytes:
        """Return *value* as bytes that fit NUL-terminated in ``ClipInit.<field_name>``."""
        raw = os.fspath(value)
        if isinstance(raw, str):
            raw = raw.encode('utf-8')
        capacity = ctypes.sizeof(dict(ClipInit._fields_)[field_name])
        # ctypes stores a value of exactly `capacity` bytes without a trailing
        # NUL, so the C side would read past the buffer; reject that too.
        if len(raw) >= capacity:
            raise ValueError(
                f"{field_name} is {len(raw)} bytes long; at most {capacity - 1} bytes fit in the native field"
            )
        return raw

    def __del__(self):
        """Release the native handle, if one was created."""
        # `handle` is missing when the object was built without running __init__.
        handle = getattr(self, 'handle', None)
        if handle:
            self.handle = None  # never destroy the same handle twice
            _lib.clip_destroy(handle)

    def add_image(self, key: str, image_data: np.ndarray) -> None:
        """Register *image_data* under *key*; does nothing if *key* is already present.

        Args:
            key: Identifier that the match methods report back.
            image_data: ``uint8`` image of shape ``(H, W, C)``. Views with
                non-standard strides (crops, channel-reversed views) are
                packed into a contiguous copy first.

        Raises:
            ValueError: if *image_data* is not a 3-D ``uint8`` array.
        """
        if self.contains_image(key):
            return
        pixels = self._packed_pixels(image_data)
        height, width, channels = pixels.shape
        image = ClipImage()
        # `pixels` stays referenced until the native call returns, keeping the pointer valid.
        image.data = pixels.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))
        image.width = width
        image.height = height
        image.channels = channels
        image.stride = width * channels
        check_error(_lib.clip_add(self.handle, key.encode('utf-8'), ctypes.byref(image), 0))

    @staticmethod
    def _packed_pixels(image_data: np.ndarray) -> np.ndarray:
        """Return *image_data* as a C-contiguous ``(H, W, C)`` uint8 array, copying only if needed."""
        pixels = np.asarray(image_data)
        if pixels.ndim != 3:
            raise ValueError(f"expected an (H, W, C) image, got shape {pixels.shape}")
        if pixels.dtype != np.uint8:
            raise ValueError(f"expected a uint8 image, got dtype {pixels.dtype}")
        # Only a base pointer and stride = width * channels are passed, i.e.
        # rows are assumed tightly packed; strided views must be packed first.
        return np.ascontiguousarray(pixels)

    def remove_image(self, key: str) -> None:
        """Remove *key* from the native store; the status is passed to ``check_error``."""
        check_error(_lib.clip_remove(self.handle, key.encode('utf-8')))

    def contains_image(self, key: str) -> bool:
        """Return True when ``clip_contain`` reports *key* as present (returns 1)."""
        return _lib.clip_contain(self.handle, key.encode('utf-8')) == 1

    def get_text_feat(self, text: str) -> np.ndarray:
        """Return the native text feature for *text* as a 1-D float64 array.

        Only the first ``len`` entries reported by the native side are returned.
        """
        feat = ClipFeatureItem()
        check_error(_lib.clip_get_text_feat(self.handle, text.encode('utf-8'), ctypes.byref(feat)))
        # Clamp the reported length into [0, capacity] so a bogus value cannot yield a wrong slice.
        count = min(max(feat.len, 0), len(feat.feat))
        return np.array(feat.feat[:count])

    def match_feat(self, feat: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
        """Match a precomputed feature vector against the registered images.

        Args:
            feat: Array-like of floats (flattened if multi-dimensional) with at
                most as many elements as ``ClipFeatureItem.feat`` holds (768).
            top_k: Number of result slots to request; 0 returns ``[]``.

        Raises:
            ValueError: if *feat* is too long or *top_k* is negative.
        """
        feat_item = ClipFeatureItem()
        values = np.asarray(feat, dtype=np.float32).ravel()
        capacity = len(feat_item.feat)
        if values.size > capacity:
            raise ValueError(f"feature has {values.size} elements; at most {capacity} are supported")
        feat_item.feat[:values.size] = values.tolist()
        feat_item.len = values.size
        return self._match(_lib.clip_match_feat, ctypes.byref(feat_item), top_k)

    def match_text(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """Match *text* against the registered images; see :meth:`match_feat` for *top_k*."""
        return self._match(_lib.clip_match_text, text.encode('utf-8'), top_k)

    def match_image(self, image_data: bytes, width: int, height: int, channels: int = 3, top_k: int = 10) -> List[Tuple[str, float]]:
        """Match a raw packed image against the registered images.

        Args:
            image_data: Raw uint8 pixels, row-major with rows of
                ``width * channels`` bytes. ``bytes`` or any buffer object
                (``bytearray``, ``memoryview``, numpy array) is accepted.
            width, height, channels: Image geometry; all must be positive.
            top_k: Number of result slots to request; 0 returns ``[]``.

        Raises:
            ValueError: on non-positive geometry, a buffer shorter than
                ``width * height * channels`` bytes, or negative *top_k*.
        """
        if width <= 0 or height <= 0 or channels <= 0:
            raise ValueError(f"invalid image geometry {width}x{height}x{channels}")
        raw = image_data if isinstance(image_data, bytes) else memoryview(image_data).tobytes()
        needed = width * height * channels
        if len(raw) < needed:
            # Without this check the native side would read past the buffer.
            raise ValueError(f"image buffer has {len(raw)} bytes, expected at least {needed}")
        pixel_buffer = ctypes.create_string_buffer(raw)
        image = ClipImage()
        image.data = ctypes.cast(pixel_buffer, ctypes.POINTER(ctypes.c_ubyte))
        image.width = width
        image.height = height
        image.channels = channels
        image.stride = width * channels
        return self._match(_lib.clip_match_image, ctypes.byref(image), top_k)

    def _match(self, native_fn, query, top_k: int) -> List[Tuple[str, float]]:
        """Run one ``clip_match_*`` call with a *top_k*-slot result buffer and decode it.

        The result array is passed as-is rather than through ``ctypes.byref``:
        ctypes accepts an array for a ``POINTER(ClipResultItem)`` parameter,
        but rejects a byref'd array with ``ArgumentError``.
        """
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if top_k == 0:
            return []
        results = (ClipResultItem * top_k)()
        check_error(native_fn(self.handle, query, results, top_k))
        return self._decode_results(results)

    @staticmethod
    def _decode_results(results) -> List[Tuple[str, float]]:
        """Convert native result items into ``(key, score)`` tuples.

        ``errors='replace'`` keeps a key that is not valid UTF-8 (for example,
        one truncated to fit the 64-byte field -- TODO confirm native
        truncation) from aborting the whole result list.
        """
        return [(item.key.decode('utf-8', errors='replace'), item.score) for item in results]