# WACA-UNet / model.py
# (Hugging Face Hub page residue, kept as a comment for provenance:
#  ymin98's picture — "Upload 15 files" — commit dd5a134 verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
import math
from typing import List, Tuple, Optional, Dict
class WACA_CBAM(nn.Module):
    """Weakness-Aware Channel Attention followed by CBAM-style spatial attention.

    A shared 1x1-conv MLP scores channels from global avg/max descriptors.
    Channels that score low ("weak") are re-examined: the input is re-weighted
    by the inverted gate and scored a second time, and the two channel
    attention maps are averaged before the spatial-attention stage.
    """

    def __init__(self, channels, reduction=16):
        super(WACA_CBAM, self).__init__()
        self.channels = channels
        # Guard against a zero-channel bottleneck when `channels` is small.
        if channels < reduction or channels // reduction == 0:
            self.reduced_channels = channels // 2 if channels > 1 else 1
        else:
            self.reduced_channels = channels // reduction
        # Shared squeeze-excite MLP (1x1 convs) used by both scoring passes.
        self.fc_layers = nn.Sequential(
            nn.Conv2d(self.channels, self.reduced_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.reduced_channels, self.channels, kernel_size=1, bias=False),
        )
        # CBAM spatial branch: 7x7 conv over [mean, max] channel maps.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def _pool_pair(self, t):
        # Global average / max descriptors, each [B, C, 1, 1].
        return F.adaptive_avg_pool2d(t, 1), F.adaptive_max_pool2d(t, 1)

    def forward(self, x):
        # Pass 1: channel gate; the MLP is applied to each pooled descriptor.
        pooled_avg, pooled_max = self._pool_pair(x)
        logits = self.fc_layers(pooled_avg) + self.fc_layers(pooled_max)
        strong_gate = torch.sigmoid(logits)
        weak_gate = torch.sigmoid(-logits)  # == 1 - strong_gate: highlights weak channels
        # Pass 2: re-score the weakness-gated input (pools summed before the MLP).
        reweighted = x * weak_gate
        re_avg, re_max = self._pool_pair(reweighted)
        second_gate = torch.sigmoid(self.fc_layers(re_avg + re_max))
        # Average the two channel-attention maps.
        channel_out = x * (strong_gate + second_gate) * 0.5
        # CBAM spatial attention on the channel-refined features.
        chan_mean = torch.mean(channel_out, dim=1, keepdim=True)
        chan_max, _ = torch.max(channel_out, dim=1, keepdim=True)
        spatial_gate = self.spatial_attn(torch.cat([chan_mean, chan_max], dim=1))
        return channel_out * spatial_gate
##################################################################################
# copy from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
class LayerNorm(nn.Module):
    """LayerNorm supporting NHWC ("channels_last") and NCHW ("channels_first") inputs.

    channels_last defers to F.layer_norm over the trailing channel dimension;
    channels_first normalizes across dim 1 by hand with the same affine params.
    (Adapted from ConvNeXt-V2.)
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize over the channel axis (dim 1) manually.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class GRN(nn.Module):
    """ConvNeXt-V2 Global Response Normalization over NHWC features.

    Scales each channel by its global L2 response relative to the mean
    response across channels. gamma/beta start at zero, so the layer is an
    exact identity at initialization (residual form: x + gamma*(x*Nx) + beta).
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel global L2 norm over the spatial dims (H, W of NHWC).
        response = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divisive normalization by the mean response across channels.
        scale = response / (response.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * scale) + self.beta + x
