DragStream / frequency_utils.py
bowmanchow's picture
add code
0328207
Raw
History Blame Contribute Delete
32.8 kB
import math
from typing import List, Sequence, Tuple, Union
import torch
from torch import Tensor
from PIL import Image
import numpy as np
import os
from pathlib import Path
# Module-level side effect: raise the print line width so that tensors printed
# by the debugging/demo code in this file are not wrapped across lines.
torch.set_printoptions(
    linewidth=10000,
)
def _get_center_distance(size: Tuple[int], device: str = "cpu") -> Tensor:
"""Compute the distance of each matrix element to the center.
Args:
size (Tuple[int]): [m, n].
device (str, optional): cpu/cuda. Defaults to 'cpu'.
Returns:
Tensor: [m, n].
"""
m, n = size
i_ind = torch.tile(
torch.tensor([[[i]] for i in range(m)], device=device), dims=[1, n, 1]
).float() # [m, n, 1]
j_ind = torch.tile(
torch.tensor([[[i] for i in range(n)]], device=device), dims=[m, 1, 1]
).float() # [m, n, 1]
ij_ind = torch.cat([i_ind, j_ind], dim=-1) # [m, n, 2]
ij_ind = ij_ind.reshape([m * n, 1, 2]) # [m * n, 1, 2]
center_ij = torch.tensor(((m - 1) / 2, (n - 1) / 2), device=device).reshape(1, 2)
center_ij = torch.tile(center_ij, dims=[m * n, 1, 1])
dist = torch.cdist(ij_ind, center_ij, p=2).reshape([m, n])
return dist
def _get_ideal_weights(
    size: Tuple[int], D0: int, lowpass: bool = True, device: str = "cpu"
) -> Tensor:
    """Get H(u, v) of an ideal (brick-wall) low-pass or high-pass filter.

    Args:
        size (Tuple[int]): [H, W].
        D0 (int): The cutoff frequency, expressed as a distance (in elements)
            from the center of the matrix.
        lowpass (bool): True for low-pass filter, otherwise for high-pass filter.
            Evaluated by truthiness, so ``1`` or a NumPy boolean behaves like
            ``True``. Defaults to True.
        device (str, optional): cpu/cuda. Defaults to 'cpu'.

    Returns:
        Tensor: [H, W] tensor of zeros and ones with the dtype of the distance
        map. Elements at distance <= D0 form the pass band of the low-pass
        filter; the high-pass filter is its complement.
    """
    center_distance = _get_center_distance(size, device)
    inside = center_distance <= D0
    passband = inside if lowpass else ~inside
    return passband.to(center_distance.dtype)
def _to_freq(image: Tensor) -> Tensor:
"""Convert from spatial domain to frequency domain.
Args:
image (Tensor): [B, C, H, W].
Returns:
Tensor: [B, C, H, W]
"""
img_fft = torch.fft.fft2(image)
img_fft_shift = torch.fft.fftshift(img_fft)
return img_fft_shift
def _to_space(image_fft: Tensor) -> Tensor:
"""Convert from frequency domain to spatial domain.
Args:
image_fft (Tensor): [B, C, H, W].
Returns:
Tensor: [B, C, H, W].
"""
img_ifft_shift = torch.fft.ifftshift(image_fft)
img_ifft = torch.fft.ifft2(img_ifft_shift)
img = img_ifft.real.clamp(0, 1)
return img
def ideal_bandpass(image: Tensor, D0: int, lowpass: bool = True) -> Tensor:
    """Apply an ideal low-pass or high-pass filter to images.

    The image is transformed with ``_to_freq``, multiplied by the 0/1 mask
    from ``_get_ideal_weights`` and transformed back with ``_to_space``.

    Args:
        image (Tensor): [B, C, H, W].
        D0 (int): Cutoff frequency.
        lowpass (bool): True for low-pass filter, otherwise for high-pass filter.
            Defaults to True.

    Returns:
        Tensor: [B, C, H, W].
    """
    spectrum = _to_freq(image)
    spatial_size = spectrum.shape[-2:]
    mask = _get_ideal_weights(spatial_size, D0=D0, lowpass=lowpass, device=image.device)
    return _to_space(spectrum * mask)
# Butterworth
def _get_butterworth_weights(size: Tuple[int], D0: int, n: int, device: str = "cpu") -> Tensor:
    """Get H(u, v) of a Butterworth low-pass filter.

    Evaluates ``1 / (1 + (D / D0) ** (2 * n))`` where ``D`` is the distance of
    each element to the center of the matrix.

    Args:
        size (Tuple[int]): [H, W].
        D0 (int): The cutoff frequency.
        n (int): Order of the Butterworth filter.
        device (str, optional): cpu/cuda. Defaults to 'cpu'.

    Returns:
        Tensor: [H, W].
    """
    radius = _get_center_distance(size=size, device=device)
    relative_radius = radius / D0
    return 1 / (1 + torch.pow(relative_radius, 2 * n))
def butterworth(image: Tensor, D0: int, n: int) -> Tensor:
    """Butterworth low-pass filter for images.

    The centered spectrum from ``_to_freq`` is scaled by the weights from
    ``_get_butterworth_weights`` and converted back with ``_to_space``.

    Args:
        image (Tensor): [B, C, H, W].
        D0 (int): Cutoff frequency.
        n (int): Order of the Butterworth low-pass filter.

    Returns:
        Tensor: [B, C, H, W].
    """
    spectrum = _to_freq(image)
    gain = _get_butterworth_weights(image.shape[-2:], D0, n, device=image.device)
    return _to_space(gain * spectrum)
