from typing import Dict, List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from kornia.constants import pi
from kornia.filters import GaussianBlur2d, SpatialGradient
from kornia.geometry.conversions import cart2pol
from kornia.utils import create_meshgrid

# Precomputed coefficients for Von Mises kernel, given N and K(appa).
sqrt2: float = 1.4142135623730951
COEFFS_N1_K1: List[float] = [0.38214156, 0.48090413]
COEFFS_N2_K8: List[float] = [0.14343168, 0.268285, 0.21979234]
COEFFS_N3_K8: List[float] = [0.14343168, 0.268285, 0.21979234, 0.15838885]
COEFFS: Dict[str, List[float]] = {'xy': COEFFS_N1_K1, 'rhophi': COEFFS_N2_K8, 'theta': COEFFS_N3_K8}

# Download locations of the pretrained whitening models, keyed by kernel type.
urls: Dict[str, str] = {
    k: f'https://github.com/manyids2/mkd_pytorch/raw/master/mkd_pytorch/mkd-{k}-64.pth'
    for k in ['cart', 'polar', 'concat']
}


def get_grid_dict(patch_size: int = 32) -> Dict[str, torch.Tensor]:
    r"""Get cartesian and polar parametrizations of grid.

    Args:
        patch_size: Side length of the square grid.

    Returns:
        Dictionary with keys ``'x'``, ``'y'``, ``'rho'`` and ``'phi'``, each holding a
        ``(patch_size, patch_size)`` tensor. ``x``/``y`` come from a normalized meshgrid and
        ``rho``/``phi`` are their polar conversion via :func:`cart2pol`.
    """
    kgrid = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)
    x = kgrid[0, :, :, 0]
    y = kgrid[0, :, :, 1]
    rho, phi = cart2pol(x, y)
    grid_dict = {'x': x, 'y': y, 'rho': rho, 'phi': phi}
    return grid_dict


def get_kron_order(d1: int, d2: int) -> torch.Tensor:
    r"""Get order for doing kronecker product.

    Args:
        d1: Size of the first factor.
        d2: Size of the second factor.

    Returns:
        ``int64`` tensor of shape ``(d1 * d2, 2)`` whose row ``i * d2 + j`` is ``(i, j)``.
    """
    # First column: every index of the first factor repeated d2 times (0,0,..,1,1,..).
    first = torch.arange(d1, dtype=torch.int64).repeat_interleave(d2)
    # Second column: the indices of the second factor cycled d1 times (0,1,..,0,1,..).
    second = torch.arange(d2, dtype=torch.int64).repeat(d1)
    return torch.stack([first, second], dim=1)


class MKDGradients(nn.Module):
    r"""Module, which computes gradients of given patches, stacked as [magnitudes, orientations].

    Given gradients :math:`g_x`, :math:`g_y` with respect to :math:`x`, :math:`y` respectively,
      - :math:`\text{mags} = \sqrt{g_x^2 + g_y^2 + eps}`
      - :math:`\text{oris} = \tan^{-1}(g_y / g_x)`.

    The conversion itself is delegated to :func:`cart2pol`.

    Returns:
        gradients of given patches.

    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, 2, patch_size, patch_size)

    Example:
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> gradient = MKDGradients()
        >>> g = gradient(patches) # 23x2x32x32
    """

    def __init__(self) -> None:
        super().__init__()
        self.eps = 1e-8
        self.grad = SpatialGradient(mode='diff', order=1, normalized=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return magnitudes and orientations of the gradients of ``x``, concatenated on dim 1.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 4-dimensional.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(x)}")
        if not len(x.shape) == 4:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")
        # Modify 'diff' gradient. Before we had lambda function, but it is not jittable
        grads_xy = -self.grad(x)
        gx = grads_xy[:, :, 0, :, :]
        gy = grads_xy[:, :, 1, :, :]
        y = torch.cat(cart2pol(gx, gy, self.eps), dim=1)
        return y

    def __repr__(self) -> str:
        return self.__class__.__name__


