Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| from PIL import Image | |
| from img2vec_pytorch import Img2Vec | |
| from sklearn.metrics.pairwise import cosine_similarity | |
def load_images(img_dir):
    """Embed every readable image in *img_dir* and add it to the search index.

    Each regular file in *img_dir* (non-recursive) is opened with PIL,
    converted to RGB and embedded with the module-level ``img2vec`` model.
    The embedding is stored in the module-level ``image_vectors`` dict and
    the RGB image in ``images``; both are keyed by filename. Entries from
    earlier calls are kept, and a file with the same name overwrites the
    earlier entry.

    Files PIL cannot open (non-images, unreadable files) and subdirectories
    are skipped instead of aborting the whole load.

    Args:
        img_dir: Path of the directory to scan.

    Returns:
        A status message for the UI. On success it is the number of entries
        now in the index. If the directory cannot be listed, it is an error
        message.
    """
    print("Loading images and vectors...\n")
    try:
        # Sorted so the load order (and console log) is deterministic.
        entries = sorted(os.listdir(img_dir))
    except OSError as err:
        # Report a bad path in the UI instead of raising inside the callback.
        return f"Cannot read directory {img_dir!r}: {err}"
    for file in entries:
        filename = os.fsdecode(file)
        path = os.path.join(img_dir, filename)
        if not os.path.isfile(path):
            continue  # skip subdirectories and other non-regular entries
        try:
            # convert() returns a new, fully loaded image, so the source
            # file handle can be closed immediately.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
        except OSError as err:
            # PIL's UnidentifiedImageError subclasses OSError; skip files
            # that are not images rather than failing the whole directory.
            print(f"Skipped {filename}: {err}")
            continue
        image_vectors[filename] = img2vec.get_vec(img)
        images[filename] = img
        print("Loaded " + filename)
    return f"Finished loading {len(image_vectors)} files."
def _hidden_slot(image):
    """Return an invisible result image used to fill an unused output slot."""
    return gr.Image(value=image, label="0.00", show_label=True, visible=False)


def image_sim(image):
    """Find up to three indexed images that are similar to the query image.

    The query is embedded with ``img2vec`` and compared to every vector in
    ``image_vectors`` by cosine similarity. Candidates scoring between 0.8
    and 0.99 inclusive are kept, highest first. Scores above 0.99 are
    skipped, presumably to exclude the query itself when it is already in
    the collection (the examples come from the same directory); confirm
    this intent.

    Args:
        image: Path of the query image, as supplied by the
            ``gr.Image(type="filepath")`` input.

    Returns:
        Always a list of four Gradio updates: one status ``gr.Textbox``
        followed by three ``gr.Image`` slots. Matched images are visible and
        labelled with their similarity score; unused slots are hidden.
    """
    max_results = 3  # must equal the number of image output components

    def status_only(message):
        # A status message followed by all-hidden image slots.
        return [gr.Textbox(value=message)] + [_hidden_slot(image) for _ in range(max_results)]

    try:
        query_img = Image.open(image).convert('RGB')
        query_vec = img2vec.get_vec(query_img)
    except Exception as err:  # UI boundary: any load/embedding failure is shown to the user
        print(f"Error loading input image: {err}")
        return status_only("Error loading input image!")

    if not image_vectors:
        return status_only("Please load images first!")

    # Compute cosine similarity with all the available images.
    query_row = query_vec.reshape((1, -1))
    sims = {
        key: cosine_similarity(query_row, vec.reshape((1, -1)))[0][0]
        for key, vec in image_vectors.items()
    }

    # Best scores first; keep the top matches inside the [0.8, 0.99] window.
    ranked = sorted(((score, key) for key, score in sims.items()), reverse=True)
    matches = [(score, key) for score, key in ranked if 0.8 <= score <= 0.99][:max_results]

    if not matches:
        return status_only("No matching images found!")

    results = [gr.Textbox(value="Warning: similar images already exist in collection:")]
    results += [
        gr.Image(value=images[key], label=str(score), show_label=True, visible=True)
        for score, key in matches
    ]
    # Hide the slots that did not receive a match.
    results += [_hidden_slot(image) for _ in range(max_results - len(matches))]
    return results
# Search index shared by load_images() (writer) and image_sim() (reader).
# filename -> embedding returned by img2vec.get_vec()
image_vectors = dict()
# filename -> RGB PIL image, displayed in the result slots
images = dict()
# Embedding model used for both collection images and query images.
img2vec = Img2Vec()
with gr.Blocks() as demo:
    default_dir = "boxes"

    # Indexing controls: pick a directory and embed its images.
    img_dir = gr.Textbox(label="Image Directory", value=default_dir)
    output = gr.Textbox(label="", value="Please load images to continue")
    load_images_button = gr.Button("Load Images")
    load_images_button.click(fn=load_images, inputs=[img_dir], outputs=output)

    # Query UI. Components are created in the same order as before because
    # Gradio registers components with the active Blocks on construction.
    query_input = gr.Image(type="filepath", label="Input Image")
    # One status textbox plus three result slots (image_sim returns four values).
    result_outputs = [gr.Textbox(label="Image similarity search results:")]
    result_outputs.extend(gr.Image(label=f"Image {slot}") for slot in range(1, 4))
    example_files = [
        os.path.join(default_dir, name)
        for name in ("box1.jpeg", "box17.jpeg", "box29.jpeg", "box30.jpeg")
    ]
    gr.Interface(
        image_sim,
        query_input,
        result_outputs,
        allow_flagging="never",
        examples=example_files,
    )
if __name__ == "__main__":
    # Script entry point: serve the app with debug=True, share=True and the API page hidden.
    demo.launch(debug=True, share=True, show_api=False)