| | from collections import OrderedDict |
| | from typing import List, Union, Dict |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import Tensor |
| | from torch.nn.utils.rnn import pad_sequence |
| |
|
| | import fairseq |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| |
|
class UpstreamExpert(nn.Module):
    """Upstream wrapper around a fairseq checkpoint (HuBERT Base by default).

    `forward` takes a list of 1-D waveforms. It augments the batch with
    silence-padded copies of every waveform, then runs the loaded model's
    `extract_features` on the padded batch.
    """

    # Each silence segment is len(wav) // divisor zero samples long,
    # i.e. 1/5, 1/10 and 1/20 of the waveform.
    SILENCE_DIVISORS = (5, 10, 20)

    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for SUPERB Challenge.
            upstream_feature_selection:
                The value could be
                'hidden_states', 'PR', 'SID', 'ER', 'ASR', 'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others(new tasks).
                You can use it to control which task-specified pre- / post-processing to do.
            **kwargs:
                Accepted for interface compatibility and ignored.

        Raises:
            AssertionError: if the installed fairseq version is not newer
                than 0.10.2.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # Uses a stdlib version parse; the original referenced an
        # unimported `version` module and raised NameError here.
        assert self._parse_version(fairseq.__version__) > (0, 10, 2), \
            "Please install the fairseq master branch."

        models, _cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = models[0]
        self.task = task

    @staticmethod
    def _parse_version(version_string: str) -> tuple:
        """Return the leading dotted-numeric release part of a version string.

        "0.10.2" -> (0, 10, 2); "1.0.0a0+abc" -> (1, 0, 0). Pre-release and
        local suffixes are ignored. A string without a numeric prefix gives ().
        """
        import re

        match = re.match(r"\d+(?:\.\d+)*", version_string.strip())
        if match is None:
            return ()
        return tuple(int(part) for part in match.group(0).split("."))

    def get_downsample_rates(self, key: str) -> int:
        """
        Return the number of input samples per output frame, for every key.

        This upstream reports 320 for all keys, i.e. a 20 ms stride when input
        wavs are 16 kHz (for comparison, a 10 ms stride would be 160).
        """
        return 320

    def _add_silence_variants(self, wavs: List[Tensor]) -> List[Tensor]:
        """Return a NEW list: `wavs` followed by silence-padded copies of them.

        Order of the result:
        1. the original wavs;
        2. for each divisor in SILENCE_DIVISORS, every wav with
           ``len(wav) // divisor`` zeros prepended;
        3. for each divisor again, every wav with that many zeros appended.

        The input list is never modified.
        """
        augmented = list(wavs)  # copy so we never grow the list being iterated
        for prepend in (True, False):
            for divisor in self.SILENCE_DIVISORS:
                for wav in wavs:
                    # new_zeros matches the wav's dtype and device.
                    silence = wav.new_zeros(len(wav) // divisor)
                    pieces = (silence, wav) if prepend else (wav, silence)
                    augmented.append(torch.cat(pieces))
        return augmented

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returning Dict contains the List with more than one Tensor,
        those Tensors should be in the same shape to train a weighted-sum on them.

        The batch passed to the model is the silence-augmented batch, so it holds
        ``1 + 2 * len(SILENCE_DIVISORS)`` times as many entries as `wavs`.
        NOTE(review): "hidden_states" is whatever the model's `extract_features`
        returns first (not wrapped in a list), despite the List annotation.

        Raises:
            ValueError: if `wavs` is empty.
        """
        if not wavs:
            raise ValueError("wavs must contain at least one waveform")

        wavs = self._add_silence_variants(wavs)

        device = wavs[0].device
        wav_lengths = torch.tensor([len(wav) for wav in wavs], dtype=torch.long, device=device)
        padded_wav = pad_sequence(wavs, batch_first=True)
        # True at positions past each waveform's real length (padding).
        positions = torch.arange(padded_wav.size(1), device=device)
        wav_padding_mask = positions.unsqueeze(0) >= wav_lengths.unsqueeze(1)

        features, _feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )

        return {
            "hidden_states": features,
        }
| |
|