import torch
import torch.nn as nn
from transformers import AutoModel


class ECGEncoder(nn.Module):
    """Thin wrapper around the pre-trained MELP ECG encoder on Hugging Face.

    Loads ``fuyingw/MELP_Encoder`` (bfloat16, custom remote code) and exposes
    its ECG token embeddings via :meth:`forward`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Load the pre-trained model from Hugging Face.
        # NOTE: trust_remote_code=True executes the repo's custom modeling
        # code; revision is pinned to "main" and bfloat16 is the default
        # dtype. Ensure the model is available and accessible.
        self.encoder = AutoModel.from_pretrained(
            "fuyingw/MELP_Encoder",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            revision="main",
        )

    def forward(self, input_embeds, audio_attention_mask):
        """Encode a batch of ECG signals.

        Args:
            input_embeds: Tensor of shape (batch_size, seq_len, features);
                permuted to (batch_size, features, seq_len) before encoding.
            audio_attention_mask: Unused; kept so the signature matches the
                sibling audio-encoder interface expected by callers.

        Returns:
            Tuple ``(ecg_token_emb, None)``: the encoder output's
            ``"ecg_token_emb"`` entry (per the original note, shape
            (bz, 128, 768) — confirm against the MELP model card), and
            ``None`` because this encoder produces no attention mask.
        """
        # (batch_size, seq_len, features) -> (batch_size, features, seq_len)
        x = input_embeds.permute(0, 2, 1)
        output = self.encoder(x)
        return output["ecg_token_emb"], None

    def gradient_checkpointing_enable(self, **kwargs):
        # Intentional no-op: gradient checkpointing is not supported for the
        # wrapped encoder. Accepts **kwargs (e.g. gradient_checkpointing_kwargs
        # passed by the HF Trainer) so callers using the standard HF signature
        # do not crash.
        pass