# def my_butterworth_low_pass_filter(
# shape,
# stop_freqs: List[float],
# n=4,
# ):
# assert len(shape) == len(stop_freqs)
# grid = torch.meshgrid(
# *[torch.arange(s, dtype=torch.float32) for s in shape],
# indexing='ij',
# )
# # ( [shape[0], shape[1], ..., shape[N]] ) * len(shape)
# indices = torch.stack(grid, dim=-1).float()
# # print(f"{indices.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# max_len = torch.tensor(shape).float()
# max_len -= 1.0
# max_len /= 2.0
# # print(f"{max_len = }")
# # print(f"{max_len.shape = }")
# # [len(shape)]
# max_len = max_len.view(*([1]*len(shape)), -1)
# # print(f"{max_len.shape = }")
# # [1, 1, ..., 1, len(shape)]
# normalized_indices = indices / max_len
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# normalized_indices_offset = normalized_indices - 1
# # print(f"{normalized_indices_offset.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# stop_freqs_torch = torch.tensor(stop_freqs).float().view(*([1]*len(shape)), -1)
# # print(f"{stop_freqs_torch.shape = }")
# # [1, 1, ..., 1, len(shape)]
# scaled_normalized_indices_offset = normalized_indices_offset / stop_freqs_torch
# # print(f"{scaled_normalized_indices_offset.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# filter_ = 1.0 / (1.0 + torch.pow(scaled_normalized_indices_offset.norm(p=2, dim=-1), 2 * n))
# return filter_
# def my_butterworth_low_pass_filter_non_center(
# shape,
# stop_freqs: List[float],
# n=4,
# ):
# new_shape = [
# 2*i-1
# for i in shape
# ]
# filter_ = my_butterworth_low_pass_filter(
# new_shape,
# n=n,
# stop_freqs=stop_freqs,
# )
# if len(shape) == 1:
# crop_filter = filter_[-shape[0]:]
# elif len(shape) == 2:
# crop_filter = filter_[-shape[0]:, -shape[1]:]
# elif len(shape) == 3:
# crop_filter = filter_[-shape[0]:, -shape[1]:, -shape[2]:]
# else:
# raise ValueError("Shape must be 1D, 2D, or 3D.")
# return crop_filter
# def my_butterworth_high_pass_filter(
# shape,
# stop_freqs: List[float],
# n=4,
# ):
# assert len(shape) == len(stop_freqs)
# grid = torch.meshgrid(
# *[torch.arange(s, dtype=torch.float32) for s in shape],
# indexing='ij',
# )
# # ( [shape[0], shape[1], ..., shape[N]] ) * len(shape)
# indices = torch.stack(grid, dim=-1).float()
# # print(f"{indices.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# max_len = torch.tensor(shape).float()
# max_len -= 1.0
# max_len /= 2.0
# # print(f"{max_len = }")
# # print(f"{max_len.shape = }")
# # [len(shape)]
# max_len = max_len.view(*([1]*len(shape)), -1)
# # print(f"{max_len.shape = }")
# # [1, 1, ..., 1, len(shape)]
# normalized_indices = indices / max_len
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# normalized_indices_offset = normalized_indices - 1
# # print(f"{normalized_indices_offset.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# stop_freqs_torch = torch.tensor(stop_freqs).float().view(*([1]*len(shape)), -1)
# # print(f"{stop_freqs_torch.shape = }")
# # [1, 1, ..., 1, len(shape)]
# scaled_normalized_indices_offset = stop_freqs_torch / normalized_indices_offset
# # print(f"{scaled_normalized_indices_offset.shape = }")
# # [shape[0], shape[1], ..., shape[N], len(shape)]
# filter_ = 1.0 / (1.0 + torch.pow(scaled_normalized_indices_offset.norm(p=2, dim=-1), 2 * n))
# return filter_
# def my_butterworth_high_pass_filter_non_center(
# shape,
# stop_freqs: List[float],
# n=4,
# ):
# new_shape = [
# 2*i-1
# for i in shape
# ]
# filter_ = my_butterworth_high_pass_filter(
# new_shape,
# n=n,
# stop_freqs=stop_freqs,
# )
# if len(shape) == 1:
# crop_filter = filter_[-shape[0]:]
# elif len(shape) == 2:
# crop_filter = filter_[-shape[0]:, -shape[1]:]
# elif len(shape) == 3:
# crop_filter = filter_[-shape[0]:, -shape[1]:, -shape[2]:]
# else:
# raise ValueError("Shape must be 1D, 2D, or 3D.")
# return crop_filter
# ------------------------ Image loading ------------------------
def load_grayscale_image():
    """Load the first available sample image from the working directory.

    Despite the name, the image is converted to RGB (not grayscale). No
    fallback beyond the local candidate files is attempted.

    Returns:
        numpy.ndarray: float64 array built from the RGB image, scaled by
        1/255 so values lie in [0, 1].

    Raises:
        FileNotFoundError: if none of the candidate file names exists in the
            current working directory.
    """
    candidates = ["onion.png", "cameraman.tif", "peppers.png", "lena.png", "camera.png"]
    for name in candidates:
        if os.path.exists(name):
            # Image.open keeps the file handle open; convert() produces an
            # independent in-memory image, so the handle can be released here.
            with Image.open(name) as img:
                rgb = img.convert("RGB")
            image_np = np.asarray(rgb, dtype=np.float64)
            return image_np / 255.0
    raise FileNotFoundError(
        "Could not find a local image. Place an image (e.g., cameraman.tif/peppers.png) in the working directory."
    )
