| | import torch |
| | import fairseq |
| | from packaging import version |
| | import torch.nn.functional as F |
| | from fairseq import tasks |
| | from fairseq.checkpoint_utils import load_checkpoint_to_cpu |
| | from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
| | from omegaconf import OmegaConf |
| | from s3prl.upstream.interfaces import UpstreamBase |
| | from torch.nn.utils.rnn import pad_sequence |
| |
|
def load_model(filepath):
    """Build a fairseq model and its task from a checkpoint file.

    The checkpoint is read with ``torch.load`` onto the CPU. The config is
    taken from ``state["args"]`` (converted to OmegaConf) when that entry is
    present and not None, otherwise from ``state["cfg"]``.

    Args:
        filepath: Path to the checkpoint file.

    Returns:
        Tuple ``(model, cfg, task)``: the model built by ``task.build_model``
        (with ``state["model"]`` loaded into it when that entry exists), the
        OmegaConf config it was built from, and the fairseq task.

    Raises:
        RuntimeError: If the checkpoint has neither a non-None ``"args"`` nor
            a non-None ``"cfg"`` entry.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code -- only pass checkpoints from trusted sources.
    # Returning ``storage`` unchanged keeps every tensor on the CPU.
    state = torch.load(filepath, map_location=lambda storage, loc: storage)

    # Convert "cfg" only when it is actually there: checkpoints that carry
    # just "args" used to crash here with KeyError before reaching the
    # args branch (or the descriptive RuntimeError) below.
    if state.get("cfg") is not None:
        state["cfg"] = OmegaConf.create(state["cfg"])

    if state.get("args") is not None:
        cfg = convert_namespace_to_omegaconf(state["args"])
    elif state.get("cfg") is not None:
        cfg = state["cfg"]
    else:
        raise RuntimeError(
            f"Neither args nor cfg exist in state keys = {state.keys()}"
        )

    task = tasks.setup_task(cfg.task)
    if "task_state" in state:
        task.load_state_dict(state["task_state"])

    model = task.build_model(cfg.model)

    # build_model only constructs the architecture; restore the checkpoint
    # weights so callers do not silently receive a randomly initialised model.
    model_state = state.get("model")
    if model_state is not None:
        model.load_state_dict(model_state)

    return model, cfg, task
| |
|
| |
|
| | |
| | |
| | |
class UpstreamExpert(UpstreamBase):
    """s3prl upstream that wraps a fairseq model built by ``load_model``."""

    def __init__(self, ckpt, **kwargs):
        """Load the checkpoint and register default hidden-state hooks.

        Args:
            ckpt: Path to the checkpoint, passed to ``load_model``.
            **kwargs: Forwarded unchanged to ``UpstreamBase.__init__``.

        Raises:
            RuntimeError: If the installed fairseq is not newer than 0.10.2.
        """
        super().__init__(**kwargs)
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if version.parse(fairseq.__version__) <= version.parse("0.10.2"):
            raise RuntimeError("Please install the fairseq master branch.")

        model, _, task = load_model(ckpt)
        self.model = model
        self.task = task

        # Install the default hooks only when no hooks exist yet
        # (presumably hooks may arrive through kwargs via UpstreamBase -- confirm).
        if len(self.hooks) == 0:
            module_name = "self.model.encoder.layers"
            # One hook per encoder layer, addressed by index. The transpose
            # presumably turns a time-major (T, B, C) layer input into
            # (B, T, C) -- confirm against the encoder implementation.
            for module_id in range(len(self.model.encoder.layers)):
                self.add_hook(
                    f"{module_name}[{module_id}]",
                    lambda input, output: input[0].transpose(0, 1),
                )
            self.add_hook("self.model.encoder", lambda input, output: output[0])

    def forward(self, wavs):
        """Run the wrapped model on a batch of variable-length waveforms.

        Args:
            wavs: Sequence of waveform tensors (presumably 1-D sample
                arrays -- confirm), all on the device of ``wavs[0]``.

        Returns:
            ``{"default": features}``, where ``features`` is the first value
            returned by ``self.model.extract_features``.
        """
        if self.task.cfg.normalize:
            # Normalise each waveform over its own full shape.
            wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]

        device = wavs[0].device
        wav_lengths = torch.tensor(
            [len(wav) for wav in wavs], dtype=torch.long, device=device
        )
        # True at padded positions, i.e. index >= that waveform's length.
        positions = torch.arange(int(wav_lengths.max()), device=device)
        wav_padding_mask = positions.unsqueeze(0) >= wav_lengths.unsqueeze(1)
        padded_wav = pad_sequence(wavs, batch_first=True)

        # NOTE(review): assumes extract_features returns a two-element
        # (features, padding_mask) sequence; some fairseq versions return a
        # dict instead -- confirm against the installed version.
        features, _ = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
        return {
            "default": features,
        }
| |
|
| |
|