| """ |
| Self-contained Q-HEART model for HuggingFace Hub. |
| All dependencies are inlined — no external repo required. |
| |
| Q-HEART: ECG Question Answering via Knowledge-Informed Multimodal LLMs (ECAI 2025) |
| """ |
|
|
| import logging |
| import math |
| import os |
| from collections import OrderedDict |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
| from transformers import ( |
| AutoModelForCausalLM, |
| PreTrainedModel, |
| ) |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.models.bert.modeling_bert import ( |
| BertConfig, |
| BertPredictionHeadTransform, |
| ) |
| try: |
| from transformers.pytorch_utils import ( |
| apply_chunking_to_forward, |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
| except ImportError: |
| from transformers.modeling_utils import ( |
| apply_chunking_to_forward, |
| find_pruneable_heads_and_indices, |
| prune_linear_layer, |
| ) |
| from transformers.activations import ACT2FN |
| from peft import LoraConfig, get_peft_model, TaskType |
|
|
| from .configuration_qheart import QHEARTConfig |
|
|
| logger = logging.getLogger(__name__) |
| os.environ.setdefault("CURL_CA_BUNDLE", "") |
|
|
|
|
| |
| |
| |
|
|
| class _Dropout(nn.Module): |
| def __init__(self, p, module_name=None): |
| super().__init__() |
| self.p = p |
| self.module_name = module_name |
| self.apply_during_inference = False |
|
|
| def forward(self, x, inplace: bool = False): |
| if self.p > 0 and (self.training or self.apply_during_inference): |
| return F.dropout(x, p=self.p, training=True, inplace=inplace) |
| return x |
|
|
|
|
class _GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by `scale` in backward.

    Used to down-weight gradients flowing into the conv feature extractor
    (see `feature_grad_mult` in `_ECGTransformerModel.get_embeddings`).
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # NOTE: `x.new(x)` is a legacy (fairseq-era) copy idiom — returns a copy
        # of `x` with the same dtype/device.
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # Second return is the gradient for `scale` (none — it's a constant).
        return grad * ctx.scale, None
|
|
|
|
| class _Fp32GroupNorm(nn.GroupNorm): |
| def forward(self, input): |
| output = F.group_norm( |
| input.float(), self.num_groups, |
| self.weight.float() if self.weight is not None else None, |
| self.bias.float() if self.bias is not None else None, |
| self.eps, |
| ) |
| return output.type_as(input) |
|
|
|
|
| class _Fp32LayerNorm(nn.LayerNorm): |
| def forward(self, input): |
| output = F.layer_norm( |
| input.float(), self.normalized_shape, |
| self.weight.float() if self.weight is not None else None, |
| self.bias.float() if self.bias is not None else None, |
| self.eps, |
| ) |
| return output.type_as(input) |
|
|
|
|
| def _make_layer_norm(normalized_shape, eps=1e-5, elementwise_affine=True): |
| try: |
| from apex.normalization import FusedLayerNorm |
| return FusedLayerNorm(normalized_shape, eps, elementwise_affine) |
| except ImportError: |
| return nn.LayerNorm(normalized_shape, eps, elementwise_affine) |
|
|
|
|
| class _TransposeLast(nn.Module): |
| def __init__(self, deconstruct_idx=None): |
| super().__init__() |
| self.deconstruct_idx = deconstruct_idx |
|
|
| def forward(self, x): |
| if self.deconstruct_idx is not None: |
| x = x[self.deconstruct_idx] |
| return x.transpose(-2, -1) |
|
|
|
|
| class _SamePad(nn.Module): |
| def __init__(self, kernel_size, causal=False): |
| super().__init__() |
| self.remove = (kernel_size - 1) if causal else (1 if kernel_size % 2 == 0 else 0) |
|
|
| def forward(self, x): |
| if self.remove > 0: |
| x = x[:, :, : -self.remove] |
| return x |
|
|
|
|
def _quant_noise(module, p, block_size):
    """Apply quantization noise (Quant-Noise, Fan et al. 2020) to a module's weight.

    Registers a forward pre-hook that, during training only, zeroes random
    blocks of the weight with probability `p` and rescales the survivors by
    1 / (1 - p). Returns the module unchanged when p <= 0.
    """
    if p <= 0:
        return module
    # Only these module types have a weight layout we know how to block.
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
    is_conv = module.weight.ndim == 4

    def _hook(mod, input):
        # Noise only while training; inference uses the weights as-is.
        if mod.training:
            weight = mod.weight
            if not is_conv:
                # Linear/Embedding: drop contiguous blocks of `block_size`
                # input features per output row.
                in_features, out_features = weight.size(1), weight.size(0)
                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
            else:
                if mod.kernel_size == (1, 1):
                    # 1x1 conv acts like a linear layer over channels.
                    mask = torch.zeros(int(mod.in_channels // block_size * mod.out_channels), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
                else:
                    # Larger kernels: drop whole (out, in) channel pairs.
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            mask = mask.to(torch.bool)
            # NOTE: mutates weight.data in place on every training forward pass.
            mod.weight.data = (1 / (1 - p)) * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_hook)
    return module
|
|
|
|
class _MultiHeadAttention(nn.Module):
    """Multi-head attention with separate Q/K/V projections (fairseq-style).

    Delegates the attention math to ``F.multi_head_attention_forward`` with
    ``use_separate_proj_weight=True``.

    Args:
        embed_dim: model dimension of queries and of the output.
        n_heads: number of attention heads (must divide embed_dim).
        kdim/vdim: key/value input dims; default to embed_dim.
        dropout: attention dropout probability.
        bias: whether the four projections carry bias terms.
        self_attention: kept for API/checkpoint compatibility; unused in forward.
        q_noise/qn_block_size: quant-noise parameters for the projections.
    """

    def __init__(self, embed_dim, n_heads, kdim=None, vdim=None, dropout=0.0,
                 bias=True, self_attention=False, q_noise=0.0, qn_block_size=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.n_heads = n_heads
        self.dropout = _Dropout(dropout, module_name=self.__class__.__name__)
        self.d_heads = embed_dim // n_heads
        assert self.d_heads * n_heads == embed_dim, "embed_dim must be divisible by n_heads"
        self.self_attention = self_attention
        self.k_proj = _quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = _quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = _quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.out_proj = _quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self._reset_parameters()

    def _reset_parameters(self):
        if self.qkv_same_dim:
            # Scaled init (1/sqrt(2)) empirically helps when Q/K/V share dims.
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None):
        """Inputs are (seq_len, batch, embed_dim); returns (attn_output, attn_weights)."""
        assert key is not None and value is not None
        # FIX: with bias=False the projections have no bias tensors; the
        # original unconditionally did torch.cat on them, raising TypeError.
        if self.q_proj.bias is not None:
            in_proj_bias = torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias))
        else:
            in_proj_bias = None
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.n_heads,
            torch.empty([0]),  # in_proj_weight unused: separate weights passed below
            in_proj_bias,
            None, None, False, self.dropout.p, self.out_proj.weight, self.out_proj.bias,
            self.training or self.dropout.apply_during_inference,
            key_padding_mask, need_weights, attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
