# dr-studio / model /code /ECGRecoverRandomMaskWithRS5.py
# wogh2012's picture
# Upload Dash Docker Space
# 12f8999 verified
import torch
def get_activation(actv_config):
    """Instantiate a ``torch.nn`` activation module from a config object.

    Args:
        actv_config: object exposing ``name`` (a class name looked up on
            ``torch.nn``, e.g. ``"LeakyReLU"``) and ``params`` (constructor
            keyword arguments given as a ``dict``, as an object with a
            ``model_dump()`` method — presumably a pydantic model — or a falsy
            value meaning "no arguments").

    Returns:
        A freshly constructed ``torch.nn.<name>`` instance.

    Raises:
        AssertionError: if ``torch.nn`` has no attribute named ``name``.
    """
    activation_cls = getattr(torch.nn, actv_config.name, None)
    assert activation_cls is not None, "No activation function"

    params = actv_config.params
    if not params:
        return activation_cls()

    # Plain dicts are used as-is; anything else is converted via model_dump().
    kwargs = params if isinstance(params, dict) else params.model_dump()
    return activation_cls(**kwargs)
"""
ECGRecoverRandomMaskWithRS4 와 차이: lead 내에서 VCP 만 적용 + lead II 를 k/v 로 사용해서 다른 lead 로 rhythm 정보 전달
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Convolution1D_layer(nn.Module):
    """Apply one shared strided 1-D conv stack to every lead independently.

    Input ``x`` is (B, C_in, L, D). Each lead slice ``x[:, :, i, :]``
    (B, C_in, D) goes through Conv1d(stride=2) -> BatchNorm1d -> LeakyReLU ->
    Dropout. The results are written back into a (B, C_out, L, D_out) tensor with
    ``D_out = (D + 2 * padding - kernel_size) // 2 + 1``.

    The stack is called once per lead, so in training mode the BatchNorm batch
    statistics are computed per lead, not jointly over all leads.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super(Convolution1D_layer, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.conv = nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=padding,
            ),
            nn.BatchNorm1d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor):
        # The lead count is taken from the input instead of being fixed at 12.
        num_leads = x.shape[2]
        # Conv1d output length for stride 2 (dilation 1).
        out_size = (x.shape[-1] + 2 * self.padding - self.kernel_size) // 2 + 1
        new_x = torch.zeros(
            (len(x), self.out_channels, num_leads, out_size),
            dtype=x.dtype,
            device=x.device,
        )
        for i in range(num_leads):
            new_x[:, :, i, :] = self.conv(x[:, :, i, :])
        return new_x
class Convolution2D_layer(nn.Module):
    """2-D conv stack: Conv2d(stride=(1, 2)) -> BatchNorm2d -> LeakyReLU.

    The stride is 1 along the lead axis and 2 along the time axis.
    ``dropout`` is accepted for signature parity with the other layers, but no
    dropout layer is applied (it is disabled in this stack).
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 2),
            padding=padding,
        )
        layers = [
            conv2d,
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            # nn.Dropout(dropout)  -- intentionally disabled
        ]
        # Attribute name and layer order define the state_dict keys (conv.0, conv.1, ...).
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        stack = self.conv
        return stack(x)
class Deconvolution2D_layer(nn.Module):
    """2-D transposed-conv stack: ConvTranspose2d(stride=(1, 2)) -> BatchNorm2d -> LeakyReLU.

    The stride is 1 along the lead axis and 2 along the time axis.
    ``dropout`` is accepted for signature parity with the other layers, but no
    dropout layer is applied (it is disabled in this stack).
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, padding, leaky_relu, dropout
    ):
        super().__init__()
        upconv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 2),
            padding=padding,
        )
        layers = [
            upconv,
            nn.BatchNorm2d(num_features=out_channels),
            nn.LeakyReLU(leaky_relu),
            # nn.Dropout(dropout)  -- intentionally disabled
        ]
        # Attribute name and layer order define the state_dict keys (deconv.0, deconv.1, ...).
        self.deconv = nn.Sequential(*layers)

    def forward(self, x):
        stack = self.deconv
        return stack(x)