#################################################################################################
import torch
import torch.nn as nn
from torch.nn import functional as F
class ConvNeXtV2BlockWACA_Atrous(nn.Module):
    """ConvNeXt-V2 block with a dilated (atrous) depthwise conv and WACA-CBAM.

    Layout: dilated 7x7 depthwise conv -> LayerNorm -> 1x1 expand (x4) -> GELU
    -> GRN -> 1x1 project -> WACA_CBAM -> DropPath, added to a residual that is
    1x1-projected when the channel counts differ.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilation=3):
        super().__init__()
        # padding = dilation * (kernel_size - 1) // 2 keeps spatial size for k=7.
        self.dwconv = nn.Conv2d(
            in_ch, in_ch,
            kernel_size=7,
            padding=dilation * 3,
            groups=in_ch,
            dilation=dilation,
        )
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        # Align the residual when in/out channel counts differ.
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)      # NCHW -> NHWC for LN / Linear stages
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 3, 1, 2)      # NHWC -> NCHW
        y = self.drop_path(self.fow(y))
        return self.proj(shortcut) + y
# Multi-scale atrous convolution을 사용하는 버전
class ConvNeXtV2BlockWACA_MultiAtrous(nn.Module):
    """ConvNeXt-V2 block whose depthwise stage is split into parallel
    depthwise branches with different dilation rates (multi-scale context),
    followed by the standard LN/MLP/GRN stages and WACA-CBAM attention.

    Fix vs. the original: when ``in_ch`` is not divisible by
    ``len(dilations)``, the last branch now owns the remainder channels.
    Previously the last branch was fed the remainder into a convolution sized
    for ``in_ch // len(dilations)`` channels, crashing at runtime. For
    divisible ``in_ch`` the parameter shapes are unchanged. The default
    ``dilations`` is now a tuple (mutable default lists are shared across
    calls).
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=(1, 2, 4)):
        super().__init__()
        num_branches = len(dilations)
        base = in_ch // num_branches
        # Channel span per branch; the final branch absorbs any remainder so
        # the concatenated outputs always total in_ch.
        self.branch_channels = [base] * (num_branches - 1) + [in_ch - base * (num_branches - 1)]
        self.dwconv_branches = nn.ModuleList([
            nn.Conv2d(
                ch, ch,
                kernel_size=7,
                padding=d * 3,   # dilation * (7 - 1) // 2 keeps spatial size
                groups=ch,       # depthwise within the branch
                dilation=d,
            ) for ch, d in zip(self.branch_channels, dilations)
        ])
        # Fuse the concatenated branch outputs back to in_ch channels.
        self.combine_conv = nn.Conv2d(in_ch, in_ch, 1)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        # Align the residual when in/out channel counts differ.
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        # Split channels across branches and run each dilation in parallel.
        splits = torch.split(x, self.branch_channels, dim=1)
        x = torch.cat(
            [conv(part) for conv, part in zip(self.dwconv_branches, splits)],
            dim=1,
        )
        x = self.combine_conv(x)
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(shortcut) + x
# ASPP (Atrous Spatial Pyramid Pooling) 스타일의 버전
class ConvNeXtV2BlockWACA_ASPP(nn.Module):
    """ConvNeXt-V2 block with an ASPP-style multi-dilation front end.

    Parallel branches over the full input: a 1x1 conv for rate 1, grouped 3x3
    atrous convs for the remaining rates, plus an image-level (global average
    pooling) branch. The concatenated branches are fused back to ``in_ch``
    before the standard LN/MLP/GRN stages and WACA-CBAM attention.

    Note: ``in_ch`` must be divisible by ``in_ch // len(dilations)`` for the
    grouped atrous convolutions to be constructible (e.g. in_ch=64 with 4
    dilations). The default ``dilations`` is now a tuple (same values as the
    original list); mutable default arguments are shared across calls.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=(1, 6, 12, 18)):
        super().__init__()
        branch_ch = in_ch // len(dilations)
        self.aspp_branches = nn.ModuleList()
        for dilation in dilations:
            if dilation == 1:
                # Rate-1 branch: plain 1x1 convolution.
                branch = nn.Conv2d(in_ch, branch_ch, 1)
            else:
                # Grouped atrous 3x3; padding == dilation keeps spatial size.
                branch = nn.Conv2d(
                    in_ch, branch_ch,
                    kernel_size=3,
                    padding=dilation,
                    dilation=dilation,
                    groups=branch_ch,
                )
            self.aspp_branches.append(branch)
        # Image-level context branch (GAP -> 1x1), upsampled in forward().
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_ch, branch_ch, 1),
        )
        # All dilation branches plus the global branch, fused back to in_ch.
        total_channels = (len(dilations) + 1) * branch_ch
        self.combine_conv = nn.Conv2d(total_channels, in_ch, 1)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        # Align the residual when in/out channel counts differ.
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        h, w = x.size()[2:]
        # Dilated branches + image-level branch resized to the input grid.
        feats = [branch(x) for branch in self.aspp_branches]
        global_feat = self.global_avg_pool(x)
        feats.append(
            F.interpolate(global_feat, size=(h, w), mode='bilinear', align_corners=False)
        )
        x = self.combine_conv(torch.cat(feats, dim=1))
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(shortcut) + x
#################################################################################################
class ConvNeXtV2BlockWACA(nn.Module):
    """Standard ConvNeXt-V2 block (7x7 depthwise, LayerNorm, 4x MLP with GELU
    and optional GRN) followed by WACA-CBAM attention, with a residual that is
    1x1-projected when channel counts differ.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., use_grn=True):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, kernel_size=7, padding=3, groups=in_ch)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        # GRN is optional; Identity keeps the layout when disabled.
        self.grn = GRN(4 * in_ch) if use_grn else nn.Identity()
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)      # NCHW -> NHWC for LN / Linear stages
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 3, 1, 2)      # NHWC -> NCHW
        y = self.drop_path(self.fow(y))
        return self.proj(shortcut) + y
