"""WACA (weakness-aware channel attention) U-Net components.

ConvNeXt-V2 style blocks combined with a CBAM-like attention module
(``WACA_CBAM``), atrous-convolution block variants, attention-gated skip
connections, and two channel-count-agnostic stems (BiGRU and Fourier/Chebyshev
basis projection).
"""
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

try:
    from timm.layers import DropPath
except ImportError:
    class DropPath(nn.Module):
        """Per-sample stochastic depth; same contract as timm's DropPath.

        Fallback so the module stays importable when timm is not installed.
        Identity in eval mode or when drop_prob == 0.
        """

        def __init__(self, drop_prob: float = 0.):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0. or not self.training:
                return x
            keep_prob = 1.0 - self.drop_prob
            # One Bernoulli mask per sample, broadcast over all other dims.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            mask = x.new_empty(shape).bernoulli_(keep_prob)
            return x * mask / keep_prob


class WACA_CBAM(nn.Module):
    """Weakness-aware channel attention followed by CBAM spatial attention.

    Channel attention is computed twice: once on the input, and once on the
    input re-weighted by the *inverse* gate (the "weak" channels); the two
    attention maps are averaged. A 7x7 CBAM-style spatial attention map is
    then applied on top.
    """

    def __init__(self, channels, reduction=16):
        super(WACA_CBAM, self).__init__()
        self.channels = channels
        # Guard against tiny channel counts that would reduce to 0 channels.
        if channels < reduction or channels // reduction == 0:
            self.reduced_channels = channels // 2 if channels > 1 else 1
        else:
            self.reduced_channels = channels // reduction
        # Shared bottleneck MLP (1x1 convs) applied to pooled descriptors.
        self.fc_layers = nn.Sequential(
            nn.Conv2d(self.channels, self.reduced_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.reduced_channels, self.channels, kernel_size=1, bias=False),
        )
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # First gating pass on the raw input.
        avg_pool = F.adaptive_avg_pool2d(x, 1)
        max_pool = F.adaptive_max_pool2d(x, 1)
        gate_logits = self.fc_layers(avg_pool) + self.fc_layers(max_pool)
        # sigmoid(-z) == 1 - sigmoid(z): emphasis on currently weak channels.
        weakness_scores = torch.sigmoid(-gate_logits)
        attn_scores = torch.sigmoid(gate_logits)
        # Second gating pass computed on the weak-channel view of the input.
        gated_weak = x * weakness_scores
        squeezed_2_avg = F.adaptive_avg_pool2d(gated_weak, 1)
        squeezed_2_max = F.adaptive_max_pool2d(gated_weak, 1)
        gate_logits_2 = self.fc_layers(squeezed_2_avg + squeezed_2_max)
        attn_scores_2 = torch.sigmoid(gate_logits_2)
        # Average the two channel-attention maps.
        gated_attn = x * (attn_scores + attn_scores_2) * 0.5
        # Spatial attention (CBAM): channel mean/max -> 7x7 conv -> sigmoid.
        avg_out = torch.mean(gated_attn, dim=1, keepdim=True)
        max_out, _ = torch.max(gated_attn, dim=1, keepdim=True)
        sa_weight = self.spatial_attn(torch.cat([avg_out, max_out], dim=1))
        return gated_attn * sa_weight


##################################################################################
# Adapted from
# https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
class LayerNorm(nn.Module):
    """LayerNorm that supports two data formats: channels_last (default)
    or channels_first. channels_last expects inputs of shape
    (batch_size, height, width, channels); channels_first expects
    (batch_size, channels, height, width).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Manual normalization over the channel dim for BCHW tensors.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GRN(nn.Module):
    """GRN (Global Response Normalization) layer; expects channels-last input."""

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # L2 norm over spatial dims, normalized by its mean across channels.
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


#################################################################################################
class ConvNeXtV2BlockWACA_Atrous(nn.Module):
    """ConvNeXt-V2 block with a dilated depthwise conv and WACA attention.

    Padding is ``dilation * 3`` (= (kernel_size - 1) // 2 * dilation for k=7)
    so the spatial size is preserved while the receptive field grows.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilation=3):
        super().__init__()
        padding = dilation * 3  # keeps H, W unchanged for the 7x7 kernel
        self.dwconv = nn.Conv2d(
            in_ch, in_ch,
            kernel_size=7,
            padding=padding,
            groups=in_ch,
            dilation=dilation,  # atrous (dilated) depthwise convolution
        )
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        # 1x1 projection on the residual path when channel counts differ.
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(input_x) + x


