| | import numpy.fft as ffto |
| |
|
| | from autograd.extend import defvjp, primitive, vspace |
| |
|
| | from . import numpy_wrapper as anp |
| | from .numpy_vjps import match_complex |
| | from .numpy_wrapper import wrap_namespace |
| |
|
# Populate this module's globals from numpy.fft through wrap_namespace; the
# fft/ifft/rfft/... names used in the registrations below come from here.
wrap_namespace(vars(ffto), globals())
| |
|
| |
|
| | |
| | |
def fft_grad(get_args, fft_fun, ans, x, *args, **kwargs):
    """Build the VJP closure for a complex FFT-family function.

    Args:
        get_args: Parser called as ``get_args(x, *args, **kwargs)`` that
            returns ``(axes, s, norm)``. Only ``axes`` is used here.
        fft_fun: Transform applied to the incoming gradient, called with the
            same extra positional and keyword arguments as the forward call.
        ans: Forward result (unused).
        x: Forward input; its vspace shape is the target shape of the result.

    Returns:
        A function ``g -> gradient`` that applies ``fft_fun`` to ``g``,
        truncates/pads to the shape of ``x`` and passes the result through
        ``match_complex``.

    Raises:
        NotImplementedError: If the parsed axes contain a repeated entry.
    """
    axes, _s, _norm = get_args(x, *args, **kwargs)
    check_no_repeated_axes(axes)
    vs = vspace(x)

    def vjp(g):
        transformed = fft_fun(g, *args, **kwargs)
        return match_complex(x, truncate_pad(transformed, vs.shape))

    return vjp
| |
|
| |
|
# VJP registrations for the complex transforms. Every lambda resolves its
# argument parser when it is called, so the parsers may be defined further
# down in this module.

# 1-D transforms take (a, n, axis, norm).
defvjp(fft, lambda *args, **kwargs: fft_grad(get_fft_args, fft, *args, **kwargs))
defvjp(ifft, lambda *args, **kwargs: fft_grad(get_fft_args, ifft, *args, **kwargs))

# 2-D transforms take (a, s, axes, norm). Parsing them with the 1-D parser
# hid keyword ``axes`` from the repeated-axes check and made a positional
# list of axes fail with "unhashable type" inside check_no_repeated_axes.
defvjp(fft2, lambda *args, **kwargs: fft_grad(get_fft2_args, fft2, *args, **kwargs))
defvjp(ifft2, lambda *args, **kwargs: fft_grad(get_fft2_args, ifft2, *args, **kwargs))

