| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from timm.layers import DropPath |
| | import math |
| | from typing import List, Tuple, Optional, Dict |
| |
|
| |
|
class WACA_CBAM(nn.Module):
    """Weakness-aware channel attention followed by CBAM-style spatial attention.

    A shared channel MLP produces a gate over channels; its negation
    highlights "weak" channels, which are re-pooled through the same MLP to
    form a second gate. The two channel gates are averaged and applied to the
    input, then a 7x7 spatial gate refines the result.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.channels = channels
        # Keep the bottleneck width at least 1 for small channel counts.
        if channels < reduction or channels // reduction == 0:
            self.reduced_channels = channels // 2 if channels > 1 else 1
        else:
            self.reduced_channels = channels // reduction

        # Shared channel MLP (1x1 convs), used by both gating passes.
        self.fc_layers = nn.Sequential(
            nn.Conv2d(self.channels, self.reduced_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.reduced_channels, self.channels, kernel_size=1, bias=False),
        )
        # Spatial gate over the [avg, max] channel statistics.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # --- first channel gate, from pooled descriptors of the raw input ---
        gate_logits = self.fc_layers(F.adaptive_avg_pool2d(x, 1)) \
                    + self.fc_layers(F.adaptive_max_pool2d(x, 1))
        weak_gate = torch.sigmoid(-gate_logits)   # emphasizes low-response channels
        strong_gate = torch.sigmoid(gate_logits)

        # --- second channel gate, computed on the weakness-gated features ---
        weak_feats = x * weak_gate
        pooled = F.adaptive_avg_pool2d(weak_feats, 1) + F.adaptive_max_pool2d(weak_feats, 1)
        second_gate = torch.sigmoid(self.fc_layers(pooled))

        # Average the two channel gates and apply them to the input.
        refined = x * (strong_gate + second_gate) * 0.5

        # --- CBAM spatial attention on the channel-refined features ---
        ch_avg = torch.mean(refined, dim=1, keepdim=True)
        ch_max, _ = torch.max(refined, dim=1, keepdim=True)
        spatial_gate = self.spatial_attn(torch.cat([ch_avg, ch_max], dim=1))
        return refined * spatial_gate
| |
|
| |
|
| | |
| | |
class LayerNorm(nn.Module):
    """LayerNorm supporting channels_last and channels_first layouts.

    channels_last expects inputs shaped (B, H, W, C) and defers to
    F.layer_norm; channels_first expects (B, C, H, W) and normalizes across
    the channel dimension manually.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: mean/variance over dim 1, affine params broadcast
        # over the spatial dimensions.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
| | |
class GRN(nn.Module):
    """Global Response Normalization (ConvNeXtV2) for channels-last features.

    Expects (B, H, W, C). gamma and beta are zero-initialized, so the layer
    is an exact identity at initialization.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        # Per-channel L2 response aggregated over the spatial dims (1, 2).
        gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Normalize each channel's response by the mean response.
        nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
        # Residual formulation: identity plus a learned recalibration.
        return self.gamma * (x * nx) + self.beta + x
| |
|
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| |
|
class ConvNeXtV2BlockWACA_Atrous(nn.Module):
    """ConvNeXtV2 block with a dilated 7x7 depthwise conv and WACA-CBAM.

    Padding is dilation*3 so the dilated kernel-7 depthwise conv preserves
    the spatial size. A 1x1 conv aligns the residual path when the channel
    count changes.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilation=3):
        super().__init__()
        self.dwconv = nn.Conv2d(
            in_ch, in_ch,
            kernel_size=7,
            padding=dilation * 3,  # keeps H and W unchanged for k=7
            groups=in_ch,          # depthwise
            dilation=dilation,
        )
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)

        self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)        # NCHW -> NHWC for LayerNorm/Linear
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 3, 1, 2)        # back to NCHW
        y = self.drop_path(self.fow(y))
        return self.proj(shortcut) + y