# ------------------------ DCT implementations (orthonormal) ------------------------
def dct2_matrix_ortho(N, device="cpu", dtype=torch.float32):
    """Build the [N, N] orthonormal DCT-II matrix.

    Entry ``[k, n]`` is ``sqrt(2/N) * beta(k) * cos(pi/N * (n + 0.5) * k)``
    with ``beta(0) = 1/sqrt(2)`` and ``beta(k) = 1`` otherwise. Because the
    matrix is orthonormal, its inverse is its transpose.
    """
    sample_idx = torch.arange(N, device=device, dtype=dtype)
    freq_idx = torch.arange(N, device=device, dtype=dtype).unsqueeze(1)
    cos_table = torch.cos(math.pi / N * (sample_idx + 0.5) * freq_idx)  # [N, N]

    # Per-row normalization; only the DC row gets the extra 1/sqrt(2).
    beta = torch.ones(N, device=device, dtype=dtype)
    beta[0] = 1 / math.sqrt(2.0)
    row_scale = math.sqrt(2.0 / N) * beta
    return row_scale.unsqueeze(1) * cos_table
def dct1_matrix_ortho(N, device="cpu", dtype=torch.float32):
    """Build the [N, N] orthonormal DCT-I matrix.

    Entry ``[k, n]`` is
    ``sqrt(2/(N-1)) * alpha(k) * alpha(n) * cos(pi/(N-1) * n * k)`` with
    ``alpha(0) = alpha(N-1) = 1/sqrt(2)`` and ``alpha = 1`` elsewhere. The
    matrix is symmetric and orthonormal, hence its own inverse. For ``N < 2``
    a 1x1 matrix containing 1 is returned.
    """
    if N < 2:
        # Degenerate size: the transform is the identity on a single sample.
        return torch.ones((1, 1), device=device, dtype=dtype)

    idx = torch.arange(N, device=device, dtype=dtype)
    cos_table = torch.cos(math.pi / (N - 1) * (idx * idx.unsqueeze(1)))  # [N, N]

    # End points are down-weighted on both the row and the column side.
    alpha = torch.ones(N, device=device, dtype=dtype)
    alpha[0] = 1 / math.sqrt(2.0)
    alpha[-1] = 1 / math.sqrt(2.0)
    weighted = alpha.unsqueeze(1) * cos_table * alpha.unsqueeze(0)
    return math.sqrt(2.0 / (N - 1)) * weighted
def dct2_ortho(x, T2=None):
    """Return the orthonormal DCT-II of ``x`` as a 1-D tensor.

    ``x`` is flattened first. ``T2`` may supply a precomputed transform
    matrix; when omitted it is built with ``dct2_matrix_ortho`` for the
    length, device and dtype of ``x``.
    """
    signal = x.reshape(-1)
    matrix = T2
    if matrix is None:
        matrix = dct2_matrix_ortho(signal.numel(), device=signal.device, dtype=signal.dtype)
    return torch.matmul(matrix, signal)
def idct2_ortho(X, T2=None):
    """Invert the orthonormal DCT-II by multiplying with the transposed matrix.

    ``X`` is flattened first. ``T2`` may supply the forward transform matrix;
    when omitted it is built with ``dct2_matrix_ortho`` for the length, device
    and dtype of ``X``.
    """
    coeffs = X.reshape(-1)
    matrix = T2
    if matrix is None:
        matrix = dct2_matrix_ortho(coeffs.numel(), device=coeffs.device, dtype=coeffs.dtype)
    return torch.matmul(matrix.t(), coeffs)
def dct1_ortho(x, T1=None):
    """Return the orthonormal DCT-I of ``x`` as a 1-D tensor.

    ``x`` is flattened first. ``T1`` may supply a precomputed transform
    matrix; when omitted it is built with ``dct1_matrix_ortho`` for the
    length, device and dtype of ``x``.
    """
    signal = x.reshape(-1)
    matrix = T1
    if matrix is None:
        matrix = dct1_matrix_ortho(signal.numel(), device=signal.device, dtype=signal.dtype)
    return torch.matmul(matrix, signal)
def idct1_ortho(X, T1=None):
    """Invert the orthonormal DCT-I.

    The same matrix is applied as in the forward direction (the file treats
    the orthonormal DCT-I matrix as self-inverse). ``X`` is flattened first;
    ``T1`` may supply a precomputed matrix, otherwise it is built with
    ``dct1_matrix_ortho``.
    """
    coeffs = X.reshape(-1)
    matrix = T1
    if matrix is None:
        matrix = dct1_matrix_ortho(coeffs.numel(), device=coeffs.device, dtype=coeffs.dtype)
    return torch.matmul(matrix, coeffs)
