Spaces:
Sleeping
Sleeping
| import sys | |
| import argparse | |
| import configparser | |
| import pickle | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import clip | |
| import annoy | |
# Ini file consulted when the script is started without command line arguments.
CONFIG_PATH = "app.ini"

# Use the GPU when torch reports a usable CUDA device, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def parse_args():
    """Parse the ``--pkl`` and ``--url`` command line options.

    Both options are plain strings; argparse leaves an option as ``None``
    when it is not supplied.

    Returns:
        argparse.Namespace: namespace with the ``pkl`` and ``url`` attributes.
    """
    cli = argparse.ArgumentParser()
    option_help = (
        ('--pkl', 'input pickle produced by create_embedding.py'),
        ('--url', 'the base URL for the images'),
    )
    for flag, description in option_help:
        cli.add_argument(flag, type=str, help=description)
    return cli.parse_args()
def parse_config_file():
    """Load the settings from the ini file at ``CONFIG_PATH``.

    Only the ``DEFAULT`` section is used; each of its keys becomes an
    attribute of the returned namespace (values are the strings that
    configparser provides).

    Returns:
        argparse.Namespace: one attribute per key of the ``DEFAULT`` section.

    Raises:
        FileNotFoundError: if the ini file could not be read.
    """
    config = configparser.ConfigParser()
    # ConfigParser.read() skips files it cannot open and returns the list of
    # files it actually parsed, so an empty result means nothing was loaded.
    if not config.read(CONFIG_PATH):
        raise FileNotFoundError(
            f"configuration file {CONFIG_PATH!r} is missing or unreadable"
        )
    return argparse.Namespace(**config['DEFAULT'])
# Choose the argument source: the ini file when the script is started without
# any command line arguments, the command line otherwise.
if len(sys.argv) == 1:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()
else:
    print("using command line arguments, ignoring ini file")
    args = parse_args()

# Validate with explicit exceptions: assert statements are removed when
# Python runs with -O, which would let bad input through.
if getattr(args, "pkl", None) is None:
    raise ValueError("missing required setting 'pkl' (path of the embeddings pickle)")
if getattr(args, "url", None) is None:
    raise ValueError("missing required setting 'url' (base URL for the images)")
if not args.url.endswith("/"):
    # Filenames are appended directly to the base URL further down.
    raise ValueError(f"the base URL must end with '/': {args.url!r}")
print("arguments:", args)
pickle_filename, base_url = args.pkl, args.url

# NOTE(security): unpickling can execute arbitrary code; only point this at
# pickle files from a trusted source.
with open(pickle_filename, "rb") as pickle_file:
    data = pickle.load(pickle_file)

# The pickle may store float16 to keep the file small; float32 is used in
# memory to guard against numerical issues (the original author was not sure
# such issues actually occur).
embeddings = data["embeddings"].astype(np.float32)
# Scale every row to unit length.
embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
n, d = embeddings.shape
def build_ann_index(embeddings, n_trees=10):
    """Build an Annoy index with the angular metric over the rows of ``embeddings``.

    Args:
        embeddings: 2-D array; row ``i`` is added to the index as item ``i``.
        n_trees: number of trees passed to ``AnnoyIndex.build``. The default
            of 10 is the value that used to be hard-coded here.

    Returns:
        The built ``annoy.AnnoyIndex``.
    """
    print("annoy indexing")
    # Only the vector dimensionality is needed to create the index.
    _, dim = embeddings.shape
    annoy_index = annoy.AnnoyIndex(dim, "angular")
    for i, vec in enumerate(embeddings):
        annoy_index.add_item(i, vec)
    annoy_index.build(n_trees)
    print("done")
    return annoy_index
# File paths stored in the pickle. The rest of the script indexes this list
# with the same item indices it uses for the embedding rows.
filenames = data["filenames"]
def thumb_patch(filename):
    """Rewrite a ``PhotoLibrary...`` path to its ``PhotoLibrary.thumbs...`` twin.

    Args:
        filename: path that must start with ``PhotoLibrary``.

    Returns:
        The same path with ``.thumbs`` inserted right after the prefix.
    """
    prefix = "PhotoLibrary"
    assert filename.startswith(prefix)
    remainder = filename[len(prefix):]
    return f"{prefix}.thumbs{remainder}"
print("patching filenames")
# Point every entry at its thumbnail counterpart.
filenames = list(map(thumb_patch, filenames))

# Folder of a file = everything before its last "/". Stored as numpy arrays
# so they can be indexed with a whole list of indices at once.
folders = np.array([filename.rpartition("/")[0] for filename in filenames])
urls = np.array([base_url + filename for filename in filenames])