| |
|
| |
|
| | |
class ConvNeXtV2BlockWACA_MultiAtrous(nn.Module):
    """ConvNeXtV2 block whose depthwise stage mixes several dilation rates.

    The input channels are partitioned into len(dilations) slices; slice i is
    processed by a depthwise 7x7 conv with dilation dilations[i] (padding
    d*3 preserves the spatial size). The slices are concatenated back to
    in_ch channels and fused with a 1x1 conv before the standard ConvNeXtV2
    MLP + GRN tail and WACA-CBAM attention.

    Args:
        in_ch: input channel count.
        out_ch: output channel count (residual path projected via 1x1 conv
            when it differs from in_ch).
        reduction: channel-reduction ratio for WACA_CBAM.
        drop_path: stochastic-depth rate.
        dilations: dilation rate per branch.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 2, 4]):
        super().__init__()
        n_branches = len(dilations)
        # Split channels as evenly as possible; the last slice absorbs any
        # remainder so the concatenated branches always total in_ch channels.
        base = in_ch // n_branches
        self.branch_channels = [base] * (n_branches - 1) + [in_ch - base * (n_branches - 1)]

        # BUG FIX: each branch previously declared in_ch input channels while
        # being fed only a channel slice (and the grouped-conv channel counts
        # did not divide), which raised a shape error at runtime. Each branch
        # is now a true depthwise conv over its own slice.
        self.dwconv_branches = nn.ModuleList([
            nn.Conv2d(
                ch, ch,
                kernel_size=7,
                padding=d * 3,   # preserves H, W for a dilated 7x7 kernel
                groups=ch,       # depthwise within the slice
                dilation=d,
            )
            for ch, d in zip(self.branch_channels, dilations)
        ])

        # 1x1 fusion of the re-assembled branch outputs.
        self.combine_conv = nn.Conv2d(in_ch, in_ch, 1)

        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)

        self.proj = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, 1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input_x = x

        # Route each channel slice through its own dilation branch.
        slices = torch.split(x, self.branch_channels, dim=1)
        x = torch.cat(
            [branch(s) for branch, s in zip(self.dwconv_branches, slices)],
            dim=1,
        )
        x = self.combine_conv(x)

        # Standard ConvNeXtV2 tail in channels-last layout.
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)
        x = self.fow(x)
        x = self.drop_path(x)

        return self.proj(input_x) + x
| |
|
| |
|
| | |
class ConvNeXtV2BlockWACA_ASPP(nn.Module):
    """ConvNeXtV2 block whose spatial mixing is an ASPP pyramid.

    One 1x1 branch (dilation 1), several dilated grouped 3x3 branches, and a
    global-average-pool branch each produce in_ch // len(dilations) channels;
    all branches are concatenated and fused by a 1x1 conv before the usual
    ConvNeXtV2 MLP + GRN tail and WACA-CBAM attention.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., dilations=[1, 6, 12, 18]):
        super().__init__()
        branch_ch = in_ch // len(dilations)

        self.aspp_branches = nn.ModuleList()
        for d in dilations:
            if d == 1:
                # Plain pointwise branch.
                self.aspp_branches.append(nn.Conv2d(in_ch, branch_ch, 1))
            else:
                # Dilated 3x3 branch; padding=d preserves spatial size.
                self.aspp_branches.append(nn.Conv2d(
                    in_ch, branch_ch,
                    kernel_size=3,
                    padding=d,
                    dilation=d,
                    groups=branch_ch,
                ))

        # Image-level context branch.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_ch, branch_ch, 1),
        )

        # len(dilations) conv branches + 1 global branch.
        total_channels = (len(dilations) + 1) * branch_ch
        self.combine_conv = nn.Conv2d(total_channels, in_ch, 1)

        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch)
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)

        self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        h, w = x.shape[2:]

        # All ASPP branches see the full input.
        feats = [branch(x) for branch in self.aspp_branches]
        pooled = self.global_avg_pool(x)
        feats.append(F.interpolate(pooled, size=(h, w), mode='bilinear', align_corners=False))

        y = self.combine_conv(torch.cat(feats, dim=1))

        y = y.permute(0, 2, 3, 1)        # NCHW -> NHWC
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 3, 1, 2)        # back to NCHW
        y = self.drop_path(self.fow(y))
        return self.proj(shortcut) + y
| |
|
| |
|
| |
|
| |
|
| | |
class ConvNeXtV2BlockWACA(nn.Module):
    """Plain ConvNeXtV2 block (7x7 depthwise, LN, 4x MLP, GRN) plus WACA-CBAM.

    GRN can be disabled via use_grn; a 1x1 conv aligns the residual path
    when in_ch != out_ch.
    """

    def __init__(self, in_ch, out_ch, reduction=16, drop_path=0., use_grn=True):
        super().__init__()
        self.dwconv = nn.Conv2d(in_ch, in_ch, kernel_size=7, padding=3, groups=in_ch)
        self.norm = LayerNorm(in_ch, eps=1e-6)
        self.pwconv1 = nn.Linear(in_ch, 4 * in_ch)
        self.act = nn.GELU()
        self.grn = GRN(4 * in_ch) if use_grn else nn.Identity()
        self.pwconv2 = nn.Linear(4 * in_ch, out_ch)
        self.fow = WACA_CBAM(out_ch, reduction=reduction)
        self.proj = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)        # NCHW -> NHWC for LayerNorm/Linear
        y = self.pwconv1(self.norm(y))
        y = self.grn(self.act(y))
        y = self.pwconv2(y)
        y = y.permute(0, 3, 1, 2)        # back to NCHW
        y = self.drop_path(self.fow(y))
        return self.proj(shortcut) + y