def _complex_dtype_from_real(real_dtype):
if real_dtype == torch.float32:
return torch.complex64
if real_dtype == torch.float64:
return torch.complex128
raise TypeError("Only float32/float64 supported.")
def dct2_fft(x, dim=-1, norm="ortho"):
"""
DCT-II via even-symmetric 2N extension and torch.fft.rfft.
x: real tensor (..., N)
Returns: real tensor (..., N)
norm: 'ortho' (orthonormal, like scipy.fft.dct(..., type=2, norm='ortho')) or None (unnormalized).
"""
if not torch.is_floating_point(x):
raise TypeError("x must be float tensor")
N = x.shape[dim]
if N < 1:
return x.clone()
# Even extension [x, flip(x)]
x_flip = torch.flip(x, dims=(dim,))
s = torch.cat([x, x_flip], dim=dim) # (..., 2N)
# RFFT over length 2N
S = torch.fft.rfft(s, n=2 * N, dim=dim) # (..., N+1)
# k = 0..N-1
k = torch.arange(N, device=x.device, dtype=x.dtype)
# exp(-j*pi*k/(2N))
ctype = _complex_dtype_from_real(x.dtype)
twiddle = torch.exp(-1j * math.pi * k / (2.0 * N)).to(dtype=ctype, device=x.device)
for _ in range(dim, S.dim() - 1):
twiddle = twiddle.unsqueeze(-1)
# Take real part; factor 1/2 (see derivation)
C = (S.narrow(dim, 0, N) * twiddle).real * 0.5 # (..., N)
if norm == "ortho":
# Orthonormal scaling: sqrt(2/N) * beta(k), beta(0)=1/sqrt(2)
C = C * math.sqrt(2.0 / N)
index0 = [slice(None)] * C.dim()
index0[dim] = 0
C[tuple(index0)] /= math.sqrt(2.0)
elif norm is None:
pass
else:
raise ValueError("norm must be 'ortho' or None")
return C
def idct2_fft(C, dim=-1, norm="ortho"):
"""
Inverse of dct2_fft (i.e., DCT-III) using torch.fft.irfft.
C: real tensor (..., N) with same norm used in dct2_fft.
Returns real tensor (..., N).
"""
if not torch.is_floating_point(C):
raise TypeError("C must be float tensor")
N = C.shape[dim]
if N < 1:
return C.clone()
# Undo orthonormal scaling to get "unnormalized" DCT-II coefficients
Cun = C
if norm == "ortho":
Cun = C / math.sqrt(2.0 / N)
index0 = [slice(None)] * Cun.dim()
index0[dim] = 0
Cun = Cun.clone()
Cun[tuple(index0)] *= math.sqrt(2.0)
elif norm is None:
Cun = C
else:
raise ValueError("norm must be 'ortho' or None")
# Build unique half-spectrum (length N+1) for the 2N-length irfft
# S[k] = 2*Cun[k] * exp(+j*pi*k/(2N)), for k=0..N-1
k = torch.arange(N, device=C.device, dtype=C.dtype)
ctype = _complex_dtype_from_real(C.dtype)
twiddle = torch.exp(+1j * math.pi * k / (2.0 * N)).to(dtype=ctype, device=C.device)
for _ in range(dim, C.dim() - 1):
twiddle = twiddle.unsqueeze(-1)
# Allocate (..., N+1)
new_shape = list(Cun.shape)
new_shape[dim] = N + 1
S_half = torch.zeros(*new_shape, dtype=ctype, device=C.device)
# Fill 0..N-1
# real times complex -> cast below
S_part = (2.0 * Cun) * twiddle.real - 0j
S_part = (2.0 * Cun).to(ctype) * twiddle
S_half.narrow(dim, 0, N).copy_(S_part)
# Nyquist (k=N) is zero for the chosen even-symmetric extension
indexN = [slice(None)] * S_half.dim()
indexN[dim] = N
S_half[tuple(indexN)] = 0
# irfft to length 2N, take first N samples
s = torch.fft.irfft(S_half, n=2 * N, dim=dim) # (..., 2N)
# Slice first N along dim
x = s.narrow(dim, 0, N)
return x
# --------- N-D (multi-axis) DCT-II / IDCT-II built from the 1D versions ---------
def _normalize_dims(dims, ndim):
if isinstance(dims, int):
dims = (dims,)
dims = tuple(d if d >= 0 else d + ndim for d in dims)
if any(d < 0 or d >= ndim for d in dims):
raise ValueError("dims out of range for input tensor.")
# You can enforce uniqueness if desired:
if len(set(dims)) != len(dims):
raise ValueError("dims must be unique.")
return dims
def dct2_nd_fft(x, dims, norm="ortho"):
    """
    Multi-axis DCT-II: ``dct2_fft`` applied once per requested axis.

    x: real tensor
    dims: int or sequence of axes (e.g., (-2, -1) for 2D, (-3, -2, -1) for 3D);
          validated and made non-negative by ``_normalize_dims``
    norm: 'ortho' or None, forwarded to every 1-D transform
    """
    result = x
    for axis in _normalize_dims(dims, x.ndim):
        result = dct2_fft(result, dim=axis, norm=norm)
    return result
def idct2_nd_fft(X, dims, norm="ortho"):
    """
    Multi-axis inverse DCT-II (DCT-III): ``idct2_fft`` applied once per axis.

    X: real tensor of coefficients
    dims: int or sequence of axes; validated and made non-negative by
          ``_normalize_dims``
    norm: 'ortho' or None, forwarded to every 1-D inverse transform
    """
    result = X
    for axis in _normalize_dims(dims, X.ndim):
        result = idct2_fft(result, dim=axis, norm=norm)
    return result