class VCBlock(nn.Module):
    """
    Visible-context refinement over an encoder feature map.

    enc:  (B, C, L, D), where L is the number of leads (12 in this model) and D is time.
    mask: (B, L, D); nonzero = invisible (excluded from the keys), 0 = visible.

    1) Lead-wise self-attention (VCP style, driven by the mask):
       - q: every position of the lead, (B, D, C)
       - k, v: only the visible positions of the same lead
    2) Lead II -> every lead cross-attention:
       - q: every position of the lead, (B, D, C)
       - k, v: only the visible positions of lead II (index 1)
    3) Residual: enc + 1) + 2), returned as (B, C, L, D).

    If a batch item has no visible key at all for one of these attention calls,
    that call contributes zero for the item (instead of NaN).
    """

    def __init__(self, channels: int, num_heads: int = 4):
        super().__init__()
        # nn.MultiheadAttention requires channels % num_heads == 0.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=channels, num_heads=num_heads, batch_first=True
        )

    @staticmethod
    def _attend_visible(attn, query, key_value, key_padding_mask):
        """Attend from ``query`` to the unmasked positions of ``key_value``.

        ``nn.MultiheadAttention`` produces NaN for a batch item whose
        ``key_padding_mask`` row is entirely True (softmax over only -inf).
        For such items the mask is lifted so the call stays finite, and their
        output is then replaced by zeros. All other items use the given mask
        unchanged, so their results are identical to a plain call.

        Returns:
            (B, Q, C) attention output.
        """
        no_visible_key = key_padding_mask.all(dim=-1)  # (B,)
        safe_mask = key_padding_mask & ~no_visible_key.unsqueeze(-1)
        out, _ = attn(query, key_value, key_value, key_padding_mask=safe_mask)
        return out.masked_fill(no_visible_key[:, None, None], 0.0)

    def forward(self, enc: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        num_leads = enc.shape[2]  # enc: (B, C, L, D)
        enc_r = enc.permute(0, 2, 3, 1)  # (B, L, D, C)
        attn_self = torch.zeros_like(enc_r)
        attn_lead2 = torch.zeros_like(enc_r)

        # Lead II keys/values and padding mask are shared by every query lead.
        lead2_kv = enc_r[:, 1, :, :]  # (B, D, C)
        key_padding_mask2 = mask[:, 1, :].bool()  # (B, D); True = ignore this key

        for lead in range(num_leads):
            # q and kv are distinct view objects on purpose. nn.MultiheadAttention
            # only considers its inference fast path when query/key/value are the
            # same object, so this keeps the standard attention path.
            q = enc_r[:, lead, :, :]  # (B, D, C)
            kv = enc_r[:, lead, :, :]  # (B, D, C)
            key_padding_mask = mask[:, lead, :].bool()  # (B, D)

            # 1) within-lead self-attention over this lead's visible positions
            attn_self[:, lead, :, :] = self._attend_visible(
                self.self_attn, q, kv, key_padding_mask
            )
            # 2) lead II -> this lead cross-attention over lead II's visible positions
            attn_lead2[:, lead, :, :] = self._attend_visible(
                self.cross_attn, q, lead2_kv, key_padding_mask2
            )

        vc = enc_r + attn_self + attn_lead2  # residual: (B, L, D, C)
        return vc.permute(0, 3, 1, 2)  # (B, C, L, D)
class ECGRecoverRandomMaskWithRS5(nn.Module):
    """U-Net-style 12-lead ECG recovery network with visible-context attention.

    Encoder: at each depth ``d``, a per-lead 1-D conv branch
    (``Convolution1D_layer``) and a 2-D conv branch (``Convolution2D_layer``)
    each downsample the time axis with stride 2. Their outputs are concatenated
    on the channel axis. From depth ``attn_start`` on, that feature map is refined
    by a ``VCBlock`` using the mask resized to the current time length.

    Decoder: ``trans_block`` on the deepest feature map, then one
    ``Deconvolution2D_layer`` per depth (deepest first). Every stage after the
    first concatenates the matching encoder feature map as a skip connection.
    The result is resized to the input's (leads, time) size.

    Config attributes read: ``activation``, ``inplanes``, ``kernel_size``
    (2 values, first one odd), ``num_heads``, ``num_depths_attn_start`` (an int,
    or a ``(num_depths, attn_start)`` pair; default 5), ``leaky_relu``,
    ``dropout``.
    """

    def __init__(self, config, verbose=False):
        super().__init__()
        self.verbose = verbose
        # NOTE(review): built from the config but never applied in forward();
        # confirm whether an output activation was intended.
        self.activation = get_activation(config.activation)

        inplanes = int(config.inplanes)
        kernel_size = tuple(config.kernel_size)
        assert len(kernel_size) == 2, "len(kernel_size) must be 2"
        assert kernel_size[0] % 2 == 1, "kernel_size[0] must be odd"
        # Padding (k - 1) // 2 per kernel axis.
        padding_1d = (kernel_size[1] - 1) // 2
        padding_2d = [(k - 1) // 2 for k in kernel_size]
        num_heads = int(config.num_heads)

        # Either an int (number of depths, no attention) or a
        # (num_depths, attn_start) pair: VCBlocks are added at depths >= attn_start.
        num_depths_cfg = getattr(config, "num_depths_attn_start", 5)
        if isinstance(num_depths_cfg, (tuple, list)):
            num_depths, attn_start = num_depths_cfg
            # Cast like every other numeric config field.
            self.num_depths = int(num_depths)
            self.attn_start = int(attn_start)
        else:
            self.num_depths = int(num_depths_cfg)
            self.attn_start = self.num_depths  # no attention block at any depth
        leaky_relu = float(config.leaky_relu)
        dropout = float(config.dropout)
        # self.output_size: int = config.output_size

        # Module attribute names and registration order below define the
        # state_dict keys; keep them stable for checkpoint compatibility.
        self.convs_1d = nn.ModuleList()
        self.convs_2d = nn.ModuleList()
        self.vc_blocks = nn.ModuleDict()  # str(depth) -> VCBlock (mask + attention + residual)
        for d in range(self.num_depths):
            in_channels = 1 if d == 0 else inplanes * (2 ** (d - 1))
            out_channels = inplanes * (2**d)
            self.convs_1d.append(
                Convolution1D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size[1],
                    padding=padding_1d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )
            self.convs_2d.append(
                Convolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )
            enc_channels = out_channels * 2  # concat(conv1d, conv2d)
            if d >= self.attn_start:
                self.vc_blocks[str(d)] = VCBlock(
                    channels=enc_channels, num_heads=num_heads
                )

        # Channels of the deepest encoder output: 2 * inplanes * 2**(num_depths - 1).
        trans_channels = inplanes * (2**self.num_depths)
        self.trans_block = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=trans_channels,
                out_channels=trans_channels,
                kernel_size=kernel_size,
                stride=(1, 1),
                padding=padding_2d,
            ),
            nn.BatchNorm2d(trans_channels),
            nn.LeakyReLU(leaky_relu),
        )

        # Decoder stages, deepest first. The first stage takes trans_block's
        # output. Each later stage at depth d takes concat(previous stage output
        # with inplanes * 2**(d+1) channels, encoder skip with 2 * inplanes * 2**d).
        self.deconvs = nn.ModuleList()
        for d in reversed(range(self.num_depths)):
            in_channels = (
                trans_channels
                if d == self.num_depths - 1
                else inplanes * 2 * (2 ** (d + 1))
            )
            out_channels = 1 if d == 0 else inplanes * (2**d)
            self.deconvs.append(
                Deconvolution2D_layer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding_2d,
                    leaky_relu=leaky_relu,
                    dropout=dropout,
                )
            )

    def _downsample_mask(self, mask: torch.Tensor, target_D: int) -> torch.Tensor:
        """Downsample a visibility mask to ``target_D`` time steps.

        mask: (B, 12, D) with 1 = invisible, 0 = visible.
        A 2:1 max-pool marks a step invisible if either source step is invisible.
        The pooled length is floor(D / 2); if it differs from ``target_D`` it is
        resized with nearest-neighbour interpolation.

        Returns:
            (B, 12, target_D) tensor of 0/1 values in ``mask.dtype``.
        """
        mask_down = mask.float()
        mask_down = F.max_pool1d(
            mask_down, kernel_size=2, stride=2
        )  # (B, 12, floor(D/2))
        if mask_down.shape[-1] != target_D:
            mask_down = F.interpolate(mask_down, size=target_D, mode="nearest")
        return (mask_down >= 0.5).to(mask.dtype)

    def _log(self, name, x):
        """Print ``name`` and the shape of ``x`` when ``self.verbose`` is set."""
        if self.verbose:
            print(f"{name:<28}: {tuple(x.shape)}")

    def make_default_group_center_mask_batch(
        self, B: int, device=None, dtype=torch.int8
    ):
        """Build the default visibility mask for a batch of ``B`` records.

        The 12 leads are split into 4 groups of 3 consecutive leads. Group ``g``
        is visible (0) only on a 625-sample window starting 312 samples into the
        ``g``-th 1250-sample segment of the 5000-sample record. It is invisible
        (1) everywhere else.

        Returns:
            (B, 12, 5000) tensor: an ``expand`` view of one (12, 5000) mask, so
            every batch item shares the same storage.
        """
        group_len = 1250  # 2.5 s
        vis_len = 625  # 1.25 s
        total_len = 5000
        center_offset = (group_len - vis_len) // 2  # 312
        # A single (12, total_len) mask, broadcast over the batch at the end.
        mask = torch.ones((12, total_len), device=device, dtype=dtype)
        for g in range(4):
            start = g * group_len + center_offset
            end = start + vis_len
            mask[g * 3 : g * 3 + 3, start:end] = 0  # visible
        return mask.unsqueeze(0).expand(B, -1, -1)

    def forward(self, x) -> torch.Tensor:
        """Reconstruct all leads from a ``(signal, mask)`` pair.

        Args:
            x: a 2-item sequence ``(signal, mask)``. ``signal`` is
                (B, 12, 5000). ``mask`` is (B, 12, 5000) with 1 = invisible and
                0 = visible, or ``None`` to use
                ``make_default_group_center_mask_batch``.

        Returns:
            Tensor of shape (B, 12, 5000).

        Raises:
            AssertionError: if ``signal`` is not 12 leads x 5000 samples.
        """
        signal, mask = x
        B, L, D = signal.shape
        if mask is None:
            mask = self.make_default_group_center_mask_batch(
                B, device=signal.device, dtype=torch.float16
            )
        assert (
            L == 12 and D == 5000
        ), "this network's input must be 12 lead 5000 points digitized signal"
        signal = signal.unsqueeze(1)  # add channel axis: (B, 1, 12, D)
        out_1d = signal
        out_2d = signal
        encs = []
        mask_down = mask
        self._log("input", signal)

        # Encoder
        for d in range(self.num_depths):
            out_1d = self.convs_1d[d](out_1d)
            out_2d = self.convs_2d[d](out_2d)
            enc = torch.cat((out_1d, out_2d), dim=1)  # (B, 2*C, 12, D_d)
            self._log(f"enc[{d}]", enc)
            # Keep the mask aligned with the current time length.
            mask_down = self._downsample_mask(mask_down, enc.shape[-1])
            self._log(f"mask_down[{d}]", mask_down)
            key = str(d)
            if key in self.vc_blocks:
                enc = self.vc_blocks[key](enc, mask_down)
                self._log(f"enc_refined[{d}]", enc)
            encs.append(enc)

        # Decoder
        trans = self.trans_block(encs[-1])
        self._log("trans", trans)
        out = self.deconvs[0](trans)
        self._log("out initial", out)
        # Combine each decoder stage with the encoder skip connection of the same depth.
        for d in range(1, self.num_depths):
            skip = encs[-(d + 1)]  # walk encoder outputs from deep to shallow
            self._log(f"skip[{d}]", skip)
            # Transposed-conv output lengths can be off by one; match the skip.
            out = F.interpolate(out, skip.shape[-2:], mode="nearest")
            self._log(f"out upsampled[{d}]", out)
            out = torch.cat((out, skip), dim=1)
            self._log(f"out concat[{d}]", out)
            out = self.deconvs[d](out)
            self._log(f"out deconv[{d}]", out)
        out = F.interpolate(out, signal.shape[-2:], mode="nearest")
        self._log("out final upsampled", out)
        out = out.squeeze(1)  # drop the single output channel: (B, 12, 5000)
        self._log("out final", out)
        return out
