# Spaces: Runtime error  (status residue from the page this file was copied from; not part of the program)
import hashlib
import os

import clip
import numpy as np
import torch
from PIL import Image

from image_process import concate_images_vertically
from image_process import crop_image_from_background
from image_process import get_products_from_pdf_file
class ImageSearch():
    """Index product images extracted from PDF certificates with CLIP and search them by image.

    Each indexed product is identified by the MD5 hash of its vertically
    concatenated PDF page image. ``base_product_features`` maps that hash to the
    CLIP embedding of the product's cropped picture. ``base_product_pdf_imgs``
    maps the same hash to the concatenated page image, which is what a search
    hit displays.
    """

    # Minimum cosine similarity for a match to be shown. Weaker matches are
    # replaced by the placeholder image below.
    similarity_threshold = 0.75
    # Placeholder image shown for matches below ``similarity_threshold``.
    not_found_image_path = 'assets/not_found.png'

    def __init__(self):
        print('CLIP Models', clip.available_models())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load('ViT-B/32', self.device)
        # Number of best matches reported per query image (fewer when the index
        # holds fewer products than this).
        self.top_k = 2
        # content hash -> CLIP image embedding (numpy array, see encode_image)
        self.base_product_features = {}
        # content hash -> concatenated PDF page image returned by
        # concate_images_vertically (presumably a PIL image -- confirm)
        self.base_product_pdf_imgs = {}

    def encode_image(self, image_list):
        """Embed one image, or a list/tuple of images, with the CLIP image encoder.

        Args:
            image_list: a single image or a list/tuple of images accepted by
                ``self.preprocess``. A non-list/tuple value is treated as one image.

        Returns:
            A numpy array (moved to CPU) with one embedding row per input image.
        """
        if not isinstance(image_list, (list, tuple)):
            image_list = [image_list]
        img_batch = torch.stack([self.preprocess(image) for image in image_list]).to(self.device)
        with torch.no_grad():
            img_features = self.model.encode_image(img_batch)
        return img_features.cpu().numpy()

    def get_base_product_features(self, product_hashes, product_images):
        """Compute and cache CLIP features for products not yet in the index.

        Args:
            product_hashes: content hashes, parallel to ``product_images``.
            product_images: images to embed for the corresponding hashes.

        Hashes already cached, or repeated within this call, are encoded only once.
        """
        pending = {}
        for product_hash, product_image in zip(product_hashes, product_images):
            if product_hash in self.base_product_features or product_hash in pending:
                continue
            pending[product_hash] = product_image
        if pending:
            features = self.encode_image(list(pending.values()))
            self.base_product_features.update(zip(pending.keys(), features))

    def upload_products_pdf_file(self, pdf_files):
        """Extract products from uploaded PDF files and add them to the search index.

        Args:
            pdf_files: uploaded file objects exposing a ``.name`` path, as passed
                to ``get_products_from_pdf_file``.

        Returns:
            A list of ``(cropped_product_image, display_name)`` pairs. The name is
            the PDF's base file name when that PDF holds exactly one product,
            otherwise the product's content hash.
        """
        product_names = []
        product_hashes = []
        product_cropped_images = []
        for pdf_file in pdf_files:
            products_images = get_products_from_pdf_file(pdf_file)
            single_product = len(products_images) == 1
            for product_images in products_images:
                product_pdf_image = concate_images_vertically(product_images)
                product_hash = hashlib.md5(product_pdf_image.tobytes()).hexdigest()
                self.base_product_pdf_imgs[product_hash] = product_pdf_image
                product_hashes.append(product_hash)
                # Each work registration certificate: the first page is text,
                # the second page is the picture of the work.
                product_image = product_images[1]
                if single_product:
                    product_names.append(os.path.basename(pdf_file.name))
                else:
                    product_names.append(product_hash)
                product_image = crop_image_from_background(product_image)
                product_cropped_images.append(Image.fromarray(product_image))
        self.get_base_product_features(product_hashes, product_cropped_images)
        # A list (not a one-shot zip iterator) so callers may iterate more than once.
        return list(zip(product_cropped_images, product_names))

    def upload_wait2search_image(self, image_infos):
        """Open the query images and search the product index for each one.

        Args:
            image_infos: iterable of ``(image_path, label)`` pairs. The label is
                ignored. ``None`` or an empty iterable yields no results.

        Returns:
            The result of ``search_image`` for the opened images.
        """
        images, hashes, names = [], [], []
        for image_file, _image_label in image_infos or []:
            image = Image.open(image_file)
            names.append(os.path.basename(image_file))
            images.append(image)
            hashes.append(hashlib.md5(image.tobytes()).hexdigest())
        return self.search_image(hashes, images, names)

    @staticmethod
    def _top_matches(query_features, base_features, top_k):
        """Return top-k cosine similarities (values, indices) of each query row against base rows.

        ``k`` is clamped to the number of base rows so a small index never makes
        ``topk`` fail. The inputs are not modified.
        """
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
        base_features = base_features / base_features.norm(dim=-1, keepdim=True)
        similarity = query_features @ base_features.T
        k = min(top_k, base_features.shape[0])
        return similarity.topk(k, dim=-1)

    def _load_not_found_image(self):
        """Load a fresh RGB copy of the placeholder image; the file handle is closed."""
        with Image.open(self.not_found_image_path) as placeholder:
            return placeholder.convert('RGB')

    def search_image(self, wait2search_image_hashes, wait2search_image_list, wait2search_image_names):
        """Find the best-matching indexed products for each query image.

        Args:
            wait2search_image_hashes: content hashes of the query images. They are
                accepted for interface compatibility and are not needed, because
                results are kept in query order.
            wait2search_image_list: query images to embed.
            wait2search_image_names: display names, parallel to the images.

        Returns:
            A list of ``(result_image, name)`` pairs, one per query. The result
            image stacks up to ``top_k`` matched PDF images, with below-threshold
            matches shown as the placeholder.

        Raises:
            ValueError: if there are queries but no products have been indexed.
        """
        if not wait2search_image_list:
            return []
        if not self.base_product_features:
            raise ValueError('No products indexed yet: upload product PDF files before searching.')

        # Keys and values come from the same dict snapshot, so row i of
        # base_features belongs to base_product_hashes[i].
        base_product_hashes = list(self.base_product_features)
        base_features = torch.from_numpy(np.stack(list(self.base_product_features.values()))).to(self.device)
        query_features = torch.from_numpy(self.encode_image(wait2search_image_list)).to(self.device)
        values, indices = self._top_matches(query_features, base_features, self.top_k)

        result_images = []
        for row_scores, row_indices in zip(values.cpu().numpy(), indices.cpu().tolist()):
            pdf_img_list = []
            for score, base_idx in zip(row_scores, row_indices):
                if score > self.similarity_threshold:
                    pdf_img_list.append(self.base_product_pdf_imgs[base_product_hashes[base_idx]])
                else:
                    pdf_img_list.append(self._load_not_found_image())
            result_images.append(concate_images_vertically(pdf_img_list, list(row_scores)))
        # Results are kept in query order, so duplicate query images no longer
        # shift the pairing between results and names.
        return list(zip(result_images, wait2search_image_names))
# Module-level singleton shared by importers. Constructing it calls clip.load,
# so importing this module loads the CLIP ViT-B/32 model.
SearchImageTask = ImageSearch()