def _to_device_dtype(x, device, dtype):
if device is None:
device = x.device if isinstance(x, torch.Tensor) else "cpu"
if dtype is None:
dtype = torch.float64 # match MATLAB double
return device, dtype
def _omega_grid_1d(N, shifted, device, dtype):
# Digital radian frequency samples on FFT bins.
# unshifted: ω_k = 2π k / N, k=0..N-1 (DC at index 0)
# shifted: fftshift layout (DC at center), monotonically increasing from negative to positive
k = torch.arange(N, device=device, dtype=dtype)
w = 2.0 * math.pi * k / N
# [0, 2π)
if shifted:
w = torch.fft.fftshift(w) # center DC
return w
def _tan_half_abs(w, eps=1e-12):
# Safe |tan(w/2)| to avoid overflow at w=π.
half = 0.5 * w
c = torch.cos(half)
s = torch.sin(half)
# Where cos is near zero, use a very large value (approach infinity)
# large but not inf to avoid NaNs downstream
large = torch.finfo(w.dtype).max ** 0.5
t = torch.where(c.abs() < eps, torch.sign(s) * large, s / c)
return t.abs()
def butterworth_mask_1d(
    N,
    fc,
    order,
    btype="low",
    shifted=False,
    device=None,
    dtype=None,
):
    """
    1D Butterworth magnitude mask sampled on the N FFT bins.

    Bin frequencies are warped with the bilinear mapping Omega = 2*tan(omega/2)
    and the cutoffs are prewarped the same way, Omega_c = 2*tan(pi*fc). The
    magnitude is 1/sqrt(1 + D^(2*order)) where D is the usual low/high/
    band-pass/band-stop frequency ratio. The original author describes the
    result as equivalent to the magnitude of MATLAB butter+freqz.

    - N: number of FFT bins (int >= 2)
    - fc: normalized cutoff(s) in cycles/sample with 0 < fc < 0.5
          low/high: scalar; bandpass/stop: [f1, f2] with 0 < f1 < f2 < 0.5
          * fc corresponds to Wn / 2 in MATLAB's butter, e.g. butter(4, 0.25)
            corresponds to fc=0.125 here.
    - order: integer >= 1
    - btype: 'low', 'high', 'bandpass', 'stop' (also 'bandstop'/'bandreject');
             case-insensitive
    - shifted: if True, return mask in fftshift layout (DC at center)
    - device, dtype: default to CPU and float64 when None

    Argument ranges are checked with ``assert``; an unknown btype raises
    ValueError.
    """
    assert isinstance(N, int) and N >= 2
    assert isinstance(order, int) and order >= 1
    btype = btype.lower()
    single_cutoff = btype in ("low", "high")
    if single_cutoff:
        fc = float(fc)
        assert 0.0 < fc < 0.5
    else:
        assert len(fc) == 2
        f_lo, f_hi = float(fc[0]), float(fc[1])
        assert 0.0 < f_lo < f_hi < 0.5

    device, dtype = _to_device_dtype(torch.empty(0), device, dtype)
    omega = _omega_grid_1d(N, shifted=shifted, device=device, dtype=dtype)
    # Warped ("analog") frequency of every bin.
    analog = 2.0 * _tan_half_abs(omega)

    if single_cutoff:
        analog_cut = 2.0 * math.tan(math.pi * fc)
        if btype == "low":
            ratio = (analog / analog_cut).clamp_min(0)
        else:
            # DC bin: ratio = inf so that the magnitude becomes 0.
            ratio = torch.where(
                analog > 0, (analog_cut / analog), torch.full_like(analog, float("inf"))
            )
    elif btype in ("bandpass", "stop", "bandstop", "bandreject"):
        edge_lo = 2.0 * math.tan(math.pi * f_lo)
        edge_hi = 2.0 * math.tan(math.pi * f_hi)
        bandwidth = edge_hi - edge_lo
        center = math.sqrt(edge_lo * edge_hi)
        if btype == "bandpass":
            # D = (Omega^2 - Omega0^2) / (B * Omega); Omega = 0 -> D = inf
            numer = analog.pow(2) - center**2
            denom = bandwidth * analog
        else:
            # D = (B * Omega) / (Omega^2 - Omega0^2); Omega = Omega0 -> D = inf
            numer = bandwidth * analog
            denom = analog.pow(2) - center**2
        ratio = torch.where(
            denom != 0, numer / denom, torch.full_like(analog, float("inf"))
        ).abs()
    else:
        raise ValueError("btype must be 'low', 'high', 'bandpass', or 'stop'.")

    mag = 1.0 / torch.sqrt(1.0 + ratio.pow(2 * order))
    return mag.to(dtype=dtype, device=device)
