import hashlib
import os

import clip
import numpy as np
import torch
from PIL import Image

from image_process import concate_images_vertically
from image_process import crop_image_from_background
from image_process import get_products_from_pdf_file


class ImageSearch:
    """CLIP (ViT-B/32) image search over products extracted from PDF files.

    Product images are encoded once and cached by the MD5 hash of the
    product's vertically concatenated PDF pages. Query images are encoded on
    demand and ranked against the cache by cosine similarity.
    """

    def __init__(self, top_k=2, similarity_threshold=0.75,
                 not_found_image_path='assets/not_found.png'):
        """Load the CLIP model and create empty product caches.

        Args:
            top_k: maximum number of matches shown per query image.
            similarity_threshold: cosine similarity a match must exceed to be
                shown; weaker matches are replaced by the placeholder image.
            not_found_image_path: placeholder image shown for weak matches.
        """
        print('CLIP Models', clip.available_models())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load('ViT-B/32', self.device)
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.not_found_image_path = not_found_image_path
        # Placeholder image, loaded lazily on first use.
        self._not_found_image = None
        # product hash -> CLIP feature vector (numpy) of the cropped product image
        self.base_product_features = {}
        # product hash -> concatenated PDF page image of that product
        self.base_product_pdf_imgs = {}

    def encode_image(self, image_list):
        """Encode one PIL image, or a non-empty list of them, with CLIP.

        Returns:
            A numpy array on the CPU with one feature row per input image.
        """
        if not isinstance(image_list, list):
            image_list = [image_list]
        img_batch = torch.stack([self.preprocess(image) for image in image_list]).to(self.device)
        with torch.no_grad():
            img_features = self.model.encode_image(img_batch)
        return img_features.cpu().numpy()

    def get_base_product_features(self, product_hashes, product_images):
        """Encode and cache features for products that are not cached yet.

        Hashes already present in ``self.base_product_features`` are skipped,
        so re-uploading the same PDF does not re-run the model.
        """
        pending = [(product_hash, product_image)
                   for product_hash, product_image in zip(product_hashes, product_images)
                   if product_hash not in self.base_product_features]
        if not pending:
            return
        pending_hashes, pending_images = zip(*pending)
        pending_features = self.encode_image(list(pending_images))
        for product_hash, product_feature in zip(pending_hashes, pending_features):
            self.base_product_features[product_hash] = product_feature

    def upload_products_pdf_file(self, pdf_files):
        """Register every product found in the given PDF files.

        For each product, its page images are concatenated vertically and
        cached under the MD5 hash of that image. The second page is cropped
        from its background and encoded with CLIP.

        Args:
            pdf_files: iterable of uploaded file objects exposing a ``.name`` path.

        Returns:
            zip of (cropped product image, display name) pairs. The display name
            is the PDF file name when the PDF yields exactly one product,
            otherwise the product hash.
        """
        product_names = []
        product_hashes = []
        product_cropped_images = []
        for pdf_file in pdf_files:
            products_images = get_products_from_pdf_file(pdf_file)
            single_product = len(products_images) == 1
            for product_images in products_images:
                product_pdf_image = concate_images_vertically(product_images)
                product_hash = hashlib.md5(product_pdf_image.tobytes()).hexdigest()
                self.base_product_pdf_imgs[product_hash] = product_pdf_image
                product_hashes.append(product_hash)
                if single_product:
                    product_names.append(os.path.basename(pdf_file.name))
                else:
                    product_names.append(product_hash)
                # Each work-registration certificate has text on the first page
                # and the artwork image on the second page.
                # NOTE(review): a product with fewer than two pages raises IndexError.
                product_image = crop_image_from_background(product_images[1])
                product_cropped_images.append(Image.fromarray(product_image))
        self.get_base_product_features(product_hashes, product_cropped_images)
        return zip(product_cropped_images, product_names)

    def upload_wait2search_image(self, image_infos):
        """Open the uploaded query images and search the product cache.

        Args:
            image_infos: iterable of pairs whose first item is an image path;
                the second item (a label) is ignored.

        Returns:
            The result of ``search_image``.
        """
        wait2search_image_list = []
        wait2search_image_hashes = []
        wait2search_image_names = []
        for image_file, _image_label in image_infos:
            wait2search_image = Image.open(image_file)
            wait2search_image_names.append(os.path.basename(image_file))
            wait2search_image_list.append(wait2search_image)
            wait2search_image_hashes.append(hashlib.md5(wait2search_image.tobytes()).hexdigest())
        return self.search_image(wait2search_image_hashes, wait2search_image_list,
                                 wait2search_image_names)

    def _get_not_found_image(self):
        """Return a fresh copy of the placeholder shown for weak matches."""
        if self._not_found_image is None:
            with Image.open(self.not_found_image_path) as placeholder:
                self._not_found_image = placeholder.convert('RGB')
        # Hand out copies so downstream code can never alter the cached image.
        return self._not_found_image.copy()

    def search_image(self, wait2search_image_hashes, wait2search_image_list, wait2search_image_names):
        """Find the closest cached products for each query image.

        Args:
            wait2search_image_hashes: MD5 hashes of the query images. Accepted for
                interface compatibility; results are no longer keyed by them.
            wait2search_image_list: query PIL images.
            wait2search_image_names: display names, parallel to the images.

        Returns:
            zip of (result image, query name) pairs: one pair per query image,
            in input order. Each result image stacks up to ``top_k`` matched
            product PDF images together with their similarity scores. Matches
            at or below ``similarity_threshold`` show the placeholder image.
            If no products are cached, each result is the placeholder image.
        """
        if not wait2search_image_list:
            return zip([], [])

        base_hashes = list(self.base_product_features)
        if not base_hashes:
            # Nothing has been uploaded yet, so no query can match anything.
            return zip([self._get_not_found_image() for _ in wait2search_image_list],
                       wait2search_image_names)

        base_features = torch.from_numpy(
            np.stack([self.base_product_features[h] for h in base_hashes])).to(self.device)
        base_features /= base_features.norm(dim=-1, keepdim=True)
        query_features = torch.from_numpy(self.encode_image(wait2search_image_list)).to(self.device)
        query_features /= query_features.norm(dim=-1, keepdim=True)

        # Cosine similarity of every query against every cached product.
        similarity = query_features @ base_features.T
        # topk raises if k exceeds the number of cached products.
        k = min(self.top_k, len(base_hashes))
        values, indices = similarity.topk(k)
        values, indices = values.cpu(), indices.cpu()

        # Build a list, not a dict keyed by image hash: duplicate uploads would
        # otherwise collapse and misalign results with their names.
        search_results = []
        for row_scores, row_indices in zip(values, indices):
            pdf_img_list = []
            for score, index in zip(row_scores.tolist(), row_indices.tolist()):
                if score > self.similarity_threshold:
                    pdf_img_list.append(self.base_product_pdf_imgs[base_hashes[index]])
                else:
                    pdf_img_list.append(self._get_not_found_image())
            search_results.append(
                concate_images_vertically(pdf_img_list, list(row_scores.numpy())))
        return zip(search_results, wait2search_image_names)


# Shared instance; constructing it loads the CLIP model at import time.
SearchImageTask = ImageSearch()