Spaces:
Build error
Build error
| from torch.utils.data import DataLoader | |
| import torch | |
| from model.base.geometry import Geometry | |
| from common.evaluation import Evaluator | |
| from common.logger import AverageMeter | |
| from common.logger import Logger | |
| from data import download | |
| from model import chmnet | |
| from matplotlib import pyplot as plt | |
| from matplotlib.patches import ConnectionPatch | |
| from PIL import Image | |
| import numpy as np | |
| import os | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| import torchvision.models as models | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import random | |
| import gradio as gr | |
# Downloading the Model
# Fetch the pretrained checkpoint from Google Drive into the working
# directory under the name 'pas_psi.pt' (the same name args['load'] uses).
torchvision.datasets.utils.download_file_from_google_drive(
    '1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt'
)

# Model Initialization
# Demo settings. This script reads 'alpha', 'img_size', 'ktype' and 'load';
# the remaining keys appear unused here.
args = {
    'alpha': [0.05, 0.1],
    'benchmark': 'pfpascal',
    'bsz': 90,
    'datapath': '../Datasets_CHM',
    'img_size': 240,
    'ktype': 'psi',
    'load': 'pas_psi.pt',
    'thres': 'img',
}

model = chmnet.CHMNet(args['ktype'])
# NOTE: torch.load unpickles the file, so only point args['load'] at a
# trusted checkpoint. Weights are mapped to CPU so no GPU is required.
model.load_state_dict(
    torch.load(args['load'], map_location=torch.device('cpu'))
)
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])
# Inference only: switch the network to evaluation mode.
model.eval()
# Transforms
# Model input pipeline: resize the shorter side to img_size, center-crop to an
# img_size x img_size square, convert to a tensor, then normalize with the
# commonly used ImageNet mean/std values.
chm_transform = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'],) * 2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Display pipeline: the same geometric resize + crop, but the result stays a
# PIL image so it can be drawn with imshow and queried for its .size.
chm_transform_plot = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'],) * 2),
])
# A Helper Function
def to_np(x):
    """Return tensor ``x`` as a NumPy array on the CPU.

    ``.data`` is read first, so the result is detached from autograd and
    ``.numpy()`` works even when ``x`` requires grad.
    """
    return x.data.cpu().numpy()
# Colors for Plotting
# One colour per keypoint of the 7x7 (= 49) grid built by the wrapper,
# sampled evenly from the 'Spectral' colormap.
# pyplot.get_cmap is used because the bare name `matplotlib` is never imported
# here (only `pyplot` is), and `matplotlib.cm.get_cmap` was removed in
# Matplotlib 3.9; `pyplot.get_cmap` works on both old and new versions.
cmap = plt.get_cmap('Spectral')
rgba = cmap(0.5)  # mid-point colour of the map; appears unused by this script
colors = [cmap(k / 49.0) for k in range(49)]
# CHM MODEL
def _draw_keypoint_panel(ax, image, points, title, point_colors, text_color, text_size):
    """Draw ``image`` on ``ax`` with ``points`` scattered on top and numbered.

    ``points`` is an (N, 2) integer array of (x, y) display-pixel positions;
    point ``i`` is labelled with the text ``str(i)``.
    """
    ax.imshow(image)
    ax.scatter(points[:, 0], points[:, 1], c=point_colors)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    # Label exactly the points that were plotted; a fixed count of 49 would
    # index past the end whenever fewer points are requested.
    for idx, (px, py) in enumerate(points):
        ax.text(x=px, y=py, s=str(idx), fontdict=dict(color=text_color, size=text_size))