def butterworth_mask_2d_separable(
    shape,
    fc,
    order,
    btype="low",
    shifted=False,
    device=None,
    dtype=None,
):
    """
    2D separable Butterworth mask: the outer product of one 1D mask per axis
    (rows x cols). This is not an isotropic (circular) Butterworth.

    - shape: (M, N), both >= 2
    - fc: for low/high, a scalar (used on both axes) or a 2-sequence
          (fc_y, fc_x); for band types, either ([f1y, f2y], [f1x, f2x]) or a
          single [f1, f2] pair used on both axes
    - order: integer, or 2-sequence (order_y, order_x)
    - btype: 'low', 'high', 'bandpass', 'stop'; case-insensitive, matching
             butterworth_mask_1d
    - shifted: if True, both axes are centered (fftshift layout)
    - device, dtype: default to CPU and float64 when None

    Returns a tensor of shape (M, N).
    """
    M, N = int(shape[0]), int(shape[1])
    assert M >= 2 and N >= 2
    device, dtype = _to_device_dtype(torch.empty(0), device, dtype)
    # Lower-case before branching: the 1D builder accepts any casing, so the
    # per-axis cutoff parsing here must classify btype the same way.
    btype = btype.lower()

    # Normalize fc/order to per-axis values.
    if btype in ("low", "high"):
        if not isinstance(fc, (list, tuple)):
            fcy = fcx = fc
        else:
            assert len(fc) == 2
            fcy, fcx = fc
    else:
        # Band types: either one band per axis or the same band on both axes.
        if isinstance(fc[0], (list, tuple)) and isinstance(fc[1], (list, tuple)):
            fcy, fcx = fc
        else:
            fcy = fcx = fc

    if isinstance(order, (list, tuple)):
        oy, ox = int(order[0]), int(order[1])
    else:
        oy = ox = int(order)

    Hy = butterworth_mask_1d(M, fcy, oy, btype=btype, shifted=shifted, device=device, dtype=dtype)
    Hx = butterworth_mask_1d(N, fcx, ox, btype=btype, shifted=shifted, device=device, dtype=dtype)
    # Outer product builds the separable 2D mask.
    return Hy.reshape(M, 1) * Hx.reshape(1, N)
def _freqvec_norm(
N: int,
shifted: bool,
device=None,
dtype=None,
):
"""
Normalized frequency vector in [-0.5, 0.5), length N.
- shifted=False: DC at index 0 (unshifted FFT layout)
- shifted=True: DC at center (fftshift layout)
"""
if device is None:
device = "cpu"
if dtype is None:
dtype = torch.float64
k = torch.arange(N, device=device, dtype=dtype)
if shifted:
f = (k - torch.floor(torch.tensor(N / 2, dtype=dtype, device=device))) / N
else:
f = k / N
f = torch.where(f >= 0.5, f - 1.0, f) # wrap into [-0.5, 0.5)
return f # [N]
def _radial_frequency_nd(
    shape: Sequence[int],
    shifted: bool,
    device=None,
    dtype=None,
):
    """
    Radial normalized frequency: the Euclidean norm, over all axes, of the
    per-axis frequency vectors from ``_freqvec_norm``. Values are
    non-negative.

    Returns R with shape 'shape'. device defaults to "cpu" and dtype to
    float64 when None.
    """
    device = "cpu" if device is None else device
    dtype = torch.float64 if dtype is None else dtype

    axis_freqs = [_freqvec_norm(N, shifted=shifted, device=device, dtype=dtype) for N in shape]
    # One full-size coordinate grid per axis.
    coord_grids = torch.meshgrid(*axis_freqs, indexing="ij")
    squared = sum((grid**2 for grid in coord_grids), torch.zeros(shape, dtype=dtype, device=device))
    return torch.sqrt(squared)
def butterworth_nd(
    shape: Sequence[int],
    cutoff: Union[float, Tuple[float, float]],
    order: int,
    btype: str = "low",
    shifted: bool = False,
    device=None,
    dtype=None,
):
    """Isotropic N-D Butterworth mask (low/high/bandpass/bandstop).

    The mask is a function of the radial normalized frequency R returned by
    ``_radial_frequency_nd`` only.

    Args:
        shape: iterable of ints, e.g., (H, W) or (D, H, W) ...
        cutoff:
            - 'low'/'high': scalar D0, intended range (0, 0.5]
            - 'bandpass'/'bandstop': pair (D1, D2), intended 0 < D1 < D2 <= 0.5
            The range is not enforced (the check is disabled), so larger
            values are accepted as given.
        order: integer >= 1
        btype: 'low' | 'high' | 'bandpass' | 'bandstop' (aliases 'stop',
            'bandreject'); case-insensitive
        shifted: if True, mask is centered (fftshift layout); else unshifted
        device, dtype: optional torch device/dtype (defaults: CPU, float64)

    Returns:
        H: tensor with shape 'shape', values in [0, 1].

    Raises:
        ValueError: if btype is not one of the supported names (after the
            cutoff has been parsed as a pair).
    """
    assert len(shape) >= 1 and all(int(s) >= 1 for s in shape), "Invalid shape."
    order = int(order)
    assert order >= 1, "order must be >= 1"
    btype = btype.lower()
    if btype in ("low", "high"):
        D0 = float(cutoff)
        # assert 0.0 < D0 <= 0.5, "cutoff must be in (0, 0.5]"
    else:
        D1, D2 = float(cutoff[0]), float(cutoff[1])
        # assert 0.0 < D1 < D2 <= 0.5, "for band types: 0 < D1 < D2 <= 0.5"
        B = D2 - D1  # bandwidth
        D0 = math.sqrt(D1 * D2)  # geometric center of the band
    if device is None:
        device = "cpu"
    if dtype is None:
        dtype = torch.float64
    R = _radial_frequency_nd(
        tuple(int(s) for s in shape), shifted=shifted, device=device, dtype=dtype
    )
    power = 2 * order

    if btype == "low":
        # H = 1 / (1 + (R/D0)^(2n))
        ratio = (R / D0).clamp_min(0)
        return 1.0 / (1.0 + ratio.pow(power))

    if btype == "high":
        # H = 1 / (1 + (D0/R)^(2n)); the DC bin is forced to 0 afterwards, so
        # a dummy radius of 1 is substituted there to avoid dividing by zero.
        safe_R = torch.where(R > 0, R, torch.tensor(1.0, device=device, dtype=dtype))
        ratio = D0 / safe_R
        H = 1.0 / (1.0 + ratio.pow(power))
        return torch.where(R > 0, H, torch.zeros_like(H))

    if btype == "bandpass":
        # D = (R^2 - D0^2) / (B*R); R = 0 -> D = inf -> H = 0
        numer = R.pow(2) - D0**2
        denom = B * R
    elif btype in ("bandstop", "stop", "bandreject"):
        # D = (B*R) / (R^2 - D0^2); R = D0 -> D = inf -> H = 0 (deep notch)
        numer = B * R
        denom = R.pow(2) - D0**2
    else:
        raise ValueError("btype must be 'low', 'high', 'bandpass', or 'bandstop'.")

    D = torch.where(denom != 0, numer / denom, torch.full_like(R, float("inf")))
    return 1.0 / (1.0 + D.abs().pow(power))