|
|
|
|
class _ConvFeatureExtraction(nn.Module):
    """Stack of 1-D conv blocks that turns a raw waveform into feature frames.

    Args:
        conv_layers: list of (dim, kernel_size, stride) tuples.
        in_d: number of input channels.
        dropout: dropout prob applied after each conv.
        mode: "default" (GroupNorm after the first conv only) or
              "layer_norm" (fp32 LayerNorm after every conv).
        conv_bias: whether the convs carry a bias term.
    """

    def __init__(self, conv_layers, in_d=1, dropout=0.0, mode="default", conv_bias=False):
        super().__init__()
        assert mode in {"default", "layer_norm"}

        def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
            def make_conv():
                c = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(c.weight)
                return c

            assert not (is_layer_norm and is_group_norm), "layer norm and group norm are exclusive"
            if is_layer_norm:
                # FIX: was `_Fp32LayerNorm(dim, dim, affine=True)`, which passed
                # `dim` as eps and used a nonexistent `affine` kwarg (LayerNorm
                # takes `elementwise_affine`), raising TypeError; it also closed
                # over the loop variable `dim` instead of this block's `n_out`.
                return nn.Sequential(
                    make_conv(), nn.Dropout(p=dropout),
                    nn.Sequential(_TransposeLast(), _Fp32LayerNorm(n_out, elementwise_affine=True), _TransposeLast()),
                    nn.GELU(),
                )
            elif is_group_norm:
                # FIX: use this block's `n_out` rather than the enclosing loop
                # variable (same value at call time, but no closure hazard).
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), _Fp32GroupNorm(n_out, n_out, affine=True), nn.GELU())
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            (dim, k, stride) = cl
            self.conv_layers.append(block(
                in_d, dim, k, stride,
                is_layer_norm=mode == "layer_norm",
                is_group_norm=mode == "default" and i == 0,
                conv_bias=conv_bias,
            ))
            in_d = dim

    def forward(self, x):
        """x: (batch, time) or (batch, channels, time) -> (batch, dim, frames)."""
        if len(x.shape) < 3:
            x = x.unsqueeze(1)  # add a channel dim for single-lead input
        for conv in self.conv_layers:
            x = conv(x)
        return x