| |
|
| |
|
class AttentionGate(nn.Module):
    """Attention gate for skip connections.

    Concatenates the skip features x and gating features g, projects to
    out_ch with a 1x1 conv + ReLU, derives a sigmoid mask with a second
    1x1 conv, and returns the masked skip features.
    """

    def __init__(self, in_ch_x, in_ch_g, out_ch):
        super().__init__()
        self.act = nn.ReLU(inplace=True)
        self.w_x_g = nn.Conv2d(in_ch_x + in_ch_g, out_ch, kernel_size=1, stride=1, padding=0, bias=False)
        self.attn = nn.Conv2d(out_ch, out_ch, kernel_size=1, padding=0, bias=False)

    def forward(self, x, g):
        fused = self.act(self.w_x_g(torch.cat([x, g], dim=1)))
        mask = torch.sigmoid(self.attn(fused))
        return x * mask
| |
|
| |
|
class WACA_Unet(nn.Module):
    """U-Net built from WACA ConvNeXtV2 blocks with attention-gated skips.

    Encoder: depth+1 blocks with depthwise strided-conv downsampling between
    levels. Decoder: transposed-conv upsampling, attention-gated skip fusion,
    then a decoder block per level. The stochastic-depth rate ramps linearly
    across encoder then decoder blocks. forward returns {'x_recon': tensor}.
    """

    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
                 depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        widths = [base_ch * 2 ** level for level in range(depth + 1)]
        self.drop_path = drop_path

        # Linear drop-path schedule over all encoder + decoder blocks.
        n_enc = depth + 1
        rates = torch.linspace(0, drop_path, n_enc + depth).tolist()
        enc_rates, dec_rates = rates[:n_enc], rates[n_enc:]

        # --- encoder ---
        enc = [block(in_ch, widths[0], reduction, drop_path=enc_rates[0])]
        enc += [block(widths[lvl], widths[lvl + 1], reduction, drop_path=enc_rates[lvl + 1])
                for lvl in range(depth)]
        self.enc_blocks = nn.ModuleList(enc)
        # Depthwise strided convs serve as learned 2x downsampling.
        self.pool = nn.ModuleList(
            nn.Conv2d(widths[lvl], widths[lvl], kernel_size=3, stride=2, padding=1, groups=widths[lvl])
            for lvl in range(depth)
        )

        # --- decoder (deepest level first) ---
        dec_levels = list(reversed(range(depth)))
        self.upconvs = nn.ModuleList(
            nn.ConvTranspose2d(widths[lvl + 1], widths[lvl], kernel_size=2, stride=2)
            for lvl in dec_levels
        )
        self.dec_blocks = nn.ModuleList(
            block(widths[lvl] * 2, widths[lvl], reduction, drop_path=dec_rates[lvl])
            for lvl in dec_levels
        )
        self.attn_gates = nn.ModuleList(
            AttentionGate(widths[lvl], widths[lvl], widths[lvl])
            for lvl in dec_levels
        )

        self.final_head = nn.Sequential(
            nn.Conv2d(widths[0], out_ch, kernel_size=1)
        )

    def forward(self, x):
        skips = []
        for level, enc in enumerate(self.enc_blocks):
            x = enc(x)
            skips.append(x)
            if level < self.depth:    # no pooling after the bottleneck block
                x = self.pool[level](x)

        # Decode from the bottleneck upward; skips[:-1] excludes the
        # bottleneck, and reversed() pairs deepest skip with the first up step.
        for up, gate, dec, skip in zip(self.upconvs, self.attn_gates,
                                       self.dec_blocks, reversed(skips[:-1])):
            x = up(x)
            gated = gate(skip, x)
            x = dec(torch.cat([gated, x], dim=1))

        return {
            'x_recon': self.final_head(x)
        }
