| from packaging import version | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.nn.utils.rnn import pad_sequence | |
| import fairseq | |
| from s3prl.upstream.interfaces import UpstreamBase | |
# Audio sample rate (Hz) the upstream model expects for input waveforms.
SAMPLE_RATE = 16000
# Length in seconds of an example input; presumably used by callers/tests —
# not referenced anywhere in this chunk, TODO confirm against the rest of the file.
EXAMPLE_SEC = 5
class UpstreamExpert(UpstreamBase):
    """s3prl upstream wrapper around a fairseq-trained speech encoder checkpoint.

    Loads the checkpoint via fairseq's checkpoint utilities and registers
    hooks on every encoder layer (and on the encoder output) so s3prl can
    collect intermediate hidden states.
    """

    def __init__(self, ckpt, **kwargs):
        """
        Args:
            ckpt: path to a fairseq checkpoint file.
            **kwargs: forwarded to UpstreamBase (e.g. hook configuration).

        Raises:
            AssertionError: if the installed fairseq is too old (<= 0.10.2).
        """
        super().__init__(**kwargs)
        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = model[0]
        self.task = task

        if len(self.hooks) == 0:
            module_name = "self.model.encoder.layers"
            # The hook *path* must stay a string (UpstreamBase resolves it
            # internally), but the layer count is read via direct attribute
            # access — the original used eval() for this, which is both
            # slower and a security/lint red flag.
            # Each layer hook returns the layer *input* transposed from
            # (seq_len, batch, dim) to batch-first for s3prl.
            for module_id in range(len(self.model.encoder.layers)):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    lambda inp, out: inp[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda inp, out: out[0])

    def forward(self, wavs):
        """Extract features from a batch of raw waveforms.

        Args:
            wavs: list of 1-D float tensors of variable length, all on the
                same device — assumed 16 kHz mono per SAMPLE_RATE, TODO confirm.

        Returns:
            dict with key "default" mapping to the encoder's final output;
            per-layer states are exposed through the hooks registered in
            __init__.
        """
        if self.task.cfg.normalize:
            # Match the waveform normalization used at pre-training time.
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # True at padded positions: position index >= actual waveform length.
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,  # no masking at inference/feature-extraction time
        )
        return {
            "default": features,
        }