"""ECG recovery network (ECGRecoverRandomMaskWithRS5).

Difference from ECGRecoverRandomMaskWithRS4 (translated from the original
Korean note): only VCP is applied within each lead, and lead II is used as
key/value so that rhythm information is propagated to the other leads.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_activation(actv_config):
    """Build an activation module from a config object.

    Parameters
    ----------
    actv_config : object with
        .name   : attribute name looked up on ``torch.nn`` (e.g. ``"ReLU"``).
        .params : falsy for no kwargs, a plain dict of kwargs, or a
                  pydantic-style object exposing ``model_dump()``.

    Raises
    ------
    AssertionError
        If ``actv_config.name`` does not resolve to a ``torch.nn`` attribute.
    """
    actv_cls = getattr(torch.nn, actv_config.name, None)
    assert actv_cls is not None, "No activation function"
    if actv_config.params:
        return (
            actv_cls(**actv_config.params)
            if isinstance(actv_config.params, dict)
            else actv_cls(**actv_config.params.model_dump())
        )
    return actv_cls()


class Convolution1D_layer(nn.Module):
    """Stride-2 1-D conv stack applied independently (shared weights) to each
    of the 12 leads.

    Input  : (B, C_in, 12, L)
    Output : (B, C_out, 12, floor((L + 2*padding - kernel_size) / 2) + 1)
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.conv = nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
            ),
            nn.BatchNorm1d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output length of a stride-2 conv: floor((L + 2p - k) / 2) + 1.
        out_size = (x.shape[-1] + 2 * self.padding - self.kernel_size) // 2 + 1
        out = torch.zeros(
            (len(x), self.out_channels, 12, out_size),
            dtype=x.dtype,
            device=x.device,
        )
        # NOTE: the per-lead loop is intentional — folding the lead axis into
        # the batch would change BatchNorm1d training-time statistics.
        for lead in range(12):
            out[:, :, lead, :] = self.conv(x[:, :, lead, :])
        return out


class Convolution2D_layer(nn.Module):
    """2-D conv block downsampling only the time axis (stride (1, 2))."""

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=(1, 2),
                padding=padding,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            # Dropout intentionally disabled here (kept in the 1-D branch).
        )

    def forward(self, x):
        return self.conv(x)


class Deconvolution2D_layer(nn.Module):
    """2-D transposed-conv block upsampling only the time axis (stride (1, 2))."""

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=(1, 2),
                padding=padding,
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            # Dropout intentionally disabled here (kept in the 1-D branch).
        )

    def forward(self, x):
        return self.deconv(x)


