# Q-HEART / modeling_qheart.py
# Manhph2211's picture
# Update modeling_qheart.py
# ad755f6 verified
"""
Self-contained Q-HEART model for HuggingFace Hub.
All dependencies are inlined — no external repo required.
Q-HEART: ECG Question Answering via Knowledge-Informed Multimodal LLMs (ECAI 2025)
"""
import logging
import math
import os
from collections import OrderedDict
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import (
AutoModelForCausalLM,
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.bert.modeling_bert import (
BertConfig,
BertPredictionHeadTransform,
)
try:
from transformers.pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
except ImportError:
from transformers.modeling_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from transformers.activations import ACT2FN
from peft import LoraConfig, get_peft_model, TaskType
from .configuration_qheart import QHEARTConfig
logger = logging.getLogger(__name__)
# NOTE(review): empty CURL_CA_BUNDLE disables curl certificate verification for
# Hub downloads in some environments — confirm this workaround is still needed.
os.environ.setdefault("CURL_CA_BUNDLE", "")
# ---------------------------------------------------------------------------
# Low-level helpers (inlined from models/ecg_encoder/modules/)
# ---------------------------------------------------------------------------
class _Dropout(nn.Module):
def __init__(self, p, module_name=None):
super().__init__()
self.p = p
self.module_name = module_name
self.apply_during_inference = False
def forward(self, x, inplace: bool = False):
if self.p > 0 and (self.training or self.apply_during_inference):
return F.dropout(x, p=self.p, training=True, inplace=inplace)
return x
class _GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``scale`` in backward.

    Used to damp (or amplify) the gradient flowing into the feature extractor.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        # Legacy ``x.new(x)`` makes a copy of x; kept as-is for parity with the
        # fairseq implementation this was ported from.
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # Second None: no gradient w.r.t. ``scale``.
        return grad * ctx.scale, None
class _Fp32GroupNorm(nn.GroupNorm):
def forward(self, input):
output = F.group_norm(
input.float(), self.num_groups,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
class _Fp32LayerNorm(nn.LayerNorm):
def forward(self, input):
output = F.layer_norm(
input.float(), self.normalized_shape,
self.weight.float() if self.weight is not None else None,
self.bias.float() if self.bias is not None else None,
self.eps,
)
return output.type_as(input)
def _make_layer_norm(normalized_shape, eps=1e-5, elementwise_affine=True):
    """Build a LayerNorm, preferring apex's fused CUDA kernel when installed.

    The constructor call stays inside the ``try`` so that an ImportError raised
    lazily by apex (e.g. missing CUDA extension) also falls back to torch.
    """
    try:
        from apex.normalization import FusedLayerNorm
        return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    except ImportError:
        return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
class _TransposeLast(nn.Module):
def __init__(self, deconstruct_idx=None):
super().__init__()
self.deconstruct_idx = deconstruct_idx
def forward(self, x):
if self.deconstruct_idx is not None:
x = x[self.deconstruct_idx]
return x.transpose(-2, -1)
class _SamePad(nn.Module):
def __init__(self, kernel_size, causal=False):
super().__init__()
self.remove = (kernel_size - 1) if causal else (1 if kernel_size % 2 == 0 else 0)
def forward(self, x):
if self.remove > 0:
x = x[:, :, : -self.remove]
return x
def _quant_noise(module, p, block_size):
    """Wrap *module* with Quant-Noise (Fan et al., 2020).

    During training, a forward pre-hook zeroes a random subset of weight blocks
    on every forward pass and rescales the survivors by ``1 / (1 - p)``.
    Only nn.Linear / nn.Embedding / nn.Conv2d are supported; ``p <= 0`` is a
    no-op and returns the module unchanged.
    """
    if p <= 0:
        return module
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
    # 4-D weight => convolution; 2-D => linear / embedding.
    is_conv = module.weight.ndim == 4
    def _hook(mod, input):
        if mod.training:
            weight = mod.weight
            if not is_conv:
                in_features, out_features = weight.size(1), weight.size(0)
                # One Bernoulli draw per block of ``block_size`` input features.
                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
            else:
                if mod.kernel_size == (1, 1):
                    # 1x1 convs are treated like linear layers (blocks over channels).
                    mask = torch.zeros(int(mod.in_channels // block_size * mod.out_channels), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
                else:
                    # Larger kernels: drop whole (out_channel, in_channel) slices.
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            mask = mask.to(torch.bool)
            # NOTE(review): this overwrites .data in place on every training
            # forward (the dropped weights are not restored afterwards).
            mod.weight.data = (1 / (1 - p)) * weight.masked_fill(mask, 0)
    module.register_forward_pre_hook(_hook)
    return module
class _MultiHeadAttention(nn.Module):
    """Fairseq-style multi-head attention.

    Holds separate Q/K/V projections (optionally wrapped with quant-noise) and
    delegates the actual attention computation to
    ``F.multi_head_attention_forward`` with ``use_separate_proj_weight=True``.
    Inputs/outputs use the (seq_len, batch, embed_dim) layout.
    """

    def __init__(self, embed_dim, n_heads, kdim=None, vdim=None, dropout=0.0,
                 bias=True, self_attention=False, q_noise=0.0, qn_block_size=8):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.n_heads = n_heads
        self.dropout = _Dropout(dropout, module_name=self.__class__.__name__)
        self.d_heads = embed_dim // n_heads
        assert self.d_heads * n_heads == embed_dim
        self.self_attention = self_attention
        # _quant_noise is a no-op when q_noise == 0.
        self.k_proj = _quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.v_proj = _quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.q_proj = _quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self.out_proj = _quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
        self._reset_parameters()

    def _reset_parameters(self):
        if self.qkv_same_dim:
            # Scaled init (as in the original transformer) when Q/K/V share a dim.
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None):
        # Returns (attn_output, attn_weights_or_None).
        # NOTE(review): assumes bias=True — torch.cat would fail on None biases.
        assert key is not None and value is not None
        return F.multi_head_attention_forward(
            query, key, value, self.embed_dim, self.n_heads,
            # in_proj_weight is unused because use_separate_proj_weight=True.
            torch.empty([0]),
            torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
            None, None, False, self.dropout.p, self.out_proj.weight, self.out_proj.bias,
            self.training or self.dropout.apply_during_inference,
            key_padding_mask, need_weights, attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )
class _ConvFeatureExtraction(nn.Module):
    """Stack of 1-D conv blocks that turns a raw signal into frame features.

    Args:
        conv_layers: list of ``(dim, kernel_size, stride)`` tuples, one per block.
        in_d: number of input channels of the raw signal.
        dropout: dropout applied after each conv.
        mode: "default" (group-norm on the first block only) or "layer_norm"
            (fp32 layer-norm after every block).
        conv_bias: whether the convs use a bias term.
    """

    def __init__(self, conv_layers, in_d=1, dropout=0.0, mode="default", conv_bias=False):
        super().__init__()
        assert mode in {"default", "layer_norm"}

        def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
            def make_conv():
                c = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(c.weight)
                return c

            assert not (is_layer_norm and is_group_norm)
            if is_layer_norm:
                # BUGFIX: was `_Fp32LayerNorm(dim, dim, affine=True)`, which
                # raises TypeError (LayerNorm has no `affine` kwarg) and would
                # have set eps=dim positionally. Matches fairseq's
                # Fp32LayerNorm(dim, elementwise_affine=True); normalization is
                # over channels, hence the Transpose sandwich.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(_TransposeLast(), _Fp32LayerNorm(n_out, elementwise_affine=True), _TransposeLast()),
                    nn.GELU(),
                )
            elif is_group_norm:
                # GroupNorm with num_groups == num_channels (per-channel stats).
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    _Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        self.conv_layers = nn.ModuleList()
        for i, (dim, k, stride) in enumerate(conv_layers):
            self.conv_layers.append(block(
                in_d, dim, k, stride,
                is_layer_norm=mode == "layer_norm",
                is_group_norm=mode == "default" and i == 0,
                conv_bias=conv_bias,
            ))
            in_d = dim

    def forward(self, x):
        # Accept (B, T) input and add the channel dim expected by Conv1d.
        if len(x.shape) < 3:
            x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)
        return x
class _ConvPositionalEncoding(nn.Module):
    """Convolutional relative positional embedding (wav2vec 2.0 style)."""

    def __init__(self, args):
        super().__init__()
        self.embedding_dim = args.encoder_embed_dim
        # Grouped conv over time provides the positional signal.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim, self.embedding_dim,
            kernel_size=args.conv_pos, padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        std = math.sqrt((4 * 1.0) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)
        # NOTE(review): nn.utils.weight_norm is deprecated in newer torch in
        # favor of nn.utils.parametrizations.weight_norm — kept as-is so that
        # checkpoint parameter names (weight_g / weight_v) still match.
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, _SamePad(args.conv_pos), nn.GELU())

    def forward(self, x, channel_first=False):
        # Default input layout is (B, T, C); Conv1d wants (B, C, T).
        if not channel_first:
            x = x.transpose(1, 2)
        return self.pos_conv(x).transpose(1, 2)