|
|
|
|
class _ConvPositionalEncoding(nn.Module):
    """Relative positional embedding via a grouped 1-D convolution (wav2vec 2.0 style).

    Expects `args` with: encoder_embed_dim, conv_pos (kernel size) and
    conv_pos_groups (conv groups).
    """

    def __init__(self, args):
        super().__init__()
        self.embedding_dim = args.encoder_embed_dim
        self.pos_conv = nn.Conv1d(
            self.embedding_dim, self.embedding_dim,
            kernel_size=args.conv_pos, padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        # Init std scaled to kernel size and width, as in fairseq's wav2vec 2.0.
        std = math.sqrt((4 * 1.0) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        # weight_norm must be applied AFTER init so the init statistics hold.
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        # _SamePad trims the extra frame produced by even kernel sizes.
        self.pos_conv = nn.Sequential(self.pos_conv, _SamePad(args.conv_pos), nn.GELU())

    def forward(self, x, channel_first=False):
        # x: (batch, time, dim) unless channel_first; always returns (batch, time, dim).
        if not channel_first:
            x = x.transpose(1, 2)
        return self.pos_conv(x).transpose(1, 2)
|
|
|
|
class _TransformerEncoderLayer(nn.Module):
    """wav2vec 2.0-style transformer encoder layer.

    Supports both pre-norm (`layer_norm_first=True`) and post-norm orderings.
    Inputs and outputs are (seq_len, batch, embed_dim).
    """

    def __init__(self, embed_dim=768, n_heads=12, ffn_dim=3072, dropout=0.1,
                 attention_dropout=0.1, activation_dropout=0.1, layer_norm_first=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        def gelu(x):
            # GELU computed in fp32 for stability, then cast back.
            return F.gelu(x.float()).type_as(x)
        self.activation_fn = gelu
        self.self_attn = _MultiHeadAttention(embed_dim, n_heads, dropout=attention_dropout, self_attention=True)
        self.dropout1 = nn.Dropout(dropout)             # after self-attention
        self.dropout2 = nn.Dropout(activation_dropout)  # inside the FFN
        self.dropout3 = nn.Dropout(dropout)             # after the FFN
        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = _make_layer_norm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = _make_layer_norm(embed_dim)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
        """Returns (x, (attn, layer_result)); layer_result is the pre-dropout FFN output."""
        residual = x
        if self.layer_norm_first:
            # Pre-norm: LN -> attn -> residual; LN -> FFN -> residual.
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(query=x, key=x, value=x,
                                     key_padding_mask=self_attn_padding_mask,
                                     attn_mask=self_attn_mask, need_weights=False)
            x = self.dropout1(x)
            x = residual + x
            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-norm: attn -> residual -> LN; FFN -> residual -> LN.
            x, attn = self.self_attn(query=x, key=x, value=x,
                                     key_padding_mask=self_attn_padding_mask,
                                     attn_mask=self_attn_mask, need_weights=False)
            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)
            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)
        return x, (attn, layer_result)
|
|
|
|
def _init_bert_params(module):
    """BERT-style initialization, applied via Module.apply().

    Linear/Embedding weights ~ N(0, 0.02); Linear biases and the embedding
    padding row are zeroed; attention Q/K/V projections are re-sampled.
    """
    def normal_(data):
        # Sample on CPU then copy back, so init is identical across devices.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, _MultiHeadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
|
|
|
|
| |
| |
| |
|
|
| class _SupporterLayerNorm(nn.LayerNorm): |
| def forward(self, x: torch.Tensor): |
| orig_type = x.dtype |
| return super().forward(x.type(torch.float32)).type(orig_type) |
|
|
|
|
| class _QuickGELU(nn.Module): |
| def forward(self, x): |
| return x * torch.sigmoid(1.702 * x) |
|
|
|
|
class _ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block (CLIP-style): MHA then MLP, each with a residual."""

    def __init__(self, d_model, n_head, attn_mask=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = _SupporterLayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", _QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = _SupporterLayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x, x_mask):
        key_padding = None
        if x_mask is not None:
            key_padding = x_mask.to(dtype=torch.bool, device=x.device)
        mask = None
        if self.attn_mask is not None:
            mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        out, _ = self.attn(x, x, x, need_weights=False, attn_mask=mask, key_padding_mask=key_padding)
        return out

    def forward(self, x, x_mask=None):
        x = x + self.attention(self.ln_1(x), x_mask)
        return x + self.mlp(self.ln_2(x))
|
|
|
|
class _SupporterTransformer(nn.Module):
    """Sequential stack of _ResidualAttentionBlock layers.

    NOTE(review): builds `layers - 1` blocks, not `layers`; the only visible
    caller (_MIMHead) passes `mim_decoder_num_layers + 1`, which compensates —
    confirm this off-by-one is intentional before reusing elsewhere.
    """

    def __init__(self, width, layers, heads, attn_mask=None):
        super().__init__()
        self.resblocks = nn.Sequential(
            *[_ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers - 1)]
        )

    def forward(self, x, x_mask=None):
        # Manual loop (not nn.Sequential.__call__) so x_mask reaches every block.
        for block in self.resblocks:
            x = block(x, x_mask)
        return x
|
|
|
|
| class _PositionalEncoding(nn.Module): |
| def __init__(self, d_model, max_len): |
| super().__init__() |
| position = torch.arange(max_len).unsqueeze(1) |
| div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) |
| pe = torch.zeros(1, max_len, d_model) |
| pe[0, :, 0::2] = torch.sin(position * div_term) |
| pe[0, :, 1::2] = torch.cos(position * div_term) |
| self.register_buffer("pe", pe) |
|
|
| def forward(self, x): |
| return x + self.pe[:, : x.size(1)] |
|
|
|
|
| class _Pooler(nn.Module): |
| def __init__(self, hidden_size): |
| super().__init__() |
| self.dense = nn.Linear(hidden_size, hidden_size) |
| self.activation = nn.Tanh() |
|
|
| def forward(self, hidden_states): |
| return self.activation(self.dense(hidden_states[:, 0])) |
|
|
|
|
| |
| |
| |
| |
|
|
class _TransformerEncoder(nn.Module):
    """Stack of _TransformerEncoderLayer with LayerDrop (Fan et al. 2020).

    Expects `args` with: dropout, encoder_embed_dim, encoder_ffn_embed_dim,
    encoder_attention_heads, attention_dropout, activation_dropout,
    layer_norm_first, encoder_layers, encoder_layerdrop.
    """

    def __init__(self, args):
        super().__init__()
        self.dropout = args.dropout
        self.embed_dim = args.encoder_embed_dim
        self.layers = nn.ModuleList([
            _TransformerEncoderLayer(
                embed_dim=self.embed_dim,
                ffn_dim=args.encoder_ffn_embed_dim,
                n_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                layer_norm_first=args.layer_norm_first,
            )
            for _ in range(args.encoder_layers)
        ])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = _make_layer_norm(self.embed_dim)
        self.layerdrop = args.encoder_layerdrop
        self.apply(_init_bert_params)

    def forward(self, x, padding_mask=None, attn_mask=None):
        # x: (batch, time, dim); padding_mask: (batch, time), True at pad positions.
        x = self._extract_features(x, padding_mask, attn_mask)
        if self.layer_norm_first:
            # Pre-norm stacks need one final LayerNorm on the output.
            x = self.layer_norm(x)
        return x

    def _extract_features(self, x, padding_mask=None, attn_mask=None):
        if padding_mask is not None:
            # NOTE: zeroes padded positions IN PLACE (mutates the caller's tensor).
            x[padding_mask] = 0
        if not self.layer_norm_first:
            x = self.layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Layers expect time-first tensors: (batch, time, dim) -> (time, batch, dim).
        x = x.transpose(0, 1)
        for layer in self.layers:
            # LayerDrop: during training, skip each layer with prob `layerdrop`.
            dropout_probability = np.random.random()
            if not self.training or dropout_probability > self.layerdrop:
                x, z = layer(x, self_attn_padding_mask=padding_mask,
                             self_attn_mask=attn_mask, need_weights=False)
        return x.transpose(0, 1)
|
|
|
|
class _ECGTransformerModel(nn.Module):
    """Raw-ECG encoder (wav2vec 2.0 style): conv feature extractor + transformer.

    `cfg` carries fairseq-style hyperparameters. The masking fields are stored
    from the pretraining config; masking itself is not applied in the methods
    defined here.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Time/channel masking hyperparameters (kept for checkpoint compatibility).
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        if cfg.apply_mask:
            # Learned embedding substituted at masked time steps.
            self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)
        self.num_updates = 0

        # NOTE(security): eval() of a config string, e.g. fairseq-style
        # "[(512,10,5)] + [(512,3,2)] * 4". Trusted configs only — the
        # arithmetic expressions prevent a drop-in ast.literal_eval swap.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]  # channel dim of the last conv

        self.feature_extractor = _ConvFeatureExtraction(
            conv_layers=feature_enc_layers, in_d=cfg.in_d,
            dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias,
        )
        # Only project when the conv output dim differs from the encoder dim.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim else None
        )
        self.feature_grad_mult = cfg.feature_grad_mult
        self.conv_pos = _ConvPositionalEncoding(cfg)
        self.layer_norm = _make_layer_norm(self.embed)
        self.encoder = _TransformerEncoder(cfg)

    def _get_feat_extract_output_lengths(self, input_lengths):
        """Map raw-sample lengths to post-conv frame counts (valid-conv formula)."""
        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)
        for cl in eval(self.cfg.conv_feature_layers):
            input_lengths = _conv_out_length(input_lengths, cl[1], cl[2])
        return input_lengths.to(torch.long)

    def get_embeddings(self, source, padding_mask):
        """Returns (x, padding_mask, x_conv) — 3 values.

        source: raw ECG; padding_mask: True at padded samples, or None.
        x = conv features + positional encoding (transformer-ready);
        x_conv is the positional encoding alone.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                # Scale (rather than block) gradients into the conv extractor.
                features = _GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # Frozen extractor: no gradients at all.
            with torch.no_grad():
                features = self.feature_extractor(source)

        features = features.transpose(1, 2)  # -> (batch, frames, dim)
        features = self.layer_norm(features)

        if padding_mask is not None and padding_mask.any():
            # Convert the sample-level mask to a frame-level mask.
            input_lengths = (1 - padding_mask.long()).sum(-1)
            if input_lengths.dim() > 1:
                # Multi-lead mask: leads share one length; take the first.
                input_lengths = input_lengths[:, 0]
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
            padding_mask = torch.zeros(features.shape[:2], dtype=features.dtype, device=features.device)
            # Mark the last valid frame of each sequence ...
            padding_mask[
                (torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)
            ] = 1
            padding_mask[torch.where(output_lengths == 0)] = 0
            # ... then flip-cumsum-flip so everything AFTER it becomes True (pad).
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        x_conv = self.conv_pos(features, channel_first=False)
        x = features + x_conv
        return x, padding_mask, x_conv

    def get_output(self, x, padding_mask=None):
        """Run the transformer encoder over prepared embeddings."""
        return self.encoder(x, padding_mask=padding_mask)
|
|
|
|
| |
| |
| |
|
|
| class _BertSelfAttention(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.num_attention_heads = config.num_attention_heads |
| self.attention_head_size = int(config.hidden_size / config.num_attention_heads) |
| self.all_head_size = self.num_attention_heads * self.attention_head_size |
| self.query = nn.Linear(config.hidden_size, self.all_head_size) |
| self.key = nn.Linear(config.hidden_size, self.all_head_size) |
| self.value = nn.Linear(config.hidden_size, self.all_head_size) |
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) |
| self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") |
| self.is_decoder = config.is_decoder |
|
|
| def transpose_for_scores(self, x): |
| return x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size).permute(0, 2, 1, 3) |
|
|
| def forward(self, hidden_states, attention_mask=None, head_mask=None, |
| encoder_hidden_states=None, encoder_attention_mask=None, |
| past_key_value=None, output_attentions=False): |
| mixed_query_layer = self.query(hidden_states) |
| is_cross_attention = encoder_hidden_states is not None |
|
|
| if is_cross_attention: |
| key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) |
| value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) |
| attention_mask = encoder_attention_mask |
| else: |
| key_layer = self.transpose_for_scores(self.key(hidden_states)) |
| value_layer = self.transpose_for_scores(self.value(hidden_states)) |
|
|
| query_layer = self.transpose_for_scores(mixed_query_layer) |
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
| if attention_mask is not None: |
| attention_scores = attention_scores + attention_mask |
| attention_probs = nn.Softmax(dim=-1)(attention_scores) |
| attention_probs = self.dropout(attention_probs) |
| if head_mask is not None: |
| attention_probs = attention_probs * head_mask |
|
|
| context_layer = torch.matmul(attention_probs, value_layer) |
| context_layer = context_layer.permute(0, 2, 1, 3).contiguous() |
| context_layer = context_layer.view(*context_layer.size()[:-2], self.all_head_size) |
| outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) |
| return outputs |
|
|
|
|
| class _BertSelfOutput(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
| def forward(self, hidden_states, input_tensor): |
| return self.LayerNorm(self.dropout(self.dense(hidden_states)) + input_tensor) |
|
|
|
|
class _BertAttention(nn.Module):
    """Attention sublayer: _BertSelfAttention followed by _BertSelfOutput."""

    def __init__(self, config):
        super().__init__()
        self.self = _BertSelfAttention(config)
        self.output = _BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        attn_outputs = self.self(hidden_states, attention_mask, head_mask,
                                 encoder_hidden_states, encoder_attention_mask,
                                 past_key_value, output_attentions)
        attended = self.output(attn_outputs[0], hidden_states)
        # Re-attach attention probs (present only when output_attentions=True).
        return (attended,) + attn_outputs[1:]
|
|
|
|
| class _BertIntermediate(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.hidden_size, config.intermediate_size) |
| self.intermediate_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act |
|
|
| def forward(self, hidden_states): |
| return self.intermediate_act_fn(self.dense(hidden_states)) |
|
|
|
|
| class _BertOutput(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.dense = nn.Linear(config.intermediate_size, config.hidden_size) |
| self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
| self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
| def forward(self, hidden_states, input_tensor): |
| return self.LayerNorm(self.dropout(self.dense(hidden_states)) + input_tensor) |
|
|
|
|
class _BertCrossLayer(nn.Module):
    """Multimodal fusion block: self-attention, cross-attention, then chunked FFN."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # sequence dimension used by feed-forward chunking
        self.attention = _BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        self.crossattention = _BertAttention(config)
        self.intermediate = _BertIntermediate(config)
        self.output = _BertOutput(config)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None,
                encoder_attention_mask=None, output_attentions=False):
        """Returns (layer_output, *attn_probs) — probs only when output_attentions."""
        self_attention_outputs = self.attention(hidden_states, attention_mask,
                                                head_mask=None, output_attentions=output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]
        # Cross-attend from this modality's tokens to the other modality's states.
        # Positional args map to (hidden_states, attention_mask, head_mask,
        # encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions).
        cross_attention_outputs = self.crossattention(
            attention_output, attention_mask, None,
            encoder_hidden_states, encoder_attention_mask, None, output_attentions,
        )
        attention_output = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:]
        # Chunked FFN keeps peak memory bounded for long sequences.
        layer_output = apply_chunking_to_forward(
            self._feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output,) + outputs

    def _feed_forward_chunk(self, attention_output):
        # Standard BERT FFN: expansion then output projection + residual + LN.
        return self.output(self.intermediate(attention_output), attention_output)