# Nearest-neighbour index over the embeddings, and the CLIP model used to
# embed queries.
annoy_index = build_ann_index(embeddings)
model, preprocess = clip.load('RN50', device=device)
def embed_text(text):
    """Encode ``text`` with the CLIP text encoder and normalise the result.

    Args:
        text: the query string.

    Returns:
        1-D float32 numpy array of length ``d``, scaled to unit length.
    """
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        # Cast to float32, as query_uploaded_image does for image features,
        # so the normalisation below is never done in reduced precision if
        # the encoder returns half-precision output.
        text_features = model.encode_text(tokens).float()
    # Sanity check: one query in, one d-dimensional vector out.
    assert text_features.shape == (1, d)
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features
def drop_same_folder(indices):
    """Keep only the first index of every folder, preserving the input order.

    Args:
        indices: item indices, looked up in the module-level ``folders`` array.

    Returns:
        List with at most one index per distinct folder.
    """
    first_in_folder = {}
    for index, folder in zip(indices, folders[indices]):
        # setdefault stores only the first index seen for a folder; dicts
        # keep insertion order, so the original ranking is preserved.
        first_in_folder.setdefault(folder, index)
    return list(first_in_folder.values())
def features_to_gallery(features):
    """Turn a query vector into gallery URLs plus the matching item indices.

    Fetches 500 neighbours from the Annoy index, keeps one hit per folder
    and truncates the result to 50 entries.

    Args:
        features: query vector handed to ``annoy_index.get_nns_by_vector``.

    Returns:
        Tuple ``(urls, indices)``: the URLs as a plain list and the item
        indices they belong to.
    """
    candidates = annoy_index.get_nns_by_vector(features, n=500)
    deduplicated = drop_same_folder(candidates)
    indices = deduplicated[:50]
    return urls[indices].tolist(), indices
def image_retrieval_from_text(text):
    """Search by text: embed ``text`` and return ``features_to_gallery`` of it."""
    return features_to_gallery(embed_text(text))
def image_retrieval_from_image(state, selected_locally):
    """Search for images similar to the one selected in the current gallery.

    Args:
        state: item indices of the images currently shown (may be ``None``).
        selected_locally: position of the selected image within ``state``.

    Returns:
        ``([], [])`` when nothing is shown, otherwise the result of
        ``features_to_gallery`` for the selected image's stored embedding.
    """
    nothing_shown = state is None or len(state) == 0
    if nothing_shown:
        return [], []
    # Translate the gallery position into an index into the embeddings.
    position = int(selected_locally)
    item_index = state[position]
    return features_to_gallery(embeddings[item_index])
def query_uploaded_image(uploaded_image):
    """Search for images similar to an uploaded picture.

    Args:
        uploaded_image: the image from the upload widget, or ``None`` when
            nothing has been uploaded.

    Returns:
        ``([], [])`` when there is no image, otherwise the result of
        ``features_to_gallery`` for the image's CLIP embedding.
    """
    # Without an upload there is nothing to embed; return an empty result
    # like the other handlers do for an empty gallery.
    if uploaded_image is None:
        return [], []
    image = preprocess(uploaded_image)
    # Add the leading batch dimension expected by the encoder.
    image_batch = image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    # Sanity checks: a single image in, a single d-dimensional vector out.
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    return features_to_gallery(image_features)
def show_folder(state, selected_locally):
    """List every image that shares a folder with the selected image.

    Args:
        state: item indices of the images currently shown (may be ``None``).
        selected_locally: position of the selected image within ``state``.

    Returns:
        ``([], [])`` when nothing is shown, otherwise ``(urls, indices)`` for
        all items whose folder equals the selected item's folder, in
        ascending index order.
    """
    if state is None or len(state) == 0:
        return [], []
    selected = state[int(selected_locally)]
    target_folder = folders[selected]
    # One vectorised comparison over the folders array instead of a Python
    # loop; tolist() yields plain ints, as the loop did.
    indices = np.flatnonzero(folders == target_folder).tolist()
    top_urls = urls[indices]
    return top_urls.tolist(), indices
# Gradio user interface.
# NOTE(review): the nesting of widgets inside the Row containers below was
# reconstructed from a listing that had lost its indentation — confirm it
# against the intended layout.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # Item indices of the images currently shown in the gallery.
    state = gr.State()

    # Text search row.
    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)

    # Search-by-upload row.
    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        # Label spelling fixed ("similiar" -> "similar").
        query_uploaded_image_button = gr.Button("Show similar to uploaded")

    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(columns=5, container=False)

    # Shows the filename of the image selected in the gallery.
    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)

    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")

    # Hidden field holding the gallery position of the selected image.
    selected = gr.Number(0, show_label=False, visible=False)

    # Every handler returns (urls, indices): the gallery content and the new state.
    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        """Record a gallery selection.

        Returns the selected gallery position and the filename of the item
        that ``state`` holds at that position.
        """
        selected_locally = evt.index
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    gallery.select(get_select_index, [state], [selected, filename_textbox])
# Launch the Gradio app only when this file is executed as a script;
# share=False is passed explicitly to launch().
if __name__ == "__main__":
    demo.launch(share=False)