Spaces:
Runtime error
Runtime error
File size: 5,190 Bytes
1c3aebe 4b4753d 1c3aebe 4b4753d 1c3aebe 4b4753d 1c3aebe 7f0e191 1c3aebe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import hashlib
import os

import clip
import numpy as np
import torch
from PIL import Image

from image_process import concate_images_vertically
from image_process import crop_image_from_background
from image_process import get_products_from_pdf_file
class ImageSearch():
    """CLIP-based image search over products extracted from PDF certificates.

    Products are registered with ``upload_products_pdf_file``. Query images
    are matched against them with ``upload_wait2search_image``, which ranks
    products by cosine similarity of CLIP ViT-B/32 image embeddings.
    """

    # Placeholder shown for a top-k slot whose similarity is not above the threshold.
    NOT_FOUND_IMAGE_PATH = 'assets/not_found.png'

    def __init__(self, top_k=2, similarity_threshold=0.75):
        """Load the CLIP model and initialise the empty product caches.

        Args:
            top_k: number of candidate products returned per query image.
            similarity_threshold: cosine similarity a candidate must exceed
                to be shown; weaker candidates are replaced by a placeholder.
        """
        print('CLIP Models', clip.available_models())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load('ViT-B/32', self.device)
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        # product hash -> CLIP embedding (numpy row) of the cropped product image
        self.base_product_features = {}
        # product hash -> vertically concatenated PDF page image shown as a hit
        self.base_product_pdf_imgs = {}

    def encode_image(self, image_list):
        """Embed one image or a list of images with CLIP.

        Args:
            image_list: a single image or a list of images accepted by
                ``self.preprocess``.

        Returns:
            A numpy array with one embedding row per input image.
        """
        if not isinstance(image_list, list):
            image_list = [image_list]
        img_batch = torch.stack([self.preprocess(image) for image in image_list]).to(self.device)
        with torch.no_grad():
            img_features = self.model.encode_image(img_batch)
        return img_features.cpu().numpy()

    def get_base_product_features(self, product_hashes, product_images):
        """Encode and cache embeddings for products not already cached.

        Args:
            product_hashes: content hashes identifying each product.
            product_images: cropped product images, parallel to ``product_hashes``.
        """
        # hash -> image for products still missing an embedding. Duplicate
        # hashes inside one batch are encoded only once.
        pending = {}
        for product_hash, product_image in zip(product_hashes, product_images):
            if product_hash not in self.base_product_features:
                pending.setdefault(product_hash, product_image)
        if pending:
            features = self.encode_image(list(pending.values()))
            for product_hash, feature in zip(pending, features):
                self.base_product_features[product_hash] = feature

    def upload_products_pdf_file(self, pdf_files):
        """Register every product found in the uploaded certificate PDFs.

        Args:
            pdf_files: uploaded file objects exposing a ``.name`` path; each is
                passed to ``get_products_from_pdf_file``.

        Returns:
            A zip of ``(cropped_product_image, display_name)`` pairs. The
            display name is the PDF's file name when the PDF yields a single
            product, otherwise the product's MD5 content hash.
        """
        product_names = []
        product_hashes = []
        product_cropped_images = []
        for pdf_file in pdf_files:
            products_images = get_products_from_pdf_file(pdf_file)
            for product_images in products_images:
                product_pdf_image = concate_images_vertically(product_images)
                product_hash = hashlib.md5(product_pdf_image.tobytes()).hexdigest()
                self.base_product_pdf_imgs[product_hash] = product_pdf_image
                product_hashes.append(product_hash)
                if len(products_images) == 1:
                    product_names.append(os.path.basename(pdf_file.name))
                else:
                    product_names.append(product_hash)
                # Each work registration certificate: the first page is text,
                # the second page is the image.
                cropped = crop_image_from_background(product_images[1])
                # The crop helper's result is an array; convert it back to PIL for CLIP.
                product_cropped_images.append(Image.fromarray(cropped))
        self.get_base_product_features(product_hashes, product_cropped_images)
        return zip(product_cropped_images, product_names)

    def upload_wait2search_image(self, image_infos):
        """Search for the registered products that best match each query image.

        Args:
            image_infos: iterable of ``(image_path, label)`` pairs; the label
                is ignored.

        Returns:
            The result of ``search_image``: a zip of ``(result_image, file_name)``.
        """
        wait2search_image_list = []
        wait2search_image_hashes = []
        wait2search_image_names = []
        for image_file, _image_label in image_infos:
            # Copy the pixels out so the underlying file handle is closed immediately.
            with Image.open(image_file) as opened:
                wait2search_image = opened.copy()
            wait2search_image_names.append(os.path.basename(image_file))
            wait2search_image_list.append(wait2search_image)
            wait2search_image_hashes.append(hashlib.md5(wait2search_image.tobytes()).hexdigest())
        return self.search_image(wait2search_image_hashes, wait2search_image_list, wait2search_image_names)

    def search_image(self, wait2search_image_hashes, wait2search_image_list, wait2search_image_names):
        """Match query images against the registered products.

        Args:
            wait2search_image_hashes: MD5 hashes of the query images. Kept for
                API compatibility. Results are positional, so duplicate query
                images each get their own entry.
            wait2search_image_list: query images to encode.
            wait2search_image_names: display names, parallel to the images.

        Returns:
            A zip of ``(result_image, name)`` pairs, one per query image. Each
            result image stacks the top-k candidate PDF images vertically. A
            candidate whose similarity is not above ``similarity_threshold`` is
            replaced by a placeholder.

        Raises:
            ValueError: if query images are given but no products are registered.
        """
        if not wait2search_image_list:
            return zip([], wait2search_image_names)
        if not self.base_product_features:
            raise ValueError('No products registered; upload product PDF files before searching.')

        # keys() and values() iterate in the same insertion order, so row i of
        # base_features belongs to base_hashes[i].
        base_hashes = list(self.base_product_features.keys())
        base_features = torch.from_numpy(np.stack(list(self.base_product_features.values()))).to(self.device)
        query_features = torch.from_numpy(self.encode_image(wait2search_image_list)).to(self.device)
        values, indices = self._rank_candidates(query_features, base_features)

        not_found_img = None  # loaded lazily, at most once per call
        result_images = []
        for row_scores, row_indices in zip(values, indices):
            pdf_img_list = []
            for score, index in zip(row_scores, row_indices):
                if score > self.similarity_threshold:
                    pdf_img_list.append(self.base_product_pdf_imgs[base_hashes[int(index)]])
                else:
                    if not_found_img is None:
                        not_found_img = self._load_not_found_image()
                    pdf_img_list.append(not_found_img)
            result_images.append(concate_images_vertically(pdf_img_list, list(row_scores.cpu().numpy())))
        return zip(result_images, wait2search_image_names)

    def _rank_candidates(self, query_features, base_features):
        """Return top-k ``(cosine similarities, base row indices)`` per query row.

        k is ``self.top_k`` clamped to the number of base rows, so that having
        fewer registered products than ``top_k`` does not make ``topk`` fail.
        The input tensors are not modified.
        """
        query_features = query_features / query_features.norm(dim=-1, keepdim=True)
        base_features = base_features / base_features.norm(dim=-1, keepdim=True)
        similarity = query_features @ base_features.T
        k = min(self.top_k, base_features.shape[0])
        return similarity.topk(k)

    def _load_not_found_image(self):
        """Load the placeholder image as RGB and close its file handle."""
        with Image.open(self.NOT_FOUND_IMAGE_PATH) as img:
            return img.convert('RGB')
# Module-level shared instance. Constructing it loads the CLIP ViT-B/32 model at import time.
SearchImageTask = ImageSearch()