class ConvNeXtV2BlockWACA_MultiAtrous(nn.Module):
    """ConvNeXt-V2/WACA block with parallel depthwise convs at several dilations.

    The input channels are split across branches (the last branch absorbs any
    remainder); each branch runs a 7x7 depthwise conv with its own dilation,
    and the concatenated result is fused by a 1x1 conv.

    Bug fix: each branch conv is now sized to the channel slice it actually
    receives. The previous version declared every branch with ``in_ch`` input
    channels while feeding it an ``in_ch // n`` slice, which failed either at
    construction (invalid ``groups``) or at runtime (channel mismatch), and
    the concatenation could produce fewer than ``in_ch`` channels when
    ``in_ch`` was not divisible by the branch count.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 2, 4]):
        super().__init__()
        n_branches = len(dilations)
        base = in_ch // n_branches
        # Last branch takes the remainder so the split sizes always sum to in_ch.
        self.split_sizes = [base] * (n_branches - 1) + [in_ch - base * (n_branches - 1)]
        self.dwconv_branches = nn.ModuleList([
            nn.Conv2d(c, c, kernel_size=7, padding=d * 3, groups=c, dilation=d)
            for c, d in zip(self.split_sizes, dilations)
        ])
        # Fuse the concatenated branches back to in_ch channels.
        self.combine_conv = nn.Conv2d(in_ch, in_ch, 1)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input_x = x
        # Multi-scale atrous convolution over disjoint channel slices.
        parts = torch.split(x, self.split_sizes, dim=1)
        branch_outputs = [conv(p) for conv, p in zip(self.dwconv_branches, parts)]
        x = torch.cat(branch_outputs, dim=1)
        x = self.combine_conv(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(input_x) + x


class ConvNeXtV2BlockWACA_ASPP(nn.Module):
    """ConvNeXt-V2/WACA block with an ASPP (Atrous Spatial Pyramid Pooling) front-end.

    NOTE(review): the grouped atrous branches require ``in_ch`` to be divisible
    by ``in_ch // len(dilations)`` (e.g. in_ch a multiple of len(dilations));
    other channel counts raise at construction — confirm against callers.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 6, 12, 18]):
        super().__init__()
        # ASPP-style parallel branches: one 1x1 conv plus atrous 3x3 convs.
        self.aspp_branches = nn.ModuleList()
        for dilation in dilations:
            if dilation == 1:
                branch = nn.Conv2d(in_ch, in_ch // len(dilations), 1)
            else:
                branch = nn.Conv2d(
                    in_ch, in_ch // len(dilations),
                    kernel_size=3,
                    padding=dilation,
                    dilation=dilation,
                    groups=in_ch // len(dilations),
                )
            self.aspp_branches.append(branch)
        # Image-level context branch (global average pooling).
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_ch, in_ch // len(dilations), 1),
        )
        # Fuse all branches (atrous + global) back to in_ch channels.
        total_channels = (len(dilations) + 1) * (in_ch // len(dilations))
        self.combine_conv = nn.Conv2d(total_channels, in_ch, 1)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input_x = x
        h, w = x.size()[2:]
        # Parallel ASPP branches.
        branch_outputs = [branch(x) for branch in self.aspp_branches]
        # Global average pooling branch, upsampled back to the input size.
        global_feat = self.global_avg_pool(x)
        global_feat = F.interpolate(global_feat, size=(h, w), mode='bilinear', align_corners=False)
        branch_outputs.append(global_feat)
        x = torch.cat(branch_outputs, dim=1)
        x = self.combine_conv(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(input_x) + x


#################################################################################################
class ConvNeXtV2BlockWACA(nn.Module):
    """Standard ConvNeXt-V2 block (7x7 depthwise, inverted MLP, GRN) + WACA attention."""

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., use_grn=True):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, kernel_size=7, padding=3, groups=in_ch)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch) if use_grn else nn.Identity()
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # BCHW -> BHWC
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # BHWC -> BCHW
        x = self.fow(x)
        x = self.drop_path(x)
        return self.proj(input_x) + x


