Spaces:
Sleeping
Sleeping
import csv
import sys
import threading

import gradio as gr
import numpy as np
import skimage.transform
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from numpy import matlib as mb
from PIL import Image
# Raise the csv field-size cap as far as the platform allows.
# ``sys.maxsize`` does not fit in a C long on every platform (64-bit Windows
# has a 32-bit long), where ``csv.field_size_limit`` raises OverflowError, so
# back off by halving until a value is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
def _flatten_feature_map(conv):
    """Return ``(features, grid)`` for one convolutional feature map.

    ``features`` has one row per spatial position and one column per channel.
    ``grid`` is the ``(height, width)`` used to fold per-position scores back
    into a 2-D map.
    """
    conv = np.asarray(conv)
    if conv.ndim >= 3:
        # Channel-first map (..., C, H, W): read the grid from the last two axes.
        grid = tuple(conv.shape[-2:])
    else:
        # Flattened input carries no grid information; keep the 7 x 7 grid
        # this function has always assumed.
        grid = (7, 7)
    features = conv.reshape(-1, grid[0] * grid[1]).T
    return features, grid


def compute_spatial_similarity(conv1, conv2):
    """Compute per-position similarity maps between two feature maps.

    Each feature map is flattened to a (positions x channels) matrix and
    divided by the L2 norm of its spatially mean-pooled feature vector and by
    its number of positions. The dot product of every position in ``conv1``
    with every position in ``conv2`` is then accumulated per position.

    Args:
        conv1: Channel-first feature map, e.g. ``(C, H, W)``.
        conv2: Channel-first feature map with the same channel count as
            ``conv1``.

    Returns:
        A tuple ``(similarity1, similarity2)``. ``similarity1`` has the grid
        shape of ``conv1`` and holds, for each of its positions, the summed
        similarity to all positions of ``conv2``; ``similarity2`` is the
        converse, shaped like the grid of ``conv2``.

    Note:
        A feature map whose mean-pooled vector is all zeros has zero norm, and
        the division then produces inf/NaN values exactly as NumPy reports it.
    """
    feats1, grid1 = _flatten_feature_map(conv1)
    feats2, grid2 = _flatten_feature_map(conv2)

    feats1 = feats1 / np.linalg.norm(feats1.mean(axis=0)) / feats1.shape[0]
    feats2 = feats2 / np.linalg.norm(feats2.mean(axis=0)) / feats2.shape[0]

    # Entry (i, j) is the dot product of position i in conv1 with position j
    # in conv2; one matrix product replaces the per-row repmat loop.
    pairwise = feats1 @ feats2.T

    similarity1 = pairwise.sum(axis=1).reshape(grid1)
    similarity2 = pairwise.sum(axis=0).reshape(grid2)
    return similarity1, similarity2
# Geometry-only preprocessing for the image shown under the heatmap:
# resize to 256, then centre-crop to 224 x 224.
display_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
    ]
)