| |
|
| | |
| | from torch.nn.utils.rnn import pack_padded_sequence |
| |
|
class GRUStem(nn.Module):
    """Stem for zero-padded variable-channel inputs.

    Each input channel is embedded by a shared conv encoder (``phi``); the
    channel axis is then treated as a sequence (time) axis and fused with a
    bidirectional GRU, yielding a fixed ``out_channels``-deep feature map.
    """
    def __init__(self, out_channels: int = 64, embed_channels: int = 16, small_input: bool = True):
        super().__init__()
        # Keep full resolution for small inputs; downsample by 2 otherwise.
        stride = 1 if small_input else 2
        # Shared per-channel encoder: 1 -> embed_channels feature maps.
        self.phi = nn.Sequential(
            nn.Conv2d(1, embed_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(embed_channels),
            nn.ReLU(inplace=True),
        )
        # BiGRU: forward + backward hidden states concatenate to out_channels.
        hidden = out_channels // 2
        assert hidden > 0, "out_channels must be >=2 to support BiGRU"
        self.gru = nn.GRU(input_size=embed_channels, hidden_size=hidden,
                          num_layers=1, bidirectional=True)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.out_channels = out_channels
        self.small_input = small_input

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, Cmax, H, W); unused channels are assumed zero-padded.
        B, Cmax, H, W = x.shape

        # Per-sample valid-channel count = channels containing any nonzero value.
        # NOTE(review): pack_padded_sequence treats everything past `lengths`
        # as padding, i.e. it assumes the valid channels form a contiguous
        # prefix along dim 1. Zero channels in the *middle* of the stack would
        # be silently truncated from the end instead — confirm the data loader
        # pads only at the tail.
        with torch.no_grad():
            nonzero_ch = (x.abs().sum(dim=(2, 3)) > 0)
            lengths = nonzero_ch.sum(dim=1).clamp(min=1)

        # Embed each channel independently with the shared encoder phi.
        feat_per_c = [self.phi(x[:, c:c+1, :, :]) for c in range(Cmax)]
        Fstack = torch.stack(feat_per_c, dim=0)  # (Cmax, B, E, Hp, Wp)
        Cseq, Bsz, E, Hp, Wp = Fstack.shape

        # Treat every spatial location as an independent sequence of length Cmax.
        Fseq = Fstack.permute(0, 1, 3, 4, 2).contiguous().view(Cseq, Bsz * Hp * Wp, E)
        # Each of the Hp*Wp sequences of a sample shares that sample's length;
        # pack_padded_sequence requires CPU lengths.
        lens = lengths.repeat_interleave(Hp * Wp).cpu()
        packed = pack_padded_sequence(Fseq, lens, enforce_sorted=False)
        _, h_n = self.gru(packed)
        # Concatenate final forward (h_n[0]) and backward (h_n[1]) states.
        h_cat = torch.cat([h_n[0], h_n[1]], dim=-1)

        # Reassemble per-pixel states into an NCHW feature map.
        out = h_cat.view(Bsz, Hp, Wp, -1).permute(0, 3, 1, 2).contiguous()
        out = self.act(self.bn(out))
        return out
| |
|
| | |
| |
|
| | class _PoolDownMixin: |
| | def __init__(self, small_input: bool): |
| | self._stride = 1 if small_input else 2 |
| | def _maybe_down(self, y: torch.Tensor) -> torch.Tensor: |
| | if self._stride == 2: |
| | return F.avg_pool2d(y, 2) |
| | return y |
| | |
class FourierStem2D(nn.Module, _PoolDownMixin):
    """Stem that projects a variable channel axis onto a fixed basis.

    Each pixel's channel vector is expanded in a Fourier (cosine) or
    Chebyshev basis of size out_dim and then linearly projected, so inputs
    with any channel count C map to out_dim feature channels. Basis matrices
    are cached per (C, basis, device, dtype).
    """

    def __init__(self, out_dim=64, basis="chebyshev", small_input: bool = True):
        nn.Module.__init__(self)
        _PoolDownMixin.__init__(self, small_input)
        assert basis in ("fourier", "chebyshev")
        self.out_dim = out_dim
        self.basis = basis
        self.proj = nn.Linear(out_dim, out_dim)
        self._basis_cache: Dict[Tuple[int, str, torch.device, torch.dtype], torch.Tensor] = {}

    def _get_basis(self, C, device, dtype):
        """Return a (1, C, out_dim) basis matrix, building and caching on miss."""
        key = (C, self.basis, device, dtype)
        cached = self._basis_cache.get(key)
        if cached is not None:
            return cached
        grid = torch.linspace(-1, 1, C, device=device, dtype=dtype).unsqueeze(0)
        if self.basis == "fourier":
            cols = [torch.cos(grid * k * math.pi) for k in range(1, self.out_dim + 1)]
        else:
            # Chebyshev polynomials: T_k(x) = cos(k * acos(x)) on [-1, 1].
            cols = [torch.cos(k * torch.acos(grid)) for k in range(1, self.out_dim + 1)]
        basis = torch.stack(cols, dim=-1)
        self._basis_cache[key] = basis
        return basis

    def forward(self, x: torch.Tensor):
        B, C, H, W = x.shape
        # Flatten every pixel's channel vector, expand in the basis, project.
        pixels = x.permute(0, 2, 3, 1).reshape(-1, C)
        basis = self._get_basis(C, x.device, x.dtype)[0]
        emb = self.proj(pixels @ basis)
        emb = emb.view(B, H, W, self.out_dim).permute(0, 3, 1, 2)
        return self._maybe_down(emb)