class AttentionGate(nn.Module):
    """Attention gate for skip connections: modulates x by a mask computed from [x, g]."""

    def __init__(self, in_ch_x, in_ch_g, out_ch):
        super().__init__()
        self.act = nn.ReLU(inplace=True)
        self.w_x_g = nn.Conv2d(in_ch_x + in_ch_g, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
        self.attn = nn.Conv2d(out_ch, out_ch, kernel_size=1, padding=0, bias=False)

    def forward(self, x, g):
        res = x
        xg = torch.cat([x, g], dim=1)  # B, (x_c + g_c), H, W
        xg = self.w_x_g(xg)
        xg = self.act(xg)
        attn = torch.sigmoid(self.attn(xg))
        return res * attn


class WACA_Unet(nn.Module):
    """U-Net built from WACA/ConvNeXt-V2 blocks with attention-gated skips.

    Returns a dict ``{'x_recon': tensor}``. Input spatial size must be
    divisible by 2 ** depth. Drop-path rates are linearly scheduled from 0 to
    ``drop_path`` across encoder then decoder blocks.
    """

    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16, depth=4,
                 drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        chs = [base_ch * 2**i for i in range(depth + 1)]
        self.drop_path = drop_path
        n_enc_blocks = depth + 1
        n_dec_blocks = depth
        total_blocks = n_enc_blocks + n_dec_blocks
        drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
        enc_dp_rates = drop_path_rates[:n_enc_blocks]
        dec_dp_rates = drop_path_rates[n_enc_blocks:]
        # Encoder
        self.enc_blocks = nn.ModuleList([
            block(in_ch, chs[0], reduction, drop_path=enc_dp_rates[0])
        ] + [
            block(chs[i], chs[i + 1], reduction, drop_path=enc_dp_rates[i + 1])
            for i in range(depth)
        ])
        # Strided depthwise convs as learned downsampling.
        self.pool = nn.ModuleList([
            nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
            for i in range(depth)
        ])
        # Decoder
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(chs[i + 1], chs[i], kernel_size=2, stride=2)
            for i in reversed(range(depth))
        ])
        self.dec_blocks = nn.ModuleList([
            block(chs[i] * 2, chs[i], reduction, drop_path=dec_dp_rates[i])
            for i in reversed(range(depth))
        ])
        # Attention gates on the skip connections.
        self.attn_gates = nn.ModuleList([
            AttentionGate(chs[i], chs[i], chs[i])
            for i in reversed(range(depth))
        ])
        self.final_head = nn.Sequential(
            nn.Conv2d(chs[0], out_ch, kernel_size=1)
        )

    def forward(self, x):
        enc_feats = []
        for i, enc in enumerate(self.enc_blocks):
            x = enc(x)
            enc_feats.append(x)
            if i < self.depth:
                x = self.pool[i](x)
        # Decoder
        for i in range(self.depth):
            x = self.upconvs[i](x)
            enc_feat = enc_feats[self.depth - 1 - i]
            # AttentionGate: (encoder feature, decoder upconv output).
            attn_enc_feat = self.attn_gates[i](enc_feat, x)
            x = torch.cat([attn_enc_feat, x], dim=1)
            x = self.dec_blocks[i](x)
        out = self.final_head(x)
        return {'x_recon': out}