|
|
|
|
| |
| |
| |
|
|
| def _init_weights(module): |
| if isinstance(module, (nn.Linear, nn.Embedding)): |
| module.weight.data.normal_(mean=0.0, std=0.02) |
| elif isinstance(module, nn.LayerNorm): |
| module.bias.data.zero_() |
| module.weight.data.fill_(1.0) |
| if isinstance(module, nn.Linear) and module.bias is not None: |
| module.bias.data.zero_() |
|
|
|
|
class _MLMHead(nn.Module):
    """Masked-language-modeling head: BERT transform + vocab decoder (optionally tied)."""

    def __init__(self, config, weight=None):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Bias kept separate so it survives weight tying with the embeddings.
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        if weight is not None:
            self.decoder.weight = weight

    def forward(self, x):
        logits = self.decoder(self.transform(x))
        return logits + self.bias
|
|
|
|
class _MIMHead(nn.Module):
    """Masked-ECG-modeling decoder (MAE-style): reconstructs masked signal patches."""

    def __init__(self, cfg):
        super().__init__()
        self.hidden_dim = cfg.hidden_dim
        self.decoder_hidden_dim = cfg.mim_decoder_hidden_dim
        self.decoder_embed = nn.Linear(self.hidden_dim, self.decoder_hidden_dim, bias=True)
        # Learned embedding standing in for masked-out patches.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_hidden_dim))
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.decoder_pos_embed = _PositionalEncoding(self.decoder_hidden_dim, max_len=512)
        # _SupporterTransformer builds (layers - 1) blocks, hence the +1 here.
        self.decoder = _SupporterTransformer(self.decoder_hidden_dim, cfg.mim_decoder_num_layers + 1, cfg.mim_decoder_num_heads)
        self.decoder_norm = _SupporterLayerNorm(self.decoder_hidden_dim)

        def _conv_out_length(il, k, s):
            # Output length of an unpadded conv: floor((L - k) / s) + 1.
            return np.floor((il - k) / s + 1)

        # NOTE(security): eval() of a config string — trusted configs only.
        # Hard-codes a 5000-sample input (presumably 10 s at 500 Hz — confirm).
        inferred = 5000
        for cl in eval(cfg.conv_feature_layers):
            inferred = _conv_out_length(inferred, cl[1], cl[2])
        # Samples represented by each encoder frame, i.e. patch size per lead.
        self.inferred_decoded_size = int(np.floor(5000 / inferred))
        # Predict all 12 leads' samples for each frame.
        self.decoder_pred = nn.Linear(self.decoder_hidden_dim, self.inferred_decoded_size * 12, bias=True)

    def forward(self, x, ids_restore):
        """x: visible-token embeddings with CLS at index 0; ids_restore: MAE unshuffle indices.

        Returns (batch, frames, 12, patch_size) reconstructions.
        """
        x = self.decoder_embed(x)
        # Append mask tokens for the removed patches, then unshuffle to original order.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # drop CLS before unshuffle
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-attach CLS
        x = self.decoder_pos_embed(x)
        # _SupporterTransformer expects (seq, batch, dim).
        x = self.decoder(x.permute(1, 0, 2)).permute(1, 0, 2)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)[:, 1:, :]  # drop the CLS prediction
        return x.view(x.size(0), x.size(1), -1, self.inferred_decoded_size)
