| import torch | |
| import numpy as np | |
| import cv2 | |
def get_fft_feature(x):
    """
    Return the centred log-magnitude spectrum of a batch of images.

    Args:
        x (torch.Tensor): Input images of shape (B, C, H, W). A 3-D input
            is treated as a single image and gets a leading batch
            dimension added, so its result has a batch size of 1.

    Returns:
        torch.Tensor: Log-magnitude spectrum with the same shape as the
            (possibly batched) input, with the zero-frequency component
            moved to the centre of the last two dimensions.
    """
    # Promote a single un-batched image to a batch of one.
    if x.dim() == 3:
        x = x.unsqueeze(0)

    # Orthonormal 2D FFT, reduced to its magnitude.
    spectrum = torch.fft.fft2(x, norm='ortho')
    magnitude = spectrum.abs()

    # The small offset keeps log() finite where the magnitude is exactly zero.
    log_magnitude = torch.log(magnitude + 1e-6)

    # Move the zero-frequency bin from the corner to the centre.
    return torch.fft.fftshift(log_magnitude, dim=(-2, -1))
def min_max_normalize(tensor):
    """
    Rescale a tensor using its global minimum and maximum.

    The minimum and maximum are taken over every element of ``tensor``
    (not per sample or per channel). The smallest element maps to 0 and
    the largest to approximately 1.

    Args:
        tensor (torch.Tensor): Tensor to normalize.

    Returns:
        torch.Tensor: ``(tensor - min) / (max - min + 1e-8)``, with the same
            shape as the input.
    """
    lowest = tensor.min()
    highest = tensor.max()

    # The epsilon keeps the division defined when every element is equal.
    value_range = highest - lowest + 1e-8
    shifted = tensor - lowest
    return shifted / value_range