|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModel |
|
|
|
|
|
|
|
|
class ECGEncoder(nn.Module):
    """Thin ``nn.Module`` wrapper around the pretrained MELP ECG encoder.

    Loads ``fuyingw/MELP_Encoder`` from the Hugging Face Hub (by default) and
    exposes a ``forward`` that returns the per-token ECG embeddings produced
    by the underlying model.
    """

    def __init__(self,
                 model_name="fuyingw/MELP_Encoder",
                 torch_dtype=torch.bfloat16,
                 revision="main",
                 **kwargs):
        """Load the pretrained ECG encoder.

        Args:
            model_name: Hub id of the encoder checkpoint. Defaults to the
                MELP encoder, matching the previous hard-coded value.
            torch_dtype: dtype the weights are loaded in (default bfloat16).
            revision: Hub revision to pin (default "main").
            **kwargs: Accepted and ignored for backward compatibility with
                callers that pass extra configuration.
        """
        super().__init__()
        # trust_remote_code is required because the MELP checkpoint ships
        # custom modeling code. NOTE(review): this executes code from the
        # Hub — only use with trusted model ids.
        self.encoder = AutoModel.from_pretrained(model_name,
                                                 torch_dtype=torch_dtype,
                                                 trust_remote_code=True,
                                                 revision=revision)

    def forward(self, input_embeds, audio_attention_mask):
        """Encode ECG signals into token embeddings.

        Args:
            input_embeds: ECG tensor; permuted from (batch, length, channels)
                to (batch, channels, length) before encoding — assumes the
                encoder expects channels-first input (TODO: confirm against
                the MELP model card).
            audio_attention_mask: Accepted for interface compatibility but
                NOT used by this encoder.

        Returns:
            Tuple of (ecg_token_emb, None) — the second element is always
            ``None`` because no attention mask is produced here.
        """
        x = input_embeds.permute(0, 2, 1)
        output = self.encoder(x)
        # No mask is derived; downstream code must treat all positions as valid.
        mask = None
        return output["ecg_token_emb"], mask

    def gradient_checkpointing_enable(self):
        """No-op stub kept so generic training loops that call
        ``gradient_checkpointing_enable`` on every submodule do not fail.
        Gradient checkpointing is intentionally not enabled for this encoder.
        """
        pass
|
|
|