File size: 888 Bytes
1140d83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
from transformers import AutoModel
class ECGEncoder(nn.Module):
    """Thin wrapper around the pre-trained MELP ECG encoder from Hugging Face.

    Loads ``fuyingw/MELP_Encoder`` (remote code trusted, bfloat16 weights) and
    exposes a ``forward`` whose signature mirrors the audio-encoder interface
    used by the surrounding training code.
    """

    def __init__(self, **kwargs):
        """Download/load the pre-trained encoder.

        Args:
            **kwargs: Accepted for config-driven construction compatibility;
                currently ignored.

        NOTE: requires network access (or a populated HF cache) at
        construction time.
        """
        super().__init__()
        # Pin to the "main" revision and load weights in bfloat16 by default.
        self.encoder = AutoModel.from_pretrained(
            "fuyingw/MELP_Encoder",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            revision="main",
        )

    def forward(self, input_embeds, audio_attention_mask):
        """Encode a batch of ECG embeddings.

        Args:
            input_embeds: Tensor of shape (batch, seq_len, features); permuted
                to (batch, features, seq_len) before the encoder call, which
                is the channel-first layout the MELP encoder expects.
            audio_attention_mask: Accepted only for interface compatibility
                with sibling encoders — it is never used; the encoder is
                called without a mask.

        Returns:
            Tuple of (ecg_token_emb, None): the token embeddings from the
            encoder output dict (presumably (batch, 128, 768) per the
            original author's note — confirm against the model card), and
            ``None`` as the attention-mask placeholder.
        """
        # (batch, seq_len, features) -> (batch, features, seq_len)
        channel_first = input_embeds.permute(0, 2, 1)
        output = self.encoder(channel_first)
        return output["ecg_token_emb"], None

    def gradient_checkpointing_enable(self):
        """No-op: satisfies the HF Trainer API; checkpointing is not wired
        through to the inner encoder here."""
        pass
|