|
|
|
|
| class _ITMHead(nn.Module): |
| def __init__(self, hidden_size): |
| super().__init__() |
| self.fc = nn.Linear(hidden_size, 2) |
|
|
| def forward(self, x): |
| return self.fc(x) |
|
|
|
|
class M3AEModel(nn.Module):
    """ECG encoder from Q-HEART (identical role to DBETA in D-BETA).

    Multimodal (ECG + text) pretraining backbone: a wav2vec-style ECG encoder,
    a flan-T5 text encoder, cross-modal fusion layers, and MLM / MIM / ITM
    pretraining heads.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.vocab_size = cfg.vocab_size
        self.mim_prob = cfg.mim_prob
        self.mim_layer = cfg.mim_layer

        self.ecg_encoder = _ECGTransformerModel(cfg)
        # Learnable CLS-style token prepended to the ECG sequence.
        self.class_embedding = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())

        # NOTE: downloads flan-t5-base from the Hub on first construction.
        from transformers import T5EncoderModel
        self.language_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base")
        self.language_encoder.pooler = None

        # Project each modality into the shared fusion dimension.
        self.multi_modal_language_proj = nn.Linear(cfg.encoder_embed_dim, cfg.hidden_dim)
        self.multi_modal_language_proj.apply(_init_weights)
        self.multi_modal_ecg_proj = nn.Linear(cfg.encoder_embed_dim, cfg.hidden_dim)
        self.multi_modal_ecg_proj.apply(_init_weights)

        # 0/1 embedding telling fusion layers which modality a token came from.
        self.modality_type_embeddings = nn.Embedding(2, cfg.hidden_dim)
        self.modality_type_embeddings.apply(_init_weights)

        # Config for the cross-modal fusion stack (weights are freshly initialized,
        # not loaded from a pretrained BERT).
        bert_config = BertConfig(
            vocab_size=cfg.vocab_size,
            hidden_size=cfg.hidden_dim,
            num_hidden_layers=cfg.num_layers,
            num_attention_heads=cfg.num_heads,
            intermediate_size=cfg.hidden_dim * 4,
            max_position_embeddings=cfg.max_text_size,
            hidden_dropout_prob=cfg.drop_rate,
            attention_probs_dropout_prob=cfg.drop_rate,
        )
        self.multi_modal_ecg_layers = nn.ModuleList([_BertCrossLayer(bert_config) for _ in range(cfg.num_top_layer)])
        self.multi_modal_ecg_layers.apply(_init_weights)
        self.multi_modal_language_layers = nn.ModuleList([_BertCrossLayer(bert_config) for _ in range(cfg.num_top_layer)])
        self.multi_modal_language_layers.apply(_init_weights)

        # Poolers for fused and unimodal CLS representations.
        self.multi_modal_ecg_pooler = _Pooler(cfg.hidden_dim)
        self.multi_modal_ecg_pooler.apply(_init_weights)
        self.multi_modal_language_pooler = _Pooler(cfg.hidden_dim)
        self.multi_modal_language_pooler.apply(_init_weights)
        self.unimodal_ecg_pooler = _Pooler(cfg.hidden_dim)
        self.unimodal_ecg_pooler.apply(_init_weights)
        self.unimodal_language_pooler = _Pooler(cfg.hidden_dim)
        self.unimodal_language_pooler.apply(_init_weights)

        # Pretraining objectives: masked language / masked ECG / ECG-text matching.
        self.mlm_head = _MLMHead(bert_config)
        self.mlm_head.apply(_init_weights)
        self.mim_head = _MIMHead(cfg)
        self.mim_head.apply(_init_weights)
        self.itm_head = _ITMHead(cfg.hidden_dim * 2)
        self.itm_head.apply(_init_weights)

    def remove_pretraining_modules(self):
        """Drop pretraining-only heads/encoders to shrink the fine-tuning model."""
        self.mlm_head = None
        self.mim_head = None
        self.itm_head = None
        self.language_encoder = None
        self.multi_modal_language_layers = None
        self.multi_modal_ecg_layers = None
|
|
|
|
| |
| |
| |
|
|
| class _MlpTransformer(nn.Module): |
| def __init__(self, in_dim, h_dim, out_d=None, act=F.relu, dropout=0.): |
| super().__init__() |
| out_d = out_d if out_d is not None else in_dim |
| self.fc1 = nn.Linear(in_dim, h_dim) |
| self.act = act |
| self.fc2 = nn.Linear(h_dim, out_d) |
| self.dropout = nn.Dropout(dropout) |
|
|
| def forward(self, x): |
| x = self.act(self.fc1(x)) |
| x = self.dropout(x) |
| x = self.fc2(x) |
| return self.dropout(x) |
|
|
|
|
class _MapperMultiHeadAttention(nn.Module):
    """Multi-head attention for the prefix mapper (ClipCap-style).

    Queries come from `x` (dim_self); keys/values from `y` (dim_ref; defaults
    to `x` for self-attention). Returns (projected output, attention weights).
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # Single projection producing keys and values together (split in forward).
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        # x: (b, n, dim_self); y: (b, m, dim_ref) or None for self-attention.
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # (b, m, 2, heads, head_dim): index 0 along dim 2 = keys, 1 = values.
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        # Attention logits per head: (b, n, m, heads).
        attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                # (b, m) -> (b, 1, m): broadcast over query positions.
                mask = mask.unsqueeze(1)
            # True entries are excluded (presumably padding — confirm at callers):
            # set to -inf before softmax.
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)  # normalize over key positions
        out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
        return self.project(out), attention
|
|
|
|
class _MapperTransformerLayer(nn.Module):
    """Pre-norm transformer layer for the prefix mapper: attention + MLP residuals."""

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0.,
                 act=F.relu, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = _MapperMultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = _MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward(self, x, y=None, mask=None):
        attn_out, _ = self.attn(self.norm1(x), y, mask)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
