| |
| |
| |
|
|
| import sys |
| from pdb import set_trace as bb |
| from PIL import Image |
| import numpy as np |
|
|
| import matplotlib.pyplot as pl; pl.ion() |
| import torch |
| import torch.nn.functional as F |
|
|
| from core import functional as myF |
| from .common import cpu, nparray, image, image_with_trf |
|
|
|
|
| def dbgfig(*args, **kwargs): |
| assert len(args) >= 2 |
| dbg = args[-1] |
| if isinstance(dbg, str): |
| dbg = dbg.split() |
| for name in args[:-1]: |
| if {name,'all'} & set(dbg): |
| return pl.figure(name, **kwargs) |
| return False |
|
|
|
|
| def noticks(ax=None): |
| if ax is None: ax = pl.gca() |
| ax.set_xticks(()) |
| ax.set_yticks(()) |
| return ax |
|
|
|
|
| def plot_grid( corres, ax1, ax2=None, marker='+' ): |
| """ corres = Nx2 or Nx4 list of correspondences |
| """ |
| if marker is True: marker = '+' |
|
|
| corres = nparray(corres) |
| |
| center = corres[:,[1,0]].mean(axis=0) |
| colors = np.arctan2(*(corres[:,[1,0]] - center).T) |
| colors = np.int32(64*colors/np.pi) % 128 |
|
|
| all_colors = np.unique(colors) |
| palette = {m:pl.cm.hsv(i/float(len(all_colors))) for i,m in enumerate(all_colors)} |
|
|
| for m in all_colors: |
| x, y = corres[colors==m,0:2].T |
| ax1.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0) |
|
|
| if not ax2: return |
| for m in all_colors: |
| x, y = corres[colors==m,2:4].T |
| ax2.plot(x, y, marker, ms=10, mew=2, color=palette[m], scalex=0, scaley=0) |
|
|
|
|
| def show_correspondences( img0, img1, corres, F=None, fig='last', show_grid=True, bb=None, clf=False): |
| img0, trf0 = img0 if isinstance(img0, tuple) else (img0, torch.eye(3)) |
| img1, trf1 = img1 if isinstance(img1, tuple) else (img1, torch.eye(3)) |
| if not bb: pl.ioff() |
| fig, axes = pl.subplots(2, 2, num=fig_num(fig, 'viz_corres')) |
| for i, ax in enumerate(axes.ravel()): |
| if clf: ax.cla() |
| noticks(ax).numaxis = i % 2 |
| ax.imshow( [image(img0),image(img1)][i%2] ) |
|
|
| if corres.shape == (3,3): |
| from pytools.hfuncs import applyh |
| H, W = axes[0,0].images[0].get_size() |
| pos1 = np.mgrid[:H,:W].reshape(2,-1)[::-1].T |
| pos2 = applyh(corres, pos1) |
| corres = np.concatenate((pos1,pos2), axis=-1) |
|
|
| inv = np.linalg.inv |
| corres = myF.affmul((inv(nparray(trf0)),inv(nparray(trf1))), nparray(corres)) |
| print(f">> Displaying {len(corres)} correspondences (move you mouse over the images)") |
|
|
| (ax1, ax2), (ax3, ax4) = axes |
| if corres.shape[-1] > 4: |
| corres = corres[corres[:,4]>0,:] |
| if show_grid: plot_grid(corres, ax3, ax4, marker=show_grid) |
|
|
| def mouse_move(event): |
| if event.inaxes==None: return |
| numaxis = event.inaxes.numaxis |
| if numaxis<0: return |
| x,y = event.xdata, event.ydata |
| ax1.lines.clear() |
| ax2.lines.clear() |
| sl = slice(2*numaxis, 2*(numaxis+1)) |
| n = np.sum((corres[:,sl] - [x,y])**2,axis=1).argmin() |
| print("\rdisplaying #%d (%d,%d) --> (%d,%d), score=%g, code=%g" % (n, |
| corres[n,0],corres[n,1],corres[n,2],corres[n,3], |
| corres[n,4] if corres.shape[-1] > 4 else np.nan, |
| corres[n,5] if corres.shape[-1] > 5 else np.nan), end=' '*7);sys.stdout.flush() |
| x,y = corres[n,0:2] |
| ax1.plot(x, y, '+', ms=10, mew=2, color='blue', scalex=False, scaley=False) |
| x,y = corres[n,2:4] |
| ax2.plot(x, y, '+', ms=10, mew=2, color='red', scalex=False, scaley=False) |
| if F is not None: |
| ax = None |
| if numaxis == 0: |
| line = corres[n,0:2] @ F[:2] + F[2] |
| ax = ax2 |
| if numaxis == 1: |
| line = corres[n,2:4] @ F.T[:2] + F.T[2] |
| ax = ax1 |
| if ax: |
| x = np.linspace(-10000,10000,2) |
| y = (line[2]+line[0]*x) / -line[1] |
| ax.plot(x, y, '-', scalex=0, scaley=0) |
|
|
| |
| renderer = fig.canvas.get_renderer() |
| ax1.draw(renderer) |
| ax2.draw(renderer) |
| fig.canvas.blit(ax1.bbox) |
| fig.canvas.blit(ax2.bbox) |
|
|
| cid_move = fig.canvas.mpl_connect('motion_notify_event',mouse_move) |
| pl.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.99, wspace=0.02, hspace=0.02) |
| bb() if bb else pl.show() |
| fig.canvas.mpl_disconnect(cid_move) |
| |
|
|
| def closest( grid, event ): |
| query = (event.xdata, event.ydata) |
| n = np.linalg.norm(grid.reshape(-1,2) - query, axis=1).argmin() |
| return np.unravel_index(n, grid.shape[:2]) |
|
|
|
|
| def local_maxima( arr2d, top=5 ): |
| maxpooled = F.max_pool2d( arr2d[None, None], 3, padding=1, stride=1)[0,0] |
| local_maxima = (arr2d == maxpooled).nonzero() |
| order = arr2d[local_maxima.split(1,dim=1)].ravel().argsort() |
| return local_maxima[order[-5:]].T |
|
|
|
|
| def fig_num( fig, default, clf=False ): |
| if fig == 'last': num = pl.gcf().number |
| elif fig: num = fig.number |
| else: num = default |
| if clf: pl.figure(num).clf() |
| return num |
|
|
|
|
| def viz_correlation_maps( img1, img2, corr, level=0, fig=None, grid1=None, grid2=None, show_grid=False, bb=bb, **kw ): |
| fig, ((ax1, ax2), (ax4, ax3)) = pl.subplots(2, 2, num=fig_num(fig, 'viz_correlation_maps', clf=True)) |
| img1 = image(img1) |
| img2 = image(img2) |
| noticks(ax1).imshow( img1 ) |
| noticks(ax2).imshow( img2 ) |
| ax4.hist(corr.ravel()[7:7777777:7].cpu().numpy(), bins=50) |
|
|
| if isinstance(corr, tuple): |
| H1, W1 = corr.grid.shape[:2] |
| corr = torch.from_numpy(corr.res_map).view(H1,W1,*corr.res_map.shape[-2:]) |
|
|
| if grid1 is None: |
| s1 = int(0.5 + np.sqrt(img1.size / (3 * corr[...,0,0].numel()))) |
| grid1 = nparray(torch.ones_like(corr[:,:,0,0]).nonzero()*s1)[:,1::-1] |
| if level == 0: grid1 += s1//2 |
| if show_grid: plot_grid(grid1, ax1) |
| grid1 = nparray(grid1).reshape(*corr[:,:,0,0].shape,2) |
|
|
| if grid2 is None: |
| s2 = int(0.5 + np.sqrt(img2.size / (3 * corr[0,0,...].numel()))) |
| grid2 = nparray(torch.ones_like(corr[0,0]).nonzero()*s2)[:,::-1] |
| grid2 = nparray(grid2).reshape(*corr.shape[2:],2) |
|
|
| def mouse_move(ev): |
| if ev.inaxes is ax1: |
| ax3.images.clear() |
| n = closest(grid1, ev) |
| ax3.imshow(corr[n].cpu().float(), vmin=0, **kw) |
|
|
| |
| lm = nparray(local_maxima(corr[n])) |
| for ax in (ax3, ax2): |
| if ax is ax2 and not show_grid: |
| ax1.lines.clear() |
| ax1.plot(*grid1[n], 'xr', ms=10, scalex=0, scaley=0) |
| ax.lines.clear() |
| x, y = grid2[y,x].T if ax is ax2 else lm[::-1] |
| if ax is not ax3: |
| ax.plot(x, y, 'xr', ms=10, scalex=0, scaley=0, label='local maxima') |
| print(f"\rCorr channel {n}. Min={corr[n].min():g}, Avg={corr[n].mean():g}, Max={corr[n].max():g} ", end='') |
|
|
| mouse_move(FakeEvent(0,0,inaxes=ax1)) |
| cid_move = fig.canvas.mpl_connect('motion_notify_event', mouse_move) |
| pl.subplots_adjust(0,0,1,1,0,0) |
| pl.sca(ax4) |
| if bb: bb(); fig.canvas.mpl_disconnect(cid_move) |
|
|
| def viz_correspondences( img1, img2, corres1, corres2, fig=None ): |
| img1, img2 = map(image, (img1, img2)) |
| fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = pl.subplots(3,2, num=fig_num(fig, 'viz_correspondences')) |
| for ax in fig.axes: noticks(ax) |
| ax1.imshow( img1 ) |
| ax2.imshow( img2 ) |
| ax3.imshow( img1 ) |
| ax4.imshow( img2 ) |
| corres1, corres2 = map(cpu, (corres1, corres2)) |
| plot_grid( corres1[0], ax1, ax2 ) |
| plot_grid( corres2[0], ax3, ax4 ) |
|
|
| corres1, corres2 = corres1[1].float(), corres2[1].float() |
| ceiling = np.ceil(max(corres1.max(), corres2.max()).item()) |
| ax5.imshow( corres1, vmin=0, vmax=ceiling ) |
| ax6.imshow( corres2, vmin=0, vmax=ceiling ) |
| bb() |
|
|
|
|
| class FakeEvent: |
| def __init__(self, xdata, ydata, **kw): |
| self.xdata = xdata |
| self.ydata = ydata |
| for name, val in kw.items(): |
| setattr(self, name, val) |
|
|
|
|
| def show_random_pairs( db, pair_idxs=None, **kw ): |
| print('Showing random pairs from', db) |
|
|
| if pair_idxs is None: |
| pair_idxs = np.random.permutation(len(db)) |
|
|
| for pair_idx in pair_idxs: |
| print(f'{pair_idx=}') |
| try: |
| img1_path, img2_path = map(db.imgs.get_image_path, db.pairs[pair_idx]) |
| print(f'{img1_path=}\n{img2_path=}') |
| if hasattr(db, 'get_corres_path'): |
| print(f'corres_path = {db.get_corres_path(pair_idx)}') |
| except: pass |
| (img1, img2), gt = db[pair_idx] |
|
|
| if 'corres' in gt: |
| corres = gt['corres'] |
| else: |
| |
| from datasets.utils import corres_from_homography |
| corres = corres_from_homography(gt['homography'], *img1.size) |
|
|
| show_correspondences(img1, img2, corres, **kw) |
|
|
|
|
| if __name__=='__main__': |
| import argparse |
| import test_singlescale as pump |
|
|
| parser = argparse.ArgumentParser('Correspondence visualization') |
| parser.add_argument('--img1', required=True, help='path to first image') |
| parser.add_argument('--img2', required=True, help='path to second image') |
| parser.add_argument('--corres', required=True, help='path to correspondences') |
| args = parser.parse_args() |
|
|
| corres = np.load(args.corres)['corres'] |
|
|
| args.resize = 0 |
| imgs = tuple(map(image, pump.Main.load_images(args))) |
|
|
| show_correspondences(*imgs, corres) |
|
|