ECGAgent / ecg_encoder.py
fuyingw's picture
Upload Phi4MMForCausalLM
1140d83 verified
import torch
import torch.nn as nn
from transformers import AutoModel
class ECGEncoder(nn.Module):
    """Thin wrapper around the pre-trained ``fuyingw/MELP_Encoder`` model.

    The wrapped model is fetched with ``AutoModel.from_pretrained`` in
    ``__init__``. ``forward`` swaps the last two axes of its input, runs the
    wrapped model, and returns the ``"ecg_token_emb"`` entry of its output
    together with a mask placeholder that is always ``None``.
    """

    def __init__(self, **kwargs):
        """Load the pre-trained encoder, requesting bfloat16 weights.

        Args:
            **kwargs: Accepted so callers can pass extra configuration
                without failing; every entry is currently ignored.
        """
        super().__init__()
        # NOTE(review): ``trust_remote_code=True`` together with the moving
        # ``revision="main"`` means the model code that runs is whatever the
        # Hub repository holds at load time. Consider pinning ``revision`` to
        # a reviewed commit hash.
        self.encoder = AutoModel.from_pretrained(
            "fuyingw/MELP_Encoder",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            revision="main",
        )

    def forward(self, input_embeds, audio_attention_mask=None):
        """Encode a batch and return ``(token_embeddings, mask)``.

        Args:
            input_embeds: 3-D tensor. Per the original author's comment it is
                laid out as (batch_size, seq_len, features) — TODO confirm
                against the caller.
            audio_attention_mask: Never read by this method; presumably kept
                so the signature matches what the caller passes. Optional.

        Returns:
            A 2-tuple: the ``"ecg_token_emb"`` entry of the wrapped model's
            output, and ``None`` (no attention mask is produced).
        """
        # Swap the last two axes so the wrapped model receives
        # (batch_size, features, seq_len).
        features_first = input_embeds.permute(0, 2, 1)
        output = self.encoder(features_first)
        # Original author's note: token embeddings are (bz, 128, 768) —
        # not verifiable from this file.
        return output["ecg_token_emb"], None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Do nothing; gradient checkpointing is not wired up for this encoder.

        Args:
            gradient_checkpointing_kwargs: Ignored. Accepted so that callers
                using the Transformers-style keyword argument do not raise
                ``TypeError``.
        """
        return None