class AttentionGate(nn.Module):
    """Attention gate for skip connections: concatenates the skip feature
    ``x`` with the gating signal ``g``, derives a sigmoid mask via two 1x1
    convs, and re-weights ``x`` elementwise.

    Note: ``out_ch`` must match (or broadcast against) ``in_ch_x`` for the
    final product; the U-Nets in this file always pass equal channel counts.
    """

    def __init__(self, in_ch_x, in_ch_g, out_ch):
        super().__init__()
        self.act = nn.ReLU(inplace=True)
        self.w_x_g = nn.Conv2d(in_ch_x + in_ch_g, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
        self.attn = nn.Conv2d(out_ch, out_ch, kernel_size=1, padding=0, bias=False)

    def forward(self, x, g):
        fused = self.act(self.w_x_g(torch.cat([x, g], dim=1)))
        mask = torch.sigmoid(self.attn(fused))
        return x * mask
class WACA_Unet(nn.Module):
    """U-Net built from WACA ConvNeXt-V2 blocks with attention-gated skips.

    Encoder: depth+1 blocks, with depthwise strided 3x3 convs as learned 2x
    downsampling between levels. Decoder: transposed-conv upsampling, an
    AttentionGate applied to each encoder skip, then a decoder block on the
    concatenated features. Returns a dict {'x_recon': logits}.

    Args:
        in_ch: channels of the input tensor fed to the first encoder block.
        out_ch: channels produced by the final 1x1 head.
        base_ch: width at the top level; doubled at every deeper level.
        reduction: channel-reduction ratio forwarded to each block's WACA_CBAM.
        depth: number of down/up-sampling stages.
        drop_path: maximum stochastic-depth rate (linear schedule over blocks).
        block: block class used for all encoder and decoder stages.
    """
    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
                 depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        # Channel widths per level: [base, 2*base, ..., base * 2**depth].
        chs = [base_ch * 2**i for i in range(depth+1)]
        self.drop_path = drop_path
        n_enc_blocks = depth + 1
        n_dec_blocks = depth
        total_blocks = n_enc_blocks + n_dec_blocks
        # Linear stochastic-depth schedule over all blocks, encoder first.
        drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
        enc_dp_rates = drop_path_rates[:n_enc_blocks]
        dec_dp_rates = drop_path_rates[n_enc_blocks:]
        # Encoder
        self.enc_blocks = nn.ModuleList([
            block(in_ch, chs[0], reduction, drop_path=enc_dp_rates[0])
        ] + [
            block(chs[i], chs[i+1], reduction, drop_path=enc_dp_rates[i+1])
            for i in range(depth)
        ])
        # Depthwise strided 3x3 convs act as learned 2x downsampling.
        self.pool = nn.ModuleList([
            nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
            for i in range(depth)
        ])
        # Decoder (list index 0 operates at the deepest level).
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(chs[i+1], chs[i], kernel_size=2, stride=2)
            for i in reversed(range(depth))
        ])
        # NOTE(review): dec_dp_rates is indexed by level i while the list is
        # built in reversed order, so the deepest decoder block receives the
        # largest rate — confirm this ordering is intended.
        self.dec_blocks = nn.ModuleList([
            block(chs[i]*2, chs[i], reduction, drop_path=dec_dp_rates[i])
            for i in reversed(range(depth))
        ])
        # Attention Gates (one per skip connection, deepest first).
        self.attn_gates = nn.ModuleList([
            AttentionGate(chs[i], chs[i], chs[i])
            for i in reversed(range(depth))
        ])
        # 1x1 prediction head at the top resolution.
        self.final_head = nn.Sequential(
            nn.Conv2d(chs[0], out_ch, kernel_size=1)
        )
    def forward(self, x):
        """Encode, decode with attention-gated skips, and predict.

        Assumes H and W are divisible by 2**depth — TODO confirm with callers.
        """
        enc_feats = []
        for i, enc in enumerate(self.enc_blocks):
            x = enc(x)
            enc_feats.append(x)
            # No downsampling after the bottleneck (last) encoder block.
            if i < self.depth:
                x = self.pool[i](x)
        # Decoder
        for i in range(self.depth):
            x = self.upconvs[i](x)
            enc_feat = enc_feats[self.depth-1-i]
            # AttentionGate: (encoder feature, decoder upconv output)
            attn_enc_feat = self.attn_gates[i](enc_feat, x)
            x = torch.cat([attn_enc_feat, x], dim=1)
            x = self.dec_blocks[i](x)
        out = self.final_head(x)
        return {
            'x_recon': out
        }