class _TransformerEncoderLayer(nn.Module):
    """Wav2vec2-style transformer encoder layer with pre- or post-norm.

    ``layer_norm_first=True`` applies LayerNorm before attention/FFN (pre-norm);
    otherwise LayerNorm follows each residual sum (post-norm, BERT-style).
    Returns ``(x, (attn, layer_result))``, where ``layer_result`` is the FFN
    output before the final dropout/residual.
    """

    def __init__(self, embed_dim=768, n_heads=12, ffn_dim=3072, dropout=0.1,
                 attention_dropout=0.1, activation_dropout=0.1, layer_norm_first=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        # GELU computed in fp32 for stability, cast back to the input dtype.
        def gelu(x):
            return F.gelu(x.float()).type_as(x)
        self.activation_fn = gelu
        self.self_attn = _MultiHeadAttention(embed_dim, n_heads, dropout=attention_dropout, self_attention=True)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = _make_layer_norm(embed_dim)
        self.fc1 = nn.Linear(embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, embed_dim)
        self.final_layer_norm = _make_layer_norm(embed_dim)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
        # x: (seq_len, batch, embed_dim). ``need_weights`` / ``att_args`` are
        # accepted for interface compatibility; weights are never requested.
        residual = x
        if self.layer_norm_first:
            # Pre-norm: LN -> attention -> residual, LN -> FFN -> residual.
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(query=x, key=x, value=x,
                                     key_padding_mask=self_attn_padding_mask,
                                     attn_mask=self_attn_mask, need_weights=False)
            x = self.dropout1(x)
            x = residual + x
            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
        else:
            # Post-norm: attention -> residual -> LN, FFN -> residual -> LN.
            x, attn = self.self_attn(query=x, key=x, value=x,
                                     key_padding_mask=self_attn_padding_mask,
                                     attn_mask=self_attn_mask, need_weights=False)
            x = self.dropout1(x)
            x = residual + x
            x = self.self_attn_layer_norm(x)
            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            layer_result = x
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)
        return x, (attn, layer_result)
def _init_bert_params(module):
    """BERT-style initialization: N(0, 0.02) weights; zero biases and padding rows.

    Also re-initializes the Q/K/V projections of the in-file
    ``_MultiHeadAttention`` (overriding its xavier init).
    """

    def normal_(data):
        # Sample on CPU for determinism across devices, then copy back.
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, _MultiHeadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)
# ---------------------------------------------------------------------------
# Helpers from dbeta.py (Pooler, LayerNorm, PositionalEncoding, Transformer)
# ---------------------------------------------------------------------------
class _SupporterLayerNorm(nn.LayerNorm):
def forward(self, x: torch.Tensor):
orig_type = x.dtype
return super().forward(x.type(torch.float32)).type(orig_type)
class _QuickGELU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(1.702 * x)
class _ResidualAttentionBlock(nn.Module):
    """CLIP-style pre-norm transformer block (attention + MLP, both residual)."""

    def __init__(self, d_model, n_head, attn_mask=None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = _SupporterLayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", _QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = _SupporterLayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x, x_mask):
        # x: (seq, batch, dim); x_mask is a key-padding mask (True = ignore,
        # per nn.MultiheadAttention semantics).
        if x_mask is not None:
            x_mask = x_mask.to(dtype=torch.bool, device=x.device)
        attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask, key_padding_mask=x_mask)[0]

    def forward(self, x, x_mask=None):
        x = x + self.attention(self.ln_1(x), x_mask)
        x = x + self.mlp(self.ln_2(x))
        return x
