from __future__ import absolute_import

import numpy as np
from dtcwt.numpy import Transform2d as Transform2d_np
from dtcwt.numpy import Pyramid


def appropriate_complex_type_for(X):
    """Return an appropriate complex data type depending on the type of X.

    If X is already complex (complex64 or complex128), return its dtype. If
    it is float32 or float64, return the complex type of matching precision.
    Anything else (including integers, which :py:func:`asfarray` first
    converts to float64) falls back to complex128.

    Parameters
    ----------
    X: array_like
        Data whose dtype determines the result.

    Returns
    -------
    dtype
        ``X.dtype`` for complex input, otherwise ``np.complex64`` or
        ``np.complex128``.
    """
    X = asfarray(X)

    # np.issubsctype was removed in NumPy 2.0; np.issubdtype performs the
    # same scalar-type subclass check here and also exists in NumPy 1.x.
    if (np.issubdtype(X.dtype, np.complex64)
            or np.issubdtype(X.dtype, np.complex128)):
        return X.dtype
    elif np.issubdtype(X.dtype, np.float32):
        return np.complex64
    elif np.issubdtype(X.dtype, np.float64):
        return np.complex128

    # God knows, err on the side of caution
    return np.complex128


def asfarray(X):
    """Convert X to an array with a floating point (or complex) dtype.

    Behaves like the former :py:func:`numpy.asfarray` (removed in NumPy 2.0)
    except that the original dtype of X is preserved when it is already an
    inexact (floating point or complex) type, in which case the array is
    passed through without copying. Any other dtype is converted to float64.

    Parameters
    ----------
    X: array_like
        Input data.

    Returns
    -------
    ndarray
        X as a base-class ndarray with an inexact dtype.
    """
    X = np.asanyarray(X)
    if np.issubdtype(X.dtype, np.inexact):
        dtype = X.dtype
    else:
        dtype = np.float64
    # np.asarray returns the input itself when no dtype conversion is needed.
    return np.asarray(X, dtype=dtype)