class VCBlock(nn.Module):
    """Visible-context attention block.

    Given encoded features ``enc`` of shape (B, C, 12, D):
      1) lead-wise self-attention (VCP-style, via the mask):
         q = full lead (B, D, C); k, v = visible part selected by
         ``key_padding_mask``
      2) lead II -> other leads cross-attention:
         q = full lead; k, v = visible part of lead II
      3) residual: enc + (1) + (2)
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )

    def forward(self, enc: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """enc: (B, C, 12, D); mask: (B, 12, D), 1 = invisible, 0 = visible."""
        _, _, num_leads, _ = enc.shape
        enc_r = enc.permute(0, 2, 3, 1)  # (B, 12, D, C)
        attn_self = torch.zeros_like(enc_r)
        attn_lead2 = torch.zeros_like(enc_r)

        # Lead II keys/values, shared by every query lead.
        k2 = v2 = enc_r[:, 1, :, :]  # (B, D, C)
        key_padding_mask2 = mask[:, 1, :].bool()  # (B, D); True = ignore

        for lead in range(num_leads):
            # Self-attention inside the lead (VCP-style masking).
            q = enc_r[:, lead, :, :]  # (B, D, C)
            k = v = enc_r[:, lead, :, :]  # (B, D, C)
            key_padding_mask = mask[:, lead, :].bool()  # (B, D)
            # NOTE(review): a lead whose mask is all-True (nothing visible)
            # would yield NaNs from attention — assumed not to occur; confirm
            # against the mask generator.
            _attn_self, _ = self.self_attn(q, k, v, key_padding_mask=key_padding_mask)
            attn_self[:, lead, :, :] = _attn_self

            # Lead II -> current lead cross-attention.
            _attn_lead2, _ = self.cross_attn(
                q, k2, v2, key_padding_mask=key_padding_mask2
            )
            attn_lead2[:, lead, :, :] = _attn_lead2

        vc = enc_r + attn_self + attn_lead2  # residual: (B, 12, D, C)
        return vc.permute(0, 3, 1, 2)  # back to (B, C, 12, D)


class ECGRecoverRandomMaskWithRS5(nn.Module):
    """U-Net-style 12-lead ECG recovery model with visible-context attention.

    Encoder: parallel per-lead 1-D convs and 2-D convs whose outputs are
    concatenated at each depth; depths >= ``attn_start`` additionally pass
    through a :class:`VCBlock`. Decoder: transposed convs with skip
    connections from the encoder, followed by interpolation back to the
    input resolution.

    Config attributes read: ``inplanes``, ``kernel_size`` (2-tuple, first
    entry odd), ``num_heads``, ``num_depths_attn_start`` (int, or
    (num_depths, attn_start) pair), ``leaky_relu``, ``dropout``,
    ``activation``.
    """

    def __init__(self, config, verbose=False):
        super().__init__()
        self.verbose = verbose
        # NOTE(review): built from config but not used in forward() currently.
        self.activation = get_activation(config.activation)

        inplanes = int(config.inplanes)
        kernel_size = tuple(config.kernel_size)
        assert len(kernel_size) == 2, "len(kernel_size) must be 2"
        assert kernel_size[0] % 2 == 1, "kernel_size[0] must be odd"
        padding_1d = (kernel_size[1] - 1) // 2
        padding_2d = [(k - 1) // 2 for k in kernel_size]
        num_heads = int(config.num_heads)

        # Either a bare depth count (no attention) or a (depth, attn_start) pair.
        num_depths_cfg = getattr(config, "num_depths_attn_start", 5)
        if isinstance(num_depths_cfg, (tuple, list)):
            self.num_depths, self.attn_start = num_depths_cfg
        else:
            self.num_depths = int(num_depths_cfg)
            self.attn_start = self.num_depths  # no attention blocks

        leaky_relu = float(config.leaky_relu)
        dropout = float(config.dropout)

        self.convs_1d = nn.ModuleList()
        self.convs_2d = nn.ModuleList()
        self.vc_blocks = nn.ModuleDict()  # mask + cross-attn + residual
        for d in range(self.num_depths):
            in_channels = 1 if d == 0 else inplanes * (2 ** (d - 1))
            out_channels = inplanes * (2**d)
            self.convs_1d.append(
                Convolution1D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size[1],
                    padding=padding_1d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )
            self.convs_2d.append(
                Convolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )
            enc_channels = out_channels * 2  # concat(conv1d, conv2d)
            if d >= self.attn_start:
                self.vc_blocks[str(d)] = VCBlock(
                    channels=enc_channels, num_heads=num_heads
                )

        # Bottleneck: stride (1, 1) keeps the spatial size.
        trans_channels = inplanes * (2**self.num_depths)
        self.trans_block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=trans_channels,
                out_channels=trans_channels,
                kernel_size=kernel_size,
                stride=(1, 1),
                padding=padding_2d,
            ),
            nn.BatchNorm2d(trans_channels),
            nn.LeakyReLU(leaky_relu),
        )

        # Decoder, deepest first; all but the first take concat(out, skip).
        self.deconvs = nn.ModuleList()
        for d in reversed(range(self.num_depths)):
            in_channels = (
                trans_channels
                if d == self.num_depths - 1
                else inplanes * 2 * (2 ** (d + 1))
            )
            out_channels = 1 if d == 0 else inplanes * (2**d)
            self.deconvs.append(
                Deconvolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )

    def _downsample_mask(self, mask: torch.Tensor, target_D: int) -> torch.Tensor:
        """Downsample a visibility mask to the current feature width.

        mask: (B, 12, D) with 1 = invisible, 0 = visible
        return: (B, 12, target_D), same 1/0 convention and dtype.
        """
        mask_down = mask.float()
        # Max-pool so a position stays "invisible" if either source was.
        mask_down = F.max_pool1d(
            mask_down, kernel_size=2, stride=2
        )  # (B, 12, floor(D/2))
        if mask_down.shape[-1] != target_D:
            mask_down = F.interpolate(mask_down, size=target_D, mode="nearest")
        return (mask_down >= 0.5).to(mask.dtype)

    def _log(self, name, x):
        # Shape tracing helper, active only in verbose mode.
        if self.verbose:
            print(f"{name:<28}: {tuple(x.shape)}")

    def make_default_group_center_mask_batch(
        self, B: int, device=None, dtype=torch.int8
    ):
        """Default mask: per 2.5 s group, only the centered 1.25 s of the
        group's three leads is visible (0); everything else is invisible (1).

        Returns (B, 12, 5000), expanded (shared storage) across the batch.
        """
        group_len = 1250  # 2.5 s
        vis_len = 625  # 1.25 s
        total_len = 5000
        center_offset = (group_len - vis_len) // 2  # 312

        mask = torch.ones((12, total_len), device=device, dtype=dtype)
        for g in range(4):
            start = g * group_len + center_offset
            end = start + vis_len
            for lead in range(g * 3, g * 3 + 3):
                mask[lead, start:end] = 0  # visible
        return mask.unsqueeze(0).expand(B, -1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (signal, mask) pair; signal (B, 12, 5000), mask (B, 12, 5000)
        with 1 = invisible / 0 = visible, or None for the default group mask.
        Returns the recovered signal, shape (B, 12, 5000).
        """
        signal, mask = x
        B, L, D = signal.shape
        if mask is None:
            mask = self.make_default_group_center_mask_batch(
                B, device=signal.device, dtype=torch.float16
            )
        assert (
            L == 12 and D == 5000
        ), "this network's input must be 12 lead 5000 points digitized signal"

        signal = signal.unsqueeze(1)  # add channel dim -> (B, 1, 12, 5000)
        out_1d = signal
        out_2d = signal
        encs = []
        mask_down = mask
        self._log("input", signal)

        # Encoder: parallel 1-D / 2-D branches, concatenated per depth.
        for d in range(self.num_depths):
            out_1d = self.convs_1d[d](out_1d)
            out_2d = self.convs_2d[d](out_2d)
            enc = torch.cat((out_1d, out_2d), dim=1)  # (B, 2*C, 12, D')
            self._log(f"enc[{d}]", enc)

            mask_down = self._downsample_mask(mask_down, enc.shape[-1])
            self._log(f"mask_down[{d}]", mask_down)

            key = str(d)
            if key in self.vc_blocks:
                enc = self.vc_blocks[key](enc, mask_down)
                self._log(f"enc_refined[{d}]", enc)
            encs.append(enc)

        trans = self.trans_block(encs[-1])
        self._log("trans", trans)

        out = self.deconvs[0](trans)
        self._log("out initial", out)

        # Decoder: combine skip connections (deepest first) with the upsampled
        # feature map.
        for d in range(1, self.num_depths):
            skip = encs[-(d + 1)]
            self._log(f"skip[{d}]", skip)
            out = F.interpolate(out, skip.shape[-2:], mode="nearest")
            self._log(f"out upsampled[{d}]", out)
            out = torch.cat((out, skip), dim=1)
            self._log(f"out concat[{d}]", out)
            out = self.deconvs[d](out)
            self._log(f"out deconv[{d}]", out)

        out = F.interpolate(out, signal.shape[-2:], mode="nearest")
        self._log("out final upsampled", out)
        out = out.squeeze(1)  # drop channel dim -> (B, 12, 5000)
        self._log("out final", out)
        return out


if __name__ == "__main__":

    def get_model_size(model):
        """Print parameter counts and an estimated float32 size in MB."""
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        model_size_MB = total_params * 4 / (1024**2)  # float32 (4 bytes each)
        print(f"Total Parameters     : {total_params:,}")
        print(f"Trainable Parameters : {trainable_params:,}")
        print(f"Estimated Model Size : {model_size_MB:.2f} MB")
        return total_params, model_size_MB

    class Config:
        pass

    class Activation:
        pass

    config = Config()
    config.inplanes = 8
    config.kernel_size = (7, 7)
    config.num_depths_attn_start = (5, 2)
    config.num_heads = 8
    config.leaky_relu = 0.02
    config.dropout = 0.2
    config.activation = Activation()
    config.activation.name = "Identity"
    config.activation.params = None

    dummy_input = torch.rand(size=(1, 12, 5000))
    model = ECGRecoverRandomMaskWithRS5(config, True)
    model.eval()
    out = model([dummy_input, None])
    print(out.shape)
    # get_model_size(model)