class _SupporterTransformer(nn.Module):
    """Stack of CLIP-style residual attention blocks sharing one attn_mask.

    NOTE(review): constructs ``layers - 1`` blocks, not ``layers`` — the MIM
    head compensates by passing ``mim_decoder_num_layers + 1``. Kept as-is for
    checkpoint parity; confirm before "fixing" the off-by-one.
    """

    def __init__(self, width, layers, heads, attn_mask=None):
        super().__init__()
        self.resblocks = nn.Sequential(
            *[_ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers - 1)]
        )

    def forward(self, x, x_mask=None):
        # Iterate manually: nn.Sequential cannot forward the extra mask argument.
        for block in self.resblocks:
            x = block(x, x_mask)
        return x
class _PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len):
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
return x + self.pe[:, : x.size(1)]
class _Pooler(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
return self.activation(self.dense(hidden_states[:, 0]))
# ---------------------------------------------------------------------------
# TransformerEncoder + ECGTransformerModel
# NOTE: get_embeddings returns (x, padding_mask, x_conv) — 3 values
# ---------------------------------------------------------------------------
class _TransformerEncoder(nn.Module):
    """Stack of wav2vec2 transformer layers with LayerDrop."""

    def __init__(self, args):
        super().__init__()
        self.dropout = args.dropout
        self.embed_dim = args.encoder_embed_dim
        self.layers = nn.ModuleList([
            _TransformerEncoderLayer(
                embed_dim=self.embed_dim,
                ffn_dim=args.encoder_ffn_embed_dim,
                n_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                layer_norm_first=args.layer_norm_first,
            )
            for _ in range(args.encoder_layers)
        ])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = _make_layer_norm(self.embed_dim)
        self.layerdrop = args.encoder_layerdrop
        self.apply(_init_bert_params)

    def forward(self, x, padding_mask=None, attn_mask=None):
        x = self._extract_features(x, padding_mask, attn_mask)
        if self.layer_norm_first:
            # Pre-norm stacks need a final LayerNorm on the output.
            x = self.layer_norm(x)
        return x

    def _extract_features(self, x, padding_mask=None, attn_mask=None):
        if padding_mask is not None:
            # NOTE: zeroes padded timesteps in place (mutates the input tensor).
            x[padding_mask] = 0
        if not self.layer_norm_first:
            x = self.layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # (B, T, C) -> (T, B, C) for the attention layers.
        x = x.transpose(0, 1)
        for layer in self.layers:
            # LayerDrop: randomly skip whole layers during training only.
            dropout_probability = np.random.random()
            if not self.training or dropout_probability > self.layerdrop:
                x, z = layer(x, self_attn_padding_mask=padding_mask,
                             self_attn_mask=attn_mask, need_weights=False)
        return x.transpose(0, 1)
class _ECGTransformerModel(nn.Module):
    """wav2vec2-style ECG encoder: conv feature extractor + transformer encoder.

    Masking hyper-parameters are stored for config/checkpoint compatibility;
    this port exposes ``get_embeddings`` / ``get_output`` rather than a masked
    pretraining forward.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Time- and channel-masking configuration (kept for compatibility).
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space
        if cfg.apply_mask:
            # Learned embedding substituted at masked timesteps.
            self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)
        self.num_updates = 0
        # NOTE(review): eval() of a config-provided string — safe only for
        # trusted configs; ast.literal_eval would be the safer choice.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]
        self.feature_extractor = _ConvFeatureExtraction(
            conv_layers=feature_enc_layers, in_d=cfg.in_d,
            dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias,
        )
        # Project conv features to the encoder width only when they differ.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim else None
        )
        self.feature_grad_mult = cfg.feature_grad_mult
        self.conv_pos = _ConvPositionalEncoding(cfg)
        self.layer_norm = _make_layer_norm(self.embed)
        self.encoder = _TransformerEncoder(cfg)

    def _get_feat_extract_output_lengths(self, input_lengths):
        # Map raw-signal lengths to post-conv frame counts (no-padding convs).
        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)
        for cl in eval(self.cfg.conv_feature_layers):
            input_lengths = _conv_out_length(input_lengths, cl[1], cl[2])
        return input_lengths.to(torch.long)

    def get_embeddings(self, source, padding_mask):
        """Returns (x, padding_mask, x_conv) — 3 values."""
        # feature_grad_mult scales the extractor's gradient; 0 freezes it.
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = _GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        features = features.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        features = self.layer_norm(features)
        if padding_mask is not None and padding_mask.any():
            # Recompute the padding mask at the downsampled frame rate: mark the
            # last valid frame with 1, then a reversed cumsum marks every frame
            # from there on; the final negation yields True at padded frames.
            input_lengths = (1 - padding_mask.long()).sum(-1)
            if input_lengths.dim() > 1:
                # Multi-channel input mask: all channels share one length.
                input_lengths = input_lengths[:, 0]
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)
            padding_mask = torch.zeros(features.shape[:2], dtype=features.dtype, device=features.device)
            padding_mask[
                (torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)
            ] = 1
            padding_mask[torch.where(output_lengths == 0)] = 0
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None
        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)
        features = self.dropout_input(features)
        # Convolutional positional embedding, added to the features.
        x_conv = self.conv_pos(features, channel_first=False)
        x = features + x_conv
        return x, padding_mask, x_conv

    def get_output(self, x, padding_mask=None):
        """Run the transformer encoder over embedded features."""
        return self.encoder(x, padding_mask=padding_mask)
# ---------------------------------------------------------------------------
# Cross-attention layer (from cross_layer.py — same as D-BETA)
# ---------------------------------------------------------------------------
class _BertSelfAttention(nn.Module):
    """BERT scaled-dot-product attention supporting self- and cross-attention.

    NOTE(review): ``position_embedding_type`` and ``past_key_value`` are stored/
    accepted but never used in this forward — only absolute-position behavior
    is implemented.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        # (B, T, all_head_size) -> (B, heads, T, head_size)
        return x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size).permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        mixed_query_layer = self.query(hidden_states)
        # Cross-attention: keys/values (and the mask) come from the other stream.
        is_cross_attention = encoder_hidden_states is not None
        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Additive mask: large negative values disable masked positions.
            attention_scores = attention_scores + attention_mask
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # Dropout on attention probabilities, as in the original BERT.
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2], self.all_head_size)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
class _BertSelfOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
return self.LayerNorm(self.dropout(self.dense(hidden_states)) + input_tensor)
class _BertAttention(nn.Module):
    """Self/cross attention followed by the BERT output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = _BertSelfAttention(config)
        self.output = _BertSelfOutput(config)
        # Kept for interface parity with HF BertAttention; pruning is unused here.
        self.pruned_heads = set()

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None,
                past_key_value=None, output_attentions=False):
        attn_outputs = self.self(
            hidden_states, attention_mask, head_mask,
            encoder_hidden_states, encoder_attention_mask,
            past_key_value, output_attentions,
        )
        attended = self.output(attn_outputs[0], hidden_states)
        # Re-attach attention probabilities when output_attentions was requested.
        return (attended,) + attn_outputs[1:]
class _BertIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
return self.intermediate_act_fn(self.dense(hidden_states))
class _BertOutput(nn.Module):
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
return self.LayerNorm(self.dropout(self.dense(hidden_states)) + input_tensor)
class _BertCrossLayer(nn.Module):
    """Cross-modal fusion layer (METER/M3AE style): self-attention, then
    cross-attention onto the other modality, then a chunked feed-forward block."""

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Chunk the FFN along the sequence dimension.
        self.seq_len_dim = 1
        self.attention = _BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        self.crossattention = _BertAttention(config)
        self.intermediate = _BertIntermediate(config)
        self.output = _BertOutput(config)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None,
                encoder_attention_mask=None, output_attentions=False):
        # 1) Self-attention within this modality.
        self_attention_outputs = self.attention(hidden_states, attention_mask,
                                                head_mask=None, output_attentions=output_attentions)
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]
        # 2) Cross-attention: queries from this modality, keys/values from the
        #    other (positional args: mask, head_mask=None, encoder states/mask,
        #    past_key_value=None, output_attentions).
        cross_attention_outputs = self.crossattention(
            attention_output, attention_mask, None,
            encoder_hidden_states, encoder_attention_mask, None, output_attentions,
        )
        attention_output = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:]
        # 3) Feed-forward, chunked to bound peak memory.
        layer_output = apply_chunking_to_forward(
            self._feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output,) + outputs

    def _feed_forward_chunk(self, attention_output):
        return self.output(self.intermediate(attention_output), attention_output)
# ---------------------------------------------------------------------------
# M3AEModel — ECG encoder (from models/ecg_encoder/dbeta.py)
# ---------------------------------------------------------------------------
def _init_weights(module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
class _MLMHead(nn.Module):
    """Masked-language-modeling head: BERT transform + vocab decoder.

    The decoder weight can be tied to the input embeddings via ``weight``;
    the output bias is kept separate so tying stays possible.
    """

    def __init__(self, config, weight=None):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        if weight is not None:
            self.decoder.weight = weight

    def forward(self, x):
        hidden = self.transform(x)
        return self.decoder(hidden) + self.bias
class _MIMHead(nn.Module):
    """Masked-ECG-modeling decoder (MAE-style): re-inserts mask tokens, decodes
    with a small transformer, and predicts raw signal values per frame."""

    def __init__(self, cfg):
        super().__init__()
        self.hidden_dim = cfg.hidden_dim
        self.decoder_hidden_dim = cfg.mim_decoder_hidden_dim
        self.decoder_embed = nn.Linear(self.hidden_dim, self.decoder_hidden_dim, bias=True)
        # Learned token standing in for every masked position.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.decoder_hidden_dim))
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.decoder_pos_embed = _PositionalEncoding(self.decoder_hidden_dim, max_len=512)
        # +1 compensates the ``layers - 1`` inside _SupporterTransformer.
        self.decoder = _SupporterTransformer(self.decoder_hidden_dim, cfg.mim_decoder_num_layers + 1, cfg.mim_decoder_num_heads)
        self.decoder_norm = _SupporterLayerNorm(self.decoder_hidden_dim)
        # Infer how many raw samples one encoder frame covers for a 5000-sample
        # input, given the conv downsampling stack.
        def _conv_out_length(il, k, s):
            return np.floor((il - k) / s + 1)
        inferred = 5000
        # NOTE(review): eval() of a config string — safe only for trusted configs.
        for cl in eval(cfg.conv_feature_layers):
            inferred = _conv_out_length(inferred, cl[1], cl[2])
        self.inferred_decoded_size = int(np.floor(5000 / inferred))
        # Predict ``inferred_decoded_size`` samples per frame for each of the
        # 12 channels (presumably the 12 ECG leads).
        self.decoder_pred = nn.Linear(self.decoder_hidden_dim, self.inferred_decoded_size * 12, bias=True)

    def forward(self, x, ids_restore):
        # x: encoder output for visible tokens (leading CLS included);
        # ids_restore: permutation restoring the original token order.
        x = self.decoder_embed(x)
        # Append mask tokens for the removed positions, then un-shuffle.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)
        x = self.decoder_pos_embed(x)
        # _SupporterTransformer expects (seq, batch, dim).
        x = self.decoder(x.permute(1, 0, 2)).permute(1, 0, 2)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)[:, 1:, :]  # drop CLS before reconstruction
        # (B, T, 12, samples_per_frame)
        return x.view(x.size(0), x.size(1), -1, self.inferred_decoded_size)
class _ITMHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.fc = nn.Linear(hidden_size, 2)
def forward(self, x):
return self.fc(x)
class M3AEModel(nn.Module):
    """ECG encoder from Q-HEART (identical role to DBETA in D-BETA).

    Combines a wav2vec2-style ECG tower with a text encoder (flan-t5-base) and
    METER-style cross-modal fusion layers, plus MLM / MIM / ITM pretraining heads.

    NOTE(review): instantiation calls ``from_pretrained`` and therefore
    downloads google/flan-t5-base unless it is already cached.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.vocab_size = cfg.vocab_size
        self.mim_prob = cfg.mim_prob
        self.mim_layer = cfg.mim_layer
        # Unimodal ECG tower.
        self.ecg_encoder = _ECGTransformerModel(cfg)
        # Learned CLS-style token for the ECG stream.
        self.class_embedding = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
        # Unimodal language tower (encoder only; no decoder weights loaded).
        from transformers import T5EncoderModel
        self.language_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base")
        self.language_encoder.pooler = None
        # Project both towers into the shared multimodal width.
        self.multi_modal_language_proj = nn.Linear(cfg.encoder_embed_dim, cfg.hidden_dim)
        self.multi_modal_language_proj.apply(_init_weights)
        self.multi_modal_ecg_proj = nn.Linear(cfg.encoder_embed_dim, cfg.hidden_dim)
        self.multi_modal_ecg_proj.apply(_init_weights)
        # Modality tag embedding (two types: text / ECG).
        self.modality_type_embeddings = nn.Embedding(2, cfg.hidden_dim)
        self.modality_type_embeddings.apply(_init_weights)
        bert_config = BertConfig(
            vocab_size=cfg.vocab_size,
            hidden_size=cfg.hidden_dim,
            num_hidden_layers=cfg.num_layers,
            num_attention_heads=cfg.num_heads,
            intermediate_size=cfg.hidden_dim * 4,
            max_position_embeddings=cfg.max_text_size,
            hidden_dropout_prob=cfg.drop_rate,
            attention_probs_dropout_prob=cfg.drop_rate,
        )
        # Cross-modal fusion stacks, one per modality, run in lock-step.
        self.multi_modal_ecg_layers = nn.ModuleList([_BertCrossLayer(bert_config) for _ in range(cfg.num_top_layer)])
        self.multi_modal_ecg_layers.apply(_init_weights)
        self.multi_modal_language_layers = nn.ModuleList([_BertCrossLayer(bert_config) for _ in range(cfg.num_top_layer)])
        self.multi_modal_language_layers.apply(_init_weights)
        # Poolers for multimodal and unimodal representations.
        self.multi_modal_ecg_pooler = _Pooler(cfg.hidden_dim)
        self.multi_modal_ecg_pooler.apply(_init_weights)
        self.multi_modal_language_pooler = _Pooler(cfg.hidden_dim)
        self.multi_modal_language_pooler.apply(_init_weights)
        self.unimodal_ecg_pooler = _Pooler(cfg.hidden_dim)
        self.unimodal_ecg_pooler.apply(_init_weights)
        self.unimodal_language_pooler = _Pooler(cfg.hidden_dim)
        self.unimodal_language_pooler.apply(_init_weights)
        # Pretraining heads (removable for pure inference, see below).
        self.mlm_head = _MLMHead(bert_config)
        self.mlm_head.apply(_init_weights)
        self.mim_head = _MIMHead(cfg)
        self.mim_head.apply(_init_weights)
        # ITM head consumes the concatenation of both pooled vectors.
        self.itm_head = _ITMHead(cfg.hidden_dim * 2)
        self.itm_head.apply(_init_weights)

    def remove_pretraining_modules(self):
        """Free pretraining-only submodules when using the encoder for QA."""
        self.mlm_head = None
        self.mim_head = None
        self.itm_head = None
        self.language_encoder = None
        self.multi_modal_language_layers = None
        self.multi_modal_ecg_layers = None