def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
    """Transfer keypoints from ``source_image`` to ``target_image`` with CHM and plot both.

    Args:
        source_image: PIL image the keypoints are defined on.
        target_image: PIL image the keypoints are transferred to.
        selected_points: array of shape (2, N); row 0 holds x and row 1 holds y.
            Presumably pixel coordinates in the ``args['img_size']`` model
            input, since they are rescaled by ``w / args['img_size']`` below.
        number_src_points: how many of the N keypoints to transfer and plot.
        chm_transform: callable mapping a PIL image to the model input tensor
            (e.g. the module-level ``chm_transform``); a batch dim is added here.
        display_transform: callable mapping a PIL image to the PIL image that
            is plotted. It is applied once per image and assumed deterministic
            (the module-level ``chm_transform_plot`` is Resize + CenterCrop).

    Returns:
        A matplotlib Figure: source keypoints on the left, predicted target
        keypoints on the right, colour- and number-matched.
    """
    img_size = args['img_size']

    # Add a batch dimension to both images and to the keypoints.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Run CHM and transfer the source keypoints through the correlation.
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Apply the display transform once per image and reuse the result both for
    # its size and for plotting.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    # Source keypoints: rescale from model-input pixels to display pixels.
    w, h = src_display.size
    src_points_converted = np.asarray(
        [[int(x * w / img_size), int(y * h / img_size)]
         for x, y in zip(src_points[0], src_points[1])][:number_src_points],
        dtype=int,
    ).reshape(-1, 2)  # keep (N, 2) even when N == 0

    # Predicted keypoints: this conversion assumes they are normalized to
    # [-1, 1] (transfer_kps output); map them to display pixels.
    w, h = tgt_display.size
    tgt_points_converted = np.asarray(
        [[int((x + 1) / 2.0 * w), int((y + 1) / 2.0 * h)]
         for x, y in zip(tgt_points[0], tgt_points[1])][:number_src_points],
        dtype=int,
    ).reshape(-1, 2)

    # PLOT
    point_colors = colors[:number_src_points]
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
    _draw_keypoint_panel(ax[0], src_display, src_points_converted, 'Source', point_colors, 'red', 10)
    _draw_keypoint_panel(ax[1], tgt_display, tgt_points_converted, 'Target', point_colors, 'orange', 8)
    plt.tight_layout()
    # Backslashes are escaped explicitly: '\i' and '\_' are invalid escape
    # sequences (SyntaxWarning on Python 3.12+). The runtime string is unchanged.
    fig.suptitle('CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights ', fontsize=16)
    # NOTE(review): the figure is created via pyplot and never closed, so pyplot
    # keeps every figure alive across requests. Whether plt.close(fig) is safe
    # here depends on how the installed Gradio renders "plot" outputs; verify.
    return fig
# Wrapper
def generate_correspondences(sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
    """Build a 7x7 grid of source keypoints and plot their CHM correspondences.

    The grid spans [min_x, max_x] x [min_y, max_y] with 7 evenly spaced values
    per axis. The bounds are presumably pixels of the ``args['img_size']`` model
    input (the UI sliders range over 1..240).

    Args:
        sousrce_image: source PIL image. The misspelled name is kept so
            existing keyword callers keep working.
        target_image: target PIL image.
        min_x, max_x, min_y, max_y: bounds of the keypoint grid.

    Returns:
        The matplotlib Figure produced by ``run_chm``.
    """
    # Imported locally: the module-level imports do not bring in itertools,
    # so `product` would otherwise be an undefined name.
    from itertools import product

    grid_size = 7
    xs = np.linspace(min_x, max_x, grid_size)
    ys = np.linspace(min_y, max_y, grid_size)
    # product() yields (x, y) pairs with x varying slowest. Transposing gives
    # the (2, N) layout run_chm expects: row 0 = x, row 1 = y.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T
    return run_chm(
        sousrce_image,
        target_image,
        selected_points=new_points,
        number_src_points=grid_size * grid_size,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
# GRADIO APP
# Inputs, in the order generate_correspondences receives them: the source and
# target images (as PIL), then the MinX/MaxX/MinY/MaxY bounds of the grid.
iface = gr.Interface(
    fn=generate_correspondences,
    inputs=[gr.inputs.Image(shape=(240, 240), type='pil') for _ in range(2)]
    + [
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=start, label=name)
        for name, start in (('MinX', 15), ('MaxX', 215), ('MinY', 15), ('MaxY', 215))
    ],
    outputs="plot",
)
iface.launch()