# File size: 1,237 Bytes
# f55a095
import torch
from speechbrain.lobes.models.ECAPA_TDNN import ECAPA_TDNN
class Model(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around SpeechBrain's ``ECAPA_TDNN``.

    The wrapped network is stored as ``self.model``, so its parameters appear
    under the ``model.`` prefix in this module's state dict.

    Args:
        n_mels: Forwarded to ``ECAPA_TDNN`` as ``input_size``.
        embedding_dim: Forwarded to ``ECAPA_TDNN`` as ``lin_neurons``.
        channel: Base channel width; the network is configured with four
            entries of this width followed by one entry of ``3 * channel``.
    """

    def __init__(self, n_mels=80, embedding_dim=192, channel=512):
        super().__init__()
        # Four equal-width entries, then a final entry three times as wide.
        layer_widths = [channel] * 4 + [channel * 3]
        self.model = ECAPA_TDNN(
            input_size=n_mels,
            lin_neurons=embedding_dim,
            channels=layer_widths,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped network.

        Axis 1 is squeezed (dropped when it has size 1) both before and after
        the wrapped network is applied.

        NOTE(review): presumably ``x`` is ``(B, 1, T, n_mels)`` and the result
        is ``(B, embedding_dim)`` -- confirm against the ECAPA_TDNN contract.
        """
        features = x.squeeze(1)
        embedding = self.model(features)
        return embedding.squeeze(1)
if __name__ == '__main__':
    # Smoke test: build the wrapper, load pretrained weights, and push one
    # random batch through it to check the output shape.
    # channel=1024 presumably matches the widths stored in the checkpoint;
    # load_state_dict below raises if the shapes disagree.
    model = Model(n_mels=80, embedding_dim=192, channel=1024)

    # Load the pretrained model checkpoint.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # map_location="cpu" lets a checkpoint saved from a GPU be loaded on a
    # machine without one; load_state_dict copies into the model's own device.
    checkpoint = torch.load(
        "/ocean/projects/cis220031p/abdulhan/AVIS_baseline/ECAPA/pretrained_models/spkrec-ecapa-voxceleb/embedding_model.ckpt",
        map_location="cpu",
    )

    # Assumes the checkpoint is a bare state dict for the inner network.
    # Model stores that network as ``self.model``, so every key needs the
    # "model." prefix before it matches this module's parameter names.
    new_state_dict = {f"model.{k}": v for k, v in checkpoint.items()}
    model.load_state_dict(new_state_dict)

    # Switch to evaluation mode for inference.
    model.eval()

    # Dummy input laid out as (B, 1, T, n_mels): 300 frames of 80 features,
    # matching n_mels=80 above.
    dummy_input = torch.randn(1, 1, 300, 80)
    # Inference only, so skip building the autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    print(output.shape)