# Model preprocessing: the same resize and crop steps as the display image,
# followed by tensor conversion and per-channel normalisation.
imagenet_transform = transforms.Compose(
    [
        *display_transform.transforms,
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
class Wrapper(torch.nn.Module):
    """Module whose forward pass returns the wrapped model's ``layer4`` output.

    A forward hook registered on ``model.layer4`` stores that layer's output
    on the wrapper; ``forward`` runs the full model and returns the stored
    value instead of the model's own result.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Overwritten by the hook on every forward pass through ``layer4``;
        # stays ``None`` until the first one.
        self.layer4_ouputs = None
        self.model.layer4.register_forward_hook(self._capture_layer4)

    def _capture_layer4(self, module, input, output):
        """Forward-hook callback: remember the latest ``layer4`` output."""
        self.layer4_ouputs = output

    def forward(self, input):
        # The model's final output is discarded; only the hooked value matters.
        self.model(input)
        return self.layer4_ouputs
# The wrapped ResNet-50 is built once and reused: constructing it (and loading
# its pretrained weights) on every call is by far the most expensive step.
_LAYER4_MODEL = None
# Guards both the lazy construction above and each forward pass, because the
# wrapper keeps the captured activations as shared instance state.
_LAYER4_LOCK = threading.Lock()


def get_layer4(input_image):
    """Return the ResNet-50 ``layer4`` activations for one image.

    Args:
        input_image: Image accepted by ``imagenet_transform`` (the UI in this
            file passes PIL images).

    Returns:
        NumPy array holding the ``layer4`` output for a batch of one image.
    """
    global _LAYER4_MODEL

    data = imagenet_transform(input_image).unsqueeze(0)

    with _LAYER4_LOCK:
        if _LAYER4_MODEL is None:
            # NOTE(review): newer torchvision releases prefer ``weights=``
            # over ``pretrained=True``; kept as-is because the installed
            # version is not visible from this file.
            backbone = models.resnet50(pretrained=True)
            backbone.eval()
            _LAYER4_MODEL = Wrapper(backbone)
        with torch.no_grad():
            reference_layer4 = _LAYER4_MODEL(data)

    return reference_layer4.detach().cpu().numpy()
def NormalizeData(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data: Array-like of numbers.

    Returns:
        Float array of the same shape with the minimum mapped to 0 and the
        maximum mapped to 1. Constant input (max == min) has no spread to
        scale, so it maps to all zeros instead of dividing by zero.
    """
    data = np.asarray(data)
    lowest = np.min(data)
    span = np.max(data) - lowest
    if span == 0:
        return np.zeros(data.shape, dtype=float)
    return (data - lowest) / span
# Visualization
def visualize_similarities(image1, image2):
    """Plot both images side by side with their similarity heatmaps overlaid.

    Args:
        image1: First image, in a form accepted by ``get_layer4`` and
            ``display_transform`` (the UI in this file passes PIL images).
        image2: Second image, same requirements as ``image1``.

    Returns:
        A matplotlib figure with two panels. Each panel shows the resized and
        cropped image under a semi-transparent "jet" heatmap of its normalised
        similarity map, plus a colorbar.

    Raises:
        ValueError: If either image is ``None``.
    """
    if image1 is None or image2 is None:
        raise ValueError("Both images are required to compute a similarity map.")

    features1 = get_layer4(image1).squeeze()
    features2 = get_layer4(image2).squeeze()
    sim1, sim2 = compute_spatial_similarity(features1, features2)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for axis, image, similarity in zip(axes, (image1, image2), (sim1, sim2)):
        axis.imshow(display_transform(image))
        # Upsample the coarse similarity grid to the 224 x 224 display crop.
        overlay = axis.imshow(
            skimage.transform.resize(NormalizeData(similarity), (224, 224)),
            alpha=0.5,
            cmap="jet",
            vmin=0,
            vmax=1,
        )
        axis.set_axis_off()
        fig.colorbar(overlay, ax=axis)
    fig.tight_layout()

    # Drop the figure from pyplot's global registry so repeated requests do
    # not accumulate open figures; the returned Figure object stays usable.
    plt.close(fig)
    return fig
# Gradio UI: two image inputs in the left column, the similarity plot in the
# right column, one bundled example pair, and a button that runs
# ``visualize_similarities``.
demo = blocks = gr.Blocks()
with demo:
    gr.Markdown("# Visualizing Deep Similarity Networks")
    gr.Markdown("A quick demo to visualize the similarity between two images.")
    gr.Markdown(
        "[Original Paper](https://arxiv.org/pdf/1901.00536.pdf)"
        " - [Github Page](https://github.com/GWUvision/Similarity-Visualization)"
    )

    with gr.Row():
        with gr.Column():
            image1 = gr.Image(label="Image 1", type="pil")
            image2 = gr.Image(label="Image 2", type="pil")
        with gr.Column():
            sim1_output = gr.Plot()

    examples = gr.Examples(
        examples=[
            [
                "./examples/Red_Winged_Blackbird_0012_6015.jpg",
                "./examples/Red_Winged_Blackbird_0025_5342.jpg",
            ],
        ],
        inputs=[image1, image2],
    )

    btn = gr.Button("Compute Similarity")
    btn.click(visualize_similarities, inputs=[image1, image2], outputs=[sim1_output])

demo.launch(debug=True)