Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from facenet_pytorch import InceptionResnetV1 | |
| import torch.nn as nn | |
| import torchvision.transforms as tf | |
| import numpy as np | |
| import torch | |
| import faiss | |
| import h5py | |
| import os | |
| import random | |
| from PIL import Image | |
| import matplotlib.cm as cm | |
| import matplotlib as mpl | |
# Image file names, in file order. image_search maps FAISS result row i to
# img_names[i], so this order must match the order vectors are indexed in.
# Presumably each line is "<filename> <partition>" (CelebA eval partition
# list) -- TODO confirm; only the first field is used here.
img_names = []
with open('list_eval_partition.txt', 'r', encoding='utf-8') as f:
    for line in f:
        fields = line.split()
        if not fields:
            # Skip blank lines (e.g. a trailing empty line), which would
            # otherwise break parsing.
            continue
        img_names.append(fields[0])

# For a model pretrained on VGGFace2
print('Loading model weights ........')
class SiameseModel(nn.Module):
    """Face-embedding network: an InceptionResnetV1 backbone whose outputs
    are rescaled to unit L2 norm along dim 1.

    The backbone is built with ``pretrained='vggface2'``; the script then
    loads its own checkpoint over these weights.
    """

    def __init__(self):
        super().__init__()
        # Start from the VGGFace2-pretrained InceptionResnetV1 weights.
        self.backbone = InceptionResnetV1(pretrained='vggface2')

    def forward(self, x):
        """Embed batch ``x`` and return unit-norm embeddings.

        Assumes the backbone output is (batch, features) -- TODO confirm.
        """
        embeddings = self.backbone(x)
        # With unit-norm vectors, ranking by L2 distance matches ranking by
        # cosine similarity (provided indexed vectors are normalised alike).
        return nn.functional.normalize(embeddings, p=2.0, dim=1)
model = SiameseModel()
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load a trusted model.pt (consider weights_only=True
# if the installed torch supports it).
model.load_state_dict(torch.load('model.pt', map_location=torch.device('cpu')))
print("____________LOADED___________________")
# Switch to inference mode (affects dropout / batch-norm layers).
model.eval()

# Make FAISS index
print('Make index .............')
# 512 must equal the dimensionality of the stored face vectors.
index = faiss.IndexFlatL2(512)
# The context manager closes the HDF5 file even if reading a dataset or
# adding it to the index raises.
with h5py.File('face_vecs_full.h5', 'r') as hf:
    # NOTE(review): vectors are appended in hf.keys() iteration order, and
    # image_search maps FAISS row i to img_names[i]; confirm this order
    # matches list_eval_partition.txt.
    for key in hf.keys():
        vec = np.array(hf.get(key))
        index.add(vec)
def image_search(image, k=5):
    """Return the ``k`` indexed face images nearest to ``image``.

    The query is resized to 160x160, embedded with the module-level
    ``model``, and searched in the module-level FAISS ``index``. Each hit's
    row number is mapped to a file in ``img_align_celeba`` via ``img_names``.

    Args:
        image: Query as a PIL image (the Gradio input uses ``type='pil'``),
            or ``None`` when no image has been supplied.
        k: Number of neighbours to return. It may arrive as a float from the
            Gradio slider, so it is coerced to ``int``.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples in FAISS result order,
        suitable for ``gr.Gallery``. Empty when ``image`` is ``None``.
    """
    if image is None:
        # Search clicked with nothing uploaded: show an empty gallery
        # instead of crashing inside the transform.
        return []
    k = int(k)

    transform = tf.Compose([
        tf.Resize((160, 160)),
        tf.ToTensor(),
    ])
    # Force 3 channels: RGBA or greyscale uploads would otherwise yield 4- or
    # 1-channel tensors that the backbone cannot consume.
    query_img = transform(image.convert('RGB')).unsqueeze(0)

    model.eval()
    # Retrieval needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        query_vec = model(query_img).numpy()

    _, neighbour_ids = index.search(query_vec, k=k)

    folder = 'img_align_celeba'
    retrieval_imgs = []
    for idx in neighbour_ids[0]:
        # FAISS pads with -1 when fewer than k vectors are available; without
        # this check img_names[-1] would silently return the last image.
        if idx < 0:
            continue
        path = os.path.join(folder, img_names[idx])
        # Copy the pixels into memory and close the file immediately rather
        # than keeping one open handle per result.
        with Image.open(path) as hit:
            retrieval_imgs.append((hit.copy(), ''))
    return retrieval_imgs
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown('''
# Face Image Retrieval with Content-Based Image Retrieval (CBIR) & Saliency Map
--------
''')
    with gr.Row():
        with gr.Column():
            image = gr.Image(type='pil', scale=1)
            slider = gr.Slider(1, 10, value=5, step=1, label='Number of retrieval image')
            with gr.Row():
                btn = gr.Button('Search')
                clear_btn = gr.ClearButton()
        gallery = gr.Gallery(label='Retrieval Images', columns=[5], show_label=True, scale=2)

    img_dir = './img_align_celeba'
    # random.sample draws without replacement, so the examples are distinct
    # (random.choices could show the same face twice).
    example_names = random.sample(img_names, k=min(6, len(img_names)))
    # Pass file paths: gr.Examples loads them itself, so no PIL file handles
    # are left open here.
    examples = [os.path.join(img_dir, name) for name in example_names]
    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=image,
        )

    btn.click(image_search,
              inputs=[image, slider],
              outputs=[gallery])

    def clear_image():
        """Reset the query image input to empty."""
        return None

    clear_btn.click(
        fn=clear_image,
        inputs=[],
        outputs=[image],
    )

if __name__ == "__main__":
    demo.launch()