def butterworth_low_pass_filter(
    tensor: torch.Tensor,
    dims: Sequence[int],
    cutoff: float,
    order: int,
    shifted: bool = False,
    device=None,
    dtype=None,
):
    """
    Apply an isotropic Butterworth low-pass filter along selected dimensions.

    The tensor is transformed with an N-D FFT over ``dims``, multiplied by the
    mask from ``butterworth_nd`` and transformed back; the real part is
    returned.

    Args:
        tensor: input tensor.
        dims: axis or axes to filter over (negative indices allowed, order
            irrelevant, must be unique).
        cutoff: radial cutoff passed to ``butterworth_nd``.
        order: Butterworth order passed to ``butterworth_nd``.
        shifted: if True, the spectrum and the mask are both handled in
            fftshift layout; the filtering result is intended to be the same
            either way.
        device: device to compute on; defaults to the tensor's device. The
            result stays on this device.
        dtype: working dtype; defaults to the tensor's dtype for floating
            input and float32 otherwise. float16/bfloat16 are promoted to
            float32 for the computation.

    Returns:
        Filtered tensor with the input's shape, cast back to the input's dtype.
    """
    if not isinstance(dims, (list, tuple)):
        dims = (dims,)
    ndims_total = tensor.ndim
    # Resolve negatives, then sort: the FFT does not depend on axis order, but
    # the mask is built axis-by-axis and reshaped into the tensor's layout,
    # which is only valid when the axes are in ascending order.
    norm_dims = tuple(sorted(_normalize_dims(dims, ndim=ndims_total)))

    original_dtype = tensor.dtype
    work_dtype = dtype or (tensor.dtype if torch.is_floating_point(tensor) else torch.float32)
    if work_dtype == torch.bfloat16 or work_dtype == torch.float16:
        work_dtype = torch.float32
    device = device or tensor.device

    # Frequency-domain representation.
    x = tensor.to(device=device, dtype=work_dtype)
    X = torch.fft.fftn(x, dim=norm_dims)
    if shifted:
        X = torch.fft.fftshift(X, dim=norm_dims)

    # Isotropic Butterworth mask over the selected dims.
    shape_subset = [x.shape[d] for d in norm_dims]
    H_small = butterworth_nd(
        shape=shape_subset,
        cutoff=cutoff,
        order=order,
        btype="low",
        shifted=shifted,
        device=device,
        dtype=work_dtype,
    )

    # Insert singleton axes so the mask broadcasts against the full tensor.
    mask_shape = [1] * ndims_total
    for d, size in zip(norm_dims, shape_subset):
        mask_shape[d] = size
    H = H_small.view(*mask_shape)

    X_filtered = X * H

    # Back to the original domain.
    if shifted:
        X_filtered = torch.fft.ifftshift(X_filtered, dim=norm_dims)
    x_filtered = torch.fft.ifftn(X_filtered, dim=norm_dims).real
    return x_filtered.to(dtype=original_dtype)
