wli1995 commited on
Commit
14a799a
·
verified ·
1 Parent(s): 44e9a1d

update python script

Browse files
Files changed (3) hide show
  1. pyclip/gradio_example.py +96 -83
  2. pyclip/pyaxdev.py +147 -0
  3. pyclip/pyclip.py +147 -260
pyclip/gradio_example.py CHANGED
@@ -1,83 +1,96 @@
1
- import os
2
- import gradio as gr
3
- from pyclip import Clip, enum_devices, sys_init, sys_deinit, ClipDeviceType
4
- import cv2
5
- import glob
6
- from PIL import Image
7
- import tqdm
8
- import argparse
9
-
10
- if __name__ == '__main__':
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument('--ienc', type=str, default='cnclip/cnclip_vit_l14_336px_vision_u16u8.axmodel')
13
- parser.add_argument('--tenc', type=str, default='cnclip/cnclip_vit_l14_336px_text_u16.axmodel')
14
- parser.add_argument('--vocab', type=str, default='cnclip/cn_vocab.txt')
15
- parser.add_argument('--isCN', type=int, default=1)
16
- parser.add_argument('--db_path', type=str, default='clip_feat_db_coco')
17
- parser.add_argument('--image_folder', type=str, default='coco_1000')
18
- parser.add_argument('--dev_type', type=str, default='host', help='host or axcl')
19
- args = parser.parse_args()
20
-
21
- image_folder = args.image_folder
22
- device = ClipDeviceType.host_device if args.dev_type == 'host' else ClipDeviceType.axcl_device
23
-
24
- # 初始化
25
- print("可用设备:", enum_devices())
26
- sys_init(device, 0)
27
-
28
- clip = Clip({
29
- 'text_encoder_path': args.tenc,
30
- 'image_encoder_path': args.ienc,
31
- 'tokenizer_path': args.vocab,
32
- 'db_path': args.db_path,
33
- 'isCN': args.isCN,
34
- 'dev_type': device,
35
- })
36
-
37
-
38
- # 加载图片数据库(只做一次)
39
- image_files = glob.glob(os.path.join(image_folder, '*.jpg'))
40
- for image_file in tqdm.tqdm(image_files):
41
- filename = os.path.basename(image_file)
42
- if clip.contains_image(filename) == 1:
43
- continue
44
- img = cv2.imread(image_file)
45
- cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
46
- clip.add_image(filename, img)
47
-
48
- # 工具函数:图片转 base64
49
- def img_to_pil(img_path):
50
- return Image.open(img_path).convert("RGB")
51
-
52
- # 主搜索函数
53
- def search_images(query, top_k):
54
- results = clip.match_text(query, top_k=top_k)
55
- images = []
56
- for filename, score in results:
57
- img_path = os.path.join(image_folder, filename)
58
- if os.path.exists(img_path):
59
- img = img_to_pil(img_path)
60
- images.append((img, f"{filename} Score: {score:.4f}"))
61
- return images
62
-
63
-
64
- # Gradio界面
65
- with gr.Blocks() as demo:
66
- gr.Markdown("# 🔍 文搜图 Demo")
67
-
68
- with gr.Row():
69
- query_input = gr.Textbox(label="请输入文本查询")
70
- topk_input = gr.Number(value=25, precision=0, label="Top-K")
71
- search_btn = gr.Button("搜图")
72
-
73
- gallery = gr.Gallery(label="匹配结果", show_label=True, columns=4)
74
-
75
- search_btn.click(fn=search_images, inputs=[query_input, topk_input], outputs=gallery)
76
-
77
- # 启动
78
- ip = "0.0.0.0"
79
- demo.launch(server_name=ip, server_port=7860)
80
-
81
- # 关闭系统(你可加信号处理来自动关闭)
82
- import atexit
83
- atexit.register(lambda: sys_deinit(ClipDeviceType.host_device, 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from pyclip import Clip
4
+ from pyaxdev import enum_devices, sys_init, sys_deinit, AxDeviceType
5
+ import cv2
6
+ import glob
7
+ from PIL import Image
8
+ import tqdm
9
+ import argparse
10
+
11
+ if __name__ == '__main__':
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--ienc', type=str, default='cnclip/cnclip_vit_l14_336px_vision_u16u8.axmodel')
14
+ parser.add_argument('--tenc', type=str, default='cnclip/cnclip_vit_l14_336px_text_u16.axmodel')
15
+ parser.add_argument('--vocab', type=str, default='cnclip/cn_vocab.txt')
16
+ parser.add_argument('--isCN', type=int, default=1)
17
+ parser.add_argument('--db_path', type=str, default='clip_feat_db_coco')
18
+ parser.add_argument('--image_folder', type=str, default='coco_1000')
19
+ parser.add_argument('--dev_type', type=str, default='host', choices=['host', 'axcl'])
20
+ args = parser.parse_args()
21
+
22
+ image_folder = args.image_folder
23
+
24
+ # 初始化
25
+ devices_info = enum_devices()
26
+ print("可用设备:", devices_info)
27
+ device = AxDeviceType.host_device if args.dev_type == 'host' else AxDeviceType.axcl_device
28
+
29
+ if devices_info['host']['available']:
30
+ print("host device available")
31
+ sys_init(device, 0)
32
+ elif devices_info['devices']['count'] > 0:
33
+ print("axcl device available, use device-0")
34
+ sys_init(device, 0)
35
+ else:
36
+ raise Exception("No available device")
37
+
38
+ clip = Clip({
39
+ 'text_encoder_path': args.tenc,
40
+ 'image_encoder_path': args.ienc,
41
+ 'tokenizer_path': args.vocab,
42
+ 'db_path': args.db_path,
43
+ 'isCN': args.isCN,
44
+ 'dev_type': device
45
+ })
46
+
47
+
48
+ # 加载图片数据库(只做一次)
49
+ image_files = glob.glob(os.path.join(image_folder, '*.jpg'))
50
+ for image_file in tqdm.tqdm(image_files):
51
+ filename = os.path.basename(image_file)
52
+ if clip.contains_image(filename) == 1:
53
+ continue
54
+ img = cv2.imread(image_file)
55
+ cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
56
+ clip.add_image(filename, img)
57
+
58
+ def img_to_pil(img_path):
59
+ return Image.open(img_path).convert("RGB")
60
+
61
+ # 主搜索函数
62
+ def search_images(query, top_k):
63
+ results = clip.match_text(query, top_k=top_k)
64
+ images = []
65
+ for filename, score in results:
66
+ img_path = os.path.join(image_folder, filename)
67
+ if os.path.exists(img_path):
68
+ img = img_to_pil(img_path)
69
+ images.append((img, f"{filename} Score: {score:.4f}"))
70
+ return images
71
+
72
+
73
+ # Gradio界面
74
+ with gr.Blocks() as demo:
75
+ gr.Markdown("# 🔍 文搜图 Demo")
76
+
77
+ with gr.Row():
78
+ query_input = gr.Textbox(label="请输入文本查询")
79
+ topk_input = gr.Number(value=25, precision=0, label="Top-K")
80
+ search_btn = gr.Button("搜图")
81
+
82
+ gallery = gr.Gallery(label="匹配结果", show_label=True, columns=4)
83
+
84
+ search_btn.click(fn=search_images, inputs=[query_input, topk_input], outputs=gallery)
85
+
86
+ # 启动
87
+ ip = "0.0.0.0"
88
+ demo.launch(server_name=ip, server_port=7860)
89
+
90
+ import atexit
91
+ if devices_info['host']['available']:
92
+ atexit.register(lambda: sys_deinit(AxDeviceType.host_device, -1))
93
+ elif devices_info['devices']['count'] > 0:
94
+ atexit.register(lambda: sys_deinit(AxDeviceType.axcl_device, 0))
95
+
96
+
pyclip/pyaxdev.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import os
3
+ import platform
4
+
5
+ def check_error(code: int) -> None:
6
+ if code != 0:
7
+ raise Exception(f"API错误: {code}")
8
+
9
+ base_dir = os.path.dirname(__file__)
10
+ arch = platform.machine()
11
+
12
+ if arch == 'x86_64':
13
+ arch_dir = 'x86_64'
14
+ elif arch in ('aarch64', 'arm64'):
15
+ arch_dir = 'aarch64'
16
+ else:
17
+ raise RuntimeError(f"Unsupported architecture: {arch}")
18
+
19
+ lib_paths = [
20
+ os.path.join(base_dir, arch_dir, 'libclip.so'),
21
+ os.path.join(base_dir, 'libclip.so')
22
+ ]
23
+
24
+ last_error = None
25
+ diagnostic_shown = set()
26
+
27
+ for lib_path in lib_paths:
28
+ try:
29
+ print(f"Trying to load: {lib_path}")
30
+ _lib = ctypes.CDLL(lib_path)
31
+ print(f"✅ Successfully loaded: {lib_path}")
32
+ break
33
+ except OSError as e:
34
+ last_error = e
35
+ err_str = str(e)
36
+ print(f"\n❌ Failed to load: {lib_path}")
37
+ print(f" {err_str}")
38
+
39
+ # Only show GLIBCXX tip once
40
+ if "GLIBCXX" in err_str and "not found" in err_str:
41
+ if "missing_glibcxx" not in diagnostic_shown:
42
+ diagnostic_shown.add("missing_glibcxx")
43
+ print("🔍 Detected missing GLIBCXX version in libstdc++.so.6")
44
+ print("💡 This usually happens when your environment (like Conda) uses an older libstdc++")
45
+ print(f"👉 Try running with system libstdc++ preloaded:")
46
+ print(f" export LD_PRELOAD=/usr/lib/{arch_dir}-linux-gnu/libstdc++.so.6\n")
47
+ elif "No such file" in err_str:
48
+ if "file_not_found" not in diagnostic_shown:
49
+ diagnostic_shown.add("file_not_found")
50
+ print("🔍 File not found. Please verify that libclip.so exists and the path is correct.\n")
51
+ elif "wrong ELF class" in err_str:
52
+ if "elf_mismatch" not in diagnostic_shown:
53
+ diagnostic_shown.add("elf_mismatch")
54
+ print("🔍 ELF class mismatch — likely due to architecture conflict (e.g., loading x86_64 .so on aarch64).")
55
+ print(f"👉 Run `file {lib_path}` to verify the binary architecture.\n")
56
+ else:
57
+ if "generic_error" not in diagnostic_shown:
58
+ diagnostic_shown.add("generic_error")
59
+ print("📎 Tip: Use `ldd` to inspect missing dependencies:")
60
+ print(f" ldd {lib_path}\n")
61
+ else:
62
+ raise RuntimeError(f"\n❗ Failed to load libclip.so.\nLast error:\n{last_error}")
63
+
64
+
65
+ # 定义枚举类型
66
+ class AxDeviceType(ctypes.c_int):
67
+ unknown_device = 0
68
+ host_device = 1
69
+ axcl_device = 2
70
+
71
+ # 定义结构体
72
+ class AxMemInfo(ctypes.Structure):
73
+ _fields_ = [
74
+ ('remain', ctypes.c_int),
75
+ ('total', ctypes.c_int)
76
+ ]
77
+
78
+ class AxHostInfo(ctypes.Structure):
79
+ _fields_ = [
80
+ ('available', ctypes.c_char),
81
+ ('version', ctypes.c_char * 32),
82
+ ('mem_info', AxMemInfo)
83
+ ]
84
+
85
+ class AxDeviceInfo(ctypes.Structure):
86
+ _fields_ = [
87
+ ('temp', ctypes.c_int),
88
+ ('cpu_usage', ctypes.c_int),
89
+ ('npu_usage', ctypes.c_int),
90
+ ('mem_info', AxMemInfo)
91
+ ]
92
+
93
+ class AxDevices(ctypes.Structure):
94
+ _fields_ = [
95
+ ('host', AxHostInfo),
96
+ ('host_version', ctypes.c_char * 32),
97
+ ('dev_version', ctypes.c_char * 32),
98
+ ('count', ctypes.c_ubyte),
99
+ ('devices_info', AxDeviceInfo * 16)
100
+ ]
101
+
102
+
103
+ _lib.ax_dev_enum_devices.argtypes = [ctypes.POINTER(AxDevices)]
104
+ _lib.ax_dev_enum_devices.restype = ctypes.c_int
105
+
106
+ _lib.ax_dev_sys_init.argtypes = [AxDeviceType, ctypes.c_char]
107
+ _lib.ax_dev_sys_init.restype = ctypes.c_int
108
+
109
+ _lib.ax_dev_sys_deinit.argtypes = [AxDeviceType, ctypes.c_char]
110
+ _lib.ax_dev_sys_deinit.restype = ctypes.c_int
111
+
112
+ def enum_devices() -> dict:
113
+ devices = AxDevices()
114
+ check_error(_lib.ax_dev_enum_devices(ctypes.byref(devices)))
115
+
116
+ return {
117
+ 'host': {
118
+ 'available': bool(devices.host.available[0]),
119
+ 'version': devices.host.version.decode('utf-8'),
120
+ 'mem_info': {
121
+ 'remain': devices.host.mem_info.remain,
122
+ 'total': devices.host.mem_info.total
123
+ }
124
+ },
125
+ 'devices': {
126
+ 'host_version': devices.host_version.decode('utf-8'),
127
+ 'dev_version': devices.dev_version.decode('utf-8'),
128
+ 'count': devices.count,
129
+ 'devices_info': [{
130
+ 'temp': dev.temp,
131
+ 'cpu_usage': dev.cpu_usage,
132
+ 'npu_usage': dev.npu_usage,
133
+ 'mem_info': {
134
+ 'remain': dev.mem_info.remain,
135
+ 'total': dev.mem_info.total
136
+ }
137
+ } for dev in devices.devices_info[:devices.count]]
138
+ }
139
+ }
140
+
141
+
142
+ def sys_init(dev_type: AxDeviceType = AxDeviceType.axcl_device, devid: int = 0) -> None:
143
+ check_error(_lib.ax_dev_sys_init(dev_type, devid))
144
+
145
+
146
+ def sys_deinit(dev_type: AxDeviceType = AxDeviceType.axcl_device, devid: int = 0) -> None:
147
+ check_error(_lib.ax_dev_sys_deinit(dev_type, devid))
pyclip/pyclip.py CHANGED
@@ -1,260 +1,147 @@
1
- import ctypes
2
- import os
3
- from typing import List, Tuple, Optional
4
- import numpy as np
5
- import platform
6
-
7
- base_dir = os.path.dirname(__file__)
8
- arch = platform.machine()
9
-
10
- if arch == 'x86_64':
11
- arch_dir = 'x86_64'
12
- elif arch in ('aarch64', 'arm64'):
13
- arch_dir = 'aarch64'
14
- else:
15
- raise RuntimeError(f"Unsupported architecture: {arch}")
16
-
17
- lib_paths = [
18
- os.path.join(base_dir, arch_dir, 'libclip.so'),
19
- os.path.join(base_dir, 'libclip.so')
20
- ]
21
-
22
- last_error = None
23
- diagnostic_shown = set()
24
-
25
- for lib_path in lib_paths:
26
- try:
27
- print(f"Trying to load: {lib_path}")
28
- _lib = ctypes.CDLL(lib_path)
29
- print(f"✅ Successfully loaded: {lib_path}")
30
- break
31
- except OSError as e:
32
- last_error = e
33
- err_str = str(e)
34
- print(f"\n❌ Failed to load: {lib_path}")
35
- print(f" {err_str}")
36
-
37
- # Only show GLIBCXX tip once
38
- if "GLIBCXX" in err_str and "not found" in err_str:
39
- if "missing_glibcxx" not in diagnostic_shown:
40
- diagnostic_shown.add("missing_glibcxx")
41
- print("🔍 Detected missing GLIBCXX version in libstdc++.so.6")
42
- print("💡 This usually happens when your environment (like Conda) uses an older libstdc++")
43
- print(f"👉 Try running with system libstdc++ preloaded:")
44
- print(f" export LD_PRELOAD=/usr/lib/{arch_dir}-linux-gnu/libstdc++.so.6\n")
45
- elif "No such file" in err_str:
46
- if "file_not_found" not in diagnostic_shown:
47
- diagnostic_shown.add("file_not_found")
48
- print("🔍 File not found. Please verify that libclip.so exists and the path is correct.\n")
49
- elif "wrong ELF class" in err_str:
50
- if "elf_mismatch" not in diagnostic_shown:
51
- diagnostic_shown.add("elf_mismatch")
52
- print("🔍 ELF class mismatch — likely due to architecture conflict (e.g., loading x86_64 .so on aarch64).")
53
- print(f"👉 Run `file {lib_path}` to verify the binary architecture.\n")
54
- else:
55
- if "generic_error" not in diagnostic_shown:
56
- diagnostic_shown.add("generic_error")
57
- print("📎 Tip: Use `ldd` to inspect missing dependencies:")
58
- print(f" ldd {lib_path}\n")
59
- else:
60
- raise RuntimeError(f"\n❗ Failed to load libclip.so.\nLast error:\n{last_error}")
61
-
62
-
63
- # 定义枚举类型
64
- class ClipDeviceType(ctypes.c_int):
65
- unknown_device = 0
66
- host_device = 1
67
- axcl_device = 2
68
-
69
- # 定义结构体
70
- class ClipMemInfo(ctypes.Structure):
71
- _fields_ = [
72
- ('remain', ctypes.c_int),
73
- ('total', ctypes.c_int)
74
- ]
75
-
76
- class ClipHostInfo(ctypes.Structure):
77
- _fields_ = [
78
- ('available', ctypes.c_char),
79
- ('version', ctypes.c_char * 32),
80
- ('mem_info', ClipMemInfo)
81
- ]
82
-
83
- class ClipDeviceInfo(ctypes.Structure):
84
- _fields_ = [
85
- ('temp', ctypes.c_int),
86
- ('cpu_usage', ctypes.c_int),
87
- ('npu_usage', ctypes.c_int),
88
- ('mem_info', ClipMemInfo)
89
- ]
90
-
91
- class ClipDevices(ctypes.Structure):
92
- _fields_ = [
93
- ('host', ClipHostInfo),
94
- ('host_version', ctypes.c_char * 32),
95
- ('dev_version', ctypes.c_char * 32),
96
- ('count', ctypes.c_ubyte),
97
- ('devices_info', ClipDeviceInfo * 16)
98
- ]
99
-
100
- class ClipInit(ctypes.Structure):
101
- _fields_ = [
102
- ('dev_type', ClipDeviceType),
103
- ('devid', ctypes.c_char),
104
- ('text_encoder_path', ctypes.c_char * 128),
105
- ('image_encoder_path', ctypes.c_char * 128),
106
- ('tokenizer_path', ctypes.c_char * 128),
107
- ('isCN', ctypes.c_char),
108
- ('db_path', ctypes.c_char * 128)
109
- ]
110
-
111
- class ClipImage(ctypes.Structure):
112
- _fields_ = [
113
- ('data', ctypes.POINTER(ctypes.c_ubyte)),
114
- ('width', ctypes.c_int),
115
- ('height', ctypes.c_int),
116
- ('channels', ctypes.c_int),
117
- ('stride', ctypes.c_int)
118
- ]
119
-
120
- class ClipResultItem(ctypes.Structure):
121
- _fields_ = [
122
- ('key', ctypes.c_char * 64),
123
- ('score', ctypes.c_float)
124
- ]
125
-
126
- # 设置函数参数和返回类型
127
- _lib.clip_enum_devices.argtypes = [ctypes.POINTER(ClipDevices)]
128
- _lib.clip_enum_devices.restype = ctypes.c_int
129
-
130
- _lib.clip_sys_init.argtypes = [ClipDeviceType, ctypes.c_char]
131
- _lib.clip_sys_init.restype = ctypes.c_int
132
-
133
- _lib.clip_sys_deinit.argtypes = [ClipDeviceType, ctypes.c_char]
134
- _lib.clip_sys_deinit.restype = ctypes.c_int
135
-
136
- _lib.clip_create.argtypes = [ctypes.POINTER(ClipInit), ctypes.POINTER(ctypes.c_void_p)]
137
- _lib.clip_create.restype = ctypes.c_int
138
-
139
- _lib.clip_destroy.argtypes = [ctypes.c_void_p]
140
- _lib.clip_destroy.restype = ctypes.c_int
141
-
142
- _lib.clip_add.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ClipImage), ctypes.c_char]
143
- _lib.clip_add.restype = ctypes.c_int
144
-
145
- _lib.clip_remove.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
146
- _lib.clip_remove.restype = ctypes.c_int
147
-
148
- _lib.clip_contain.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
149
- _lib.clip_contain.restype = ctypes.c_int
150
-
151
- _lib.clip_match_text.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ClipResultItem), ctypes.c_int]
152
- _lib.clip_match_text.restype = ctypes.c_int
153
-
154
- _lib.clip_match_image.argtypes = [ctypes.c_void_p, ctypes.POINTER(ClipImage), ctypes.POINTER(ClipResultItem), ctypes.c_int]
155
- _lib.clip_match_image.restype = ctypes.c_int
156
-
157
- class ClipError(Exception):
158
- pass
159
-
160
- def check_error(code: int) -> None:
161
- if code != 0:
162
- raise ClipError(f"CLIP API错误: {code}")
163
-
164
- class Clip:
165
- def __init__(self, init_info: dict):
166
- self.handle = None
167
- self.init_info = ClipInit()
168
-
169
- # 设置初始化参数
170
- self.init_info.dev_type = init_info.get('dev_type', ClipDeviceType.axcl_device)
171
- self.init_info.devid = init_info.get('devid', 0)
172
- self.init_info.isCN = init_info.get('isCN', 1)
173
-
174
- # 设置路径
175
- for path_name in ['text_encoder_path', 'image_encoder_path', 'tokenizer_path', 'db_path']:
176
- if path_name in init_info:
177
- setattr(self.init_info, path_name, init_info[path_name].encode('utf-8'))
178
-
179
- # 创建CLIP实例
180
- handle = ctypes.c_void_p()
181
- check_error(_lib.clip_create(ctypes.byref(self.init_info), ctypes.byref(handle)))
182
- self.handle = handle
183
-
184
- def __del__(self):
185
- if self.handle:
186
- _lib.clip_destroy(self.handle)
187
-
188
- def add_image(self, key: str, image_data: np.ndarray) -> None:
189
- if self.contains_image(key):
190
- return
191
- image = ClipImage()
192
- image.data = ctypes.cast(image_data.ctypes.data, ctypes.POINTER(ctypes.c_ubyte))
193
- image.width = image_data.shape[1]
194
- image.height = image_data.shape[0]
195
- image.channels = image_data.shape[2]
196
- image.stride = image_data.shape[1] * image_data.shape[2]
197
-
198
- check_error(_lib.clip_add(self.handle, key.encode('utf-8'), ctypes.byref(image), 0))
199
-
200
- def remove_image(self, key: str) -> None:
201
- check_error(_lib.clip_remove(self.handle, key.encode('utf-8')))
202
-
203
- def contains_image(self, key: str) -> bool:
204
- return _lib.clip_contain(self.handle, key.encode('utf-8')) == 1
205
-
206
- def match_text(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
207
- results = (ClipResultItem * top_k)()
208
- check_error(_lib.clip_match_text(self.handle, text.encode('utf-8'), results, top_k))
209
-
210
- return [(item.key.decode('utf-8'), item.score) for item in results]
211
-
212
- def match_image(self, image_data: bytes, width: int, height: int, channels: int = 3, top_k: int = 10) -> List[Tuple[str, float]]:
213
- image = ClipImage()
214
- image.data = ctypes.cast(ctypes.create_string_buffer(image_data), ctypes.POINTER(ctypes.c_ubyte))
215
- image.width = width
216
- image.height = height
217
- image.channels = channels
218
- image.stride = width * channels
219
-
220
- results = (ClipResultItem * top_k)()
221
- check_error(_lib.clip_match_image(self.handle, ctypes.byref(image), ctypes.byref(results), top_k))
222
-
223
- return [(item.key.decode('utf-8'), item.score) for item in results]
224
-
225
- def enum_devices() -> dict:
226
- devices = ClipDevices()
227
- check_error(_lib.clip_enum_devices(ctypes.byref(devices)))
228
-
229
- return {
230
- 'host': {
231
- 'available': bool(devices.host.available),
232
- 'version': devices.host.version.decode('utf-8'),
233
- 'mem_info': {
234
- 'remain': devices.host.mem_info.remain,
235
- 'total': devices.host.mem_info.total
236
- }
237
- },
238
- 'devices': {
239
- 'host_version': devices.host_version.decode('utf-8'),
240
- 'dev_version': devices.dev_version.decode('utf-8'),
241
- 'count': devices.count,
242
- 'devices_info': [{
243
- 'temp': dev.temp,
244
- 'cpu_usage': dev.cpu_usage,
245
- 'npu_usage': dev.npu_usage,
246
- 'mem_info': {
247
- 'remain': dev.mem_info.remain,
248
- 'total': dev.mem_info.total
249
- }
250
- } for dev in devices.devices_info[:devices.count]]
251
- }
252
- }
253
-
254
-
255
- def sys_init(dev_type: ClipDeviceType = ClipDeviceType.axcl_device, devid: int = 0) -> None:
256
- check_error(_lib.clip_sys_init(dev_type, devid))
257
-
258
-
259
- def sys_deinit(dev_type: ClipDeviceType = ClipDeviceType.axcl_device, devid: int = 0) -> None:
260
- check_error(_lib.clip_sys_deinit(dev_type, devid))
 
1
+ import ctypes
2
+ import os
3
+ from typing import List, Tuple
4
+ import numpy as np
5
+ import platform
6
+ from pyaxdev import _lib, AxDeviceType, AxDevices, check_error
7
+
8
+
9
+ class ClipInit(ctypes.Structure):
10
+ _fields_ = [
11
+ ('dev_type', AxDeviceType),
12
+ ('devid', ctypes.c_char),
13
+ ('text_encoder_path', ctypes.c_char * 128),
14
+ ('image_encoder_path', ctypes.c_char * 128),
15
+ ('tokenizer_path', ctypes.c_char * 128),
16
+ ('isCN', ctypes.c_char),
17
+ ('db_path', ctypes.c_char * 128)
18
+ ]
19
+
20
+ class ClipImage(ctypes.Structure):
21
+ _fields_ = [
22
+ ('data', ctypes.POINTER(ctypes.c_ubyte)),
23
+ ('width', ctypes.c_int),
24
+ ('height', ctypes.c_int),
25
+ ('channels', ctypes.c_int),
26
+ ('stride', ctypes.c_int)
27
+ ]
28
+
29
+ class ClipFeatureItem(ctypes.Structure):
30
+ _fields_ = [
31
+ ('feat', ctypes.c_float * 768),
32
+ ('len', ctypes.c_int)
33
+ ]
34
+
35
+
36
+ class ClipResultItem(ctypes.Structure):
37
+ _fields_ = [
38
+ ('key', ctypes.c_char * 64),
39
+ ('score', ctypes.c_float)
40
+ ]
41
+
42
+ _lib.clip_create.argtypes = [ctypes.POINTER(ClipInit), ctypes.POINTER(ctypes.c_void_p)]
43
+ _lib.clip_create.restype = ctypes.c_int
44
+
45
+ _lib.clip_destroy.argtypes = [ctypes.c_void_p]
46
+ _lib.clip_destroy.restype = ctypes.c_int
47
+
48
+ _lib.clip_add.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ClipImage), ctypes.c_char]
49
+ _lib.clip_add.restype = ctypes.c_int
50
+
51
+ _lib.clip_remove.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
52
+ _lib.clip_remove.restype = ctypes.c_int
53
+
54
+ _lib.clip_contain.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
55
+ _lib.clip_contain.restype = ctypes.c_int
56
+
57
+ _lib.clip_get_text_feat.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ClipFeatureItem)]
58
+ _lib.clip_get_text_feat.restype = ctypes.c_int
59
+
60
+ _lib.clip_match_feat.argtypes = [ctypes.c_void_p, ctypes.POINTER(ClipFeatureItem), ctypes.POINTER(ClipResultItem), ctypes.c_int]
61
+ _lib.clip_match_feat.restype = ctypes.c_int
62
+
63
+ _lib.clip_match_text.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.POINTER(ClipResultItem), ctypes.c_int]
64
+ _lib.clip_match_text.restype = ctypes.c_int
65
+
66
+ _lib.clip_match_image.argtypes = [ctypes.c_void_p, ctypes.POINTER(ClipImage), ctypes.POINTER(ClipResultItem), ctypes.c_int]
67
+ _lib.clip_match_image.restype = ctypes.c_int
68
+
69
+
70
+ class Clip:
71
+ def __init__(self, init_info: dict):
72
+ self.handle = None
73
+ self.init_info = ClipInit()
74
+
75
+ # 设置初始化参数
76
+ self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
77
+ self.init_info.devid = init_info.get('devid', 0)
78
+ self.init_info.isCN = init_info.get('isCN', 1)
79
+
80
+ # 设置路径
81
+ for path_name in ['text_encoder_path', 'image_encoder_path', 'tokenizer_path', 'db_path']:
82
+ if path_name in init_info:
83
+ setattr(self.init_info, path_name, init_info[path_name].encode('utf-8'))
84
+
85
+ # 创建CLIP实例
86
+ handle = ctypes.c_void_p()
87
+ check_error(_lib.clip_create(ctypes.byref(self.init_info), ctypes.byref(handle)))
88
+ self.handle = handle
89
+
90
+ def __del__(self):
91
+ if self.handle:
92
+ _lib.clip_destroy(self.handle)
93
+
94
+ def add_image(self, key: str, image_data: np.ndarray) -> None:
95
+ if self.contains_image(key):
96
+ return
97
+ image = ClipImage()
98
+ image.data = ctypes.cast(image_data.ctypes.data, ctypes.POINTER(ctypes.c_ubyte))
99
+ image.width = image_data.shape[1]
100
+ image.height = image_data.shape[0]
101
+ image.channels = image_data.shape[2]
102
+ image.stride = image_data.shape[1] * image_data.shape[2]
103
+
104
+ check_error(_lib.clip_add(self.handle, key.encode('utf-8'), ctypes.byref(image), 0))
105
+
106
+ def remove_image(self, key: str) -> None:
107
+ check_error(_lib.clip_remove(self.handle, key.encode('utf-8')))
108
+
109
+ def contains_image(self, key: str) -> bool:
110
+ return _lib.clip_contain(self.handle, key.encode('utf-8')) == 1
111
+
112
+ def get_text_feat(self, text: str) -> np.ndarray:
113
+ feat = ClipFeatureItem()
114
+ check_error(_lib.clip_get_text_feat(self.handle, text.encode('utf-8'), ctypes.byref(feat)))
115
+ return np.array(feat.feat[:feat.len])
116
+
117
+ def match_feat(self, feat: np.ndarray, top_k: int = 10) -> List[Tuple[str, float]]:
118
+ feat_item = ClipFeatureItem()
119
+ # feat_item.feat = feat.astype(np.float32).tolist()
120
+ arr = feat.astype(np.float32)
121
+ for i in range(len(arr)):
122
+ feat_item.feat[i] = arr[i]
123
+ feat_item.len = len(feat)
124
+
125
+ results = (ClipResultItem * top_k)()
126
+ check_error(_lib.clip_match_feat(self.handle, ctypes.byref(feat_item), results, top_k))
127
+
128
+ return [(item.key.decode('utf-8'), item.score) for item in results]
129
+
130
+ def match_text(self, text: str, top_k: int = 10) -> List[Tuple[str, float]]:
131
+ results = (ClipResultItem * top_k)()
132
+ check_error(_lib.clip_match_text(self.handle, text.encode('utf-8'), results, top_k))
133
+
134
+ return [(item.key.decode('utf-8'), item.score) for item in results]
135
+
136
+ def match_image(self, image_data: bytes, width: int, height: int, channels: int = 3, top_k: int = 10) -> List[Tuple[str, float]]:
137
+ image = ClipImage()
138
+ image.data = ctypes.cast(ctypes.create_string_buffer(image_data), ctypes.POINTER(ctypes.c_ubyte))
139
+ image.width = width
140
+ image.height = height
141
+ image.channels = channels
142
+ image.stride = width * channels
143
+
144
+ results = (ClipResultItem * top_k)()
145
+ check_error(_lib.clip_match_image(self.handle, ctypes.byref(image), ctypes.byref(results), top_k))
146
+
147
+ return [(item.key.decode('utf-8'), item.score) for item in results]