###############################################################################
from torch.nn.utils.rnn import pack_padded_sequence
class GRUStem(nn.Module):
    """
    Stem for zero-padded variable-channel inputs: each input channel is
    embedded by a shared encoder (phi); the channel axis is then treated as a
    time axis and fused per pixel with a bidirectional GRU.

    NOTE(review): pack_padded_sequence assumes the valid (non-zero) channels
    form a prefix of the channel axis; interleaved zero channels would be
    mis-packed — confirm the data loader pads at the end.
    """
    def __init__(self, out_channels: int = 64, embed_channels: int = 16, small_input: bool = True):
        super().__init__()
        # small_input keeps full resolution; otherwise phi downsamples by 2.
        stride = 1 if small_input else 2
        # Shared single-channel encoder applied to every input channel.
        self.phi = nn.Sequential(
            nn.Conv2d(1, embed_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(embed_channels),
            nn.ReLU(inplace=True),
        )
        # Forward + backward hidden states concatenate to out_channels.
        hidden = out_channels // 2
        assert hidden > 0, "out_channels must be >=2 to support BiGRU"
        self.gru = nn.GRU(input_size=embed_channels, hidden_size=hidden,
                          num_layers=1, bidirectional=True)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.out_channels = out_channels
        self.small_input = small_input
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, Cmax, H, W] with zero-padded channels
        B, Cmax, H, W = x.shape
        # non-zero channel lengths (a channel is "valid" if any pixel is non-zero)
        with torch.no_grad():
            nonzero_ch = (x.abs().sum(dim=(2, 3)) > 0)  # [B, Cmax]
            lengths = nonzero_ch.sum(dim=1).clamp(min=1)  # [B]; clamp avoids empty sequences
        # shared encoder φ applied to each channel independently
        feat_per_c = [self.phi(x[:, c:c+1, :, :]) for c in range(Cmax)]  # list of [B,E,H',W']
        Fstack = torch.stack(feat_per_c, dim=0)  # [Cmax, B, E, H', W']
        Cseq, Bsz, E, Hp, Wp = Fstack.shape
        # sequence for GRU: [T=Cmax, N=B*Hp*Wp, E] — every pixel becomes its own sequence
        Fseq = Fstack.permute(0, 1, 3, 4, 2).contiguous().view(Cseq, Bsz * Hp * Wp, E)
        # Every pixel of a sample shares that sample's length; lengths must be on CPU.
        lens = lengths.repeat_interleave(Hp * Wp).cpu()  # [N]
        packed = pack_padded_sequence(Fseq, lens, enforce_sorted=False)
        _, h_n = self.gru(packed)  # [2, N, hidden] (final forward/backward states)
        h_cat = torch.cat([h_n[0], h_n[1]], dim=-1)  # [N, out_ch]
        # Fold pixels back into a feature map.
        out = h_cat.view(Bsz, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous()  # [B,out,H',W']
        out = self.act(self.bn(out))
        return out
##################################################################
class _PoolDownMixin:
    """Mixin adding an optional 2x average-pool downsample.

    ``small_input=True`` keeps the resolution (the input is returned
    unchanged); otherwise ``_maybe_down`` halves H and W with a 2x2 average
    pool.
    """

    def __init__(self, small_input: bool):
        self._stride = 1 if small_input else 2

    def _maybe_down(self, y: torch.Tensor) -> torch.Tensor:
        # Pass-through unless configured for stride-2 downsampling.
        return F.avg_pool2d(y, 2) if self._stride == 2 else y
class FourierStem2D(nn.Module, _PoolDownMixin):
    """Maps a variable channel count C to a fixed ``out_dim`` per pixel by
    projecting each channel vector onto a cosine basis (Fourier or Chebyshev)
    evaluated over the channel axis, followed by a learned linear mixing.
    Basis matrices are cached per (C, basis, device, dtype).
    """

    def __init__(self, out_dim=64, basis="chebyshev", small_input: bool = True):
        nn.Module.__init__(self)
        _PoolDownMixin.__init__(self, small_input)
        assert basis in ("fourier", "chebyshev")
        self.out_dim = out_dim
        self.basis = basis
        self.proj = nn.Linear(out_dim, out_dim)
        # Cache grows with each distinct (C, basis, device, dtype) seen.
        self._basis_cache: Dict[Tuple[int, str, torch.device, torch.dtype], torch.Tensor] = {}

    def _get_basis(self, C, device, dtype):
        key = (C, self.basis, device, dtype)
        cached = self._basis_cache.get(key)
        if cached is not None:
            return cached
        # Channel positions mapped to [-1, 1].
        idx = torch.linspace(-1, 1, C, device=device, dtype=dtype).unsqueeze(0)
        if self.basis == "fourier":
            cols = [torch.cos(idx * i * math.pi) for i in range(1, self.out_dim + 1)]
        else:
            # Chebyshev polynomials: T_i(x) = cos(i * arccos(x)).
            cols = [torch.cos(i * torch.acos(idx)) for i in range(1, self.out_dim + 1)]
        B = torch.stack(cols, dim=-1)
        self._basis_cache[key] = B
        return B

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        # Flatten pixels so each row is a length-C channel vector.
        x_flat = x.permute(0, 2, 3, 1).reshape(-1, C)        # [B*H*W, C]
        basis = self._get_basis(C, x.device, x.dtype)[0]     # [C, D]
        emb = self.proj(x_flat @ basis)                      # [B*H*W, D]
        emb = emb.view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
        return self._maybe_down(emb)
class WACA_Unet_stem(nn.Module):
    """WACA_Unet variant whose input is first mapped to ``base_ch`` channels
    by a FourierStem2D, so the encoder always starts at chs[0].

    Returns a dict {'x_recon': logits}, like WACA_Unet.

    NOTE(review): ``in_ch`` is accepted for signature compatibility but is
    never used — FourierStem2D projects any input channel count onto chs[0].
    """
    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
                 depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        # Channel widths per level: [base, 2*base, ..., base * 2**depth].
        chs = [base_ch * 2**i for i in range(depth+1)]
        self.drop_path = drop_path
        n_enc_blocks = depth + 1
        n_dec_blocks = depth
        total_blocks = n_enc_blocks + n_dec_blocks
        # Linear stochastic-depth schedule over all blocks, encoder first.
        drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
        enc_dp_rates = drop_path_rates[:n_enc_blocks]
        dec_dp_rates = drop_path_rates[n_enc_blocks:]
        # self.stem = GRUStem(out_channels=16,embed_channels=16,small_input=True)
        # Stem projects the raw channel axis onto chs[0] feature channels.
        self.stem = FourierStem2D(chs[0])
        # self.up0 = nn.Upsample(scale_factor=2,mode='bicubic')
        # Encoder (first block consumes the stem's chs[0]-channel output).
        self.enc_blocks = nn.ModuleList([
            block(chs[0], chs[0], reduction, drop_path=enc_dp_rates[0])
        ] + [
            block(chs[i], chs[i+1], reduction, drop_path=enc_dp_rates[i+1])
            for i in range(depth)
        ])
        # Depthwise strided 3x3 convs act as learned 2x downsampling.
        self.pool = nn.ModuleList([
            nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
            for i in range(depth)
        ])
        # Decoder (list index 0 operates at the deepest level).
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(chs[i+1], chs[i], kernel_size=2, stride=2)
            for i in reversed(range(depth))
        ])
        # NOTE(review): dec_dp_rates is indexed by level i while the list is
        # built in reversed order, so the deepest decoder block receives the
        # largest rate — confirm this ordering is intended.
        self.dec_blocks = nn.ModuleList([
            block(chs[i]*2, chs[i], reduction, drop_path=dec_dp_rates[i])
            for i in reversed(range(depth))
        ])
        # Attention Gates (one per skip connection, deepest first).
        self.attn_gates = nn.ModuleList([
            AttentionGate(chs[i], chs[i], chs[i])
            for i in reversed(range(depth))
        ])
        # 1x1 prediction head at the top resolution.
        self.final_head = nn.Sequential(
            nn.Conv2d(chs[0], out_ch, kernel_size=1)
        )
    def forward(self, x):
        """Stem, encode, decode with attention-gated skips, and predict.

        Assumes H and W are divisible by 2**depth — TODO confirm with callers.
        """
        enc_feats = []
        x = self.stem(x)
        # x = self.up0(x)
        for i, enc in enumerate(self.enc_blocks):
            x = enc(x)
            enc_feats.append(x)
            # No downsampling after the bottleneck (last) encoder block.
            if i < self.depth:
                x = self.pool[i](x)
        # Decoder
        for i in range(self.depth):
            x = self.upconvs[i](x)
            enc_feat = enc_feats[self.depth-1-i]
            # AttentionGate: (encoder feature, decoder upconv output)
            attn_enc_feat = self.attn_gates[i](enc_feat, x)
            x = torch.cat([attn_enc_feat, x], dim=1)
            x = self.dec_blocks[i](x)
        out = self.final_head(x)
        return {
            'x_recon': out
        }
if __name__ == '__main__':
    # Smoke test: build the stem U-Net for each candidate block type, push a
    # dummy batch through it, and report shapes plus trainable parameters.
    for block in [ConvNeXtV2BlockWACA]:
        print(f"Testing block: {block.__name__}")
        model = WACA_Unet_stem(in_ch=25, out_ch=1, block=block, depth=4)
        dummy = torch.randn(2, 25, 384, 384)
        out = model(dummy)['x_recon']
        print(f"Input shape: {dummy.shape}")
        print(f"Output shape: {out.shape}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters: {total_params:,}")