# Project references:
#   1. https://github.com/haltakov/natural-language-image-search by Vladimir Haltakov
#   2. OpenAI's CLIP
"""Gradio demo: search a precomputed photo collection with CLIP.

A text query or a search image is embedded with CLIP ViT-B/32 and compared
against precomputed photo embeddings loaded from ``features.npy``. The best
matches are downloaded and shown in a gallery.
"""
from io import BytesIO

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from PIL import Image as pil
from sentence_transformers import SentenceTransformer, util  # NOTE(review): not referenced in this file
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL_NAME = 'openai/clip-vit-base-patch32'
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_NAME)

# Number of matches returned per query.
TOP_K = 4
# Seconds to wait for each photo download before giving up.
REQUEST_TIMEOUT = 10
# Radio value(s) that select text search. 'Text to Image' is what the Radio
# below emits; 'Text2Image' is kept for callers using the old spelling.
TEXT_MODES = {'Text to Image', 'Text2Image'}

# One value per input component: (search image, search mode, text query).
examples = [
    ['flag.jpg', 'Image to Image', ''],
    ['flowe.jpg', 'Image to Image', ''],
    ['dance.jpg', 'Image to Image', ''],
]

# Photo metadata; rows are looked up by 'photo_id' and 'photo_image_url' is read.
data = pd.read_csv('./photos.tsv000', sep='\t', header=0)
# Precomputed photo embeddings. Row i is assumed to belong to ids[i].
# They are presumably L2-normalized as well (TODO confirm).
photo_features = np.load('./features.npy')
ids = list(pd.read_csv('./photo_ids.csv')['photo_id'])


def encode_img(image):
    """Return the L2-normalized CLIP embedding of *image* as a (1, embed_dim) array.

    Args:
        image: Pixel array as delivered by ``gr.Image`` (numpy by default).
            It is assumed to be HxW or HxWxC and uint8-compatible (TODO
            confirm for other ``image_mode`` settings).
    """
    # Convert to RGB so grayscale or RGBA uploads also match CLIP's 3-channel input.
    img = pil.fromarray(np.asarray(image).astype('uint8')).convert('RGB')
    with torch.no_grad():
        pixel_values = processor(images=img, return_tensors='pt')['pixel_values']
        features = model.get_image_features(pixel_values=pixel_values.to(device))
        features /= features.norm(dim=1, keepdim=True)
    return features.cpu().numpy()


def encode_txt(text):
    """Return the L2-normalized CLIP embedding of *text* as a (1, embed_dim) array."""
    with torch.no_grad():
        # truncation=True: CLIP's text encoder has a fixed maximum sequence
        # length, and over-long queries would otherwise raise.
        inputs = tokenizer([text], padding=True, truncation=True, return_tensors='pt')
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        features = model.get_text_features(**inputs)
        # Normalize exactly like encode_img, so scores are comparable across modes.
        features /= features.norm(dim=1, keepdim=True)
    return features.cpu().numpy()


def similarity(feature, photo_features):
    """Return the dot product of a single query embedding with every photo embedding.

    Args:
        feature: Query embedding of shape (1, D). The leading axis is squeezed away.
        photo_features: Photo embeddings of shape (N, D).

    Returns:
        A list of N scalar scores, where entry i belongs to photo_features[i].
    """
    similarities = list((feature @ photo_features.T).squeeze(0))
    return similarities


def _fetch_photo(photo_id):
    """Download the photo listed for *photo_id* in ``data`` and return it as a PIL image."""
    photo_data = data[data['photo_id'] == photo_id].iloc[0]
    # '?w=640' is presumably a width hint understood by the image host.
    response = requests.get(photo_data['photo_image_url'] + '?w=640', timeout=REQUEST_TIMEOUT)
    response.raise_for_status()
    return pil.open(BytesIO(response.content))


def find_best_matches(image, mode, text):
    """Return the TOP_K photos most similar to the query, best match first.

    Args:
        image: Search image from ``gr.Image``. Used unless *mode* selects text search.
        mode: Radio value. Any value in TEXT_MODES runs a text search; anything
            else (including None) runs an image search.
        text: Text query. Used only for text search.

    Raises:
        gr.Error: If the input required by the selected mode is missing.
    """
    if mode in TEXT_MODES:
        if not text or not text.strip():
            raise gr.Error('Please enter a text query.')
        query_features = encode_txt(text.strip())
    else:
        if image is None:
            raise gr.Error('Please provide a search image.')
        query_features = encode_img(image)

    scores = np.asarray(similarity(query_features, photo_features))
    # Sort by descending similarity. Slicing also copes with collections smaller than TOP_K.
    best_indices = np.argsort(-scores, kind='stable')[:TOP_K]
    return [_fetch_photo(ids[idx]) for idx in best_indices]


demo = gr.Interface(
    fn=find_best_matches,
    inputs=[
        gr.Image(label='Search image'),
        gr.Radio(['Text to Image', 'Image to Image'], label='Search mode'),
        gr.Textbox(lines=1, label='Text query'),
    ],
    examples=examples,
    outputs=[gr.Gallery(label='Generated images', show_label=False, elem_id='gallery', scale=3)],
    title='CLIP tutorial',
    description=(
        f'This application gives the best {TOP_K} results for the query by the user. '
        'The query is either a text or an image, and the results are the most similar images.'
    ),
)

if __name__ == '__main__':
    demo.launch()