File size: 939 Bytes
d33766a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import numpy as np
import cv2

def get_fft_feature(x, *, eps=1e-6):
    """
    Computes the log-magnitude spectrum of the input images.

    The 2-D FFT is taken over the last two dimensions (H, W) with orthonormal
    scaling. The magnitude is log-compressed, and the zero-frequency (DC)
    component is shifted to the centre of each spectrum.

    Args:
        x (torch.Tensor): Input images of shape (B, C, H, W). A 3-D tensor is
            treated as a single unbatched image (C, H, W), and a leading batch
            dimension of size 1 is added.
        eps (float): Constant added to the magnitude before ``log``, so that
            zero-magnitude frequencies map to ``log(eps)`` instead of
            ``-inf``. Defaults to 1e-6.
    Returns:
        torch.Tensor: Real-valued log-magnitude spectrum of shape (B, C, H, W),
        or (1, C, H, W) for 3-D input. The DC term is at index
        (H // 2, W // 2) of the last two dimensions.
    """
    if x.dim() == 3:
        # Promote an unbatched (C, H, W) image to a batch of one.
        x = x.unsqueeze(0)

    # norm='ortho' scales the transform by 1/sqrt(H*W), making it unitary
    # (energy-preserving).
    spectrum = torch.fft.fft2(x, norm='ortho')

    # Log-compress the magnitude; eps keeps log() finite where the magnitude
    # is exactly zero.
    log_mag = torch.log(torch.abs(spectrum) + eps)

    # fftshift only permutes elements, so applying it after the element-wise
    # log gives the same result as shifting first.
    return torch.fft.fftshift(log_mag, dim=(-2, -1))

def min_max_normalize(tensor, *, dim=None, eps=1e-8):
    """
    Linearly rescales a tensor into approximately [0, 1] (min-max normalization).

    Useful for visualizing spectra, or for putting features on a common scale
    before training.

    Args:
        tensor (torch.Tensor): Input tensor of any shape. It must be non-empty,
            because torch's min/max reductions raise on empty input.
        dim (int | tuple[int, ...] | None): Dimension(s) over which the min and
            max are computed. ``None`` (the default) uses a single global
            min/max over the whole tensor. For example, pass ``dim=(1, 2, 3)``
            on a (B, C, H, W) batch to normalize each sample independently.
        eps (float): Added to the denominator, so a constant input
            (max == min) maps to zeros instead of producing NaN from 0 / 0.
    Returns:
        torch.Tensor: Tensor with the same shape as ``tensor``.
    """
    if dim is None:
        min_val = tensor.min()
        max_val = tensor.max()
    else:
        # keepdim=True so the reduced statistics broadcast back over `tensor`.
        min_val = torch.amin(tensor, dim=dim, keepdim=True)
        max_val = torch.amax(tensor, dim=dim, keepdim=True)
    return (tensor - min_val) / (max_val - min_val + eps)