Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import h5py | |
| import faiss | |
| from PIL import Image | |
| import io | |
| import pickle | |
| import random | |
| import click | |
def getRandID():
    """Pick a random entry of ``index_to_id_dict``.

    Returns:
        A ``(sample_id, position)`` tuple: ``position`` is drawn uniformly
        from ``range(len(index_to_id_dict))`` and ``sample_id`` is the value
        stored under that key.
    """
    # Assumes the dict keys are the contiguous integers 0..len-1 -- TODO confirm.
    position = random.randrange(len(index_to_id_dict))
    sample_id = index_to_id_dict[position]
    return sample_id, position
def get_image_index(indexType):
    """Return the FAISS image index that backs the selected index type.

    Args:
        indexType: Label chosen in the "Index:" radio group.

    Returns:
        The module-level ``image_index_IP`` when ``indexType`` is
        ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: If ``indexType`` is one of the listed-but-not-yet
            supported types (``"FlatL2"``, ``"HNSWFlat"``, ``"IVFFlat"``,
            ``"LSH"``); their index files are not loaded by this script.
        ValueError: If ``indexType`` is not a known label at all.
    """
    if indexType == "FlatIP(default)":
        return image_index_IP
    # These types are offered in the UI but have no loaded index behind them.
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"Image index type {indexType!r} is not available yet"
        )
    # Fail loudly instead of returning None and crashing later in the caller.
    raise ValueError(f"Invalid index type: {indexType}")
def get_dna_index(indexType):
    """Return the FAISS DNA index that backs the selected index type.

    Args:
        indexType: Label chosen in the "Index:" radio group.

    Returns:
        The module-level ``dna_index_IP`` when ``indexType`` is
        ``"FlatIP(default)"``.

    Raises:
        NotImplementedError: If ``indexType`` is one of the listed-but-not-yet
            supported types (``"FlatL2"``, ``"HNSWFlat"``, ``"IVFFlat"``,
            ``"LSH"``); their index files are not loaded by this script.
        ValueError: If ``indexType`` is not a known label at all.
    """
    if indexType == "FlatIP(default)":
        return dna_index_IP
    # These types are offered in the UI but have no loaded index behind them.
    if indexType in ("FlatL2", "HNSWFlat", "IVFFlat", "LSH"):
        raise NotImplementedError(
            f"DNA index type {indexType!r} is not available yet"
        )
    # Fail loudly instead of returning None and crashing later in the caller.
    raise ValueError(f"Invalid index type: {indexType}")
def searchEmbeddings(id, key_type, query_type, index_type, num_results: int = 10):
    """Return the IDs of the samples closest to a given sample's embedding.

    The stored embedding of sample ``id`` is read back from the index chosen
    by ``query_type`` and then used as the query against the index chosen by
    ``key_type``.

    Args:
        id: Sample ID to search for; must be a key of ``id_to_index_dict``.
            (The name shadows the builtin but is kept for caller compatibility.)
        key_type: ``"Image"`` or ``"DNA"`` -- which index is searched.
        query_type: ``"Image"`` or ``"DNA"`` -- which index the query
            embedding is reconstructed from.
        index_type: Index label understood by ``get_image_index`` /
            ``get_dna_index``.
        num_results: Number of neighbours requested from the index.

    Returns:
        A list of sample IDs, in the order returned by the index search.

    Raises:
        ValueError: If ``query_type`` or ``key_type`` is not ``"Image"``/``"DNA"``.
        KeyError: If ``id`` is not a known sample ID.
    """
    image_index = get_image_index(index_type)
    dna_index = get_dna_index(index_type)

    # Choose the index the query embedding is read back from.
    if query_type == "Image":
        source_index = image_index
    elif query_type == "DNA":
        source_index = dna_index
    else:
        raise ValueError(f"Invalid query type: {query_type}")

    try:
        query_position = id_to_index_dict[id]
    except KeyError:
        # Same exception type as before, but with a message that names the problem.
        raise KeyError(f"Unknown sample ID: {id!r}") from None

    # The search call expects a 2-D float32 batch, so add a leading axis.
    query = source_index.reconstruct(query_position).astype(np.float32)
    query = np.expand_dims(query, axis=0)

    # Choose the index that is actually searched.
    if key_type == "Image":
        target_index = image_index
    elif key_type == "DNA":
        target_index = dna_index
    else:
        raise ValueError(f"Invalid key type: {key_type}")

    _, neighbours = target_index.search(query, num_results)

    # FAISS pads the label row with -1 when it finds fewer than `num_results`
    # neighbours; skip those instead of failing on a missing dict key.
    return [
        index_to_id_dict[position]
        for position in neighbours[0]
        if position >= 0
    ]