class VonMisesKernel(nn.Module):
    r"""Module, which computes parameters of Von Mises kernel given coefficients, and embeds given patches.

    With ``n = len(coeffs) - 1`` the embedding has ``d = 2 * n + 1`` channels: one constant
    channel, ``n`` cosine channels and ``n`` sine channels, each scaled by the square root of
    the matching coefficient.

    Args:
        patch_size: Input patch size in pixels.
        coeffs: List of coefficients. Some examples are hardcoded in COEFFS.

    Returns:
        Von Mises embedding of given parametrization.

    Shape:
        - Input: (B, 1, patch_size, patch_size)
        - Output: (B, d, patch_size, patch_size)

    Examples:
        >>> oris = torch.rand(23, 1, 32, 32)
        >>> vm = VonMisesKernel(patch_size=32,
        ...                     coeffs=[0.14343168,
        ...                             0.268285,
        ...                             0.21979234])
        >>> emb = vm(oris) # 23x5x32x32
    """

    def __init__(self, patch_size: int, coeffs: Union[list, tuple]) -> None:
        super().__init__()

        self.patch_size = patch_size
        b_coeffs: torch.Tensor = torch.tensor(coeffs)
        self.register_buffer('coeffs', b_coeffs)

        # Compute parameters.
        n: int = len(coeffs) - 1
        self.n: int = n
        self.d: int = 2 * n + 1

        # Precompute helper variables.
        # Constant (frequency 0) channel.
        emb0 = torch.ones([1, 1, patch_size, patch_size])
        # Frequencies 1..n, shaped to broadcast against a Bx1xHxW input.
        frange = torch.arange(n) + 1
        frange = frange.reshape(-1, 1, 1)
        # Per-channel weights, laid out to match cat([emb0, cos, sin]) in forward().
        weights = torch.zeros([2 * n + 1])
        weights[: n + 1] = torch.sqrt(b_coeffs)
        weights[n + 1:] = torch.sqrt(b_coeffs[1:])
        weights = weights.reshape(-1, 1, 1)
        self.register_buffer('emb0', emb0)
        self.register_buffer('frange', frange)
        self.register_buffer('weights', weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` into ``d`` weighted constant/cosine/sine channels.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not of shape Bx1xHxW.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(x)}")

        if not len(x.shape) == 4 or x.shape[1] != 1:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {x.shape}")

        # TODO: unify the two lines below when pytorch 1.6 support is dropped
        emb0: torch.Tensor = torch.jit.annotate(torch.Tensor, self.emb0)
        emb0 = emb0.to(x).repeat(x.size(0), 1, 1, 1)
        frange = self.frange.to(x) * x
        emb1 = torch.cos(frange)
        emb2 = torch.sin(frange)
        embedding = torch.cat([emb0, emb1, emb2], dim=1)
        embedding = self.weights * embedding
        return embedding

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"patch_size={self.patch_size}, "
            f"n={self.n}, "
            f"d={self.d}, "
            f"coeffs={self.coeffs!s})"
        )


class EmbedGradients(nn.Module):
    r"""Module that computes gradient embedding, weighted by sqrt of magnitudes of given patches.

    Args:
        patch_size: Input patch size in pixels.
        relative: absolute or relative gradients. When ``True`` the polar angle of each pixel
          position is subtracted from the orientations before embedding.

    Returns:
        Gradient embedding.

    Shape:
        - Input: (B, 2, patch_size, patch_size)
        - Output: (B, 7, patch_size, patch_size)

    Examples:
        >>> grads = torch.rand(23, 2, 32, 32)
        >>> emb_grads = EmbedGradients(patch_size=32,
        ...                            relative=False)
        >>> emb = emb_grads(grads) # 23x7x32x32
    """

    def __init__(self, patch_size: int = 32, relative: bool = False) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.relative = relative
        self.eps = 1e-8

        # Theta kernel for gradients.
        self.kernel = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS['theta'])

        # Relative gradients.
        kgrid = create_meshgrid(height=patch_size, width=patch_size, normalized_coordinates=True)
        _, phi = cart2pol(kgrid[:, :, :, 0], kgrid[:, :, :, 1])
        self.register_buffer('phi', phi)

    def emb_mags(self, mags: torch.Tensor) -> torch.Tensor:
        """Embed square roots of magnitudes with eps for numerical reasons."""
        mags = torch.sqrt(mags + self.eps)
        return mags

    def forward(self, grads: torch.Tensor) -> torch.Tensor:
        """Embed the orientations (channels 1 onward) and weight them by sqrt of channel 0.

        Raises:
            TypeError: if ``grads`` is not a tensor.
            ValueError: if ``grads`` is not 4-dimensional.
        """
        if not isinstance(grads, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(grads)}")
        if not len(grads.shape) == 4:
            raise ValueError(f"Invalid input shape, we expect Bx2xHxW. Got: {grads.shape}")
        mags = grads[:, :1, :, :]
        oris = grads[:, 1:, :, :]
        if self.relative:
            oris = oris - self.phi.to(oris)
        y = self.kernel(oris) * self.emb_mags(mags)
        return y

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(patch_size={self.patch_size}, relative={self.relative})"


def spatial_kernel_embedding(kernel_type: str, grids: Dict[str, torch.Tensor]) -> torch.Tensor:
    r"""Compute embeddings for cartesian and polar parametrizations.

    Args:
        kernel_type: ``'cart'`` (uses grids ``'x'`` and ``'y'``) or ``'polar'``
          (uses grids ``'phi'`` and ``'rho'``).
        grids: Dictionary of 2-D coordinate grids, e.g. the output of :func:`get_grid_dict`.

    Returns:
        Tensor of shape ``(d * d, patch_size, patch_size)`` holding all pairwise products of
        the two per-coordinate Von Mises embeddings, where ``d`` is the embedding size.

    Raises:
        NotImplementedError: if ``kernel_type`` is neither ``'cart'`` nor ``'polar'``.
    """
    factors = {"phi": 1.0, "rho": pi / sqrt2, "x": pi / 2, "y": pi / 2}
    if kernel_type == 'cart':
        coeffs_ = 'xy'
        params_ = ['x', 'y']
    elif kernel_type == 'polar':
        coeffs_ = 'rhophi'
        params_ = ['phi', 'rho']
    else:
        # Fail explicitly instead of hitting an UnboundLocalError further down.
        raise NotImplementedError(f'{kernel_type} is not valid, use polar or cart.')

    # Infer patch_size.
    patch_size = grids[params_[0]].shape[-1]

    # Scale appropriately and reshape each needed grid to 1x1xHxW.
    grids_normed = {k: (grids[k] * factors[k]).unsqueeze(0).unsqueeze(0).float() for k in params_}

    # x,y/rho,phi kernels. Both coordinates share the same coefficients, so one kernel suffices.
    vm = VonMisesKernel(patch_size=patch_size, coeffs=COEFFS[coeffs_])

    # Drop only the batch dimension so that degenerate 1x1 grids keep their spatial dims.
    emb_a = vm(grids_normed[params_[0]]).squeeze(0)
    emb_b = vm(grids_normed[params_[1]]).squeeze(0)

    # Final precomputed position embedding.
    kron_order = get_kron_order(vm.d, vm.d)
    spatial_kernel = emb_a.index_select(0, kron_order[:, 0]) * emb_b.index_select(0, kron_order[:, 1])
    return spatial_kernel


class ExplicitSpacialEncoding(nn.Module):
    r"""Module that computes explicit cartesian or polar embedding.

    The input feature map is multiplied channel-wise with a precomputed positional embedding
    and summed over the spatial dimensions, so the result is one vector per batch element.

    Args:
        kernel_type: Parametrization of kernel ``'polar'`` or ``'cart'``.
        fmap_size: Input feature map size in pixels.
        in_dims: Dimensionality of input feature map.
        do_gmask: Apply gaussian mask.
        do_l2: Apply l2-normalization.

    Returns:
        Explicit cartesian or polar embedding.

    Shape:
        - Input: (B, in_dims, fmap_size, fmap_size)
        - Output: (B, out_dims)

    Example:
        >>> emb_ori = torch.rand(23, 7, 32, 32)
        >>> ese = ExplicitSpacialEncoding(kernel_type='polar',
        ...                               fmap_size=32,
        ...                               in_dims=7,
        ...                               do_gmask=True,
        ...                               do_l2=True)
        >>> desc = ese(emb_ori) # 23x175
    """

    def __init__(
        self,
        kernel_type: str = 'polar',
        fmap_size: int = 32,
        in_dims: int = 7,
        do_gmask: bool = True,
        do_l2: bool = True,
    ) -> None:
        super().__init__()

        if kernel_type not in ['polar', 'cart']:
            raise NotImplementedError(f'{kernel_type} is not valid, use polar or cart.')

        self.kernel_type = kernel_type
        self.fmap_size = fmap_size
        self.in_dims = in_dims
        self.do_gmask = do_gmask
        self.do_l2 = do_l2
        self.grid = get_grid_dict(fmap_size)
        self.gmask = None

        # Precompute embedding.
        emb = spatial_kernel_embedding(self.kernel_type, self.grid)

        # Gaussian mask.
        if self.do_gmask:
            self.gmask = self.get_gmask(sigma=1.0)
            emb = emb * self.gmask

        # Store precomputed embedding.
        self.register_buffer('emb', emb.unsqueeze(0))
        self.d_emb: int = emb.shape[0]
        self.out_dims: int = self.in_dims * self.d_emb
        self.odims: int = self.out_dims

        # Store kronecker form.
        emb2, idx1 = self.init_kron()
        self.register_buffer('emb2', emb2)
        self.register_buffer('idx1', idx1)

    def get_gmask(self, sigma: float) -> torch.Tensor:
        """Compute Gaussian mask over the grid radius, normalized to a maximum radius of 1."""
        norm_rho = self.grid['rho'] / self.grid['rho'].max()
        gmask = torch.exp(-1 * norm_rho ** 2 / sigma ** 2)
        return gmask

    def init_kron(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Initialize helper variables to calculate kronecker.

        Returns:
            The positional embedding expanded to ``out_dims`` channels, and for every output
            channel the index of the input channel it is multiplied with.
        """
        kron = get_kron_order(self.in_dims, self.d_emb)
        _emb = torch.jit.annotate(torch.Tensor, self.emb)
        emb2 = torch.index_select(_emb, 1, kron[:, 1])
        return emb2, kron[:, 0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and pool it over the spatial dimensions.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 4-dimensional with ``in_dims`` channels.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(x)}")
        # Both conditions must hold; the rank is tested first so shape[1] is safe to read.
        if x.dim() != 4 or x.shape[1] != self.in_dims:
            raise ValueError(f"Invalid input shape, we expect Bx{self.in_dims}xHxW. Got: {x.shape}")
        idx1 = torch.jit.annotate(torch.Tensor, self.idx1)
        emb1 = torch.index_select(x, 1, idx1)
        output = emb1 * self.emb2
        output = output.sum(dim=(2, 3))
        if self.do_l2:
            output = F.normalize(output, dim=1)
        return output

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"kernel_type={self.kernel_type}, "
            f"fmap_size={self.fmap_size}, "
            f"in_dims={self.in_dims}, "
            f"out_dims={self.out_dims}, "
            f"do_gmask={self.do_gmask}, "
            f"do_l2={self.do_l2})"
        )


class Whitening(nn.Module):
    r"""Module, performs supervised or unsupervised whitening.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Args:
        xform: Variant of whitening to use. None, 'lw', 'pca', 'pcaws', 'pcawt'.
        whitening_model: Dictionary with keys 'mean', 'eigvecs', 'eigvals' holding torch.Tensors,
          nested under the key ``'lw'`` or ``'pca'``. If ``None`` an identity transform is used.
        in_dims: Dimensionality of input descriptors.
        output_dims: (int) Dimensionality reduction.
        keval: Shrinkage parameter.
        t: Attenuation parameter.

    Returns:
        l2-normalized, whitened descriptors.

    Shape:
        - Input: (B, in_dims)
        - Output: (B, output_dims)

    Examples:
        >>> descs = torch.rand(23, 238)
        >>> whitening_model = {'pca': {'mean': torch.zeros(238),
        ...                            'eigvecs': torch.eye(238),
        ...                            'eigvals': torch.ones(238)}}
        >>> whitening = Whitening(xform='pcawt',
        ...                       whitening_model=whitening_model,
        ...                       in_dims=238,
        ...                       output_dims=128,
        ...                       keval=40,
        ...                       t=0.7)
        >>> wdescs = whitening(descs) # 23x128
    """

    def __init__(
        self,
        xform: str,
        whitening_model: Union[Dict[str, Dict[str, torch.Tensor]], None],
        in_dims: int,
        output_dims: int = 128,
        keval: int = 40,
        t: float = 0.7,
    ) -> None:
        super().__init__()

        self.xform = xform
        self.in_dims = in_dims
        self.keval = keval
        self.t = t
        self.pval = 1.0

        # Compute true output_dims.
        output_dims = min(output_dims, in_dims)
        self.output_dims = output_dims

        # Initialize identity transform.
        self.mean = nn.Parameter(torch.zeros(in_dims), requires_grad=True)
        self.evecs = nn.Parameter(torch.eye(in_dims)[:, :output_dims], requires_grad=True)
        self.evals = nn.Parameter(torch.ones(in_dims)[:output_dims], requires_grad=True)

        if whitening_model is not None:
            self.load_whitening_parameters(whitening_model)

    def load_whitening_parameters(self, whitening_model: Dict[str, Dict[str, torch.Tensor]]) -> None:
        """Load mean, eigenvectors and eigenvalues, then apply the ``xform``-specific modification.

        The ``'lw'`` entry of ``whitening_model`` is used when ``xform == 'lw'``, the ``'pca'``
        entry otherwise.

        Raises:
            KeyError: if ``xform`` is not one of 'pca', 'lw', 'pcaws', 'pcawt'.
        """
        algo = 'lw' if self.xform == 'lw' else 'pca'
        wh_model = whitening_model[algo]
        self.mean.data = wh_model['mean']
        self.evecs.data = wh_model['eigvecs'][:, : self.output_dims]
        self.evals.data = wh_model['eigvals'][: self.output_dims]

        modifications = {
            'pca': self._modify_pca,
            'lw': self._modify_lw,
            'pcaws': self._modify_pcaws,
            'pcawt': self._modify_pcawt,
        }

        # Call modification.
        modifications[self.xform]()

    def _modify_pca(self) -> None:
        """Modify powerlaw parameter."""
        self.pval = 0.5

    def _modify_lw(self) -> None:
        """No modification required."""

    def _modify_pcaws(self) -> None:
        """Shrinkage for eigenvalues."""
        alpha = self.evals[self.keval]
        evals = ((1 - alpha) * self.evals) + alpha
        self.evecs.data = self.evecs @ torch.diag(torch.pow(evals, -0.5))

    def _modify_pcawt(self) -> None:
        """Attenuation for eigenvalues."""
        m = -0.5 * self.t
        self.evecs.data = self.evecs @ torch.diag(torch.pow(self.evals, m))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Center, project, apply the signed power law and l2-normalize ``x``.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 2-dimensional.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(x)}")
        if not len(x.shape) == 2:
            raise ValueError(f"Invalid input shape, we expect NxD. Got: {x.shape}")
        x = x - self.mean  # Center the data.
        x = x @ self.evecs  # Apply rotation and/or scaling.
        x = torch.sign(x) * torch.pow(torch.abs(x), self.pval)  # Powerlaw.
        return F.normalize(x, dim=1)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"xform={self.xform}, "
            f"in_dims={self.in_dims}, "
            f"output_dims={self.output_dims})"
        )


class MKDDescriptor(nn.Module):
    r"""Module that computes Multiple Kernel local descriptors.

    This is based on the paper "Understanding and Improving Kernel Local Descriptors".
    See :cite:`mukundan2019understanding` for more details.

    Args:
        patch_size: Input patch size in pixels.
        kernel_type: Parametrization of kernel ``'concat'``, ``'cart'``, ``'polar'``.
        whitening: Whitening transform to apply ``None``, ``'lw'``, ``'pca'``, ``'pcawt'``, ``'pcaws'``.
        training_set: Set that model was trained on ``'liberty'``, ``'notredame'``, ``'yosemite'``.
        output_dims: Dimensionality reduction.

    Returns:
        Local descriptors of the given patches.

    Shape:
        - Input: :math:`(B, 1, patch_{size}, patch_{size})`.
        - Output: :math:`(B, out_{dims})`.

    Examples:
        >>> patches = torch.rand(23, 1, 32, 32)
        >>> mkd = MKDDescriptor(patch_size=32,
        ...                     kernel_type='concat',
        ...                     whitening='pcawt',
        ...                     training_set='liberty',
        ...                     output_dims=128)
        >>> desc = mkd(patches) # 23x128
    """

    def __init__(
        self,
        patch_size: int = 32,
        kernel_type: str = 'concat',
        whitening: str = 'pcawt',
        training_set: str = 'liberty',
        output_dims: int = 128,
    ) -> None:
        super().__init__()

        self.patch_size: int = patch_size
        self.kernel_type: str = kernel_type
        self.whitening: str = whitening
        self.training_set: str = training_set

        self.sigma = 1.4 * (patch_size / 64)
        self.smoothing = GaussianBlur2d((5, 5), (self.sigma, self.sigma), 'replicate')
        self.gradients = MKDGradients()
        # The string constants are bound to typed locals because jitting requires it.
        polar_s: str = 'polar'
        cart_s: str = 'cart'
        self.parametrizations = [polar_s, cart_s] if self.kernel_type == 'concat' else [self.kernel_type]

        # Initialize cartesian/polar embedding with absolute/relative gradients.
        self.odims: int = 0
        relative_orientations = {polar_s: True, cart_s: False}
        self.feats = {}
        for parametrization in self.parametrizations:
            gradient_embedding = EmbedGradients(patch_size=patch_size, relative=relative_orientations[parametrization])
            spatial_encoding = ExplicitSpacialEncoding(
                kernel_type=parametrization, fmap_size=patch_size, in_dims=gradient_embedding.kernel.d
            )

            self.feats[parametrization] = nn.Sequential(gradient_embedding, spatial_encoding)
            self.odims += spatial_encoding.odims

        # Compute true output_dims.
        self.output_dims: int = min(output_dims, self.odims)

        # Load supervised(lw)/unsupervised(pca) model trained on training_set.
        if self.whitening is not None:
            whitening_model = load_whitening_model(self.kernel_type, training_set)
            self.whitening_layer = Whitening(
                whitening, whitening_model, in_dims=self.odims, output_dims=self.output_dims
            )
            self.odims = self.output_dims

        self.eval()

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """Compute descriptors for a batch of patches.

        Raises:
            TypeError: if ``patches`` is not a tensor.
            ValueError: if ``patches`` is not 4-dimensional.
        """
        if not isinstance(patches, torch.Tensor):
            raise TypeError(f"Input type is not a torch.Tensor. Got {type(patches)}")
        if not len(patches.shape) == 4:
            raise ValueError(f"Invalid input shape, we expect Bx1xHxW. Got: {patches.shape}")
        # Extract gradients.
        g = self.smoothing(patches)
        g = self.gradients(g)

        # Extract polar/cart features.
        features = []
        for parametrization in self.parametrizations:
            # ``feats`` is a plain dict, so its modules are not moved by ``self.to(...)``
            # and have to follow the input device explicitly.
            self.feats[parametrization].to(g.device)
            features.append(self.feats[parametrization](g))

        # Concatenate.
        y = torch.cat(features, dim=1)

        # l2-normalize.
        y = F.normalize(y, dim=1)

        # Whiten descriptors.
        if self.whitening is not None:
            y = self.whitening_layer(y)

        return y

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"patch_size={self.patch_size}, "
            f"kernel_type={self.kernel_type}, "
            f"whitening={self.whitening}, "
            f"training_set={self.training_set}, "
            f"output_dims={self.output_dims})"
        )


def load_whitening_model(kernel_type: str, training_set: str) -> Dict:
    """Download (or load from the torch hub cache) the whitening model for a kernel type.

    Args:
        kernel_type: Key into ``urls``: ``'cart'``, ``'polar'`` or ``'concat'``.
        training_set: Key of the downloaded dictionary selecting the training set.

    Returns:
        The whitening model stored under ``training_set``.
    """
    whitening_models = torch.hub.load_state_dict_from_url(urls[kernel_type], map_location=lambda storage, loc: storage)
    whitening_model = whitening_models[training_set]
    return whitening_model


class SimpleKD(nn.Module):
    """Example to write custom Kernel Descriptors.

    Chains smoothing, gradient extraction, gradient embedding, spatial encoding and whitening
    into a single ``nn.Sequential``.

    Args:
        patch_size: Input patch size in pixels.
        kernel_type: Parametrization of kernel ``'cart'`` or ``'polar'``.
        whitening: Whitening transform to apply ``'lw'``, ``'pca'``, ``'pcaws'``, ``'pcawt'``.
        training_set: Set that model was trained on ``'liberty'``, ``'notredame'``, ``'yosemite'``.
        output_dims: Dimensionality reduction.
    """

    def __init__(
        self,
        patch_size: int = 32,
        kernel_type: str = 'polar',  # 'cart' 'polar'
        whitening: str = 'pcawt',  # 'lw', 'pca', 'pcaws', 'pcawt
        training_set: str = 'liberty',  # 'liberty', 'notredame', 'yosemite'
        output_dims: int = 128,
    ) -> None:
        super().__init__()

        # Polar kernels use orientations relative to the pixel position.
        relative: bool = kernel_type == 'polar'
        sigma: float = 1.4 * (patch_size / 64)
        self.patch_size = patch_size

        # Sequence of modules.
        smoothing = GaussianBlur2d((5, 5), (sigma, sigma), 'replicate')
        gradients = MKDGradients()
        ori = EmbedGradients(patch_size=patch_size, relative=relative)
        ese = ExplicitSpacialEncoding(kernel_type=kernel_type, fmap_size=patch_size, in_dims=ori.kernel.d)
        wh = Whitening(
            whitening, load_whitening_model(kernel_type, training_set), in_dims=ese.odims, output_dims=output_dims
        )

        self.features = nn.Sequential(smoothing, gradients, ori, ese, wh)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the full descriptor pipeline."""
        return self.features(x)