Image_Embedding / app.py
GOWaz's picture
Upload 11 files
a8e2ab4 verified
import os
import gradio as gr
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from clip_embedding import Clip
from efficientnet_embedding import EfficientNet
from vit_embedding import Vit
from resnet_embedding import Resnet
from dino_embedding import Dino
from histogram_embedding import cosine, get_embedding
from bovw_embedding import Bovw
# Each embedding backend is constructed once, at import time, and the same
# instance is reused by every call to get_image_embedding.
# NOTE(review): presumably each constructor loads model weights, so this is
# where startup cost is paid — confirm against the *_embedding modules.
# The "Histogram" option has no object here: it uses the module-level
# get_embedding / cosine helpers imported from histogram_embedding.
resnet = Resnet()
vit = Vit()
efficientnet = EfficientNet()
bovw = Bovw()
dino = Dino()
clip = Clip()
def get_image_embedding(image: Image.Image, name):
    """Return the embedding of *image* produced by the model called *name*.

    Recognised names are "ResNet", "VIT", "EfficientNet", "Histogram",
    "BOVW" and "DINO"; any other value (including "CLIP") falls back to the
    CLIP model.

    For the neural backends, the value returned by ``get_embedding`` is moved
    to the CPU and converted with ``.numpy()``. The "Histogram" and "BOVW"
    results are returned exactly as their helpers produce them.
    """
    # Hand-crafted descriptors: their helpers' output is returned untouched.
    if name == "Histogram":
        return get_embedding(image)
    if name == "BOVW":
        return bovw.get_embedding(image)

    # Learned backends: pick the model, then convert its output uniformly.
    if name == "ResNet":
        backend = resnet
    elif name == "VIT":
        backend = vit
    elif name == "EfficientNet":
        backend = efficientnet
    elif name == "DINO":
        backend = dino
    else:
        backend = clip

    embedding = backend.get_embedding(image)
    return embedding.cpu().numpy()
def compare_images(main_img, compare_imgs, name):
    """Score every image in *compare_imgs* against *main_img* and rank them.

    Args:
        main_img: the reference image.
        compare_imgs: iterable of images to score against the reference.
        name: model name understood by ``get_image_embedding``.

    Returns:
        A list of ``(image, percentage)`` pairs, where percentage is the
        similarity times 100 rounded to two decimals. The list is ordered from
        most to least similar; ties keep their input order because the sort is
        stable.
    """
    # These three models are scored with the project's own ``cosine`` helper.
    # Every other model uses sklearn's ``cosine_similarity``, which returns a
    # pairwise matrix; ``[0][0]`` takes its first entry. That assumes each
    # embedding is a single row — TODO confirm.
    uses_local_cosine = name in ("Histogram", "ResNet", "BOVW")

    reference = get_image_embedding(main_img, name)
    scored = []
    for candidate in compare_imgs:
        candidate_emb = get_image_embedding(candidate, name)
        if uses_local_cosine:
            similarity = cosine(reference, candidate_emb)
        else:
            similarity = cosine_similarity(reference, candidate_emb)[0][0]
        scored.append((candidate, round(similarity * 100, 2)))

    return sorted(scored, key=lambda pair: pair[1], reverse=True)
# Model names offered in the dropdown; the first entry is the default choice.
model_list = ["CLIP", "VIT", "EfficientNet", "ResNet", "DINO", "Histogram", "BOVW"]


def _as_path(file):
    """Return the filesystem path for one entry of a ``gr.File`` upload.

    Gradio 4+ delivers multi-file uploads as plain path strings
    (``type="filepath"``). Older versions deliver tempfile-like objects that
    expose the path via ``.name``. Both forms are accepted here.
    """
    # Check path-likes first: pathlib.Path also has a ``.name`` attribute, but
    # it holds only the basename, not the full path.
    if isinstance(file, (str, os.PathLike)):
        return file
    return file.name


def process_comparison(main_img, compare_files, name):
    """Gradio callback: rank the uploaded comparison images against *main_img*.

    Args:
        main_img: PIL image from the "Main Image" input, or ``None`` if empty.
        compare_files: uploads from the "Comparison Images" input (paths or
            file objects), or ``None`` when nothing was uploaded.
        name: model name selected in the dropdown.

    Returns:
        ``(images, text)``. *images* holds the comparison images sorted by
        descending similarity, for the gallery. *text* has one
        "Image: <file> -> Similarity: <pct>%" line per image, for the textbox.

    Raises:
        gr.Error: if either input is missing, so the UI shows a readable
            message instead of failing on ``None``.
    """
    if main_img is None:
        raise gr.Error("Please upload a main image.")
    if not compare_files:
        raise gr.Error("Please upload at least one comparison image.")

    compare_imgs = [Image.open(_as_path(file)) for file in compare_files]
    results = compare_images(main_img, compare_imgs, name)

    # Image.open records the source path in ``.filename``; show only the basename.
    images = [img for img, _ in results]
    scores = [
        f"Image: {os.path.basename(img.filename)} -> Similarity: {score:.2f}%"
        for img, score in results
    ]
    return images, "\n".join(scores)


with gr.Blocks() as demo:
    # NOTE(review): gr.Tab is called here without ``with``, so the components
    # below are not placed inside it. Confirm whether
    # ``with gr.Tab("Image Embedding"):`` was intended.
    gr.Tab("Image Embedding")
    gr.Markdown("# Image Similarity Finder")
    gr.Markdown(
        "Upload a main image and compare it to others. Results show similarity percentages using embeddings.")
    with gr.Row():
        with gr.Column():
            main_image = gr.Image(type="pil", label="Main Image")
            compare_images_input = gr.File(file_count="multiple", file_types=["image"], label="Comparison Images")
            modelName = gr.Dropdown(model_list, label="Model", value=model_list[0])
            submit_btn = gr.Button("Compare")
        with gr.Column():
            gallery = gr.Gallery(label="Similarity Results")
            similarity_text = gr.Textbox(label="Similarity Scores")

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=process_comparison,
        inputs=[main_image, compare_images_input, modelName],
        outputs=[gallery, similarity_text],
    )

demo.launch()