if __name__ == "__main__":
def get_model_size(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
model_size_MB = total_params * 4 / (1024**2) # float32 기준 (4 bytes)
print(f"Total Parameters : {total_params:,}")
print(f"Trainable Parameters : {trainable_params:,}")
print(f"Estimated Model Size : {model_size_MB:.2f} MB")
return total_params, model_size_MB
class Config:
pass
class Activation:
pass
config = Config()
config.inplanes = 8
config.kernel_size = (7, 7)
config.num_depths_attn_start = (5, 2)
config.num_heads = 8
config.leaky_relu = 0.02
config.dropout = 0.2
config.activation = Activation()
config.activation.name = "Identity"
config.activation.params = None
input = torch.rand(size=(1, 12, 5000))
model = ECGRecoverRandomMaskWithRS5(config, True)
model.eval()
out = model([input, None])
print(out.shape)
from torchinfo import summary
# for i in range(len(encs_visible)):
# print(encs_visible[i].shape)
# summary(model, input_size=(1, 12, 5000), depth=4)
# get_model_size(model)
# from torchviz import make_dot
# # 그래프 생성
# dot = make_dot(
# out, params=dict(model.named_parameters()), show_attrs=False, show_saved=False
# )
# # 파일로 저장 (PNG)
# dot.render("ecgrecover_vc_filtermask", format="png")
# from torchview import draw_graph
# graph = draw_graph(
# model,
# input_size=(1, 12, 5000),
# expand_nested=False, # ← 내부 세부 구조 펼치지 않음 → 매우 간단
# graph_dir="TD", # top-down
# )
# graph.visual_graph.render("model_overview", format="png")