# File size: 3,329 Bytes
# aa24fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
from einops import rearrange

class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalization over the last dimension.

    The input is divided by its standard deviation along the last axis and
    multiplied by a learnable per-feature weight. The mean is NOT subtracted
    from the input (it only enters through the variance), and there is no
    learnable bias.

    Args:
        normalized_shape: Shape of the learnable weight; it must broadcast
            against the last dimension(s) of the input.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        # Initialised to ones so the layer starts as a pure rescaling.
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        """Return ``x / sqrt(var(x) + 1e-5) * weight`` along the last axis."""
        # Population (biased) variance, kept as a size-1 axis for broadcasting.
        variance = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        # The 1e-5 term guards against division by zero on constant rows.
        std = torch.sqrt(variance + 1e-5)
        rescaled = x / std
        return rescaled * self.weight


class WithBias_LayerNorm(nn.Module):
    """Layer normalization over the last dimension with weight and bias.

    The input is centered by its mean, divided by its standard deviation
    along the last axis, then scaled and shifted by learnable per-feature
    parameters.

    Args:
        normalized_shape: Shape of the learnable weight and bias; it must
            broadcast against the last dimension(s) of the input.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        # Identity affine transform at initialisation: weight = 1, bias = 0.
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + 1e-5) * weight + bias``."""
        mean = torch.mean(x, dim=-1, keepdim=True)
        # Population (biased) variance, matching the usual layer-norm formula.
        variance = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        centered = x - mean
        # The 1e-5 term guards against division by zero on constant rows.
        normalized = centered / torch.sqrt(variance + 1e-5)
        return normalized * self.weight + self.bias


class LayerNorm(nn.Module):
    """Selects one of the two layer-norm variants and applies it.

    Args:
        dim: Passed through as the normalized shape of the chosen variant.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any
            other value selects ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = (
            BiasFree_LayerNorm
            if LayerNorm_type == 'BiasFree'
            else WithBias_LayerNorm
        )
        # Stored under ``body`` so parameter names stay ``body.weight`` etc.
        self.body = norm_cls(dim)

    def forward(self, x):
        """Apply the selected normalization to ``x`` unchanged in shape."""
        normalized = self.body(x)
        return normalized


class FSAS(nn.Module):
    """Frequency-domain self-attention over a flattened grid of tokens.

    Every token is projected once to query, key and value features of width
    ``2 * dim`` each. Queries and keys are laid back out on their ``h x w``
    grid, cut into non-overlapping ``patch_size x patch_size`` patches and
    multiplied element-wise in the Fourier domain (by the convolution
    theorem, a circular convolution of each query patch with its key patch).
    The result is layer-normalized per token, used to gate the values and
    projected back to ``dim`` features.

    Args:
        dim: Number of features per input/output token.
        bias: Whether the two linear projections carry a bias term.
        patch_size: Side length of the square patches used for the FFT.
    """

    def __init__(self, dim, bias=False, patch_size=8):
        super().__init__()

        self.to_hidden = nn.Linear(dim, dim * 6, bias=bias)
        self.project_out = nn.Linear(dim * 2, dim, bias=bias)

        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
        self.patch_size = patch_size

    def forward(self, x, h=None, w=None):
        """Apply the attention block to a token sequence.

        Args:
            x: Tensor of shape ``(batch, h * w, dim)`` whose tokens are the
                row-major flattening of an ``h x w`` grid.
            h: Grid height. When both ``h`` and ``w`` are omitted the grid is
                taken to be square.
            w: Grid width; must be given together with ``h``.

        Returns:
            Tensor of shape ``(batch, h * w, dim)``.

        Raises:
            ValueError: If ``x`` is not 3-D, if only one of ``h``/``w`` is
                given, if ``h * w`` does not equal the number of tokens, or
                if ``h``/``w`` is not a multiple of ``patch_size``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, tokens, dim), got {tuple(x.shape)}"
            )
        h, w = self._grid_shape(x.shape[1], h, w)

        # Project ONCE and split along the feature axis. Feeding the hidden
        # tensor through ``to_hidden`` a second time cannot work: the layer
        # expects ``dim`` input features, not ``6 * dim``.
        q, k, v = self.to_hidden(x).chunk(3, dim=-1)  # each (batch, tokens, 2*dim)

        # Process q and k in the frequency domain, patch by patch.
        q_patch = self._to_patches(q, h, w)
        k_patch = self._to_patches(k, h, w)

        # FFT is done in float32 regardless of the input precision.
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())

        out = torch.fft.irfft2(q_fft * k_fft, s=(self.patch_size, self.patch_size))
        out = self._from_patches(out)  # back to (batch, tokens, 2*dim)

        # Normalize over the feature axis of each token.
        out = self.norm(out)

        # Gate the values; cast back so reduced-precision inputs stay in
        # their own dtype for the output projection (no-op for float32).
        output = v * out.to(v.dtype)
        return self.project_out(output)

    def _grid_shape(self, tokens, h, w):
        """Return a validated ``(h, w)`` grid covering ``tokens`` tokens."""
        if (h is None) != (w is None):
            raise ValueError("h and w must be provided together")
        if h is None:
            # Only the token count is known, so assume a square grid.
            h = w = round(tokens ** 0.5)
        if h * w != tokens:
            raise ValueError(
                f"cannot arrange {tokens} tokens on a {h}x{w} grid; "
                "pass h and w explicitly for non-square inputs"
            )
        if h % self.patch_size or w % self.patch_size:
            raise ValueError(
                f"grid {h}x{w} is not divisible by patch_size={self.patch_size}"
            )
        return h, w

    def _to_patches(self, tokens, h, w):
        """Reshape ``(b, h*w, c)`` tokens to ``(b, c, h/p, w/p, p, p)`` patches."""
        p = self.patch_size
        b, _, c = tokens.shape
        grid = tokens.transpose(1, 2).reshape(b, c, h // p, p, w // p, p)
        # Move both intra-patch axes last so rfft2 acts on single patches.
        return grid.permute(0, 1, 2, 4, 3, 5)

    def _from_patches(self, patches):
        """Reshape ``(b, c, h/p, w/p, p, p)`` patches to ``(b, h*w, c)`` tokens."""
        b, c, grid_h, grid_w, p1, p2 = patches.shape
        grid = patches.permute(0, 1, 2, 4, 3, 5)
        flat = grid.reshape(b, c, grid_h * p1 * grid_w * p2)
        return flat.transpose(1, 2)


if __name__ == '__main__':
    # Smoke run: feed one random token sequence through FSAS and print shapes.
    dim = 64  # Example feature width per token, you can change it

    # Model is built before the input is sampled.
    model = FSAS(dim)

    # A 32x32 grid flattened into height*width tokens of `channels` features.
    batch_size, height, width = 1, 32, 32
    channels = dim

    # Input layout is (batch_size, height*width, channels).
    input_tensor = torch.randn(batch_size, height * width, channels)

    # Forward pass through the model
    output_tensor = model(input_tensor)

    # Report the input and output shapes
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")