| import torch |
| import torch.nn as nn |
| from einops import rearrange |
|
|
class BiasFree_LayerNorm(nn.Module):
    """Scale-only normalisation over the last dimension of ``x``.

    The input is divided by the square root of its (population) variance
    along the last axis plus 1e-5. The mean is *not* subtracted. The result
    is multiplied by a learnable per-feature weight, initialised to ones.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        # Learnable gain; no bias term in this variant.
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        # unbiased=False -> divide by N, not N - 1.
        variance = x.var(-1, keepdim=True, unbiased=False)
        denom = torch.sqrt(variance + 1e-5)
        normalized = x / denom
        return normalized * self.weight
|
|
|
|
class WithBias_LayerNorm(nn.Module):
    """Standard layer normalisation over the last dimension of ``x``.

    Each vector along the last axis is centred on its mean and divided by
    ``sqrt(population variance + 1e-5)``. It is then scaled by a learnable
    weight (initialised to ones) and shifted by a learnable bias
    (initialised to zeros).
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        # unbiased=False -> divide by N, not N - 1.
        std = torch.sqrt(x.var(-1, keepdim=True, unbiased=False) + 1e-5)
        return centered / std * self.weight + self.bias
|
|
|
|
class LayerNorm(nn.Module):
    """Thin selector between the two normalisation variants in this file.

    ``LayerNorm_type == 'BiasFree'`` selects ``BiasFree_LayerNorm``. Any
    other value falls back to ``WithBias_LayerNorm``. The chosen module is
    stored as ``self.body`` and applied unchanged in ``forward``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        use_bias_free = LayerNorm_type == 'BiasFree'
        norm_cls = BiasFree_LayerNorm if use_bias_free else WithBias_LayerNorm
        self.body = norm_cls(dim)

    def forward(self, x):
        normalized = self.body(x)
        return normalized
|
|
|
|
class FSAS(nn.Module):
    """Frequency-domain attention over a sequence of tokens on an H x W grid.

    Pipeline:

    1. A single linear projection maps each token from ``dim`` to ``6 * dim``
       features. These are split along the feature axis into q, k and v,
       each of width ``2 * dim``.
    2. q and k are laid out as ``(B, 2*dim, H, W)`` and cut into
       non-overlapping ``patch_size x patch_size`` patches.
    3. Within each patch, their 2-D real FFTs are multiplied and
       transformed back with ``irfft2``.
    4. The result is flattened back to tokens, normalised with a WithBias
       ``LayerNorm`` over its ``2 * dim`` features, and used to gate v
       elementwise.
    5. The gated tensor is projected back to ``dim`` features.

    Args:
        dim: number of input/output features per token.
        bias: whether the two linear projections use a bias term.
        patch_size: side length of the square FFT patches. H and W must both
            be multiples of it. Defaults to 8, the previously hard-coded value.
    """

    def __init__(self, dim, bias=False, patch_size=8):
        super().__init__()
        # One projection yields q, k and v (2 * dim features each).
        self.to_hidden = nn.Linear(dim, dim * 6, bias=bias)
        self.project_out = nn.Linear(dim * 2, dim, bias=bias)
        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
        self.patch_size = patch_size

    def _resolve_hw(self, n, H, W):
        """Return the (H, W) grid for ``n`` tokens, validating it against ``patch_size``."""
        if H is None and W is None:
            # No explicit size: assume a square grid.
            side = int(round(n ** 0.5))
            if side * side != n:
                raise ValueError(
                    f"cannot infer a square grid from {n} tokens; pass H and W explicitly"
                )
            H = W = side
        elif H is None:
            H = n // W
        elif W is None:
            W = n // H
        if H * W != n:
            raise ValueError(f"H * W = {H * W} does not match sequence length {n}")
        if H % self.patch_size or W % self.patch_size:
            raise ValueError(
                f"H={H} and W={W} must both be multiples of patch_size={self.patch_size}"
            )
        return H, W

    def forward(self, x, H=None, W=None):
        """Apply frequency-domain attention.

        Args:
            x: tensor of shape ``(B, N, dim)``. Tokens are assumed to be in
               row-major order over an ``H x W`` grid with ``N == H * W``.
            H, W: optional grid size. If both are omitted, the grid is
               assumed square. If only one is given, the other is derived
               from ``N``.

        Returns:
            Tensor of shape ``(B, N, dim)``.

        Raises:
            ValueError: if the grid size is inconsistent with ``N`` or is not
                divisible by ``patch_size``.
        """
        n = x.shape[1]
        H, W = self._resolve_hw(n, H, W)
        p = self.patch_size

        # Project once and split along the feature axis: each is (B, N, 2*dim).
        q, k, v = self.to_hidden(x).chunk(3, dim=-1)

        # Tokens -> spatial maps -> per-patch tiles (B, 2*dim, H/p, W/p, p, p).
        q = rearrange(q, 'b (h w) c -> b c h w', h=H, w=W)
        k = rearrange(k, 'b (h w) c -> b c h w', h=H, w=W)
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p, patch2=p)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p, patch2=p)

        # The FFT runs in float32. Multiplying spectra is circular convolution within each patch.
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())
        out = torch.fft.irfft2(q_fft * k_fft, s=(p, p))

        # Tiles -> spatial map -> tokens, matching v's (B, N, 2*dim) layout.
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=p, patch2=p)
        out = rearrange(out, 'b c h w -> b (h w) c')
        # Cast back so reduced-precision inputs do not clash with the layer dtypes.
        out = out.to(x.dtype)

        out = self.norm(out)
        return self.project_out(v * out)
|
|
|
|
if __name__ == '__main__':
    # Smoke run: push one random 32x32 token grid through FSAS and report shapes.
    dim = 64
    batch_size, height, width = 1, 32, 32
    channels = dim

    # Build the model before sampling the input, so the RNG draws happen in the same order.
    model = FSAS(dim)
    input_tensor = torch.randn(batch_size, height * width, channels)
    output_tensor = model(input_tensor)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")