# def fft_denoise(tensor, dim, fft_ratio):
# assert len(dim) == 2
# original_dtype = tensor.dtype
# tensor = tensor.to(torch.float32)
# # Create low pass filter
# LPF = butterworth_low_pass_filter(
# (tensor.shape[dim[0]], tensor.shape[dim[1]]),
# n=4,
# d_s=fft_ratio,
# )
# LPF = LPF.to(dtype=tensor.dtype, device=tensor.device)
# # print(f"{LPF = }")
# # print(f"{LPF.shape = }")
# for _ in range(dim[0]):
# LPF = LPF.unsqueeze(0)
# for _ in range(dim[1] + 1, len(tensor.shape)):
# LPF = LPF.unsqueeze(-1)
# # print(f"{LPF.shape = }")
# # FFT
# latents_freq_k = torch.fft.fftn(tensor, dim=dim)
# # print(f"{latents_freq_k.shape = }")
# latents_freq_k = torch.fft.fftshift(latents_freq_k, dim=dim)
# # print(f"{latents_freq_k.shape = }")
# new_freq_k = latents_freq_k * LPF
# # IFFT
# new_freq_k = torch.fft.ifftshift(new_freq_k, dim=dim)
# denoised_k = torch.fft.ifftn(new_freq_k, dim=dim).real
# denoised_k = denoised_k.to(original_dtype)
# return denoised_k
# Script entry point: prints a sample Butterworth mask and runs the planar-wave
# demo. The commented-out lines are earlier experiments kept for reference.
if __name__ == "__main__":
    # x = torch.linspace(0, 2 * np.pi, 8)
    # y = torch.linspace(0, 2 * np.pi, 8)
    # X, Y = torch.meshgrid(x, y, indexing='ij')
    # latents = (
    # torch.sin(2 * X + Y) +
    # torch.sin(X + 3 * Y) +
    # torch.sin(3 * X - 2 * Y)
    # ) + 1
    # latents += 0.01 * torch.randn_like(latents) # Add Gaussian noise
    # # latents = torch.randn([8, 8])
    # print(f"latents = \n{latents}")
    # latents_freq = torch.fft.fftn(latents, dim=(-2, -1))
    # print(f"latents_freq = \n{torch.abs(latents_freq)}")
    # latents_freq_shift = torch.fft.fftshift(latents_freq, dim=(-2, -1))
    # print(f"latents_freq_shift = \n{torch.abs(latents_freq_shift)}")
    # latents_freq_dct = dct_2d(latents)
    # print(f"latents_freq_dct = \n{latents_freq_dct}")
    # LPF_1 = butterworth_low_pass_filter(latents=latents, d_s=-1.0)
    # print(f"LPF_1 = \n{LPF_1}")
    # LPF_2 = my_butterworth_low_pass_filter_non_center(
    # shape=latents.shape,
    # stop_freqs=[0.25, 0.25],
    # n=4,
    # )
    # print(f"LPF_2 = \n{LPF_2}")
    # LPF_3 = my_butterworth_low_pass_filter(
    # shape=latents.shape,
    # stop_freqs=[0.25, 0.25],
    # n=4,
    # )
    # print(f"LPF_3 = \n{LPF_3}")
    # img = load_grayscale_image()
    # # Extract middle column as 1-D signal
    # col = img.shape[1] // 2 - 1
    # print(f"{col = }")
    # x_np = img[:, col].astype(np.float32) # [H]
    # # print(f"{x_np = }")
    # N = x_np.shape[0]
    # print(f"{N = }")
    # device = 'cpu'
    # dtype = torch.float64
    # x = torch.from_numpy(img).to(device=device, dtype=dtype)
    # print(f"{x = }")
    # # Transforms
    # Xf = torch.fft.fftn(x, dim=(-3, -2, -1), norm=None) # complex64
    # print(f"{Xf = }")
    # x_reconstructed = torch.fft.ifftn(Xf, dim=(-3, -2, -1), norm=None)
    # print(f"{x_reconstructed = }")
    # print(f"{(x - x_reconstructed).abs().max() = }")
    # Xd2 = dct2_nd_fft(x, dims=(-3, -2, -1), norm="ortho") # float
    # print(f"{Xd2 = }")
    # x_reconstructed = idct2_nd_fft(Xd2, dims=(-1, -2, -3), norm="ortho")
    # print(f"{x_reconstructed = }")
    # print(f"{(x - x_reconstructed).abs().max() = }")
    # H1 = butterworth_mask_1d(16, 0.125, 4, btype='low', shifted=True)
    # print(f"{H1 = }")

    # Print a centered 30x52 low-pass mask of order 4.
    # NOTE(review): cutoff=1.0 lies above the (0, 0.5] range given in
    # butterworth_nd's docstring; the range assertion there is commented out,
    # so the call is accepted. Confirm this value is intentional.
    H2 = butterworth_nd([30, 52], 1.0, 4, btype="low", shifted=True)
    print(f"{H2 = }")

    # ---- Planar wave demo with Butterworth low-pass filtering ----
    def demo_planar_wave():
        """Filter a two-component planar wave and report how well the
        low-frequency component is recovered.

        Builds a 128x128 signal as the sum of a low-frequency and a
        half-amplitude high-frequency sinusoid, low-pass filters it over the
        last two dimensions, prints the mean squared error against the
        low-frequency component before and after filtering, and asserts that
        filtering reduced that error.
        """
        # Generate 2D planar wave: low-frequency + added high-frequency component
        H, W = 128, 128
        device = "cpu"
        # Column / row index vectors, shaped to broadcast into an [H, W] grid.
        y = torch.arange(H, device=device).view(H, 1)
        x = torch.arange(W, device=device).view(1, W)
        # Low-frequency component
        kx_low, ky_low = 2, 3
        low = torch.sin(2 * math.pi * (kx_low * x / W + ky_low * y / H))
        # High-frequency component
        kx_high, ky_high = 20, 24
        high = 0.5 * torch.sin(2 * math.pi * (kx_high * x / W + ky_high * y / H))
        signal = low + high
        # Apply Butterworth low-pass (cutoff chosen to keep low freq, attenuate high freq)
        cutoff = 0.12  # normalized radial cutoff (<=0.5)
        order = 4
        filtered = butterworth_low_pass_filter(
            signal, dims=(-2, -1), cutoff=cutoff, order=order, shifted=True
        )
        # Metrics: error with respect to the clean low-frequency component.
        mse_before = (signal - low).pow(2).mean()
        mse_after = (filtered - low).pow(2).mean()
        residual_energy_ratio = (filtered - low).pow(2).sum() / (signal - low).pow(2).sum()
        print("Planar wave demo:")
        print(f"mse_before={mse_before.item():.6e}")
        print(f"mse_after ={mse_after.item():.6e}")
        print(f"residual_energy_ratio={residual_energy_ratio.item():.4%}")
        # Quick sanity: high frequency suppression (should be << 1)
        assert (
            mse_after < mse_before
        ), "Filtering did not reduce error to low-frequency ground truth."

    demo_planar_wave()