| | import numpy as np |
| | import torch |
| |
|
| |
|
def dct(x, norm=None):
    """Compute the type-II discrete cosine transform over the last dimension.

    All leading dimensions are flattened into rows, each row is transformed
    with a single length-N FFT of a reordered copy of the signal, and the
    result is reshaped back. With ``norm=None`` coefficient ``k`` equals
    ``2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``.

    :param x: input tensor; the transform runs along its last dimension
    :param norm: ``None`` for the unscaled transform, ``'ortho'`` to divide
        coefficient 0 by ``2 * sqrt(N)`` and the others by ``2 * sqrt(N / 2)``
    :return: tensor of DCT coefficients with the same shape as ``x``
    """
    original_shape = x.shape
    length = original_shape[-1]
    rows = x.contiguous().view(-1, length)

    # Even-indexed samples first, then the odd-indexed samples in reverse.
    evens = rows[:, ::2]
    odds_reversed = rows[:, 1::2].flip([1])
    reordered = torch.cat([evens, odds_reversed], dim=1)

    # Split the complex spectrum into real and imaginary planes.
    spectrum = torch.view_as_real(torch.fft.fft(reordered, dim=1))
    spec_re = spectrum[:, :, 0]
    spec_im = spectrum[:, :, 1]

    # Rotate bin k by the angle -pi * k / (2N) and keep the real part.
    index = torch.arange(length, dtype=rows.dtype, device=rows.device)
    angles = -index[None, :] * np.pi / (2 * length)
    coeffs = spec_re * torch.cos(angles) - spec_im * torch.sin(angles)

    if norm == 'ortho':
        dc_scale = np.sqrt(length) * 2
        ac_scale = np.sqrt(length / 2) * 2
        coeffs[:, 0] /= dc_scale
        coeffs[:, 1:] /= ac_scale

    return 2 * coeffs.view(*original_shape)
| |
|
| |
|
def idct(X, norm=None):
    """Invert the type-II DCT over the last dimension.

    All leading dimensions are flattened into rows, each row of coefficients
    is turned into a complex spectrum and passed through one inverse real
    FFT, and the samples are interleaved back into their original order.
    With matching ``norm`` this undoes the forward transform whose
    unscaled coefficient ``k`` is
    ``2 * sum_n x[n] * cos(pi * k * (2n + 1) / (2N))``.

    :param X: tensor of DCT coefficients along its last dimension
    :param norm: ``None`` for unscaled coefficients, ``'ortho'`` when the
        coefficients carry the orthonormal scaling
    :return: reconstructed signal with the same shape as ``X``
    """
    original_shape = X.shape
    length = original_shape[-1]

    # The division yields a new tensor, so the in-place scaling below
    # never writes into the caller's X.
    halved = X.contiguous().view(-1, length) / 2

    if norm == 'ortho':
        halved[:, 0] *= np.sqrt(length) * 2
        halved[:, 1:] *= np.sqrt(length / 2) * 2

    index = torch.arange(length, dtype=X.dtype, device=X.device)
    angles = index[None, :] * np.pi / (2 * length)
    cos_w = torch.cos(angles)
    sin_w = torch.sin(angles)

    # Imaginary input: a zeroed first column followed by the negated,
    # reversed coefficients with the last reversed entry dropped.
    zero_col = halved[:, :1] * 0
    mirrored = -halved.flip([1])[:, :-1]
    imag_in = torch.cat([zero_col, mirrored], dim=1)

    # Complex multiplication by cos(angles) + i * sin(angles).
    spec_re = halved * cos_w - imag_in * sin_w
    spec_im = halved * sin_w + imag_in * cos_w
    packed = torch.stack([spec_re, spec_im], dim=2)

    signal = torch.fft.irfft(
        torch.view_as_complex(packed), n=packed.shape[1], dim=1
    )

    # The front of `signal` fills the even slots; its reversed tail fills
    # the odd slots.
    even_count = length - (length // 2)
    odd_count = length // 2
    out = signal.new_zeros(signal.shape)
    out[:, ::2] += signal[:, :even_count]
    out[:, 1::2] += signal.flip([1])[:, :odd_count]

    return out.view(*original_shape)
| |
|