# Author: sanketmalde
# Commit fd0f9ee ("Update app.py", verified)
import os
import gradio as gr
from PIL import Image
from img2vec_pytorch import Img2Vec
from sklearn.metrics.pairwise import cosine_similarity
def load_images(img_dir):
    """Embed every readable image in ``img_dir`` and cache it for searching.

    Each image is converted to RGB, passed through the module-level
    ``img2vec`` extractor, and stored in the module-level dicts
    ``image_vectors`` (filename -> vector) and ``images``
    (filename -> PIL image). Entries from earlier calls are kept; a file
    with the same name overwrites the previous entry.

    Args:
        img_dir: Path of the directory to scan (not recursive).

    Returns:
        A status message for the UI. Sub-directories are ignored. Files that
        cannot be decoded as images are skipped and counted in the message
        instead of aborting the whole load. An unreadable or missing
        directory is reported in the message instead of raising.
    """
    print("Loading images and vectors...\n")
    try:
        filenames = sorted(os.listdir(img_dir))
    except OSError as err:
        # Missing/unreadable directory: tell the user instead of raising.
        return f"Could not read directory {img_dir!r}: {err}"

    skipped = 0
    for filename in filenames:
        path = os.path.join(img_dir, filename)
        if not os.path.isfile(path):
            continue  # ignore sub-directories and other non-file entries
        try:
            # ``with`` releases the file handle even if decoding fails;
            # ``convert`` returns a new, fully loaded image we can keep.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
        except OSError as err:
            # Pillow's UnidentifiedImageError (non-image file) is an OSError.
            print(f"Skipping {filename}: {err}")
            skipped += 1
            continue
        image_vectors[filename] = img2vec.get_vec(img)
        images[filename] = img
        print("Loaded " + filename)

    message = f"Finished loading {len(image_vectors)} files."
    if skipped:
        message += f" Skipped {skipped} unreadable file(s)."
    return message
def _hidden_image(image):
    """Return an invisible ``gr.Image`` that fills an unused result slot."""
    return gr.Image(value=image, label="0.00", show_label=True, visible=False)


def _status_only(message, image, slots):
    """Return a status ``gr.Textbox`` followed by ``slots`` hidden image slots."""
    return [gr.Textbox(value=message)] + [_hidden_image(image) for _ in range(slots)]


def image_sim(image):
    """Find up to three loaded images that look similar to ``image``.

    The query is embedded with the module-level ``img2vec`` and compared by
    cosine similarity against every vector cached by ``load_images``.
    Matches scoring between 0.8 and 0.99 (inclusive) are reported, best
    first. Scores above 0.99 are skipped, presumably because they are the
    query image itself or exact duplicates (TODO confirm intent).

    Args:
        image: Filepath of the query image (the UI input uses
            ``type="filepath"``). It may be ``None`` when nothing was uploaded.

    Returns:
        Four Gradio component updates matching the interface outputs: one
        status ``Textbox`` followed by three ``Image`` slots. Unused slots
        are returned hidden.
    """
    max_results = 3        # must match the three gr.Image outputs in the UI
    min_similarity = 0.8   # weaker matches are not reported
    max_similarity = 0.99  # stronger matches are skipped (see docstring)

    # Embed the query image.
    try:
        with Image.open(image) as raw:
            query_vec = img2vec.get_vec(raw.convert('RGB'))
    except Exception as err:
        # UI boundary: report any failure (no upload, unreadable file, model
        # error) to the user, but keep a trace in the server log.
        print(f"Error loading input image {image!r}: {err}")
        return _status_only("Error loading input image!", image, max_results)

    if not image_vectors:
        return _status_only("Please load images first!", image, max_results)

    # Score the query against every cached vector in a single call.
    names = list(image_vectors)
    scores = cosine_similarity(
        query_vec.reshape((1, -1)),
        [image_vectors[name].reshape(-1) for name in names],
    )[0]

    # Best score first (ties broken by filename, descending); keep at most
    # ``max_results`` scores within [min_similarity, max_similarity].
    matches = []
    for score, name in sorted(zip(scores, names), reverse=True):
        print(f"{score}, {name}\n")
        if score < min_similarity:
            break  # sorted descending, so nothing further can qualify
        if score > max_similarity:
            continue
        matches.append((score, name))
        if len(matches) == max_results:
            break

    if not matches:
        return _status_only("No matching images found!", image, max_results)

    results = [gr.Textbox(value="Warning: similar images already exist in collection:")]
    results.extend(
        gr.Image(value=images[name], label=str(score), show_label=True, visible=True)
        for score, name in matches
    )
    # Pad with hidden slots so the output count always matches the UI.
    results.extend(_hidden_image(image) for _ in range(max_results - len(matches)))
    return results
# Module-level caches filled by load_images() and read by image_sim():
# filename -> embedding vector, and filename -> RGB PIL image.
image_vectors, images = {}, {}
# Embedding extractor (img2vec_pytorch) shared by both callbacks.
img2vec = Img2Vec()
with gr.Blocks() as demo:
    # Initial value of the directory textbox; also holds the example queries.
    default_dir = "boxes"

    # Step 1: pick a directory and load its images into the search cache.
    img_dir = gr.Textbox(label="Image Directory", value=default_dir)
    output = gr.Textbox(label="", value="Please load images to continue")
    load_images_button = gr.Button("Load Images")
    load_images_button.click(fn=load_images, inputs=[img_dir], outputs=output)

    # Step 2: search the loaded collection with an uploaded image.
    example_paths = [
        os.path.join(default_dir, example_name)
        for example_name in ("box1.jpeg", "box17.jpeg", "box29.jpeg", "box30.jpeg")
    ]
    gr.Interface(
        fn=image_sim,
        inputs=gr.Image(type="filepath", label="Input Image"),
        outputs=[
            gr.Textbox(label="Image similarity search results:"),
            gr.Image(label="Image 1"),
            gr.Image(label="Image 2"),
            gr.Image(label="Image 3"),
        ],
        allow_flagging="never",
        examples=example_paths,
    )
if __name__ == "__main__":
    # Launch options: hide the API page, enable debug mode, request a share link.
    launch_options = {"show_api": False, "debug": True, "share": True}
    demo.launch(**launch_options)