|
|
|
|
class _MapperTransformer(nn.Module):
    """Stack of ``_MapperTransformerLayer`` blocks.

    When ``enc_dec`` is True the stack holds ``2 * num_layers`` layers that
    alternate cross-attention to ``y`` (even indices, built with ``dim_ref``)
    and self-attention (odd indices, built with ``dim_self``). Otherwise all
    layers attend to ``y`` (or to ``x`` itself when ``y`` is None).
    """

    def __init__(self, dim_self, num_heads, num_layers, dim_ref=None, mlp_ratio=2.,
                 act=F.relu, norm_layer=nn.LayerNorm, enc_dec=False):
        super().__init__()
        if dim_ref is None:
            dim_ref = dim_self
        self.enc_dec = enc_dec
        total = num_layers * 2 if enc_dec else num_layers
        layers = []
        for i in range(total):
            # Odd enc_dec layers self-attend, so their reference dim is dim_self.
            ref = dim_self if (enc_dec and i % 2 == 1) else dim_ref
            layers.append(_MapperTransformerLayer(dim_self, ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if self.enc_dec and i % 2 == 0:
                x = layer(x, y)        # cross-attention (mask not applied here)
            elif self.enc_dec:
                x = layer(x, x, mask)  # self-attention
            else:
                x = layer(x, y, mask)
        return x
|
|
|
|
class TransformerMapper(nn.Module):
    """Maps a pooled ECG feature to ``prefix_length`` LLM-space prefix tokens.

    The feature is linearly expanded into ``clip_length`` pseudo-tokens,
    concatenated with a learned constant prefix, run through a small
    transformer, and only the transformed prefix half is returned.
    """

    def __init__(self, dim_clip, dim_embedding, prefix_length, clip_length, num_layers=4, num_heads=4):
        super().__init__()
        self.clip_length = clip_length
        self.transformer = _MapperTransformer(dim_embedding, num_heads, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        batch = x.shape[0]
        clip_tokens = self.linear(x).view(batch, self.clip_length, -1)
        learned = self.prefix_const.unsqueeze(0).expand(batch, *self.prefix_const.shape)
        combined = torch.cat((clip_tokens, learned), dim=1)
        # Discard the clip tokens; keep only the transformed learned prefix.
        return self.transformer(combined)[:, self.clip_length:]
|
|
|
|
class AttentionMapper(nn.Module):
    """Cross-attention mapper: ECG tokens attend over question embeddings.

    ECG features are projected to the LLM embedding size and used as
    attention queries; keys/values come from the layer-normed question
    embeddings. The attended ECG tokens are then prepended to the untouched
    question embeddings.

    NOTE(review): the default ``dim=786`` looks like a typo for 768 (the
    encoder feature dim used elsewhere in this file); it is harmless because
    callers pass ``dim`` explicitly, so the default is left unchanged for
    interface compatibility.
    """

    def __init__(self, dim=786, output_dim=2048, num_heads=8, dim_head=64):
        super().__init__()
        self.num_heads = num_heads
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.inner_dim = num_heads * dim_head
        self.ecg_projection_layer = nn.Linear(dim, output_dim)
        self.norm_ecg = nn.LayerNorm(output_dim)
        self.norm_query = nn.LayerNorm(output_dim)
        self.to_q = nn.Linear(output_dim, self.inner_dim, bias=False)
        self.to_kv = nn.Linear(output_dim, self.inner_dim * 2, bias=False)
        self.to_out = nn.Linear(self.inner_dim, output_dim, bias=False)

    def forward(self, ecg_features, query_features, prefix_len=None):
        """Return ``cat([attended_ecg_tokens, query_features], dim=1)``.

        Args:
            ecg_features: (batch, n_ecg, dim) ECG token features.
            query_features: (batch, n_query, output_dim) question embeddings.
            prefix_len: if given, only the first ``prefix_len`` query tokens
                serve as keys/values (the full query is still appended).
        """
        ecg_features = self.norm_ecg(self.ecg_projection_layer(ecg_features))
        kv_source = query_features if prefix_len is None else query_features[:, :prefix_len, :]
        normed_query = self.norm_query(kv_source)

        q = self.to_q(ecg_features)
        k, v = self.to_kv(normed_query).chunk(2, dim=-1)

        def _split_heads(t):
            # (b, n, h*d) -> (b, h, n, d)
            b, n, _ = t.shape
            return t.view(b, n, self.num_heads, self.dim_head).transpose(1, 2)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = scores.softmax(dim=-1)
        # BUGFIX: the original einsum ("bhqk, bhvd -> bhqd") used mismatched
        # subscripts (k vs v), which summed the softmaxed weights to 1 and the
        # values over all positions independently — the attention pattern was
        # ignored and every query position got the same output. Contract over
        # the shared key axis instead.
        out = torch.matmul(attn, v)
        b, h, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, h * d)
        return torch.cat((self.to_out(out), query_features), dim=1)
|
|
|
|
class MoEMapper(nn.Module):
    """Mixture-of-experts mapper with hard top-1 routing.

    A linear gate over the mean-pooled question embedding selects one linear
    expert per sample; that expert projects the single-token ECG feature into
    the LLM embedding space. Output shape: (batch, 1, output_dim).
    """

    def __init__(self, input_dim, output_dim, num_experts=12):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, output_dim)) for _ in range(num_experts)])
        self.text_gate = nn.Linear(output_dim, num_experts)
        self.output_dim = output_dim

    def forward(self, x, t):
        batch = x.size(0)
        tokens = x.squeeze(1)
        # Hard routing: each sample goes to the expert with the highest
        # gate logit, conditioned on the mean text embedding.
        chosen = self.text_gate(t.mean(dim=1)).argmax(dim=-1)
        out = torch.zeros(batch, self.output_dim, device=x.device, dtype=x.dtype)
        for idx, expert in enumerate(self.experts):
            picked = chosen == idx
            if picked.any():
                out[picked] = expert(tokens[picked])
        return out.reshape(batch, 1, self.output_dim)
|
|
|
|
| class _MLPBlock(nn.Module): |
| def __init__(self, dim, hidden_dim): |
| super().__init__() |
| self.mlp = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim)) |
|
|
| def forward(self, x): |
| return self.mlp(x) |
|
|
|
|
class MLPMixer(nn.Module):
    """MLP-Mixer projecting ECG token features into the LLM embedding space.

    Alternates residual token-mixing (across the sequence axis, via a
    transpose) and channel-mixing MLP blocks, then layer-norms and linearly
    projects to ``llm_embedding_size``.

    NOTE(review): the token mixers are built for exactly ``num_tokens``
    sequence positions; inputs with a different sequence length will fail.
    """

    def __init__(self, num_tokens=12, input_dim=768, llm_embedding_size=1024, num_layers=4):
        super().__init__()
        self.token_mixer = nn.ModuleList([_MLPBlock(num_tokens, input_dim) for _ in range(num_layers)])
        self.channel_mixer = nn.ModuleList([_MLPBlock(input_dim, input_dim) for _ in range(num_layers)])
        self.last = nn.Linear(input_dim, llm_embedding_size)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        for mix_tokens, mix_channels in zip(self.token_mixer, self.channel_mixer):
            x = x + mix_tokens(x.transpose(1, 2)).transpose(1, 2)
            x = x + mix_channels(x)
        return self.last(self.norm(x))
