Spaces:
Build error
Build error
| from cProfile import label | |
| from turtle import title | |
| import numpy as np | |
| import gradio as gr | |
| import pickle | |
| from skimage import io | |
| from scipy.spatial import distance | |
# Names of all indexed images: one stripped line of the listing file per entry.
# The file is read once inside a context manager so the handle is closed
# deterministically (it was previously opened twice and never closed).
with open("holidays_images.dat", "r") as listing:
    images = [line.strip() for line in listing]

# Query images are the entries whose numeric stem (the file name without its
# ".jpg" suffix) is a multiple of 100.
query_images = [
    imname for imname in images
    if int(imname[:-len(".jpg")]) % 100 == 0
]
def _load_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Precomputed embedding tables, loaded once at start-up.
# NOTE(review): pickle.load can execute arbitrary code; these files are
# assumed to be trusted artifacts shipped with the app -- confirm.
cnn_embeddings = _load_pickle('saved_cnn.pkl')
bovw_embeddings = _load_pickle('saved_bovw.pkl')
naive_embeddings = _load_pickle('saved_naive.pkl')
def similarity_all(query_image_name, embeddings, metric):
    """Score every indexed image against the query image.

    Args:
        query_image_name: Key of the query image in ``embeddings``.
        embeddings: Mapping from image name to its embedding.
        metric: Callable ``metric(query_embedding, target_embedding)``
            returning the score for one pair.

    Returns:
        A dict mapping each name in the module-level ``images`` list to the
        score between the query embedding and that image's embedding.

    Raises:
        KeyError: If the query name or any listed image is missing from
            ``embeddings``.
    """
    reference = embeddings[query_image_name]
    scores = {}
    for candidate in images:
        scores[candidate] = metric(reference, embeddings[candidate])
    return scores
def euclidean_similarity_score(query_embedding, target_embedding):
    """Return the Euclidean distance between two embeddings.

    Despite the name, this is a distance: identical embeddings score 0 and
    larger values mean the embeddings are further apart.
    """
    difference = query_embedding - target_embedding
    return np.linalg.norm(difference)
def cosine_similarity_score(query_embedding, target_embedding):
    """Return the cosine distance between two embeddings.

    Both inputs are flattened to 1-D before being passed to
    ``scipy.spatial.distance.cosine``. Despite the name, the result is a
    distance: 0 for vectors pointing the same way, larger when they diverge.
    """
    flat_query = np.ravel(query_embedding)
    flat_target = np.ravel(target_embedding)
    return distance.cosine(flat_query, flat_target)
def retrieve(query_image_name, embeddings_type, metric_type):
    """Return the query image and its nearest neighbours in the index.

    Args:
        query_image_name: File name of the query image; it must be a key of
            the selected embedding table.
        embeddings_type: ``'MobileNetV2'`` or ``'BoVW'``; any other value
            selects the naive baseline embeddings.
        metric_type: ``'Euclidean'``; any other value selects the cosine
            distance.

    Returns:
        A tuple ``(query_image, neighbours)``: the query image read from
        ``smallholidays/`` and a list of up to 10 other images read from the
        same folder, ordered by increasing distance to the query.
    """
    # Unknown choices fall back to the baseline embeddings / cosine metric,
    # mirroring the default branches of the selection logic.
    known_embeddings = {'MobileNetV2': cnn_embeddings, 'BoVW': bovw_embeddings}
    embeddings = known_embeddings.get(embeddings_type, naive_embeddings)
    if metric_type == 'Euclidean':
        metric = euclidean_similarity_score
    else:
        metric = cosine_similarity_score

    scores = similarity_all(query_image_name, embeddings, metric)

    # Rank only the *other* images. Relying on the query sorting first is
    # fragile: a tie (e.g. a duplicate embedding listed earlier) or a NaN
    # score could put a different image in front, showing the wrong picture
    # as the query and leaking the query into the results.
    candidates = (name for name in scores if name != query_image_name)
    neighbours = sorted(candidates, key=scores.get)[:10]

    query_image = io.imread('smallholidays/' + query_image_name)
    return query_image, [io.imread('smallholidays/' + name) for name in neighbours]
# --- Gradio user interface -------------------------------------------------
# NOTE(review): gr.inputs / gr.outputs and the Carousel output are legacy
# Gradio APIs -- confirm the pinned gradio version still provides them.

# Input widgets: the query image plus the embedding and metric choices.
input_button = gr.inputs.Dropdown(
    query_images,
    label='Choice of the query image',
)
embeddings_selection = gr.inputs.Radio(
    ['MobileNetV2', 'BoVW', 'Baseline'],
    label='Embeddings',
)
metric_selection = gr.inputs.Radio(
    ['Euclidean', 'Cosine'],
    label='Similarity Metric',
)

# Output widget showing the retrieved images.
retrieved_images = gr.outputs.Carousel(["image"], label='Retrieved images')

# Markdown text displayed under the title of the demo page.
description = (
    "This is a demo of the content-based image retrieval system developed as part of the IR course project, 2022. "
    "The indexed dataset is [INRIA Holidays](https://lear.inrialpes.fr/~jegou/data.php). \n\n"
    "Several image embeddings can be used :\n \n"
    "-**MobileNetV2** : feature extraction is performed using a MobileNet architecture trained on ImageNet.\n\n"
    "-**BoVW (Bag of Visual Words)** : embedding is the BoVW histogram using color histogram as a descriptor.\n\n"
    "-**Baseline** : basic descriptor that uses pixel values of the downsized images."
)

# Wire the widgets to `retrieve` and start serving the app.
query_preview = gr.outputs.Image(label='Query image')
iface = gr.Interface(
    fn=retrieve,
    inputs=[input_button, embeddings_selection, metric_selection],
    outputs=[query_preview, retrieved_images],
    title='Image Retrieval on INRIA Holidays',
    description=description,
)
iface.launch()