# N-D transforms take (a, s, axes, norm) with axes defaulting to all axes.
defvjp(fftn, lambda *args, **kwargs: fft_grad(get_fftn_args, fftn, *args, **kwargs))
defvjp(ifftn, lambda *args, **kwargs: fft_grad(get_fftn_args, ifftn, *args, **kwargs))
| |
|
| |
|
def rfft_grad(get_args, irfft_fun, ans, x, *args, **kwargs):
    """Build the VJP closure for a real-input FFT (rfft/rfft2/rfftn).

    Args:
        get_args: Parser returning ``(axes, s, norm)`` for the forward call.
        irfft_fun: Inverse real transform applied to the rescaled gradient,
            called with the forward call's extra arguments.
        ans: Forward result; its vspace shape is the shape of the factors.
        x: Forward input; the result is truncated/padded to its shape.

    Returns:
        A function ``g -> gradient`` with respect to ``x``.

    Raises:
        NotImplementedError: If axes repeat or the last transform length
            is odd.
    """
    axes, s, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if s is None:
        # No explicit lengths: use the input's extent along each axis.
        s = [vs.shape[i] for i in axes]
    check_even_shape(s)

    # ``s`` is the full transform shape; ``gs`` is the half-spectrum shape,
    # whose last entry is s[-1] // 2 + 1.
    gs = [*s[:-1], s[-1] // 2 + 1]
    fac = make_rfft_factors(axes, gvs.shape, gs, s, norm)

    def vjp(g):
        rescaled = anp.conj(g / fac)
        back = irfft_fun(rescaled, *args, **kwargs)
        return match_complex(x, truncate_pad(back, vs.shape))

    return vjp
| |
|
| |
|
def irfft_grad(get_args, rfft_fun, ans, x, *args, **kwargs):
    """Build the VJP closure for an inverse real FFT (irfft/irfft2/irfftn).

    Args:
        get_args: Parser returning ``(axes, gs, norm)`` for the forward call,
            where ``gs`` holds the requested output lengths (or None).
        rfft_fun: Forward real transform applied to the incoming gradient,
            called with the forward call's extra arguments.
        ans: Forward result; supplies the output lengths when ``gs`` is None.
        x: Forward input; the result is truncated/padded to its shape.

    Returns:
        A function ``g -> gradient`` with respect to ``x``.

    Raises:
        NotImplementedError: If axes repeat or the last output length is odd.
    """
    axes, gs, norm = get_args(x, *args, **kwargs)
    vs = vspace(x)
    gvs = vspace(ans)
    check_no_repeated_axes(axes)
    if gs is None:
        gs = [gvs.shape[i] for i in axes]
    check_even_shape(gs)

    # ``gs`` is the full (real-space) shape; ``s`` is the half-spectrum
    # shape, whose last entry is gs[-1] // 2 + 1.
    s = list(gs)
    s[-1] = s[-1] // 2 + 1

    # The factors depend only on values fixed at this point, so build them
    # once here (as rfft_grad does) instead of on every vjp call.
    fac = make_rfft_factors(axes, vs.shape, s, gs, norm)

    def vjp(g):
        r = match_complex(x, truncate_pad(rfft_fun(g, *args, **kwargs), vs.shape))
        return anp.conj(r) * fac

    return vjp
| |
|
| |
|
# VJP registrations for the real transforms and the shift helpers. The maker
# functions look up their parser and partner transform when called, so names
# defined later in this module are resolved correctly.


def _rfft_vjp_maker(*args, **kwargs):
    return rfft_grad(get_fft_args, irfft, *args, **kwargs)


defvjp(rfft, _rfft_vjp_maker)


def _irfft_vjp_maker(*args, **kwargs):
    return irfft_grad(get_fft_args, rfft, *args, **kwargs)


defvjp(irfft, _irfft_vjp_maker)


def _rfft2_vjp_maker(*args, **kwargs):
    return rfft_grad(get_fft2_args, irfft2, *args, **kwargs)


defvjp(rfft2, _rfft2_vjp_maker)


def _irfft2_vjp_maker(*args, **kwargs):
    return irfft_grad(get_fft2_args, rfft2, *args, **kwargs)


defvjp(irfft2, _irfft2_vjp_maker)


def _rfftn_vjp_maker(*args, **kwargs):
    return rfft_grad(get_fftn_args, irfftn, *args, **kwargs)


defvjp(rfftn, _rfftn_vjp_maker)


def _irfftn_vjp_maker(*args, **kwargs):
    return irfft_grad(get_fftn_args, rfftn, *args, **kwargs)


defvjp(irfftn, _irfftn_vjp_maker)


def _fftshift_vjp_maker(ans, x, axes=None):
    # The gradient of fftshift is ifftshift applied over the same axes.
    def vjp(g):
        return match_complex(x, anp.conj(ifftshift(anp.conj(g), axes)))

    return vjp


defvjp(fftshift, _fftshift_vjp_maker)


def _ifftshift_vjp_maker(ans, x, axes=None):
    # The gradient of ifftshift is fftshift applied over the same axes.
    def vjp(g):
        return match_complex(x, anp.conj(fftshift(anp.conj(g), axes)))

    return vjp


defvjp(ifftshift, _ifftshift_vjp_maker)
| |
|
| |
|
@primitive
def truncate_pad(x, shape):
    """Return ``x`` cropped and/or zero-padded at the end of each axis to ``shape``.

    Axes shorter than the target are padded with constant zeros after the
    existing data; axes longer than the target keep only their leading
    ``shape[i]`` entries.
    """
    window = tuple(slice(n) for n in shape)
    no_leading_pad = anp.zeros(len(shape), dtype=int)
    # Amount missing along each axis; never negative.
    trailing_pad = anp.maximum(0, anp.array(shape) - anp.array(x.shape))
    pad_widths = tuple(zip(no_leading_pad, trailing_pad))
    padded = anp.pad(x, pad_widths, "constant")
    return padded[window]
| |
|
| |
|
def _truncate_pad_vjp_maker(ans, x, shape):
    """Build the VJP of truncate_pad: truncate/pad the gradient back to x's shape."""

    def vjp(g):
        return match_complex(x, truncate_pad(g, vspace(x).shape))

    return vjp


defvjp(truncate_pad, _truncate_pad_vjp_maker)
| |
|
| |
|
| | |
def check_no_repeated_axes(axes):
    """Raise NotImplementedError if any entry of ``axes`` occurs more than once.

    Entries are compared as given; no normalisation of negative indices is
    performed.
    """
    distinct_count = len(set(axes))
    if distinct_count == len(axes):
        return
    raise NotImplementedError("FFT gradient for repeated axes not implemented.")
| |
|
| |
|
def check_even_shape(shape):
    """Raise NotImplementedError unless the last entry of ``shape`` is even."""
    last_length = shape[-1]
    if last_length % 2 == 0:
        return
    raise NotImplementedError("Real FFT gradient for odd lengthed last axes is not implemented.")
| |
|
| |
|
def get_fft_args(a, d=None, axis=-1, norm=None, *args, **kwargs):
    """Parse the arguments of a 1-D transform into ``(axes, s, norm)``.

    Args:
        a: Input array (unused; present so the signature mirrors the call).
        d: Transform length when passed positionally. numpy names this
            parameter ``n``, so a keyword ``n=...`` arrives in ``kwargs``
            and is picked up from there.
        axis: Axis of the transform.
        norm: Normalisation mode, passed through unchanged.

    Returns:
        ``([axis], [length] or None, norm)``.
    """
    if d is None:
        # Without this fallback a keyword ``n`` was silently ignored and
        # callers derived the length from the array shape instead.
        d = kwargs.get("n")
    s = None if d is None else [d]
    return [axis], s, norm
| |
|
| |
|
def get_fft2_args(a, s=None, axes=(-2, -1), norm=None, *args, **kwargs):
    """Parse the arguments of a 2-D transform into ``(axes, s, norm)``.

    The values are returned as received; ``a`` and any extra arguments are
    ignored.
    """
    parsed = (axes, s, norm)
    return parsed
| |
|
| |
|
def get_fftn_args(a, s=None, axes=None, norm=None, *args, **kwargs):
    """Parse the arguments of an N-D transform into ``(axes, s, norm)``.

    When ``axes`` is None it is replaced by the list of all axes of ``a``
    (``0 .. a.ndim - 1``); otherwise it is returned as received.
    """
    resolved_axes = list(range(a.ndim)) if axes is None else axes
    return resolved_axes, s, norm
| |
|
| |
|
def make_rfft_factors(axes, resshape, facshape, normshape, norm):
    """Make the compression factors and the normalization for irfft and rfft.

    Args:
        axes: Transform axes; only the last one is used, to locate the
            half-spectrum axis in the result.
        resshape: Shape of the returned array.
        facshape: Half-spectrum shape; ``facshape[-1] - 1`` is the index of
            the last stored frequency bin.
        normshape: Full transform lengths; their product is N.
        norm: Normalisation mode of the transform: None or "backward"
            (factors divided by N), "forward" (factors multiplied by N);
            any other value (e.g. "ortho") leaves the factors unscaled.

    Returns:
        A float array of shape ``resshape`` holding 2 everywhere except 1 at
        index 0 and, when it fits, index ``facshape[-1] - 1`` along
        ``axes[-1]``, scaled according to ``norm``.
    """
    N = 1.0
    for n in normshape:
        N *= n

    # In-place edits are fine here: the array is a constant that is built
    # from shapes only and is not itself differentiated.
    fac = anp.zeros(resshape)
    fac[...] = 2
    index = [slice(None)] * len(resshape)
    last_axis = axes[-1]
    if facshape[-1] <= resshape[last_axis]:
        # First and last stored bins appear only once in the half spectrum.
        index[last_axis] = (0, facshape[-1] - 1)
    else:
        # The last bin lies outside the result, so only the first is marked.
        index[last_axis] = (0,)
    fac[tuple(index)] = 1

    if norm is None or norm == "backward":
        # "backward" is the explicit spelling of the default; it was
        # previously left unscaled, giving gradients off by a factor of N.
        fac /= N
    elif norm == "forward":
        # The 1/N scaling sits on the forward transform in this mode, so the
        # factor carries N instead of 1/N.
        fac *= N
    return fac
| |
|