# ---------------------------------------------------------------------------
# Mapper classes (from models/prefix_mappers.py)
# ---------------------------------------------------------------------------
class _MlpTransformer(nn.Module):
def __init__(self, in_dim, h_dim, out_d=None, act=F.relu, dropout=0.):
super().__init__()
out_d = out_d if out_d is not None else in_dim
self.fc1 = nn.Linear(in_dim, h_dim)
self.act = act
self.fc2 = nn.Linear(h_dim, out_d)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
x = self.act(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return self.dropout(x)
class _MapperMultiHeadAttention(nn.Module):
    """Multi-head (cross-)attention used by the prefix mappers.

    NOTE(review): ``self.dropout`` is constructed but never applied in forward.
    """

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        # Keys and values are produced jointly from the reference sequence.
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        # Self-attention when y is None, cross-attention otherwise.
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        # attention: (batch, query, key, head)
        attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            # True entries in the mask are excluded from attention.
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
        # Also returns the attention map for inspection by callers.
        return self.project(out), attention
class _MapperTransformerLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual."""

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0.,
                 act=F.relu, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = _MapperMultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = _MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)

    def forward(self, x, y=None, mask=None):
        # The attention module returns (output, attention_map); keep the output.
        attn_out, _ = self.attn(self.norm1(x), y, mask)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
class _MapperTransformer(nn.Module):
    """Transformer stack for prefix mapping, with optional enc-dec interleaving.

    With ``enc_dec=True`` the stack has ``2 * num_layers`` layers alternating
    cross-attention (even indices, attending to ``y``) and self-attention
    (odd indices); otherwise it is a plain ``num_layers`` stack.
    """

    def __init__(self, dim_self, num_heads, num_layers, dim_ref=None, mlp_ratio=2.,
                 act=F.relu, norm_layer=nn.LayerNorm, enc_dec=False):
        super().__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        layers = []
        for i in range(num_layers * 2 if enc_dec else num_layers):
            if enc_dec and i % 2 == 0:
                # Cross-attention onto the reference sequence.
                layers.append(_MapperTransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:
                # Self-attention layer.
                layers.append(_MapperTransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:
                layers.append(_MapperTransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:
                x = layer(x, y)
            elif self.enc_dec:
                x = layer(x, x, mask)
            else:
                x = layer(x, y, mask)
        return x
class TransformerMapper(nn.Module):
    """Map one feature vector to a ``prefix_length``-token LLM prefix.

    The input is linearly expanded into ``clip_length`` tokens, concatenated
    with learned constant queries, passed through a transformer, and only the
    transformed constant-query positions are returned as the prefix.
    """

    def __init__(self, dim_clip, dim_embedding, prefix_length, clip_length, num_layers=4, num_heads=4):
        super().__init__()
        self.clip_length = clip_length
        self.transformer = _MapperTransformer(dim_embedding, num_heads, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        # Learned query tokens, shared across the batch.
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)

    def forward(self, x):
        # (B, dim_clip) -> (B, clip_length, dim_embedding)
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        # Keep only the learned-query positions after joint self-attention.
        return self.transformer(prefix)[:, self.clip_length:]
class AttentionMapper(nn.Module):
    """Cross-attention mapper: projected ECG features act as queries over the
    text (query) embeddings; the attended ECG tokens are prepended to the
    original text embeddings.

    NOTE(review): the default ``dim=786`` looks like a typo for 768 (the ECG
    feature dim used at the call site). Callers always pass ``dim``
    explicitly, so the default is kept for backward compatibility.
    """

    def __init__(self, dim=786, output_dim=2048, num_heads=8, dim_head=64):
        super().__init__()
        self.num_heads = num_heads
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.inner_dim = num_heads * dim_head
        self.ecg_projection_layer = nn.Linear(dim, output_dim)
        self.norm_ecg = nn.LayerNorm(output_dim)
        self.norm_query = nn.LayerNorm(output_dim)
        self.to_q = nn.Linear(output_dim, self.inner_dim, bias=False)
        self.to_kv = nn.Linear(output_dim, self.inner_dim * 2, bias=False)
        self.to_out = nn.Linear(self.inner_dim, output_dim, bias=False)

    def forward(self, ecg_features, query_features, prefix_len=None):
        """Attend ECG tokens over (optionally truncated) text embeddings.

        Args:
            ecg_features: (B, n_ecg, dim) raw ECG features.
            query_features: (B, n_txt, output_dim) text embeddings.
            prefix_len: if given, keys/values use only the first
                ``prefix_len`` text positions.

        Returns:
            (B, n_ecg + n_txt, output_dim): attended ECG tokens concatenated
            in front of the untouched text embeddings.
        """
        batch = ecg_features.shape[0]
        ecg_features = self.norm_ecg(self.ecg_projection_layer(ecg_features))
        normed_query = self.norm_query(
            query_features[:, :prefix_len, :] if prefix_len is not None else query_features
        )
        q = self.to_q(ecg_features)
        k, v = self.to_kv(normed_query).chunk(2, dim=-1)

        def _split_heads(t):
            # (B, n, h*d) -> (B, h, n, d)
            return t.view(batch, t.shape[1], self.num_heads, self.dim_head).transpose(1, 2)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = scores.softmax(dim=-1)
        # BUG FIX: the original `einsum("bhqk, bhvd -> bhqd", attn, v)` used
        # distinct labels (k and v) for the same key/value token axis, so both
        # were summed independently; with softmax rows summing to 1 the result
        # collapsed to an unweighted sum of the values, ignoring the attention
        # weights entirely. The correct contraction weights v by attn.
        out = torch.matmul(attn, v)                                   # (B, h, q, d)
        out = out.transpose(1, 2).reshape(batch, -1, self.inner_dim)  # (B, q, h*d)
        return torch.cat((self.to_out(out), query_features), dim=1)
class MoEMapper(nn.Module):
    """Top-1 mixture-of-experts projection.

    A gate over the mean text embedding selects exactly one linear expert
    per sample; that expert projects the (single-token) ECG feature into
    the LLM embedding space.
    """

    def __init__(self, input_dim, output_dim, num_experts=12):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(input_dim, output_dim)) for _ in range(num_experts)
        )
        self.text_gate = nn.Linear(output_dim, num_experts)
        self.output_dim = output_dim

    def forward(self, x, t):
        batch = x.size(0)
        flat = x.squeeze(1)  # expects x of shape (B, 1, input_dim)
        # Hard top-1 routing from the pooled text representation.
        routed = self.text_gate(t.mean(dim=1)).argmax(dim=-1)
        out = torch.zeros(batch, self.output_dim, device=x.device, dtype=x.dtype)
        for idx, expert in enumerate(self.experts):
            chosen = routed == idx
            if chosen.any():
                out[chosen] = expert(flat[chosen])
        return out.reshape(batch, 1, self.output_dim)
class _MLPBlock(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.mlp = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))
def forward(self, x):
return self.mlp(x)
class MLPMixer(nn.Module):
    """MLP-Mixer mapping head.

    Alternates token mixing (MLP over the token axis, applied via a
    transpose) and channel mixing (MLP over features), each with a residual
    connection, then layer-norms and projects into the LLM embedding space.
    Sub-module names match the original for checkpoint compatibility.
    """

    def __init__(self, num_tokens=12, input_dim=768, llm_embedding_size=1024, num_layers=4):
        super().__init__()
        self.token_mixer = nn.ModuleList(
            _MLPBlock(num_tokens, input_dim) for _ in range(num_layers)
        )
        self.channel_mixer = nn.ModuleList(
            _MLPBlock(input_dim, input_dim) for _ in range(num_layers)
        )
        self.last = nn.Linear(input_dim, llm_embedding_size)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        for mix_tokens, mix_channels in zip(self.token_mixer, self.channel_mixer):
            # Token mixing runs along the sequence axis, hence transpose in/out.
            x = x + mix_tokens(x.transpose(1, 2)).transpose(1, 2)
            x = x + mix_channels(x)
        return self.last(self.norm(x))
# ---------------------------------------------------------------------------
# CustomECGQAModel (from models/models.py)
# ---------------------------------------------------------------------------
class CustomECGQAModel(nn.Module):
    """ECG question-answering core.

    An ECG encoder produces features that a mapping layer (ET-Mapper)
    projects into LLM token space; those tokens are prepended to the text
    embeddings before the causal LM.

    Args:
        ecg_encoder: pretrained ECG encoder (M3AE-style) exposing
            ``ecg_encoder.get_embeddings`` / ``get_output``, a class
            embedding, a multimodal projection and a unimodal pooler.
        mapping_type: one of "MLP", "MLPMixer", "Attention", "Transformer",
            "MOE" — selects the mapper variant.
        setting: "lora" wraps the LLM with LoRA adapters; "frozen" freezes
            all LLM parameters; any other value leaves the LLM trainable.
        prefix_length: number of ECG prefix tokens (Transformer/MLPMixer).
        clip_length: number of clip tokens inside the Transformer mapper.
        llm_model_type: HuggingFace model id of the causal LM backbone.
    """
    def __init__(self, ecg_encoder, mapping_type="Transformer", setting="lora",
                 prefix_length=12, clip_length=12, llm_model_type="meta-llama/Llama-3.2-1B-Instruct"):
        super().__init__()
        self.mapping_type = mapping_type
        self.setting = setting
        self.llm_type = llm_model_type
        self.llm = AutoModelForCausalLM.from_pretrained(self.llm_type)
        self.llm_embedding_size = self.llm.config.hidden_size
        if setting == "lora":
            # LoRA on every attention and MLP projection of the LM.
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
            )
            self.llm = get_peft_model(self.llm, peft_config)
        elif setting == "frozen":
            for param in self.llm.parameters():
                param.requires_grad = False
            self.llm.eval()
        self.ecg_encoder = ecg_encoder
        # Collapses the encoder's conv-embedding token axis (312 -> 12) so it
        # can be added to the Transformer mapper's 12 prefix tokens.
        self.postconv = nn.Conv1d(in_channels=312, out_channels=12, kernel_size=1)
        # Number of ECG tokens prepended to the text sequence; single-token
        # mappers (MLP / Attention / MOE) contribute exactly one.
        self.ecg_token_nums = prefix_length if mapping_type in ["Transformer", "MLPMixer"] else 1
        self.ecg_feature_dim = 768
        if mapping_type == "MLP":
            self.ecg_projection_layer = nn.Linear(self.ecg_feature_dim, self.llm_embedding_size)
        elif mapping_type == "MLPMixer":
            self.ecg_projection_layer = MLPMixer(input_dim=self.ecg_feature_dim, llm_embedding_size=self.llm_embedding_size)
        elif mapping_type == "Attention":
            self.ecg_projection_layer = AttentionMapper(dim=self.ecg_feature_dim, output_dim=self.llm_embedding_size)
        elif mapping_type == "Transformer":
            self.ecg_projection_layer = TransformerMapper(
                dim_clip=self.ecg_feature_dim,
                dim_embedding=self.llm_embedding_size,
                prefix_length=self.ecg_token_nums,
                clip_length=clip_length,
                num_heads=4,
                num_layers=2,
            )
            # Projects the conv embeddings for the residual skip path below.
            self.postlinear = nn.Linear(self.ecg_feature_dim, self.llm_embedding_size)
        elif mapping_type == "MOE":
            self.ecg_projection_layer = MoEMapper(self.ecg_feature_dim, self.llm_embedding_size)
    def _get_ecg_features(self, ecg):
        """Run the ECG encoder: prepend a class token, encode, project and
        pool. Returns (pooled features, conv embeddings for the skip path)."""
        uni_modal_ecg_feats, ecg_padding_mask, conv_embedd = (
            self.ecg_encoder.ecg_encoder.get_embeddings(ecg, padding_mask=None)
        )
        cls_emb = self.ecg_encoder.class_embedding.repeat(len(uni_modal_ecg_feats), 1, 1)
        uni_modal_ecg_feats = torch.cat([cls_emb, uni_modal_ecg_feats], dim=1)
        uni_modal_ecg_feats = self.ecg_encoder.ecg_encoder.get_output(uni_modal_ecg_feats, ecg_padding_mask)
        out = self.ecg_encoder.multi_modal_ecg_proj(uni_modal_ecg_feats)
        ecg_features = self.ecg_encoder.unimodal_ecg_pooler(out)
        return ecg_features, conv_embedd
    def _build_inputs_embeds(self, input_ids, ecg):
        """Project ECG features into LLM embedding space and prepend them to
        the text embeddings.

        Returns (inputs_embeds, n_ecg_tokens). n_ecg_tokens is None for the
        Attention mapper, which concatenates internally.
        """
        ecg_features, conv_embedd = self._get_ecg_features(ecg)
        # assumes the pooled features flatten to one token per sample — TODO confirm
        ecg_features = ecg_features.reshape(input_ids.shape[0], 1, -1)
        if self.mapping_type == "MOE":
            embeddings = self.llm.get_input_embeddings()(input_ids)
            ecg_features_projected = self.ecg_projection_layer(ecg_features, embeddings)
        elif self.mapping_type == "Attention":
            embeddings = self.llm.get_input_embeddings()(input_ids)
            return self.ecg_projection_layer(ecg_features, embeddings, None), None
        else:
            ecg_features_projected = self.ecg_projection_layer(ecg_features)
            if self.mapping_type == "Transformer":
                # Residual skip from the encoder's conv embeddings.
                ecg_features_projected = ecg_features_projected + self.postlinear(self.postconv(conv_embedd))
        embeddings = self.llm.get_input_embeddings()(input_ids)
        return torch.cat((ecg_features_projected, embeddings), dim=1), ecg_features_projected.shape[1]
    def forward(self, input_ids=None, attention_mask=None, labels=None, ecg=None, **kwargs):
        """Causal-LM forward with the ECG prefix. Mask and labels are
        left-padded by ``ecg_token_nums`` (labels with -100 so the prefix is
        excluded from the loss)."""
        inputs_embeds, n_ecg_tokens = self._build_inputs_embeds(input_ids, ecg)
        device = input_ids.device
        if attention_mask is not None:
            # NOTE(review): padding width uses self.ecg_token_nums, not the
            # returned n_ecg_tokens — these agree for the configured mappers,
            # but verify if a mapper ever emits a different prefix width.
            attention_mask = torch.cat(
                (torch.ones((input_ids.size(0), self.ecg_token_nums), dtype=attention_mask.dtype, device=device),
                 attention_mask), dim=1,
            )
        if labels is not None:
            labels = torch.cat(
                (torch.full((labels.size(0), self.ecg_token_nums), -100, dtype=labels.dtype, device=device),
                 labels), dim=1,
            )
        return self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)
    def generate(self, input_ids=None, attention_mask=None, ecg=None, max_length=50, **kwargs):
        """Autoregressive generation with the ECG prefix prepended.

        ``max_length`` is treated as a budget of NEW tokens: the inner call
        receives ``inputs_embeds.shape[1] + max_length``.
        """
        # ECG is force-reshaped to 12 leads x 5000 samples here (unlike
        # forward, which takes the tensor as-is) — presumably 10 s at 500 Hz.
        inputs_embeds, _ = self._build_inputs_embeds(input_ids, ecg.reshape(-1, 12, 5000))
        device = input_ids.device
        if attention_mask is not None:
            attention_mask = torch.cat(
                (torch.ones((input_ids.size(0), self.ecg_token_nums), dtype=attention_mask.dtype, device=device),
                 attention_mask), dim=1,
            )
        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_length=inputs_embeds.shape[1] + max_length,
            attention_mask=attention_mask,
            **kwargs,
        )
# ---------------------------------------------------------------------------
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------
class QHEARTForECGQA(PreTrainedModel):
    """
    Q-HEART: ECG Question Answering model wrapped as a HuggingFace PreTrainedModel.
    Combines a 12-lead ECG encoder (M3AEModel) with a causal LLM (Llama/Gemma/etc.)
    via a learned mapping layer (ET-Mapper).

    Example::

        from transformers import AutoModel, AutoTokenizer
        import torch

        model = AutoModel.from_pretrained("Manhph2211/Q-HEART", trust_remote_code=True)
        tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
        model.eval()
        ecg = torch.randn(1, 12, 5000)  # [batch, leads, length] at 500 Hz
        question = "What is the heart rhythm shown in this ECG?"
        inputs = tokenizer(question, return_tensors="pt")
        with torch.no_grad():
            output_ids = model.generate(
                ecg=ecg,
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=50,
            )
        print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    """
    config_class = QHEARTConfig
    def __init__(self, config: QHEARTConfig):
        super().__init__(config)
        # Build the ECG encoder and drop its pretraining-only heads.
        ecg_encoder = M3AEModel(config)
        ecg_encoder.remove_pretraining_modules()
        self.model = CustomECGQAModel(
            ecg_encoder=ecg_encoder,
            mapping_type=config.mapping_type,
            setting="lora",
            prefix_length=config.prefix_length,
            clip_length=config.clip_length,
            llm_model_type=config.llm_model_type,
        )
        self.post_init()
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load from a HuggingFace Hub repo or local directory.
        Supports ``pytorch_model.bin`` (standard) and ``sample.bin`` (original format).
        Both are flat state dicts — no nested key handling needed.
        """
        import os
        from transformers.utils import cached_file
        config = kwargs.pop("config", None)
        cache_dir = kwargs.get("cache_dir", None)
        # "use_auth_token" is the deprecated alias of "token"; accept both.
        token = kwargs.get("token", kwargs.get("use_auth_token", None))
        revision = kwargs.get("revision", None)
        local_files_only = kwargs.get("local_files_only", False)
        device_map = kwargs.pop("device_map", None)
        if config is None:
            config = QHEARTConfig.from_pretrained(pretrained_model_name_or_path, **{
                k: v for k, v in kwargs.items()
                if k in ("cache_dir", "token", "use_auth_token", "revision", "local_files_only", "trust_remote_code")
            })
        # Instantiate with random weights first, then overlay the checkpoint.
        model = cls(config)
        resolved_path = None
        # Prefer the standard HF filename; fall back to the original format.
        for fname in ("pytorch_model.bin", "sample.bin"):
            try:
                if os.path.isdir(pretrained_model_name_or_path):
                    candidate = os.path.join(pretrained_model_name_or_path, fname)
                    if os.path.isfile(candidate):
                        resolved_path = candidate
                        break
                else:
                    resolved_path = cached_file(
                        pretrained_model_name_or_path, fname,
                        cache_dir=cache_dir, token=token,
                        revision=revision, local_files_only=local_files_only,
                    )
                    if resolved_path:
                        break
            except Exception:
                # File absent in the repo (or download failed) — try the
                # next candidate name.
                continue
        if resolved_path is None:
            logger.warning("No checkpoint found (pytorch_model.bin or sample.bin). Returning model with random weights.")
            return model
        # NOTE(review): torch.load unpickles arbitrary objects from the
        # checkpoint; consider weights_only=True once the checkpoints are
        # confirmed to be tensor-only state dicts.
        state_dict = torch.load(resolved_path, map_location="cpu")
        # sample.bin is already a flat state dict (no nested "model" key)
        if isinstance(state_dict, dict) and "model" in state_dict and not any(
            k.startswith("model.") or k.startswith("ecg_encoder.") or k.startswith("llm.")
            for k in state_dict.keys()
        ):
            state_dict = state_dict["model"]
        # Loaded into the inner CustomECGQAModel (keys are relative to it).
        missing, unexpected = model.model.load_state_dict(state_dict, strict=False)
        if missing:
            logger.warning(f"Missing keys: {missing}")
        if unexpected:
            logger.warning(f"Unexpected keys: {unexpected}")
        logger.info(f"Loaded Q-HEART weights from {resolved_path}")
        if device_map is not None:
            # NOTE(review): this treats device_map as a plain device string
            # ("cuda", "cpu"); an accelerate-style map like "auto" would fail
            # in .to() — confirm intended usage.
            model = model.to(device_map)
        return model
    def forward(self, input_ids=None, attention_mask=None, labels=None, ecg=None, **kwargs):
        # Thin delegation to the inner model; extra kwargs are dropped.
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, ecg=ecg)
    def generate(self, input_ids=None, attention_mask=None, ecg=None, max_new_tokens=50, **kwargs):
        # max_new_tokens is forwarded as the inner model's max_length, which
        # already adds the prefix length internally (new-token budget).
        return self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask,
            ecg=ecg, max_length=max_new_tokens, **kwargs,
        )