"""
Difference from ECGRecoverRandomMaskWithRS4: VCP is applied only within each
lead, and lead II is used as k/v to pass rhythm information to the other leads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_activation(actv_config):
    actv_cls = getattr(torch.nn, actv_config.name, None)
    assert actv_cls is not None, f"torch.nn has no activation named {actv_config.name!r}"
    if actv_config.params:
        return (
            actv_cls(**actv_config.params)
            if isinstance(actv_config.params, dict)
            else actv_cls(**actv_config.params.model_dump())
        )
    return actv_cls()
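
# Usage sketch (illustrative only; `_ActvCfg` is a hypothetical stand-in for
# the real config object): params may be a plain dict or a pydantic model
# exposing .model_dump().
#
#     class _ActvCfg:
#         name = "LeakyReLU"
#         params = {"negative_slope": 0.1}
#
#     actv = get_activation(_ActvCfg())  # -> nn.LeakyReLU(negative_slope=0.1)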


class Convolution1D_layer(nn.Module):
    """Per-lead 1D conv block: Conv1d(stride=2) -> BatchNorm -> LeakyReLU -> Dropout."""

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.conv = nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
            ),
            nn.BatchNorm1d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor):
        # Strided-conv output length: floor((D + 2p - k) / 2) + 1.
        out_size = (x.shape[-1] + 2 * self.padding - self.kernel_size) // 2 + 1
        new_x = torch.zeros(
            (len(x), self.out_channels, 12, out_size),
            dtype=x.dtype,
            device=x.device,
        )
        # One shared conv applied lead by lead on (B, C, D) slices, so the
        # BatchNorm statistics are computed per lead rather than pooled.
        for i in range(12):
            new_x[:, :, i, :] = self.conv(x[:, :, i, :])
        return new_x
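
# Shape sketch, assuming the demo config below (kernel_size[1] = 7, padding = 3):
# a (B, C, 12, 5000) input gives out_size = (5000 + 6 - 7) // 2 + 1 = 2500, so
# each stage halves the time axis while the lead axis stays at 12.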


class Convolution2D_layer(nn.Module):
    """2D conv block with stride (1, 2): halves the time axis, keeps the lead axis."""

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=(1, 2),
                padding=padding,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.conv(x)


class Deconvolution2D_layer(nn.Module):
    """Transposed-conv block with stride (1, 2): doubles the time axis, keeps the lead axis."""

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=(1, 2),
                padding=padding,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.deconv(x)
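
# Output-width sketch for the transposed conv (stride (1, 2), padding p):
# W_out = (W_in - 1) * 2 - 2p + k. With the demo kernel k = 7, p = 3, a width
# of 157 maps to (157 - 1) * 2 - 6 + 7 = 313, mirroring the encoder halving.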


class VCBlock(nn.Module):
    """
    enc: (B, C, 12, D)
    1) lead-wise self-attention (VCP style, via the mask):
        - q: full lead (B, D, C)
        - k, v: the visible span, selected through the mask
    2) lead II -> others cross-attention:
        - q: full lead (B, D, C)
        - k, v: the visible span of lead II, selected through the mask
    3) residual: enc + 1) + 2)
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )

    def forward(self, enc: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        _, _, L, _ = enc.shape  # L = 12 leads

        enc_r = enc.permute(0, 2, 3, 1)  # (B, 12, D, C): time-major per lead
        attn_self = torch.zeros_like(enc_r)
        attn_lead2 = torch.zeros_like(enc_r)

        # Lead II (index 1) supplies k/v for the rhythm cross-attention.
        k2 = v2 = enc_r[:, 1, :, :]
        key_padding_mask2 = mask[:, 1, :].bool()

        for lead in range(L):
            # Self-attention within the lead; invisible positions are excluded
            # from the keys via key_padding_mask.
            q = enc_r[:, lead, :, :]
            k = v = enc_r[:, lead, :, :]
            key_padding_mask = mask[:, lead, :].bool()
            _attn_self, _ = self.self_attn(q, k, v, key_padding_mask=key_padding_mask)
            attn_self[:, lead, :, :] = _attn_self

            # Cross-attention from the visible span of lead II into this lead.
            _attn_lead2, _ = self.cross_attn(
                q, k2, v2, key_padding_mask=key_padding_mask2
            )
            attn_lead2[:, lead, :, :] = _attn_lead2

        vc = enc_r + attn_self + attn_lead2  # residual fusion
        vc_r = vc.permute(0, 3, 1, 2)  # back to (B, C, 12, D)
        return vc_r
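
# Mask-semantics sketch: nn.MultiheadAttention ignores key positions where
# key_padding_mask is True, so passing the {1 = invisible} mask as .bool()
# restricts k/v to visible samples. A minimal standalone check:
#
#     attn = nn.MultiheadAttention(embed_dim=8, num_heads=1, batch_first=True)
#     q = k = v = torch.randn(1, 4, 8)
#     pad = torch.tensor([[False, False, True, True]])  # keys 2 and 3 ignored
#     out, _ = attn(q, k, v, key_padding_mask=pad)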


class ECGRecoverRandomMaskWithRS5(nn.Module):
    def __init__(self, config, verbose=False):
        super().__init__()
        self.verbose = verbose
        self.activation = get_activation(config.activation)
        inplanes = int(config.inplanes)
        kernel_size = tuple(config.kernel_size)
        assert len(kernel_size) == 2, "len(kernel_size) must be 2"
        assert kernel_size[0] % 2 == 1, "kernel_size[0] must be odd"
        padding_1d = (kernel_size[1] - 1) // 2
        padding_2d = [(k - 1) // 2 for k in kernel_size]

        num_heads = int(config.num_heads)
        # num_depths_attn_start may be a plain int (attention only at the
        # deepest level) or a (num_depths, attn_start) pair.
        num_depths_cfg = getattr(config, "num_depths_attn_start", 5)
        if isinstance(num_depths_cfg, (tuple, list)):
            self.num_depths, self.attn_start = num_depths_cfg
        else:
            self.num_depths = int(num_depths_cfg)
            self.attn_start = self.num_depths
        leaky_relu = float(config.leaky_relu)
        dropout = float(config.dropout)

        self.convs_1d = nn.ModuleList()
        self.convs_2d = nn.ModuleList()
        self.vc_blocks = nn.ModuleDict()

        for d in range(self.num_depths):
            in_channels = 1 if d == 0 else inplanes * (2 ** (d - 1))
            out_channels = inplanes * (2**d)
            self.convs_1d.append(
                Convolution1D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size[1],
                    padding=padding_1d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )
            self.convs_2d.append(
                Convolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )

            # The encoder feature is the channel-wise concat of the 1D and 2D
            # branches, hence twice out_channels.
            enc_channels = out_channels * 2
            if d >= self.attn_start:
                self.vc_blocks[str(d)] = VCBlock(
                    channels=enc_channels, num_heads=num_heads
                )

        # Bottleneck transposed conv with stride (1, 1): shape-preserving.
        trans_channels = inplanes * (2**self.num_depths)
        self.trans_block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=trans_channels,
                out_channels=trans_channels,
                kernel_size=kernel_size,
                stride=(1, 1),
                padding=padding_2d,
            ),
            nn.BatchNorm2d(trans_channels),
            nn.LeakyReLU(leaky_relu),
        )

        self.deconvs = nn.ModuleList()
        for d in reversed(range(self.num_depths)):
            # After the first deconv, inputs are [upsampled, skip] concats,
            # hence the doubled channel count.
            in_channels = (
                trans_channels
                if d == self.num_depths - 1
                else inplanes * 2 * (2 ** (d + 1))
            )
            out_channels = 1 if d == 0 else inplanes * (2**d)
            self.deconvs.append(
                Deconvolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )

    def _downsample_mask(self, mask: torch.Tensor, target_D: int) -> torch.Tensor:
        """
        mask: (B, 12, D) with 1 = invisible, 0 = visible
        return: (B, 12, target_D), keeping the 1/0 coding
        """
        # Max-pooling keeps a downsampled position masked if any sample it
        # covers was masked (the conservative choice for key_padding_mask).
        mask_down = F.max_pool1d(mask.float(), kernel_size=2, stride=2)

        # Halving may be off by one against the conv output width; snap to it.
        if mask_down.shape[-1] != target_D:
            mask_down = F.interpolate(mask_down, size=target_D, mode="nearest")

        return (mask_down >= 0.5).to(mask.dtype)
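
    # Downsampling example: with the {0, 1} coding, max_pool1d(k=2, s=2) maps
    # [1, 0, 0, 0, 1, 1] -> [1, 0, 1], i.e. a coarse position stays masked
    # when either of the two samples it covers was masked.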

    def _log(self, name, x):
        if self.verbose:
            print(f"{name:<28}: {tuple(x.shape)}")

    def make_default_group_center_mask_batch(
        self, B: int, device=None, dtype=torch.int8
    ):
        # Default visibility pattern: four groups of three leads, each group
        # visible only for the centered 625 samples of its 1250-sample window.
        group_len = 1250
        vis_len = 625
        total_len = 5000
        center_offset = (group_len - vis_len) // 2

        # Start fully masked (1 = invisible), then open the visible windows.
        mask = torch.ones((12, total_len), device=device, dtype=dtype)

        for g in range(4):
            start = g * group_len + center_offset
            end = start + vis_len

            for lead in range(g * 3, g * 3 + 3):
                mask[lead, start:end] = 0

        return mask.unsqueeze(0).expand(B, -1, -1)
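
    # Resulting layout (0 = visible, 1 = masked), with center_offset = 312:
    #   leads 0-2  : visible in [ 312,  937)
    #   leads 3-5  : visible in [1562, 2187)
    #   leads 6-8  : visible in [2812, 3437)
    #   leads 9-11 : visible in [4062, 4687)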

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input, mask = x
        B, L, D = input.shape
        if mask is None:
            mask = self.make_default_group_center_mask_batch(
                B, device=input.device, dtype=torch.float16
            )

        assert (
            L == 12 and D == 5000
        ), "this network's input must be a 12-lead, 5000-point digitized signal"

        input = input.unsqueeze(1)  # (B, 1, 12, 5000)
        out_1d = input
        out_2d = input
        encs = []
        mask_down = mask

        self._log("input", input)
        for d in range(self.num_depths):
            out_1d = self.convs_1d[d](out_1d)
            out_2d = self.convs_2d[d](out_2d)
            enc = torch.cat((out_1d, out_2d), dim=1)
            self._log(f"enc[{d}]", enc)
            mask_down = self._downsample_mask(mask_down, enc.shape[-1])
            self._log(f"mask_down[{d}]", mask_down)
            key = str(d)
            if key in self.vc_blocks:
                enc = self.vc_blocks[key](enc, mask_down)
                self._log(f"enc_refined[{d}]", enc)
            encs.append(enc)

        trans = self.trans_block(encs[-1])
        self._log("trans", trans)
        out = self.deconvs[0](trans)
        self._log("out initial", out)

        # Decoder: upsample, fuse the matching encoder skip, then deconvolve.
        for d in range(1, self.num_depths):
            skip = encs[-(d + 1)]
            self._log(f"skip[{d}]", skip)
            out = F.interpolate(out, skip.shape[-2:], mode="nearest")
            self._log(f"out upsampled[{d}]", out)
            out = torch.cat((out, skip), dim=1)
            self._log(f"out concat[{d}]", out)
            out = self.deconvs[d](out)
            self._log(f"out deconv[{d}]", out)

        out = F.interpolate(out, input.shape[-2:], mode="nearest")
        self._log("out final upsampled", out)
        out = self.activation(out.squeeze(1))  # configured output activation
        self._log("out final", out)
        return out
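
# Shape progression under the demo config (inplanes = 8, kernel 7, depth 5):
# enc[d] has 16 * 2**d channels and time widths 2500, 1250, 625, 313, 157 for
# d = 0..4; VCBlocks run at d >= attn_start, and the decoder walks back up.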


if __name__ == "__main__":

    def get_model_size(model):
        # Parameter count plus a size estimate assuming float32 (4 bytes each).
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        model_size_MB = total_params * 4 / (1024**2)
        print(f"Total Parameters     : {total_params:,}")
        print(f"Trainable Parameters : {trainable_params:,}")
        print(f"Estimated Model Size : {model_size_MB:.2f} MB")
        return total_params, model_size_MB

    class Config:
        pass

    class Activation:
        pass

    config = Config()
    config.inplanes = 8
    config.kernel_size = (7, 7)
    config.num_depths_attn_start = (5, 2)
    config.num_heads = 8
    config.leaky_relu = 0.02
    config.dropout = 0.2
    config.activation = Activation()
    config.activation.name = "Identity"
    config.activation.params = None

    input = torch.rand(size=(1, 12, 5000))
    model = ECGRecoverRandomMaskWithRS5(config, True)
    model.eval()
    out = model([input, None])
    print(out.shape)
    get_model_size(model)

    from torchinfo import summary

    # input_data is forwarded as the single [signal, mask] argument that
    # forward() expects.
    summary(model, input_data=[[input, None]])