# Q-HEART / configuration_qheart.py
# Contributor: Contributor21 · commit 5ad21a6
from transformers import PretrainedConfig
class QHEARTConfig(PretrainedConfig):
    """Hugging Face configuration for the Q-HEART model (``model_type="qheart"``).

    The constructor accepts two groups of settings:

    * ECG encoder (``M3AEModel``) settings: encoder transformer size
      (``encoder_layers``, ``encoder_embed_dim``, ``encoder_ffn_embed_dim``,
      ``encoder_attention_heads``), dropout rates, time- and channel-masking
      options (``mask_*`` / ``mask_channel_*``), the convolutional feature
      extractor options (``extractor_mode``, ``conv_feature_layers``, ``in_d``,
      ``conv_bias``, ``conv_pos``, ``conv_pos_groups``), pretrained-weight
      loading flags, and the ``vocab_size`` / ``hidden_dim`` / ``num_layers`` /
      ``num_heads`` / ``mim_*`` / ``max_text_size`` settings.
    * ``CustomECGQAModel`` (LLM / mapper) settings: ``llm_model_type``
      (default ``"meta-llama/Llama-3.2-1B-Instruct"``), ``mapping_type``,
      ``prefix_length`` and ``clip_length``.

    Each named argument is stored verbatim as an instance attribute of the
    same name. No validation or type conversion is done here.
    ``conv_feature_layers`` is kept as a string (default
    ``"[(256, 2, 2)] * 4"``); it is presumably parsed by the encoder, so
    confirm before changing its format. Any additional keyword arguments are
    forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "qheart"

    def __init__(
        self,
        # --- ECG encoder (M3AEModel) ---
        encoder_layers=12,
        encoder_embed_dim=768,
        encoder_ffn_embed_dim=3072,
        encoder_attention_heads=12,
        layer_norm_first=False,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.0,
        encoder_layerdrop=0.0,
        dropout_input=0.1,
        dropout_features=0.1,
        apply_mask=False,
        mask_length=10,
        mask_prob=0.0,
        mask_selection="static",
        mask_other=0.0,
        no_mask_overlap=False,
        mask_min_space=1,
        mask_channel_length=10,
        mask_channel_prob=0.0,
        mask_channel_selection="static",
        mask_channel_other=0.0,
        no_mask_channel_overlap=False,
        mask_channel_min_space=1,
        extractor_mode="default",
        conv_feature_layers="[(256, 2, 2)] * 4",
        in_d=12,
        conv_bias=False,
        feature_grad_mult=1.0,
        conv_pos=128,
        conv_pos_groups=16,
        load_pretrained_weights=False,
        pretrained_model_path="",
        vocab_size=32128,
        hidden_dim=768,
        num_layers=6,
        num_heads=12,
        drop_rate=0.1,
        num_top_layer=6,
        mim_layer=3,
        mim_prob=0.75,
        mim_decoder_hidden_dim=384,
        mim_decoder_num_layers=4,
        mim_decoder_num_heads=6,
        max_text_size=256,
        # --- CustomECGQAModel ---
        llm_model_type="meta-llama/Llama-3.2-1B-Instruct",
        mapping_type="Transformer",
        prefix_length=12,
        clip_length=12,
        **kwargs,
    ):
        # Attribute name -> value for the ECG encoder. Keys must match the
        # attribute names exactly; dict order fixes the assignment order.
        ecg_encoder_settings = {
            "encoder_layers": encoder_layers,
            "encoder_embed_dim": encoder_embed_dim,
            "encoder_ffn_embed_dim": encoder_ffn_embed_dim,
            "encoder_attention_heads": encoder_attention_heads,
            "layer_norm_first": layer_norm_first,
            "dropout": dropout,
            "attention_dropout": attention_dropout,
            "activation_dropout": activation_dropout,
            "encoder_layerdrop": encoder_layerdrop,
            "dropout_input": dropout_input,
            "dropout_features": dropout_features,
            "apply_mask": apply_mask,
            "mask_length": mask_length,
            "mask_prob": mask_prob,
            "mask_selection": mask_selection,
            "mask_other": mask_other,
            "no_mask_overlap": no_mask_overlap,
            "mask_min_space": mask_min_space,
            "mask_channel_length": mask_channel_length,
            "mask_channel_prob": mask_channel_prob,
            "mask_channel_selection": mask_channel_selection,
            "mask_channel_other": mask_channel_other,
            "no_mask_channel_overlap": no_mask_channel_overlap,
            "mask_channel_min_space": mask_channel_min_space,
            "extractor_mode": extractor_mode,
            "conv_feature_layers": conv_feature_layers,
            "in_d": in_d,
            "conv_bias": conv_bias,
            "feature_grad_mult": feature_grad_mult,
            "conv_pos": conv_pos,
            "conv_pos_groups": conv_pos_groups,
            "load_pretrained_weights": load_pretrained_weights,
            "pretrained_model_path": pretrained_model_path,
            "vocab_size": vocab_size,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "num_heads": num_heads,
            "drop_rate": drop_rate,
            "num_top_layer": num_top_layer,
            "mim_layer": mim_layer,
            "mim_prob": mim_prob,
            "mim_decoder_hidden_dim": mim_decoder_hidden_dim,
            "mim_decoder_num_layers": mim_decoder_num_layers,
            "mim_decoder_num_heads": mim_decoder_num_heads,
            "max_text_size": max_text_size,
        }
        # Attribute name -> value for the LLM / mapper (CustomECGQAModel).
        llm_mapper_settings = {
            "llm_model_type": llm_model_type,
            "mapping_type": mapping_type,
            "prefix_length": prefix_length,
            "clip_length": clip_length,
        }

        for settings in (ecg_encoder_settings, llm_mapper_settings):
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

        # Remaining options (e.g. generic PretrainedConfig fields) go to the base.
        super().__init__(**kwargs)