|
|
|
|
| |
| |
| |
|
|
class CustomECGQAModel(nn.Module):
    """Core ECG-QA model: ECG encoder -> mapper -> causal LLM.

    ECG features are projected into the LLM embedding space by one of
    several mapper types and prepended to the token embeddings of the
    question; the LLM then models the answer autoregressively.

    Args:
        ecg_encoder: pretrained ECG encoder; must expose
            ``ecg_encoder.get_embeddings`` / ``get_output``, a
            ``class_embedding``, ``multi_modal_ecg_proj`` and
            ``unimodal_ecg_pooler`` (see ``_get_ecg_features``).
        mapping_type: one of "MLP", "MLPMixer", "Attention", "Transformer",
            "MOE" — selects the ECG->LLM projection layer.
        setting: "lora" wraps the LLM with LoRA adapters; "frozen" freezes
            all LLM weights; anything else leaves the LLM fully trainable.
        prefix_length: number of ECG prefix tokens for the Transformer /
            MLPMixer mappers (the other mappers emit a single token).
        clip_length: number of pseudo-tokens the Transformer mapper expands
            the pooled ECG feature into.
        llm_model_type: HuggingFace hub id of the causal LLM backbone.
    """

    def __init__(self, ecg_encoder, mapping_type="Transformer", setting="lora",
                 prefix_length=12, clip_length=12, llm_model_type="meta-llama/Llama-3.2-1B-Instruct"):
        super().__init__()
        self.mapping_type = mapping_type
        self.setting = setting
        self.llm_type = llm_model_type

        # Backbone LLM; its hidden size is the mapper's target dimension.
        self.llm = AutoModelForCausalLM.from_pretrained(self.llm_type)
        self.llm_embedding_size = self.llm.config.hidden_size

        if setting == "lora":
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
            )
            self.llm = get_peft_model(self.llm, peft_config)
        elif setting == "frozen":
            for param in self.llm.parameters():
                param.requires_grad = False
            self.llm.eval()

        self.ecg_encoder = ecg_encoder
        # 1x1 conv collapsing the encoder's 312 conv-embedding positions to 12
        # tokens; used only by the Transformer mapper's residual path.
        self.postconv = nn.Conv1d(in_channels=312, out_channels=12, kernel_size=1)

        # Nominal number of ECG prefix tokens. forward/generate derive the
        # actual count from tensor shapes so masks/labels always stay aligned.
        self.ecg_token_nums = prefix_length if mapping_type in ["Transformer", "MLPMixer"] else 1
        self.ecg_feature_dim = 768

        if mapping_type == "MLP":
            self.ecg_projection_layer = nn.Linear(self.ecg_feature_dim, self.llm_embedding_size)
        elif mapping_type == "MLPMixer":
            self.ecg_projection_layer = MLPMixer(input_dim=self.ecg_feature_dim, llm_embedding_size=self.llm_embedding_size)
        elif mapping_type == "Attention":
            self.ecg_projection_layer = AttentionMapper(dim=self.ecg_feature_dim, output_dim=self.llm_embedding_size)
        elif mapping_type == "Transformer":
            self.ecg_projection_layer = TransformerMapper(
                dim_clip=self.ecg_feature_dim,
                dim_embedding=self.llm_embedding_size,
                prefix_length=self.ecg_token_nums,
                clip_length=clip_length,
                num_heads=4,
                num_layers=2,
            )
            self.postlinear = nn.Linear(self.ecg_feature_dim, self.llm_embedding_size)
        elif mapping_type == "MOE":
            self.ecg_projection_layer = MoEMapper(self.ecg_feature_dim, self.llm_embedding_size)

    def _get_ecg_features(self, ecg):
        """Encode the ECG and return (pooled features, raw conv embeddings)."""
        uni_modal_ecg_feats, ecg_padding_mask, conv_embedd = (
            self.ecg_encoder.ecg_encoder.get_embeddings(ecg, padding_mask=None)
        )
        # Prepend a learned [CLS]-style embedding per sample before encoding.
        cls_emb = self.ecg_encoder.class_embedding.repeat(len(uni_modal_ecg_feats), 1, 1)
        uni_modal_ecg_feats = torch.cat([cls_emb, uni_modal_ecg_feats], dim=1)
        uni_modal_ecg_feats = self.ecg_encoder.ecg_encoder.get_output(uni_modal_ecg_feats, ecg_padding_mask)
        out = self.ecg_encoder.multi_modal_ecg_proj(uni_modal_ecg_feats)
        ecg_features = self.ecg_encoder.unimodal_ecg_pooler(out)
        return ecg_features, conv_embedd

    def _build_inputs_embeds(self, input_ids, ecg):
        """Build LLM input embeddings with projected ECG tokens prepended.

        Returns:
            Tuple of (inputs_embeds, n_ecg_tokens) where ``n_ecg_tokens`` is
            the number of ECG tokens actually prepended. (Previously the
            Attention path returned None here; it now reports the real count.)
        """
        ecg_features, conv_embedd = self._get_ecg_features(ecg)
        ecg_features = ecg_features.reshape(input_ids.shape[0], 1, -1)

        embeddings = self.llm.get_input_embeddings()(input_ids)
        if self.mapping_type == "MOE":
            # MoE routing is conditioned on the question token embeddings.
            ecg_features_projected = self.ecg_projection_layer(ecg_features, embeddings)
        elif self.mapping_type == "Attention":
            # AttentionMapper already concatenates the projected ECG tokens
            # in front of the question embeddings.
            inputs_embeds = self.ecg_projection_layer(ecg_features, embeddings, None)
            return inputs_embeds, inputs_embeds.shape[1] - embeddings.shape[1]
        else:
            ecg_features_projected = self.ecg_projection_layer(ecg_features)
            if self.mapping_type == "Transformer":
                # Residual path from the raw conv embeddings (ET-Mapper).
                ecg_features_projected = ecg_features_projected + self.postlinear(self.postconv(conv_embedd))

        return torch.cat((ecg_features_projected, embeddings), dim=1), ecg_features_projected.shape[1]

    def forward(self, input_ids=None, attention_mask=None, labels=None, ecg=None, **kwargs):
        """Training/eval forward pass; returns the LLM's causal-LM output.

        The attention mask and labels are left-padded to cover the prepended
        ECG tokens (labels use -100 so the prefix is excluded from the loss).
        """
        inputs_embeds, _ = self._build_inputs_embeds(input_ids, ecg)
        device = input_ids.device
        # FIX: pad by the actual number of prepended ECG tokens (derived from
        # the shapes) instead of the configured self.ecg_token_nums, which can
        # disagree with what the mapper actually emitted.
        n_prefix = inputs_embeds.size(1) - input_ids.size(1)

        if attention_mask is not None:
            attention_mask = torch.cat(
                (torch.ones((input_ids.size(0), n_prefix), dtype=attention_mask.dtype, device=device),
                 attention_mask), dim=1,
            )
        if labels is not None:
            labels = torch.cat(
                (torch.full((labels.size(0), n_prefix), -100, dtype=labels.dtype, device=device),
                 labels), dim=1,
            )
        return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)

    def generate(self, input_ids=None, attention_mask=None, ecg=None, max_length=50, **kwargs):
        """Generate an answer; ``max_length`` counts tokens beyond the prompt.

        NOTE(review): the ECG is reshaped to (-1, 12, 5000), i.e. 12 leads at
        a fixed 5000-sample length — confirm against the preprocessing.
        """
        inputs_embeds, _ = self._build_inputs_embeds(input_ids, ecg.reshape(-1, 12, 5000))
        device = input_ids.device
        # Same shape-derived prefix length as in forward().
        n_prefix = inputs_embeds.size(1) - input_ids.size(1)

        if attention_mask is not None:
            attention_mask = torch.cat(
                (torch.ones((input_ids.size(0), n_prefix), dtype=attention_mask.dtype, device=device),
                 attention_mask), dim=1,
            )
        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=inputs_embeds.shape[1] + max_length,
            attention_mask=attention_mask,
            **kwargs,
        )
