# cherrydata/ultralytics — Else/FSAS_U-Net.py
# Commit aa24fe8 by Voidljc ("Your commit message"), 3.33 kB.
import torch
import torch.nn as nn
from einops import rearrange
class BiasFree_LayerNorm(nn.Module):
    """Layer normalisation over the last axis without mean-centring or bias.

    Every vector along the last axis is divided by its population (biased)
    standard deviation, stabilised by ``1e-5``, and then multiplied by a
    learnable ``weight`` that starts at all ones.

    Args:
        normalized_shape: Shape of the learnable ``weight``. It has to
            broadcast against the input, presumably matching the size of the
            last input dimension.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x):
        # Population variance (unbiased=False) over the last axis, kept for broadcasting.
        variance = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        denom = torch.sqrt(variance + 1e-5)
        return x / denom * self.weight
class WithBias_LayerNorm(nn.Module):
    """Standard layer normalisation over the last axis with learnable affine terms.

    Each vector along the last axis is centred on its mean, then divided by
    its population (biased) standard deviation (stabilised by ``1e-5``).
    Finally it is scaled by ``weight`` (initialised to ones) and shifted by
    ``bias`` (initialised to zeros).

    Args:
        normalized_shape: Shape of ``weight`` and ``bias``. They have to
            broadcast against the input, presumably matching the size of the
            last input dimension.
    """

    def __init__(self, normalized_shape):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        mean = torch.mean(x, dim=-1, keepdim=True)
        variance = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        normalized = (x - mean) / torch.sqrt(variance + 1e-5)
        return normalized * self.weight + self.bias
class LayerNorm(nn.Module):
    """Wrapper that chooses between the bias-free and with-bias layer norms.

    Args:
        dim: Forwarded to the chosen implementation as its ``normalized_shape``.
        LayerNorm_type: ``'BiasFree'`` selects ``BiasFree_LayerNorm``; any
            other value falls back to ``WithBias_LayerNorm``.
    """

    def __init__(self, dim, LayerNorm_type):
        super().__init__()
        norm_cls = BiasFree_LayerNorm if LayerNorm_type == 'BiasFree' else WithBias_LayerNorm
        # Registered as ``body``, so its parameters appear as ``body.weight``
        # (and ``body.bias`` for the with-bias variant).
        self.body = norm_cls(dim)

    def forward(self, x):
        normalized = self.body(x)
        return normalized
class FSAS(nn.Module):
    """Frequency-domain self-attention (FSAS) applied to a sequence of image tokens.

    The input is a token sequence of shape ``(batch, H * W, dim)``, with tokens
    laid out row-major over an ``H x W`` grid. One linear projection produces
    q, k and v together, each ``2 * dim`` channels wide.

    q and k are folded back onto the grid and cut into non-overlapping
    ``patch_size x patch_size`` patches. They are multiplied elementwise in the
    2-D real-FFT domain of every patch and transformed back. The result is then
    layer-normalised over channels, gates v elementwise, and is projected back
    to ``dim`` channels.

    Args:
        dim: Channels per input/output token.
        bias: Whether the linear projections carry a bias term.
        patch_size: Side of the square patches the FFT is applied to. ``H``
            and ``W`` must both be multiples of it. Defaults to 8.
    """

    def __init__(self, dim, bias=False, patch_size=8):
        super(FSAS, self).__init__()
        # One projection yields q, k and v together: 3 * (2 * dim) channels.
        self.to_hidden = nn.Linear(dim, dim * 6, bias=bias)
        self.project_out = nn.Linear(dim * 2, dim, bias=bias)
        self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias')
        self.patch_size = patch_size

    @staticmethod
    def _grid_size(n_tokens, H, W):
        """Return the ``(H, W)`` token grid, assuming a square one when neither side is given.

        Raises:
            ValueError: if exactly one of ``H``/``W`` is given, or the grid is
                empty or does not hold exactly ``n_tokens`` tokens.
        """
        if H is None and W is None:
            side = int(round(n_tokens ** 0.5))
            H = W = side
        elif H is None or W is None:
            raise ValueError("pass both H and W, or neither to assume a square grid")
        if H <= 0 or W <= 0 or H * W != n_tokens:
            raise ValueError(f"a {H}x{W} grid does not match {n_tokens} tokens")
        return H, W

    def forward(self, x, H=None, W=None):
        """Apply FSAS to a token sequence.

        Args:
            x: Tensor of shape ``(batch, H * W, dim)``, tokens in row-major order.
            H: Height of the token grid. Give it together with ``W``, or omit
                both to assume a square grid.
            W: Width of the token grid.

        Returns:
            Tensor of shape ``(batch, H * W, dim)``.

        Raises:
            ValueError: if ``x`` is not 3-D, the grid does not match the number
                of tokens, or ``H``/``W`` is not a multiple of ``patch_size``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, tokens, channels), got {tuple(x.shape)}")
        H, W = self._grid_size(x.shape[1], H, W)
        p = self.patch_size
        if H % p or W % p:
            raise ValueError(f"grid {H}x{W} is not divisible by patch_size={p}")

        # Project once and split along the channel (last) axis: each of q, k, v
        # is (batch, H*W, 2*dim).
        q, k, v = self.to_hidden(x).chunk(3, dim=-1)

        # Fold q and k onto the spatial grid, then cut them into p x p patches:
        # (b, H*W, c) -> (b, c, H, W) -> (b, c, H/p, W/p, p, p).
        q = rearrange(q, 'b (h w) c -> b c h w', h=H, w=W)
        k = rearrange(k, 'b (h w) c -> b c h w', h=H, w=W)
        q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p, patch2=p)
        k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=p, patch2=p)

        # Elementwise product of the per-patch 2-D real FFTs, computed in float32.
        q_fft = torch.fft.rfft2(q_patch.float())
        k_fft = torch.fft.rfft2(k_patch.float())
        # s=(p, p) restores the full patch width that rfft2 halved.
        out = torch.fft.irfft2(q_fft * k_fft, s=(p, p))

        # Undo the patching and return to token layout: -> (b, H*W, 2*dim).
        out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=p, patch2=p)
        out = rearrange(out, 'b c h w -> b (h w) c')

        # The layer norm normalises the last axis, i.e. the channels in token layout.
        out = self.norm(out)

        # v gates the normalised frequency response elementwise.
        output = v * out
        return self.project_out(output)
if __name__ == '__main__':
    # Smoke test: build an FSAS block and push one random batch through it.
    dim = 64  # channels per token; any value works
    model = FSAS(dim)

    # FSAS consumes a token sequence of shape (batch_size, height * width, channels).
    # height and width should be multiples of FSAS's patch size (8).
    batch_size = 1
    channels = dim
    height = 32
    width = 32
    input_tensor = torch.randn(batch_size, height * width, channels)

    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        output_tensor = model(input_tensor)

    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")