# Build the Gradio UI and load the data it serves. Names assigned inside the
# `with` block are still module-level globals, which is how getRandID,
# get_image_index, get_dna_index and searchEmbeddings see them.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened copy of this file -- confirm against the running app.
with gr.Blocks() as demo:
    # for hf: change all file paths, indx_to_id_dict as well
    # load indexes
    image_index_IP = faiss.read_index("bioscan_5m_image_IndexFlatIP.index")
    # image_index_L2 = faiss.read_index("big_image_index_FlatL2.index")
    # image_index_HNSW = faiss.read_index("big_image_index_HNSWFlat.index")
    # image_index_IVF = faiss.read_index("big_image_index_IVFFlat.index")
    # image_index_LSH = faiss.read_index("big_image_index_LSH.index")
    dna_index_IP = faiss.read_index("bioscan_5m_dna_IndexFlatIP.index")
    # dna_index_L2 = faiss.read_index("big_dna_index_FlatL2.index")
    # dna_index_HNSW = faiss.read_index("big_dna_index_HNSWFlat.index")
    # dna_index_IVF = faiss.read_index("big_dna_index_IVFFlat.index")
    # dna_index_LSH = faiss.read_index("big_dna_index_LSH.index")
    # with open("dataset_processid_list.pickle", "rb") as f:
    # dataset_processid_list = pickle.load(f)
    # with open("processid_to_index.pickle", "rb") as f:
    # processid_to_index = pickle.load(f)
    # SECURITY: pickle.load can execute arbitrary code; only ship a trusted file here.
    with open("big_indx_to_id_dict.pickle", "rb") as f:
        index_to_id_dict = pickle.load(f)
    # Reverse lookup (sample ID -> index position). If two positions share an
    # ID, the later one wins.
    id_to_index_dict = {v: k for k, v in index_to_id_dict.items()}
    with gr.Column():
        with gr.Row():
            # Random-ID helper: shows a sample ID and its index position.
            with gr.Column():
                rand_id = gr.Textbox(label="Random ID:")
                rand_id_indx = gr.Textbox(label="Index:")
                id_btn = gr.Button("Get Random ID")
            # Search options.
            with gr.Column():
                # NOTE(review): in searchEmbeddings, key_type selects the index
                # that is searched and query_type selects where the query
                # embedding is read from; the "Search From"/"Search To" labels
                # may be swapped relative to that -- confirm the intended wording.
                key_type = gr.Radio(choices=["Image", "DNA"], label="Search From:", value="Image")
                query_type = gr.Radio(choices=["Image", "DNA"], label="Search To:", value="Image")
                # Only "FlatIP(default)" is served; the other choices raise
                # NotImplementedError in get_image_index / get_dna_index.
                index_type = gr.Radio(
                    choices=["FlatIP(default)", "FlatL2", "HNSWFlat", "IVFFlat", "LSH"], label="Index:", value="FlatIP(default)"
                )
                num_results = gr.Number(label="Number of Results:", value=10, precision=0)
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        # NOTE(review): the label says 10, but the count comes from `num_results`.
        process_id_list = gr.Textbox(label="Closest 10 matches:")
        search_btn = gr.Button("Search")
    # Wire the buttons to their callbacks.
    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, key_type, query_type, index_type, num_results],
        outputs=[process_id_list],
    )
# Start the web server for the app.
demo.launch()