import torch import torch.nn as nn from einops import rearrange class BiasFree_LayerNorm(nn.Module): def __init__(self, normalized_shape): super(BiasFree_LayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) def forward(self, x): sigma = x.var(-1, keepdim=True, unbiased=False) return x / torch.sqrt(sigma + 1e-5) * self.weight class WithBias_LayerNorm(nn.Module): def __init__(self, normalized_shape): super(WithBias_LayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) def forward(self, x): mu = x.mean(-1, keepdim=True) sigma = x.var(-1, keepdim=True, unbiased=False) return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias class LayerNorm(nn.Module): def __init__(self, dim, LayerNorm_type): super(LayerNorm, self).__init__() if LayerNorm_type == 'BiasFree': self.body = BiasFree_LayerNorm(dim) else: self.body = WithBias_LayerNorm(dim) def forward(self, x): return self.body(x) class FSAS(nn.Module): def __init__(self, dim, bias=False): super(FSAS, self).__init__() self.to_hidden = nn.Linear(dim, dim * 6, bias=bias) self.project_out = nn.Linear(dim * 2, dim, bias=bias) self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias') self.patch_size = 8 def forward(self, x): hidden = self.to_hidden(x) # Shape: (batch_size, channels, height, width) # Generate q, k, v tensors q, k, v = self.to_hidden(hidden).chunk(3, dim=1) # Shape: (batch_size, channels, height, width) # Process q and k in the frequency domain q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size, patch2=self.patch_size) k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size, patch2=self.patch_size) q_fft = torch.fft.rfft2(q_patch.float()) k_fft = torch.fft.rfft2(k_patch.float()) out = q_fft * k_fft # Frequency domain multiplication out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size)) out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size, patch2=self.patch_size) # Normalize the output out = self.norm(out) # Adjust v's shape to match out v = rearrange(v, 'b c h w -> b (h w) c') # Ensure matching shape # Compute final output output = v * out # Element-wise multiplication output = self.project_out(output) return output if __name__ == '__main__': # Instantiate FSAS model dim = 64 # Example dimension, you can change it model = FSAS(dim) # Example input tensor with shape (batch_size, channels, height, width) batch_size = 1 channels = dim height = 32 width = 32 input_tensor = torch.randn(batch_size, height*width, channels) # Forward pass through the model output_tensor = model(input_tensor) # Print input and output shapes print(f"Input shape: {input_tensor.shape}") print(f"Output shape: {output_tensor.shape}")