# CLIP_Image_Search / search_engine_model.py
# Author: DanielIglesias97
# First upload of the CLIP_Image_Search code to this repository.
# Commit: 11c750c
import clip
import logging
import os
import pandas as pd
from PIL import Image
import random
import torch
class SearchEngineModel():
    """CLIP-based image search engine.

    Encodes every image under ``image_root_dir`` with CLIP ViT-B/32, caches the
    embeddings as rows of a CSV file at ``csv_file_path``, and ranks the images
    against a text or image query by dot-product similarity.
    """

    def __init__(self, image_root_dir, csv_file_path):
        self.logger = logging.getLogger(__name__)
        logging.basicConfig(level=logging.INFO)
        # Run on GPU when available; every tensor is kept on this device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.image_root_dir = image_root_dir
        self.csv_file_path = csv_file_path
        self.model, self.preprocess = self.load_clip_model()

    def load_clip_model(self):
        """Load the pretrained CLIP ViT-B/32 model and its preprocessing transform."""
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        return model, preprocess

    def encode_images(self, model, preprocess, image_folder, csv_file_path):
        """Return ``(encoded_images, image_paths)`` for all images in *image_folder*.

        On the first call every .png/.jpg/.jpeg file is encoded with CLIP and
        the features are written to *csv_file_path*; subsequent calls read the
        cached CSV instead of re-encoding.

        Returns:
            encoded_images (torch.Tensor): (N, D) feature matrix on ``self.device``.
            image_paths (list[str]): path of the image behind each feature row.
        """
        encoded_images = []
        image_paths = []
        if not os.path.exists(csv_file_path):
            dataset_images = os.listdir(image_folder)
            total_nof_dataset_images = len(dataset_images)
            for idx, filename in enumerate(dataset_images):
                if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                    # idx + 1 so the first image logs as [1/N], not [0/N];
                    # lazy %-args so the string is only built when emitted.
                    self.logger.info('[%d/%d] Processing %s...', idx + 1, total_nof_dataset_images, filename)
                    image_path = os.path.join(image_folder, filename)
                    image = preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
                    with torch.no_grad():
                        image_features = model.encode_image(image)
                    encoded_images.append(image_features)
                    image_paths.append(image_path)
            encoded_images = torch.cat(encoded_images)
            # Detach to CPU before building the DataFrame: pandas cannot
            # consume CUDA tensors directly.
            image_features_df = pd.DataFrame(encoded_images.cpu().numpy())
            image_features_df['path'] = image_paths
            image_features_df.to_csv(csv_file_path, index=False)
        else:
            image_features_df = pd.read_csv(csv_file_path)
            image_paths = image_features_df['path'].values.tolist()
            # Rebuild the feature matrix once (the original built it twice and
            # discarded the first copy) and move it to the active device so it
            # can be multiplied with a possibly-CUDA prompt tensor.
            encoded_images = torch.Tensor(image_features_df.drop(columns=['path']).values).to(self.device)
        return encoded_images, image_paths

    def __search_image_auxiliar_func__(self, prompt_features, nofimages_to_show):
        """Rank all dataset images against *prompt_features* and return the
        paths of the (at most) *nofimages_to_show* most similar ones."""
        encoded_images, image_paths = self.encode_images(self.model, self.preprocess, self.image_root_dir, self.csv_file_path)
        # Cast both sides to float32: CSV-cached features load as float32
        # while a CUDA CLIP prompt is float16, and mixed dtypes cannot matmul.
        similarity = encoded_images.float() @ prompt_features.float().T
        # Never request more results than there are images (topk would raise).
        k = min(nofimages_to_show, len(image_paths))
        values, indices = similarity.topk(k, dim=0)
        results = []
        for value, index in zip(values, indices):
            results.append(image_paths[int(index)])
        return results

    def search_image_by_text_prompt(self, text_prompt, nofimages_to_show):
        """Return the paths of the images most similar to *text_prompt*."""
        query = clip.tokenize([text_prompt]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(query)
        search_results = self.__search_image_auxiliar_func__(text_features, nofimages_to_show)
        return search_results

    def search_image_by_image_prompt(self, image_prompt, nofimages_to_show):
        """Return the paths of the images most similar to the PIL image *image_prompt*."""
        image = self.preprocess(image_prompt).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
        search_results = self.__search_image_auxiliar_func__(image_features, nofimages_to_show)
        return search_results