import os

import gradio as gr
from PIL import Image
from img2vec_pytorch import Img2Vec
from sklearn.metrics.pairwise import cosine_similarity

# Number of similar-image slots in the results panel. The Interface below
# declares exactly this many gr.Image outputs, and image_sim always returns
# one status Textbox followed by this many images.
MAX_RESULTS = 3
# Similarity window for reporting a match: scores above MAX_SIMILARITY are
# skipped (treated as the same image as the query), scores below
# MIN_SIMILARITY are not similar enough to report.
MIN_SIMILARITY = 0.8
MAX_SIMILARITY = 0.99


def load_images(img_dir):
    """Vectorize every readable image file in *img_dir* and cache it.

    Each file is converted to RGB, embedded with the module-level ``img2vec``
    model, and stored under its filename in the module-level ``image_vectors``
    and ``images`` dicts. Entries from earlier calls are kept, so repeated
    loads accumulate; a file with the same name overwrites the older entry.

    Sub-directories and files that cannot be opened/decoded as images are
    skipped (and logged) instead of aborting the whole load.

    Args:
        img_dir: Path of the directory to scan (not recursive).

    Returns:
        A status message for the UI.
    """
    print("Loading images and vectors...\n")
    if not img_dir or not os.path.isdir(img_dir):
        return f"Directory not found: {img_dir}"
    # For each image we store filename -> vector and filename -> image.
    # Sorted so the load order (and log output) is deterministic.
    for file in sorted(os.listdir(img_dir)):
        filename = os.fsdecode(file)
        path = os.path.join(img_dir, filename)
        if not os.path.isfile(path):
            continue
        try:
            # convert() returns an independent in-memory copy, so the file
            # can be closed right away.
            with Image.open(path) as raw:
                img = raw.convert("RGB")
        except OSError as err:  # includes PIL.UnidentifiedImageError
            print(f"Skipping {filename}: {err}")
            continue
        image_vectors[filename] = img2vec.get_vec(img)
        images[filename] = img
        print("Loaded " + filename)
    return f"Finished loading {len(image_vectors)} files."


def _pad_results(result_images, found, image, empty_message):
    """Fill *result_images* with hidden placeholders up to MAX_RESULTS images.

    When nothing was found (``found == 0``) a status Textbox carrying
    *empty_message* is added first, so the returned list always has the
    Textbox + MAX_RESULTS images shape the Interface outputs expect.
    """
    if found == 0:
        result_images.append(gr.Textbox(value=empty_message))
    for _ in range(found, MAX_RESULTS):
        result_images.append(
            gr.Image(value=image, label="0.00", show_label=True, visible=False)
        )
    return result_images


def image_sim(image):
    """Find up to MAX_RESULTS loaded images similar to the query *image*.

    Similarity is the cosine similarity between img2vec embeddings. Matches
    are reported highest-first. Only scores in
    [MIN_SIMILARITY, MAX_SIMILARITY] count. Every computed score is printed
    for debugging.

    Args:
        image: Filepath of the query image, as supplied by
            ``gr.Image(type="filepath")``. May be None if no image was given.

    Returns:
        A list of one ``gr.Textbox`` (status/warning message) followed by
        MAX_RESULTS ``gr.Image`` components; unused slots are hidden.
    """
    result_images = []

    # Convert the given image to a vector.
    try:
        with Image.open(image) as raw:
            query_img = raw.convert("RGB")
        query_vec = img2vec.get_vec(query_img)
    except Exception:
        # UI boundary: any failure reading/embedding the input (including a
        # missing image) is shown in the status box instead of raising.
        return _pad_results(result_images, 0, image, "Error loading input image!")

    # Compute cosine similarity with all the available images. Iterate over a
    # snapshot so a concurrent load_images call cannot break the iteration.
    query_row = query_vec.reshape((1, -1))
    sims = {
        key: cosine_similarity(query_row, vec.reshape((1, -1)))[0][0]
        for key, vec in list(image_vectors.items())
    }

    # Reverse sort (ties broken by filename) and pick the top matches
    # inside the similarity window.
    ranked = sorted(((v, k) for k, v in sims.items()), reverse=True)
    found = 0
    for score, name in ranked:
        print(f"{score}, {name}\n")
        if score > MAX_SIMILARITY or score < MIN_SIMILARITY or found >= MAX_RESULTS:
            continue
        if found == 0:
            result_images.append(
                gr.Textbox(value="Warning: similar images already exist in collection:")
            )
        result_images.append(
            gr.Image(value=images[name], label=str(score), show_label=True, visible=True)
        )
        found += 1

    # Fill in the rest of the slots as not visible.
    empty_message = "No matching images found!" if sims else "Please load images first!"
    return _pad_results(result_images, found, image, empty_message)


# Module-level caches filled by load_images: filename -> vector / RGB image.
image_vectors = {}
images = {}
img2vec = Img2Vec()

with gr.Blocks() as demo:
    default_dir = "boxes"
    img_dir = gr.Textbox(label="Image Directory", value=default_dir)
    output = gr.Textbox(label="", value="Please load images to continue")
    load_images_button = gr.Button("Load Images")
    load_images_button.click(fn=load_images, inputs=[img_dir], outputs=output)
    gr.Interface(
        image_sim,
        gr.Image(type="filepath", label="Input Image"),
        # One status box plus exactly MAX_RESULTS image slots, matching the
        # shape of image_sim's return value.
        [gr.Textbox(label="Image similarity search results:")]
        + [gr.Image(label=f"Image {i}") for i in range(1, MAX_RESULTS + 1)],
        allow_flagging="never",
        examples=[
            os.path.join(default_dir, "box1.jpeg"),
            os.path.join(default_dir, "box17.jpeg"),
            os.path.join(default_dir, "box29.jpeg"),
            os.path.join(default_dir, "box30.jpeg"),
        ],
    )

if __name__ == "__main__":
    demo.launch(show_api=False, debug=True, share=True)