File size: 2,650 Bytes
b59223f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# ztrain/signal.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import torch

def gaussian_kernel(size, sigma=1.0):
    """
    Generates a normalized 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (an integer). It's recommended to use an odd number
            to have a central pixel. The grid covers the offsets -(size // 2) to
            size // 2 on each axis, so an even size yields a
            (size + 1) x (size + 1) kernel.
    - sigma: The standard deviation of the Gaussian distribution. Must be non-zero.

    Returns:
    - A 2D PyTorch tensor representing the Gaussian kernel, scaled so that its
      entries sum to 1.

    Raises:
    - ValueError: If sigma is zero (the kernel would otherwise be all NaN).
    """
    if sigma == 0:
        raise ValueError("Sigma must be non-zero.")

    half = int(size) // 2
    offsets = torch.arange(-half, half + 1)
    # Pass the indexing mode explicitly: "ij" matches the behaviour this function has
    # always relied on, and omitting it makes recent PyTorch versions emit a warning.
    x, y = torch.meshgrid(offsets, offsets, indexing="ij")
    g = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
    # Normalize so that convolving with the kernel preserves overall intensity.
    return g / g.sum()

def laplacian_kernel(size, scale=1.0):
    """
    Creates a Laplacian kernel for edge detection with an adjustable size and scale factor.

    The kernel starts from the 4-neighbour pattern (-4 at the center, 1 on the four
    adjacent pixels). For sizes above 3 the outermost ring of pixels is also set to 1.
    After scaling, the center value is recomputed so that the whole kernel sums to 0.

    Parameters:
    - size: The size of the kernel (an odd integer, at least 3).
    - scale: A float that adjusts the intensity of the edge detection effect.

    Returns:
    - A 2D float32 PyTorch tensor of shape (size, size) representing the scaled
      Laplacian kernel.

    Raises:
    - ValueError: If size is even, or smaller than 3 (there would be no room for the
      four neighbours of the center pixel).
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")
    if size < 3:
        raise ValueError("Size must be at least 3.")

    center = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # Classic 4-neighbour Laplacian stencil around the center pixel.
    kernel[center, center] = -4.0
    kernel[center, center - 1] = kernel[center, center + 1] = 1.0
    kernel[center - 1, center] = kernel[center + 1, center] = 1.0

    # For larger kernels, set the outer border to 1 (this simplistic approach might
    # need refinement for larger sizes). Slice assignment replaces a per-pixel loop.
    if size > 3:
        kernel[0, :] = 1.0
        kernel[-1, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, -1] = 1.0

    # Apply the scale factor
    kernel *= scale

    # Rebalance the center so the kernel sums to 0 (no response on flat regions).
    kernel[center, center] = -torch.sum(kernel) + kernel[center, center]

    return kernel

def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    The first two dimensions of the tensor are circularly shifted; the input is
    expected to be 2D.

    Parameters:
    - input: A 2D tensor representing the FFT output.

    Returns:
    - A 2D tensor with the zero-frequency component shifted to the center, i.e. the
      element at index [0, 0] ends up at index [rows // 2, cols // 2].
    """
    # Rolling by n // 2 moves index 0 to the center index n // 2 for both even and odd
    # n. Rolling by (n + 1) // 2 is only equivalent for even n; for odd n it is the
    # inverse shift and would leave the zero-frequency bin one step past the center.
    row_shift = input.shape[0] // 2
    col_shift = input.shape[1] // 2
    return torch.roll(input, shifts=(row_shift, col_shift), dims=(0, 1))