| import clip | |
| import logging | |
| import os | |
| import pandas as pd | |
| from PIL import Image | |
| import random | |
| import torch | |
class SearchEngineModel():
    """CLIP-based image search engine.

    Indexes every image under ``image_root_dir`` with CLIP ViT-B/32,
    caches the features in a CSV at ``csv_file_path``, and retrieves the
    most similar images for a text prompt or an image prompt.
    """

    def __init__(self, image_root_dir, csv_file_path):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Prefer GPU when available; CLIP inference is much faster there.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.image_root_dir = image_root_dir
        self.csv_file_path = csv_file_path
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        """Load the ViT-B/32 CLIP model and its preprocessing transform."""
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def encode_images(self, model, preprocess, image_folder, csv_file_path):
        """Encode all images in *image_folder* with CLIP, using a CSV cache.

        On a cache miss the features are computed and written to
        *csv_file_path*; on a hit they are simply reloaded.

        Returns:
            tuple: ``(features, paths)`` where ``features`` is a float32
            tensor of shape ``(n_images, dim)`` on ``self.device`` and
            ``paths`` is the list of corresponding image file paths.
        """
        if not os.path.exists(csv_file_path):
            encoded_images = []
            image_paths = []
            dataset_images = os.listdir(image_folder)
            total_nof_dataset_images = len(dataset_images)
            for idx, filename in enumerate(dataset_images):
                if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                # Lazy %-style args: the string is only built if INFO is on.
                self.logger.info('[%d/%d] Processing %s...',
                                 idx, total_nof_dataset_images, filename)
                image_path = os.path.join(image_folder, filename)
                image = preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    image_features = model.encode_image(image)
                encoded_images.append(image_features)
                image_paths.append(image_path)
            if not encoded_images:
                # Empty folder: return an empty index instead of letting
                # torch.cat([]) raise.
                return torch.empty(0, device=self.device), []
            # Cast to float32 once; CLIP emits float16 on CUDA.
            encoded_images = torch.cat(encoded_images).float()
            # pandas cannot consume CUDA tensors directly — detach to a
            # CPU numpy array before building the DataFrame.
            image_features_df = pd.DataFrame(encoded_images.cpu().numpy())
            image_features_df['path'] = image_paths
            image_features_df.to_csv(csv_file_path, index=False)
            return encoded_images, image_paths
        # Cache hit: restore features and paths from the CSV. The tensor is
        # moved onto self.device so it matches the query features' device.
        image_features_df = pd.read_csv(csv_file_path)
        image_paths = image_features_df['path'].values.tolist()
        encoded_images = torch.tensor(
            image_features_df.drop(columns=['path']).values,
            dtype=torch.float32, device=self.device)
        return encoded_images, image_paths

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Return the paths of the images most similar to *prompt_features*.

        Both the index and the query are L2-normalized so the dot product
        is cosine similarity — the standard CLIP retrieval metric.
        """
        encoded_images, image_paths = self.encode_images(
            self.model, self.preprocess, self.image_root_dir, self.csv_file_path)
        if not image_paths:
            return []
        encoded_images = encoded_images / encoded_images.norm(dim=-1, keepdim=True)
        prompt_features = prompt_features.float()
        prompt_features = prompt_features / prompt_features.norm(dim=-1, keepdim=True)
        similarity = encoded_images @ prompt_features.T
        # Never ask topk for more results than there are indexed images.
        k = min(nofimages_to_show, len(image_paths))
        _, indices = similarity.topk(k, dim=0)
        return [image_paths[index] for index in indices]

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        """Search the index with a free-text prompt."""
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        return self.__search_image_auxiliar_func__(text_features, nofimages_to_show)

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Search the index with a query image (a PIL.Image)."""
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        return self.__search_image_auxiliar_func__(image_features, nofimages_to_show)