class Transform2d(object):
    """ An implementation of the 2D DT-CWT via numpy.

    Wraps :py:class:`dtcwt.numpy.Transform2d` and applies it independently to
    every image of a batched / multi-channel input.

    Parameters
    ----------
    biort: str or np.array
        The biorthogonal wavelet family to use. If a string, it is used as
        the argument to :py:func:`dtcwt.coeffs.biort`. If an array, these
        values are used as the filter coefficients.
    qshift: str or np.array
        The quarter shift wavelet family to use. If a string, it is used as
        the argument to :py:func:`dtcwt.coeffs.qshift`. If an array, these
        values are used as the filter coefficients.

    .. note::

        *biort* and *qshift* are the wavelets which parameterise the
        transform. If *biort* or *qshift* are strings, they are used as an
        argument to the :py:func:`dtcwt.coeffs.biort` or
        :py:func:`dtcwt.coeffs.qshift` functions. Otherwise, they are
        interpreted as tuples of vectors giving filter coefficients. In the
        *biort* case, this should be (h0o, g0o, h1o, g1o). In the *qshift*
        case, this should be (h0a, h0b, g0a, g0b, h1a, h1b, g1a, g1b).

    .. codeauthor:: Fergal Cotter, Feb 2018
    """

    def __init__(self, biort='near_sym_a', qshift='qshift_a'):
        # All of the actual filtering is delegated to the dtcwt numpy backend.
        self.xfm = Transform2d_np(biort, qshift)

    def forward(self, X, nlevels=3, include_scale=False):
        """ Perform a forward transform on an image with multiple channels.

        Will perform the DTCWT independently on each channel. Data format for
        the input must have the height and width as the last 2 dimensions.

        Parameters
        ----------
        X: np.array
            Input image which you wish to transform. Can be 2, 3, or 4
            dimensions, but height and width must be the last 2.
        nlevels: int
            Number of levels of the dtcwt transform to calculate.
        include_scale: bool
            Whether or not to return the lowpass results at each scale of the
            transform, or only at the highest scale (as is custom for
            multiresolution analysis)

        Returns
        -------
        Yl: ndarray
            Lowpass output
        Yh: list(ndarray)
            Highpass outputs. Will be complex and have one more dimension
            than the input representing the 6 orientations of the wavelets.
            This extra dimension will be the third last dimension. The first
            entry in the list is the first scale.
        Yscale: list(ndarray)
            Only returned if include_scale was true. A list of lowpass
            outputs at each scale.

        .. codeauthor:: Fergal Cotter, Feb 2018
        """
        # Reshape the inputs to all be 3d inputs of shape (batch, h, w)
        X = asfarray(X)
        s = X.shape
        if len(s) == 2:
            X = np.reshape(X, (1, *s))
        elif len(s) == 4:
            X = np.reshape(X, (s[0] * s[1], s[2], s[3]))
        batch = X.shape[0]

        # Transform the first image up front: its pyramid tells us the
        # output shapes needed to allocate storage for the whole batch.
        first = self.xfm.forward(X[0], nlevels, include_scale)
        complex_dtype = appropriate_complex_type_for(X)
        Yl = np.zeros((batch, *first.lowpass.shape), dtype=X.dtype)
        Yh = [np.zeros((batch, 6, *first.highpasses[i].shape[0:2]),
                       dtype=complex_dtype)
              for i in range(nlevels)]
        Yscale = None
        if include_scale:
            Yscale = [np.zeros((batch, *first.scales[i].shape), dtype=X.dtype)
                      for i in range(nlevels)]

        def store(n, pyramid):
            # Copy one image's pyramid into slot n of the batched outputs.
            Yl[n] = pyramid.lowpass
            for i in range(nlevels):
                # The backend puts the 6 orientations last; move them to the
                # front so that height and width stay the last 2 dimensions.
                Yh[i][n] = pyramid.highpasses[i].transpose((2, 0, 1))
                if include_scale:
                    Yscale[i][n] = pyramid.scales[i]

        store(0, first)
        for n in range(1, batch):
            store(n, self.xfm.forward(X[n], nlevels, include_scale))

        # Reshape output to match input
        if len(s) == 2:
            Yl = Yl[0]
            Yh = [level[0] for level in Yh]
            if include_scale:
                Yscale = [scale[0] for scale in Yscale]
        elif len(s) == 4:
            Yl = np.reshape(Yl, (s[0], s[1], *Yl.shape[-2:]))
            Yh = [np.reshape(level, (s[0], s[1], *level.shape[-3:]))
                  for level in Yh]
            if include_scale:
                Yscale = [np.reshape(scale, (s[0], s[1], *scale.shape[-2:]))
                          for scale in Yscale]

        if include_scale:
            return Yl, Yh, Yscale
        return Yl, Yh

    def inverse(self, Yl, Yh, gain_mask=None):
        """ Perform an inverse transform on an image with multiple channels.

        Parameters
        ----------
        Yl: ndarray
            The lowpass coefficients. Can be 2, 3, or 4 dimensions
        Yh: list(ndarray)
            The complex high pass coefficients. Must be compatible with the
            lowpass coefficients. Should have one more dimension. E.g if Yl
            was of shape [batch, ch, h, w], then the Yh's should be each of
            shape [batch, ch, 6, h', w'] (with h' and w' being dependent on
            the scale).
        gain_mask: None or ndarray
            Can use this to set subbands to have non-unit gain. Can be
            complex or real. Useful for masking out subbands. It is passed
            unchanged to the underlying ``dtcwt.numpy.Transform2d.inverse``
            for every image in the batch.
            NOTE(review): this docstring previously gave the size as
            [nlevels, 6]; the dtcwt documentation appears to index the mask
            as gain_mask[direction, level] -- confirm the expected layout.

        Returns
        -------
        X: ndarray
            The reconstruction, with the same number of dimensions as Yl.

        .. codeauthor:: Fergal Cotter, Feb 2018
        """
        s = Yl.shape

        # Reshape the inputs to all be 3d inputs of shape (batch, h, w)
        if len(s) == 2:
            Yl = np.reshape(Yl, (1, *s))
            Yh = [np.reshape(level, (1, *level.shape)) for level in Yh]
        elif len(s) == 4:
            Yl = np.reshape(Yl, (s[0] * s[1], *s[-2:]))
            Yh = [np.reshape(level, (s[0] * s[1], *level.shape[-3:]))
                  for level in Yh]
        batch = Yl.shape[0]

        def reconstruct(n):
            # Invert image n. The backend expects the 6 orientations as the
            # last axis, so undo the transpose applied in forward().
            # gain_mask is forwarded for *every* image; previously it was
            # only applied to the first one and silently dropped for the rest.
            pyramid = Pyramid(
                Yl[n],
                [np.transpose(level[n], (1, 2, 0)) for level in Yh])
            return self.xfm.inverse(pyramid, gain_mask=gain_mask)

        # Invert the first image up front to learn the output shape/dtype.
        first = reconstruct(0)
        x = np.zeros((batch, *first.shape), dtype=first.dtype)
        x[0] = first
        for n in range(1, batch):
            x[n] = reconstruct(n)

        # Reshape output to match input
        if len(s) == 2:
            return x[0]
        if len(s) == 4:
            return np.reshape(x, (s[0], s[1], *x.shape[-2:]))
        return x