| | |
| |
|
class WACA_Unet_stem(nn.Module):
    """WACA U-Net variant with a FourierStem2D front end.

    The stem maps the input's channel axis to base_ch feature channels (it
    adapts to the runtime channel count, so in_ch is effectively unused);
    the first encoder block is therefore base_ch -> base_ch. Otherwise
    identical in structure to WACA_Unet. forward returns {'x_recon': tensor}.
    """

    def __init__(self, in_ch=25, out_ch=1, base_ch=64, reduction=16,
                 depth=4, drop_path=0.2, block=ConvNeXtV2BlockWACA, **kwargs):
        super().__init__()
        self.depth = depth
        widths = [base_ch * 2 ** level for level in range(depth + 1)]
        self.drop_path = drop_path

        # Linear drop-path schedule over all encoder + decoder blocks.
        n_enc = depth + 1
        rates = torch.linspace(0, drop_path, n_enc + depth).tolist()
        enc_rates, dec_rates = rates[:n_enc], rates[n_enc:]

        # Basis-projection stem to a fixed channel width.
        self.stem = FourierStem2D(widths[0])

        # --- encoder ---
        enc = [block(widths[0], widths[0], reduction, drop_path=enc_rates[0])]
        enc += [block(widths[lvl], widths[lvl + 1], reduction, drop_path=enc_rates[lvl + 1])
                for lvl in range(depth)]
        self.enc_blocks = nn.ModuleList(enc)
        self.pool = nn.ModuleList(
            nn.Conv2d(widths[lvl], widths[lvl], kernel_size=3, stride=2, padding=1, groups=widths[lvl])
            for lvl in range(depth)
        )

        # --- decoder (deepest level first) ---
        dec_levels = list(reversed(range(depth)))
        self.upconvs = nn.ModuleList(
            nn.ConvTranspose2d(widths[lvl + 1], widths[lvl], kernel_size=2, stride=2)
            for lvl in dec_levels
        )
        self.dec_blocks = nn.ModuleList(
            block(widths[lvl] * 2, widths[lvl], reduction, drop_path=dec_rates[lvl])
            for lvl in dec_levels
        )
        self.attn_gates = nn.ModuleList(
            AttentionGate(widths[lvl], widths[lvl], widths[lvl])
            for lvl in dec_levels
        )

        self.final_head = nn.Sequential(
            nn.Conv2d(widths[0], out_ch, kernel_size=1)
        )

    def forward(self, x):
        x = self.stem(x)

        skips = []
        for level, enc in enumerate(self.enc_blocks):
            x = enc(x)
            skips.append(x)
            if level < self.depth:    # no pooling after the bottleneck block
                x = self.pool[level](x)

        # Decode upward, fusing attention-gated skips (bottleneck excluded).
        for up, gate, dec, skip in zip(self.upconvs, self.attn_gates,
                                       self.dec_blocks, reversed(skips[:-1])):
            x = up(x)
            gated = gate(skip, x)
            x = dec(torch.cat([gated, x], dim=1))

        return {
            'x_recon': self.final_head(x)
        }
| |
|
| |
|
| |
|
def _smoke_test():
    """Quick forward-pass sanity check for the stem U-Net."""
    for block in (ConvNeXtV2BlockWACA,):
        print(f"Testing block: {block.__name__}")
        model = WACA_Unet_stem(in_ch=25, out_ch=1, block=block, depth=4)
        dummy = torch.randn(2, 25, 384, 384)
        out = model(dummy)['x_recon']
        print(f"Input shape: {dummy.shape}")
        print(f"Output shape: {out.shape}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total trainable parameters: {total_params:,}")


if __name__ == '__main__':
    _smoke_test()
| |
|