# pic2pdf_explorer / algorithm.py
# wuzengcheng
# [fix] base features and search features should be in the same device
# 7f0e191
import clip
import torch
import hashlib
import numpy as np
from PIL import Image
from image_process import crop_image_from_background
from image_process import concate_images_vertically
from image_process import get_products_from_pdf_file
class ImageSearch():
    """Search indexed product certificates by CLIP image similarity.

    Product images extracted from PDF files are embedded with CLIP and cached
    by content hash. Query images are embedded the same way and compared with
    the cached products by cosine similarity.
    """

    # Placeholder shown in place of a match that is not similar enough.
    not_found_image_path = 'assets/not_found.png'

    def __init__(self, model_name='ViT-B/32', top_k=2, similarity_threshold=0.75):
        """Load the CLIP model and start with empty product caches.

        Args:
            model_name: Name of the CLIP model passed to ``clip.load``.
            top_k: Number of best matches reported per query image.
            similarity_threshold: Similarity a match must exceed; weaker
                matches are replaced by the "not found" placeholder.
        """
        print('CLIP Models', clip.available_models())
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, self.device)
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        # Both caches are keyed by the MD5 hex digest of the stacked PDF pages.
        self.base_product_features = {}
        self.base_product_pdf_imgs = {}

    def encode_image(self, image_list):
        """Embed one or more images with the CLIP image encoder.

        Args:
            image_list: A single image, or a list/tuple of images, accepted
                by ``self.preprocess``. Must contain at least one image.

        Returns:
            A NumPy array with one feature row per input image.
        """
        if not isinstance(image_list, (list, tuple)):
            image_list = [image_list]
        img_batch = torch.stack(
            [self.preprocess(image) for image in image_list]
        ).to(self.device)
        with torch.no_grad():
            img_features = self.model.encode_image(img_batch)
        return img_features.cpu().numpy()

    def get_base_product_features(self, product_hashes, product_images):
        """Compute and cache features for products that are not cached yet.

        Args:
            product_hashes: Cache keys, parallel to ``product_images``.
            product_images: Images to embed for the corresponding hashes.
        """
        pending = {}
        for product_hash, product_image in zip(product_hashes, product_images):
            if product_hash in self.base_product_features:
                continue
            # A hash repeated within one call only needs to be encoded once.
            pending.setdefault(product_hash, product_image)
        if not pending:
            return
        features = self.encode_image(list(pending.values()))
        for product_hash, feature in zip(pending, features):
            self.base_product_features[product_hash] = feature

    def upload_products_pdf_file(self, pdf_files):
        """Index the products contained in the uploaded PDF files.

        Args:
            pdf_files: Iterable of uploaded file objects exposing ``.name``;
                ``None`` is treated as no files.

        Returns:
            A ``zip`` of ``(cropped product image, product name)`` pairs.
        """
        product_names = []
        product_hashes = []
        product_cropped_images = []
        for pdf_file in pdf_files or []:
            products_images = get_products_from_pdf_file(pdf_file)
            # A PDF holding exactly one product is shown under its file name;
            # products of a multi-product PDF are shown under their hash.
            single_product = len(products_images) == 1
            for product_images in products_images:
                # Each work-registration certificate has text on its first
                # page and the picture on its second page.
                if len(product_images) < 2:
                    # Check before touching the caches so a malformed
                    # certificate cannot leave them inconsistent.
                    print('Skip product without an image page in',
                          getattr(pdf_file, 'name', pdf_file))
                    continue
                product_pdf_image = concate_images_vertically(product_images)
                product_hash = hashlib.md5(product_pdf_image.tobytes()).hexdigest()
                self.base_product_pdf_imgs[product_hash] = product_pdf_image
                product_hashes.append(product_hash)
                if single_product:
                    product_names.append(pdf_file.name.split('/')[-1])
                else:
                    product_names.append(product_hash)
                product_image = crop_image_from_background(product_images[1])
                product_cropped_images.append(Image.fromarray(product_image))
        self.get_base_product_features(product_hashes, product_cropped_images)
        return zip(product_cropped_images, product_names)

    def upload_wait2search_image(self, image_infos):
        """Open the uploaded query images and search the indexed products.

        Args:
            image_infos: Iterable of ``(image path, label)`` pairs; the label
                is ignored. ``None`` is treated as no images.

        Returns:
            The ``zip`` produced by :meth:`search_image`.
        """
        wait2search_image_list = []
        wait2search_image_hashes = []
        wait2search_image_names = []
        for image_file, _image_label in image_infos or []:
            wait2search_image = Image.open(image_file)
            wait2search_image_names.append(image_file.split('/')[-1])
            wait2search_image_list.append(wait2search_image)
            wait2search_image_hashes.append(
                hashlib.md5(wait2search_image.tobytes()).hexdigest()
            )
        return self.search_image(
            wait2search_image_hashes, wait2search_image_list, wait2search_image_names
        )

    def _load_not_found_image(self):
        """Return the placeholder image used when no product matches."""
        return Image.open(self.not_found_image_path).convert('RGB')

    def search_image(self, wait2search_image_hashes, wait2search_image_list, wait2search_image_names):
        """Match query images against the indexed products.

        Args:
            wait2search_image_hashes: Hashes of the query images. Kept for
                interface compatibility; results are positional so that
                duplicate query images stay aligned with their names.
            wait2search_image_list: Query images to embed.
            wait2search_image_names: Display names, parallel to the images.

        Returns:
            A ``zip`` of ``(result image, query name)`` pairs. Each result
            image stacks the best matches, with the placeholder standing in
            for any match at or below ``self.similarity_threshold``.
        """
        if not wait2search_image_list:
            return zip([], [])

        search_results = []
        base_product_hashes = list(self.base_product_features)
        # topk raises when asked for more entries than there are products.
        top_k = min(self.top_k, len(base_product_hashes))
        if top_k < 1:
            # Nothing indexed yet: report "not found" for every query.
            # NOTE(review): assumes concate_images_vertically accepts a
            # single image with a single score - confirm.
            not_found_img = self._load_not_found_image()
            for _ in wait2search_image_list:
                search_results.append(
                    concate_images_vertically([not_found_img], [0.0])
                )
            return zip(search_results, wait2search_image_names)

        base_product_features = torch.from_numpy(
            np.array([self.base_product_features[h] for h in base_product_hashes])
        ).to(self.device)
        base_product_features /= base_product_features.norm(dim=-1, keepdim=True)
        wait2search_image_features = torch.from_numpy(
            self.encode_image(wait2search_image_list)
        ).to(self.device)
        wait2search_image_features /= wait2search_image_features.norm(dim=-1, keepdim=True)
        # Rows are unit-normalised, so the product is the cosine similarity.
        similarity = wait2search_image_features @ base_product_features.T
        values, indices = similarity.topk(top_k)

        not_found_img = None  # opened lazily, at most once per search
        for value, indice in zip(values, indices):
            scores = list(value.cpu().numpy())
            pdf_img_list = []
            for score, base_idx in zip(scores, indice.tolist()):
                if score > self.similarity_threshold:
                    base_product_hash = base_product_hashes[base_idx]
                    pdf_img_list.append(self.base_product_pdf_imgs[base_product_hash])
                else:
                    if not_found_img is None:
                        not_found_img = self._load_not_found_image()
                    pdf_img_list.append(not_found_img)
            search_results.append(concate_images_vertically(pdf_img_list, scores))
        return zip(search_results, wait2search_image_names)
# Module-level shared instance; constructing it loads the CLIP model, so
# importing this module triggers the model load.
SearchImageTask = ImageSearch()