# First upload of the code of NodeImageSearchEngine to the repository.
# (author: DanielIglesias97, commit 708f5d3)
import clip
import logging
import os
import pandas as pd
from PIL import Image
import random
import torch
class SearchEngineModel():
    """CLIP-based image search engine.

    Encodes every image in a folder with CLIP ViT-B/32, caches the embeddings
    in a CSV file, and retrieves the most similar images for a text prompt or
    an image prompt.
    """

    def __init__(self, image_root_dir, csv_file_path):
        """
        Args:
            image_root_dir: Folder containing the images to index.
            csv_file_path: Path of the CSV embedding cache (created on the
                first search if it does not exist yet).
        """
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.image_root_dir = image_root_dir
        self.csv_file_path = csv_file_path
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        """Load the CLIP ViT-B/32 model and its preprocessing transform."""
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def encode_images(self, model, preprocess, image_folder, csv_file_path):
        """Return ``(encoded_images, image_paths)`` for the indexed folder.

        On the first call the embeddings are computed with CLIP and cached in
        ``csv_file_path``; subsequent calls just reload the CSV cache.

        Returns:
            encoded_images: Tensor of shape (n_images, feature_dim) placed on
                ``self.device``.
            image_paths: List of image file paths aligned with the rows of
                ``encoded_images``.
        """
        if not os.path.exists(csv_file_path):
            encoded_images = []
            image_paths = []
            dataset_images = os.listdir(image_folder)
            total_nof_dataset_images = len(dataset_images)
            for idx, filename in enumerate(dataset_images):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    # 1-based progress so the last entry reads [total/total].
                    self.logger.info('[%d/%d] Processing %s...' % (idx + 1, total_nof_dataset_images, filename))
                    image_path = os.path.join(image_folder, filename)
                    # Context manager closes the file handle promptly; PIL
                    # otherwise keeps it open until garbage collection.
                    with Image.open(image_path) as img:
                        image = preprocess(img).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        image_features = model.encode_image(image)
                    encoded_images.append(image_features)
                    image_paths.append(image_path)
            encoded_images = torch.cat(encoded_images)
            # Hand pandas a CPU numpy array: building a DataFrame directly
            # from a CUDA tensor fails.
            image_features_df = pd.DataFrame(encoded_images.detach().cpu().numpy())
            image_features_df['path'] = image_paths
            image_features_df.to_csv(csv_file_path, index=False)
        else:
            image_features_df = pd.read_csv(csv_file_path)
            image_paths = image_features_df['path'].values.tolist()
            # Rebuild the tensor on the active device so similarity search
            # does not fail with a CPU/CUDA device mismatch.
            feature_values = image_features_df.drop(columns=['path']).values
            encoded_images = torch.tensor(feature_values, dtype=torch.float32).to(self.device)
        return encoded_images, image_paths

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Return the paths of the images most similar to ``prompt_features``.

        Results are ordered by decreasing cosine-unnormalized dot-product
        similarity; at most ``nofimages_to_show`` paths are returned.
        """
        encoded_images, image_paths = self.encode_images(self.model, self.preprocess, self.image_root_dir, self.csv_file_path)
        # Cast both sides to float32 so cached (float32) and freshly encoded
        # (possibly float16 on CUDA) features are comparable.
        similarity = encoded_images.float() @ prompt_features.float().T
        # Never request more results than there are indexed images.
        k = min(nofimages_to_show, similarity.shape[0])
        values, indices = similarity.topk(k, dim=0)
        results = []
        for index in indices.flatten().tolist():
            results.append(image_paths[index])
        return results

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        """Return the paths of the images best matching a text prompt."""
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        search_results = self.__search_image_auxiliar_func__(text_features, nofimages_to_show)
        return search_results

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Return the paths of the images best matching a PIL image prompt."""
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)
        return search_results