|
|
|
|
| |
| |
| |
|
|
class QHEARTForECGQA(PreTrainedModel):
    """
    Q-HEART: ECG Question Answering model wrapped as a HuggingFace PreTrainedModel.

    Combines a 12-lead ECG encoder (M3AEModel) with a causal LLM (Llama/Gemma/etc.)
    via a learned mapping layer (ET-Mapper).

    Example::

        from transformers import AutoModel, AutoTokenizer
        import torch

        model = AutoModel.from_pretrained("Manhph2211/Q-HEART", trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        model.eval()

        ecg = torch.randn(1, 12, 5000)  # [batch, leads, length] at 500 Hz
        question = "What is the heart rhythm shown in this ECG?"
        inputs = tokenizer(question, return_tensors="pt")

        with torch.no_grad():
            output_ids = model.generate(
                ecg=ecg,
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=50,
            )
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    """

    config_class = QHEARTConfig

    def __init__(self, config: QHEARTConfig):
        """Build the ECG encoder (pretraining heads stripped) and the inner QA model."""
        super().__init__(config)
        ecg_encoder = M3AEModel(config)
        ecg_encoder.remove_pretraining_modules()
        self.model = CustomECGQAModel(
            ecg_encoder=ecg_encoder,
            mapping_type=config.mapping_type,
            setting="lora",
            prefix_length=config.prefix_length,
            clip_length=config.clip_length,
            llm_model_type=config.llm_model_type,
        )
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load from a HuggingFace Hub repo or local directory.
        Supports ``pytorch_model.bin`` (standard) and ``sample.bin`` (original format).
        Both are flat state dicts — no nested key handling needed.
        """
        import os
        from transformers.utils import cached_file

        config = kwargs.pop("config", None)
        cache_dir = kwargs.get("cache_dir", None)
        token = kwargs.get("token", kwargs.get("use_auth_token", None))
        revision = kwargs.get("revision", None)
        local_files_only = kwargs.get("local_files_only", False)
        device_map = kwargs.pop("device_map", None)

        if config is None:
            # Forward only the hub-related kwargs to the config loader.
            config = QHEARTConfig.from_pretrained(pretrained_model_name_or_path, **{
                k: v for k, v in kwargs.items()
                if k in ("cache_dir", "token", "use_auth_token", "revision", "local_files_only", "trust_remote_code")
            })

        # Instantiate with fresh weights; the checkpoint (if found) is
        # loaded on top below.
        model = cls(config)

        # Resolve a checkpoint file, preferring the standard name; a hub
        # lookup failure for one name falls through to the next candidate.
        resolved_path = None
        for fname in ("pytorch_model.bin", "sample.bin"):
            try:
                if os.path.isdir(pretrained_model_name_or_path):
                    candidate = os.path.join(pretrained_model_name_or_path, fname)
                    if os.path.isfile(candidate):
                        resolved_path = candidate
                        break
                else:
                    resolved_path = cached_file(
                        pretrained_model_name_or_path, fname,
                        cache_dir=cache_dir, token=token,
                        revision=revision, local_files_only=local_files_only,
                    )
                    if resolved_path:
                        break
            except Exception:
                continue

        if resolved_path is None:
            logger.warning("No checkpoint found (pytorch_model.bin or sample.bin). Returning model with random weights.")
            return model

        # NOTE(review): torch.load without weights_only=True unpickles
        # arbitrary objects — only load checkpoints from trusted sources.
        state_dict = torch.load(resolved_path, map_location="cpu")
        # Unwrap trainer-style {"model": ...} checkpoints, but only when the
        # top level does not already look like a flat module state dict.
        if isinstance(state_dict, dict) and "model" in state_dict and not any(
            k.startswith("model.") or k.startswith("ecg_encoder.") or k.startswith("llm.")
            for k in state_dict.keys()
        ):
            state_dict = state_dict["model"]

        # strict=False: LoRA / backbone key sets may legitimately differ;
        # mismatches are surfaced via the warnings below.
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning(f"Missing keys: {missing}")
        if unexpected:
            logger.warning(f"Unexpected keys: {unexpected}")

        logger.info(f"Loaded Q-HEART weights from {resolved_path}")
        if device_map is not None:
            # NOTE(review): device_map is treated as a plain torch device here;
            # HF-style maps such as "auto" or per-module dicts are not supported.
            model = model.to(device_map)
        return model

    def forward(self, input_ids=None, attention_mask=None, labels=None, ecg=None, **kwargs):
        """Delegate to the inner CustomECGQAModel (extra kwargs are dropped)."""
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, ecg=ecg)

    def generate(self, input_ids=None, attention_mask=None, ecg=None, max_new_tokens=50, **kwargs):
        """Generate an answer; the inner model interprets ``max_length`` as
        tokens beyond the prompt, so ``max_new_tokens`` maps onto it directly."""
        return self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            ecg=ecg, max_length=max_new_tokens, **kwargs,
        )
|
|