###############################################################################
class GRUStem(nn.Module):
    """Stem for zero-padded variable channel counts.

    Each input channel is embedded by a shared encoder (phi), then the channel
    axis is treated as a time axis and fused per-pixel by a BiGRU. Zero-padded
    (all-zero) channels are excluded via packed sequences.
    """

    def __init__(self, out_channels: int = 64, embed_channels: int = 16, small_input: bool = True):
        super().__init__()
        stride = 1 if small_input else 2
        self.phi = nn.Sequential(
            nn.Conv2d(1, embed_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(embed_channels),
            nn.ReLU(inplace=True),
        )
        hidden = out_channels // 2
        assert hidden > 0, "out_channels must be >=2 to support BiGRU"
        self.gru = nn.GRU(input_size=embed_channels, hidden_size=hidden,
                          num_layers=1, bidirectional=True)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.out_channels = out_channels
        self.small_input = small_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, Cmax, H, W] with zero-padded channels.
        B, Cmax, H, W = x.shape
        # Count non-zero channels per sample (clamped so packing never sees 0).
        with torch.no_grad():
            nonzero_ch = (x.abs().sum(dim=(2, 3)) > 0)   # [B, Cmax]
            lengths = nonzero_ch.sum(dim=1).clamp(min=1)  # [B]
        # Shared encoder phi applied to each channel independently.
        feat_per_c = [self.phi(x[:, c:c + 1, :, :]) for c in range(Cmax)]  # list of [B,E,H',W']
        Fstack = torch.stack(feat_per_c, dim=0)  # [Cmax, B, E, H', W']
        Cseq, Bsz, E, Hp, Wp = Fstack.shape
        # Sequence for the GRU: [T=Cmax, N=B*Hp*Wp, E].
        Fseq = Fstack.permute(0, 1, 3, 4, 2).contiguous().view(Cseq, Bsz * Hp * Wp, E)
        lens = lengths.repeat_interleave(Hp * Wp).cpu()  # [N]; lengths must live on CPU
        packed = pack_padded_sequence(Fseq, lens, enforce_sorted=False)
        _, h_n = self.gru(packed)  # [2, N, hidden] (one direction per row)
        h_cat = torch.cat([h_n[0], h_n[1]], dim=-1)  # [N, out_channels]
        out = h_cat.view(Bsz, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous()  # [B,out,H',W']
        out = self.act(self.bn(out))
        return out


##################################################################
class _PoolDownMixin:
    """Mixin providing optional 2x average-pool downsampling for stems."""

    def __init__(self, small_input: bool):
        self._stride = 1 if small_input else 2

    def _maybe_down(self, y: torch.Tensor) -> torch.Tensor:
        if self._stride == 2:
            return F.avg_pool2d(y, 2)
        return y


class FourierStem2D(nn.Module, _PoolDownMixin):
    """Stem that projects an arbitrary channel count onto a fixed basis.

    Each pixel's channel vector is projected onto ``out_dim`` Fourier or
    Chebyshev basis functions (cached per (C, basis, device, dtype)), then
    mixed by a learned linear layer.
    """

    def __init__(self, out_dim=64, basis="chebyshev", small_input: bool = True):
        nn.Module.__init__(self)
        _PoolDownMixin.__init__(self, small_input)
        assert basis in ("fourier", "chebyshev")
        self.out_dim = out_dim
        self.basis = basis
        self.proj = nn.Linear(out_dim, out_dim)
        # Cache of basis matrices; plain tensors, intentionally not buffers.
        self._basis_cache: Dict[Tuple[int, str, torch.device, torch.dtype], torch.Tensor] = {}

    def _get_basis(self, C, device, dtype):
        key = (C, self.basis, device, dtype)
        if key in self._basis_cache:
            return self._basis_cache[key]
        idx = torch.linspace(-1, 1, C, device=device, dtype=dtype).unsqueeze(0)
        if self.basis == "fourier":
            B = torch.stack([torch.cos(idx * i * math.pi) for i in range(1, self.out_dim + 1)], dim=-1)
        else:
            # Chebyshev polynomials of the first kind: T_i(x) = cos(i * acos(x)).
            B = torch.stack([torch.cos(i * torch.acos(idx)) for i in range(1, self.out_dim + 1)], dim=-1)
        self._basis_cache[key] = B
        return B

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        device, dtype = x.device, x.dtype
        x_flat = x.permute(0, 2, 3, 1).reshape(-1, C)       # [BHW, C]
        basis = self._get_basis(C, device, dtype)[0]        # [C, D]
        emb = x_flat @ basis                                 # [BHW, D]
        emb = self.proj(emb).view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
        return self._maybe_down(emb)


class WACA_Unet_stem(nn.Module):
    """WACA_Unet variant whose input passes through a channel-agnostic stem.

    The stem (FourierStem2D; a GRUStem is a drop-in alternative) maps any
    input channel count to ``base_ch`` channels before the encoder, so
    ``in_ch`` is accepted for interface compatibility but not constrained.
    Returns a dict ``{'x_recon': tensor}``.
    """

    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16, depth=4,
                 drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        chs = [base_ch * 2**i for i in range(depth + 1)]
        self.drop_path = drop_path
        n_enc_blocks = depth + 1
        n_dec_blocks = depth
        total_blocks = n_enc_blocks + n_dec_blocks
        drop_path_rates = torch.linspace(0, drop_path, total_blocks).tolist()
        enc_dp_rates = drop_path_rates[:n_enc_blocks]
        dec_dp_rates = drop_path_rates[n_enc_blocks:]
        # Alternative: GRUStem(out_channels=chs[0], embed_channels=16, small_input=True)
        self.stem = FourierStem2D(chs[0])
        # Encoder (first block keeps chs[0] since the stem already projects).
        self.enc_blocks = nn.ModuleList([
            block(chs[0], chs[0], reduction, drop_path=enc_dp_rates[0])
        ] + [
            block(chs[i], chs[i + 1], reduction, drop_path=enc_dp_rates[i + 1])
            for i in range(depth)
        ])
        self.pool = nn.ModuleList([
            nn.Conv2d(chs[i], chs[i], kernel_size=3, stride=2, padding=1, groups=chs[i])
            for i in range(depth)
        ])
        # Decoder
        self.upconvs = nn.ModuleList([
            nn.ConvTranspose2d(chs[i + 1], chs[i], kernel_size=2, stride=2)
            for i in reversed(range(depth))
        ])
        self.dec_blocks = nn.ModuleList([
            block(chs[i] * 2, chs[i], reduction, drop_path=dec_dp_rates[i])
            for i in reversed(range(depth))
        ])
        self.attn_gates = nn.ModuleList([
            AttentionGate(chs[i], chs[i], chs[i])
            for i in reversed(range(depth))
        ])
        self.final_head = nn.Sequential(
            nn.Conv2d(chs[0], out_ch, kernel_size=1)
        )

    def forward(self, x):
        enc_feats = []
        x = self.stem(x)
        for i, enc in enumerate(self.enc_blocks):
            x = enc(x)
            enc_feats.append(x)
            if i < self.depth:
                x = self.pool[i](x)
        # Decoder
        for i in range(self.depth):
            x = self.upconvs[i](x)
            enc_feat = enc_feats[self.depth - 1 - i]
            # AttentionGate: (encoder feature, decoder upconv output).
            attn_enc_feat = self.attn_gates[i](enc_feat, x)
            x = torch.cat([attn_enc_feat, x], dim=1)
            x = self.dec_blocks[i](x)
        out = self.final_head(x)
        return {'x_recon': out}


if __name__ == '__main__':
    for block in [ConvNeXtV2BlockWACA]:
        print(f"Testing block: {block.__name__}")
        model = WACA_Unet_stem(in_ch=25, out_ch=1, block=block, depth=4)
        dummy = torch.randn(2, 25, 384, 384)
        out = model(dummy)['x_recon']
        print(f"Input shape: {dummy.shape}")
        print(f"Output shape: {out.shape}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters: {total_params:,}")