Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import numpy as np | |
| import clip | |
| import gradio as gr | |
| from utils import * | |
| import os | |
# Load the OpenAI CLIP model (ViT-B/32) together with its matching image preprocessing transform.
# NOTE(review): `device` and `torch` are not defined/imported in this file — presumably they come
# from `from utils import *`; confirm.
model, preprocess = clip.load("ViT-B/32", device=device)

from pathlib import Path
import urllib.request

_DATASET_DIR = Path('unsplash-dataset')
_RELEASE_URL = 'https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0'


def _download_if_missing(filename):
    """Fetch `filename` from the GitHub release into `_DATASET_DIR` unless it is already there.

    The data is written to a `.part` file that is renamed only after the download
    completes, so an interrupted download is retried on the next start instead of
    being mistaken for a finished file by the `exists()` check.  Download errors
    propagate instead of being silently ignored.
    """
    target = _DATASET_DIR / filename
    if target.exists():
        return
    # The target directory may not exist yet on a fresh checkout.
    _DATASET_DIR.mkdir(parents=True, exist_ok=True)
    partial = target.with_name(target.name + '.part')
    urllib.request.urlretrieve(f'{_RELEASE_URL}/{filename}', partial)
    partial.replace(target)


# Download from GitHub Releases (no-op when the files are already present locally).
_download_if_missing('photo_ids.csv')
_download_if_missing('features.npy')

# Load the photo IDs; elsewhere in this file row i of `photo_features` is mapped to photo_ids[i].
photo_ids = pd.read_csv(_DATASET_DIR / 'photo_ids.csv')
photo_ids = list(photo_ids['photo_id'])

# Load the feature vectors
photo_features = np.load(_DATASET_DIR / 'features.npy')

# Convert features to tensors. On CPU they are cast to float32 (presumably because half-precision
# matmul is not reliably supported there); on GPU they keep the dtype stored in features.npy —
# float16 according to the original note, TODO confirm.
if device == "cpu":
    photo_features = torch.from_numpy(photo_features).float().to(device)
else:
    photo_features = torch.from_numpy(photo_features).to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")
| from PIL import Image | |
def encode_search_query(net, search_query):
    """Embed a text query with CLIP and scale each embedding to unit length.

    Args:
        net: CLIP model exposing ``encode_text``.
        search_query: query passed straight to ``clip.tokenize``.

    Returns:
        The text embedding tensor, divided by its L2 norm along the last dimension.
    """
    with torch.no_grad():
        tokens = clip.tokenize(search_query).to(device)
        embedding = net.encode_text(tokens)
        # Unit length makes the later dot product with photo features a cosine similarity.
        embedding /= embedding.norm(dim=-1, keepdim=True)
    return embedding
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Return the ids of the photos whose features score highest against a query vector.

    Args:
        text_features: (1, D) query embedding (text, photo, or a blend of both);
            the ``squeeze(1)`` below requires a single query row.
        photo_features: (N, D) feature matrix; row i belongs to ``photo_ids[i]``.
        photo_ids: list of N photo ids aligned with ``photo_features``.
        results_count: maximum number of ids to return.

    Returns:
        Up to ``results_count`` ids ordered from most to least similar. The dot
        product is a cosine similarity only when both inputs are unit-normalized
        (callers in this file normalize the query; the stored features presumably are too).
    """
    similarities = (photo_features @ text_features.T).squeeze(1)
    # Clamp k: topk (unlike slicing) raises when k exceeds the number of photos.
    k = max(0, min(results_count, similarities.shape[0]))
    if k == 0:
        return []
    # topk selects the best k without fully sorting all N scores; indices come back best-first.
    best_photo_idx = similarities.topk(k).indices.tolist()
    return [photo_ids[i] for i in best_photo_idx]
def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
    """Return the ids of the ``results_count`` photos that best match a text query.

    Convenience wrapper: embeds ``search_query`` with ``encode_search_query`` and
    ranks ``photo_features`` against that embedding with ``find_best_matches``.
    """
    return find_best_matches(
        encode_search_query(net, search_query),
        photo_features,
        photo_ids,
        results_count,
    )
def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5,
                             results_count=10):
    """Search the photo collection with a text query, a reference photo, or both.

    At most one photo source is used: ``query_photo_id`` (a photo already in the
    collection) takes precedence over ``query_photo`` (an image to embed with CLIP).
    With a photo, its unit-normalized embedding is scaled by ``photo_weight`` and
    added to the text embedding (only if text was given), then renormalized and
    ranked against ``photo_features``.

    Args:
        query_text: free-text query; may be empty/None when a photo is supplied.
        query_photo: image accepted by ``preprocess`` (e.g. a PIL image).
        query_photo_id: id from ``photo_ids`` whose stored features to search with.
        photo_weight: scale of the photo embedding relative to the text embedding.
        results_count: number of photo ids to return (10, as before, by default).

    Returns:
        List of up to ``results_count`` photo ids, best match first; ``[]`` when no
        query of any kind was supplied.

    Raises:
        ValueError: if ``query_photo_id`` is not present in ``photo_ids``.
    """
    if not query_text and query_photo is None and not query_photo_id:
        return []

    if query_photo_id:
        # Reuse the precomputed feature row; add a batch dim so it is (1, D) like the text features.
        query_photo_features = photo_features[photo_ids.index(query_photo_id)].unsqueeze(0)
    elif query_photo is not None:
        with torch.no_grad():
            # preprocess returns one (C, H, W) tensor; encode_image needs a (1, C, H, W) batch on
            # the model's device (the old permute(2, 0, 1) scrambled the already-CHW tensor).
            image_input = preprocess(query_photo).unsqueeze(0).to(device)
            query_photo_features = model.encode_image(image_input)
    else:
        # Text-only search.
        return search_unslash(model, query_text, photo_features, photo_ids, results_count)

    query_photo_features = query_photo_features / query_photo_features.norm(dim=-1, keepdim=True)
    if query_text:
        # Combine the text and photo queries and normalize again.
        search_features = encode_search_query(model, query_text) + query_photo_features * photo_weight
        search_features = search_features / search_features.norm(dim=-1, keepdim=True)
    else:
        # Photo-only search: encoding an empty/None text would add noise or crash clip.tokenize.
        search_features = query_photo_features
    return find_best_matches(search_features, photo_features, photo_ids, results_count)
def fn_query_on_load():
    """Return the example query pre-filled in the search box when the page loads."""
    default_query = "Dogs playing during sunset"
    return default_query
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # CLIP Image Search Engine!
            ### Enter search query or/and select image to find the similar images
            """)
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value=fn_query_on_load, placeholder='Search..', label=None)
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        # Column holding the result the user clicked, which becomes the image part of the query.
        with gr.Column(visible=True) as input_image_col:
            search_image = gr.Image(label='Select from results', interactive=False)
            # Unsplash id of the selected result (None when nothing is selected).
            search_image_id = gr.State(None)
    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results.. ',
                                   value=[], columns=5, rows=2)
        # Ids of the photos shown in the gallery, in display order, so a click can be mapped to an id.
        output_image_ids = gr.State([])

    def clear_data():
        """Reset the query text, the selected photo and the results gallery."""
        return {
            search_image: None,
            output_images: None,
            search_text: None,
            search_image_id: None,
            input_image_col: gr.update(visible=True)
        }

    clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])

    def on_select(evt: gr.SelectData, shown_ids):
        """Show the clicked gallery result as the image query and remember its id."""
        picked_id = shown_ids[evt.index]
        return {
            search_image: f"https://unsplash.com/photos/{picked_id}/download?w=320",
            search_image_id: picked_id,
            input_image_col: gr.update(visible=True)
        }

    output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])

    def func_search(query, img, img_id):
        """Run a search from the current text and/or selected photo and refresh the gallery."""
        best_photo_ids = []
        if img_id:
            best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
        elif img is not None:
            # gr.Image defaults to type="numpy", so the value normally arrives as an array;
            # Image.open() only accepts a path or file object.
            img = Image.fromarray(img) if isinstance(img, np.ndarray) else Image.open(img)
            best_photo_ids = search_by_text_and_photo(query, query_photo=img)
        elif query:
            best_photo_ids = search_by_text_and_photo(query)

        if not best_photo_ids:
            print("Invalid Search Request")
            return {
                output_image_ids: [],
                output_images: []
            }

        # NOTE(review): w=20 requests 20px-wide thumbnails (the selected-photo preview uses w=320);
        # confirm this width is intentional.
        img_urls = [f"https://unsplash.com/photos/{p_id}/download?w=20" for p_id in best_photo_ids]
        # filter_invalid_urls is not defined in this file (presumably from utils); it is expected to
        # return aligned 'image_ids' / 'image_urls' lists — verify against its implementation.
        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls']
        }

    submit_btn.click(
        func_search,
        [search_text, search_image, search_image_id],
        [output_images, output_image_ids]
    )

    def on_upload():
        """Forget the previously selected result id once the user supplies their own image."""
        # No event-data parameter: an upload event carries no gr.SelectData payload.
        return {
            search_image_id: None
        }

    search_image.upload(on_upload, None, search_image_id)
# Launch the app
app.launch()