| | import torch |
| | import torch.nn as nn |
| | import fairseq |
| | import os |
| | import hydra |
| |
|
def load_ssl_model(cp_path):
    """Load a pretrained SSL checkpoint and wrap it in ``SSL_model``.

    The checkpoint's file name decides both the loading path and the reported
    feature width:

    * names containing ``"WavLM"`` are restored through ``WavLMConfig`` and
      ``WavLM`` (width 1024 if the name also contains ``"Large"``, else 768);
    * otherwise the checkpoint is loaded with fairseq, and only the file
      names listed in ``fairseq_out_dims`` below are accepted.

    Args:
        cp_path: Path to the checkpoint file.

    Returns:
        An ``SSL_model`` built from the loaded network, its feature width and
        a flag telling whether it is a WavLM model.

    Raises:
        SystemExit: If the checkpoint file name is not a supported model.
    """
    # basename() equals split("/")[-1] on POSIX and also copes with
    # platform-specific separators.
    ssl_model_type = os.path.basename(cp_path)
    wavlm = "WavLM" in ssl_model_type

    if wavlm:
        # NOTE(review): WavLM / WavLMConfig are not imported in the visible
        # part of this file -- confirm they are in scope at runtime.
        # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
        # hosts; load_state_dict() copies into the module's own tensors.
        checkpoint = torch.load(cp_path, map_location="cpu")
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        ssl_out_dim = 1024 if 'Large' in ssl_model_type else 768
        return SSL_model(ssl_model, ssl_out_dim, wavlm)

    fairseq_out_dims = {
        "wav2vec_small.pt": 768,
        "w2v_large_lv_fsh_swbd_cv.pt": 1024,
        "xlsr_53_56k.pt": 1024,
    }
    if ssl_model_type not in fairseq_out_dims:
        print("*** ERROR *** SSL model type " + ssl_model_type + " not supported.")
        # Raise SystemExit directly: the bare exit() helper is injected by the
        # `site` module (not always present) and reports success (status 0).
        raise SystemExit(1)
    ssl_out_dim = fairseq_out_dims[ssl_model_type]

    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [cp_path]
    )
    ssl_model = model[0]
    ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, ssl_out_dim, wavlm)
| |
|
class SSL_model(nn.Module):
    """Thin wrapper that exposes an SSL network as a feature extractor.

    Args:
        ssl_model: The wrapped network. It is called either through
            ``extract_features`` (WavLM path) or directly with
            ``mask=False, features_only=True`` (other path).
        ssl_out_dim: Value reported by ``get_output_dim``.
        wavlm: Truthy when ``ssl_model`` must be driven through the WavLM path.
    """

    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super().__init__()
        self.ssl_model = ssl_model
        self.ssl_out_dim = ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        """Return ``{"ssl-feature": features}`` for ``batch['wav']``.

        Dimension 1 of ``batch['wav']`` is squeezed away before the wrapped
        network is called.
        """
        waveform = batch['wav'].squeeze(1)
        if not self.WavLM:
            outputs = self.ssl_model(waveform, mask=False, features_only=True)
            return {"ssl-feature": outputs["x"]}
        # extract_features returns a sequence; its first item is the feature.
        return {"ssl-feature": self.ssl_model.extract_features(waveform)[0]}

    def get_output_dim(self):
        """Return the feature width given at construction time."""
        return self.ssl_out_dim
| |
|
| |
|
class PhonemeEncoder(nn.Module):
    """Encode a phoneme sequence (and optionally a reference) into one vector.

    The module consists of an embedding layer, a bidirectional LSTM and a
    linear layer followed by ReLU. The same embedding and LSTM are used for
    both the phoneme sequence and the reference sequence.

    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
        with_reference: when true, a reference sequence is encoded as well
            and concatenated to the phoneme encoding before the linear layer
            (whose input width is then ``2 * hidden_dim``).
    """

    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim,n_lstm_layers,with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        linear_in_dim = hidden_dim * 2 if with_reference else hidden_dim
        self.linear = nn.Sequential(
            nn.Linear(linear_in_dim, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def _encode(self, seq, lens):
        """Embed and run ``seq`` through the LSTM; return a per-item vector."""
        embedded = self.embedding(seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(
            embedded, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(packed)
        # Sum of the last and the first slice of the final hidden-state stack.
        return ht[-1] + ht[0]

    def forward(self, batch):
        """Return ``{"phoneme-feature": tensor}`` for one batch.

        Args:
            batch: Mapping with ``'phonemes'`` and ``'phoneme_lens'``; when
                ``with_reference`` is true it must also provide
                ``'reference'`` and ``'reference_lens'``.

        Raises:
            ValueError: If ``with_reference`` is true and either reference
                entry is ``None``.
        """
        feature = self._encode(batch['phonemes'], batch['phoneme_lens'])
        if not self.with_reference:
            # The reference entries are not needed (nor looked up) here.
            return {"phoneme-feature": self.linear(feature)}

        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        if reference_seq is None or reference_lens is None:
            raise ValueError("reference_batch and reference_lens should not be None when with_reference is True")
        # Encode the *reference* embedding. The previous code fed the phoneme
        # embedding to the LSTM a second time, so the reference was ignored.
        # NOTE(review): checkpoints trained with the old behaviour may need
        # re-validation or retraining -- confirm.
        reference_feature = self._encode(reference_seq, reference_lens)
        feature = self.linear(torch.cat([feature, reference_feature], 1))
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        """Return the width of the produced feature vector."""
        return self.out_dim
| |
|
class DomainEmbedding(nn.Module):
    """Look up a learned embedding vector for each domain id.

    Args:
        n_domains: Number of rows of the embedding table.
        domain_dim: Width of each embedding vector.
    """

    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        """Return ``{"domain-feature": embedding of batch['domains']}``."""
        domain_ids = batch['domains']
        domain_vectors = self.embedding(domain_ids)
        return {"domain-feature": domain_vectors}

    def get_output_dim(self):
        """Return the embedding width."""
        return self.output_dim
| |
|
| |
|
class LDConditioner(nn.Module):
    """Condition the SSL output on a listener (judge) embedding.

    Utterance-level features (phoneme, domain, judge embedding) are repeated
    along dimension 1 of the SSL feature, concatenated on the last dimension
    and fed to a single-layer bidirectional LSTM with hidden size 512.

    Args:
        input_dim: Combined width of every feature except the judge embedding
            (the LSTM input size is ``input_dim + judge_dim``).
        judge_dim: Width of the judge embedding.
        num_judges: Number of rows of the judge embedding table; required.

    Raises:
        ValueError: If ``num_judges`` is ``None``.
    """

    def __init__(self,input_dim, judge_dim, num_judges=None):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under -O.
        if num_judges is None:
            raise ValueError("num_judges must not be None")
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)

        self.decoder_rnn = nn.LSTM(
            input_size = self.input_dim + self.judge_dim,
            hidden_size = 512,
            num_layers = 1,
            batch_first = True,
            bidirectional = True
        )
        # Bidirectional: forward and backward states are concatenated.
        self.out_dim = self.decoder_rnn.hidden_size*2

    def get_output_dim(self):
        """Return the width of the LSTM output (twice its hidden size)."""
        return self.out_dim

    @staticmethod
    def _append_broadcast(sequence, vector):
        """Repeat ``vector`` along dim 1 of ``sequence`` and concat on dim 2."""
        repeated = vector.unsqueeze(1).expand(-1, sequence.size(1), -1)
        return torch.cat((sequence, repeated), dim=2)

    def forward(self, x, batch):
        """Run the conditioned features through the LSTM.

        Args:
            x: Mapping with ``'ssl-feature'`` and, optionally,
                ``'phoneme-feature'`` and ``'domain-feature'``.
            batch: Mapping with ``'judge_id'``; when that entry is not
                ``None`` its embedding is appended as well.

        Returns:
            The LSTM output sequence.
        """
        judge_ids = batch['judge_id']
        concatenated_feature = x['ssl-feature']
        if 'phoneme-feature' in x:
            concatenated_feature = self._append_broadcast(
                concatenated_feature, x['phoneme-feature'])
        if 'domain-feature' in x:
            concatenated_feature = self._append_broadcast(
                concatenated_feature, x['domain-feature'])
        # Identity test: comparing a tensor with `!= None` relies on
        # comparison fallbacks rather than stating the intent.
        if judge_ids is not None:
            concatenated_feature = self._append_broadcast(
                concatenated_feature, self.judge_embedding(judge_ids))
        decoder_output, _ = self.decoder_rnn(concatenated_feature)
        return decoder_output
| |
|
class Projection(nn.Module):
    """Two-layer feed-forward head producing a single value per input vector.

    Args:
        input_dim: Width of the incoming feature.
        hidden_dim: Width of the hidden linear layer.
        activation: Module inserted between the first linear layer and the
            dropout layer.
        range_clipping: When true, the raw output is passed through
            ``tanh`` and mapped with ``* 2.0 + 3`` (i.e. into the 1..5 range).
    """

    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super().__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        # The squashing layer only exists when clipping is requested.
        if range_clipping:
            self.proj = nn.Tanh()

        layers = [
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)
        self.output_dim = output_dim

    def forward(self, x, batch):
        """Project ``x``; ``batch`` is accepted but not used."""
        raw = self.net(x)
        if not self.range_clipping:
            return raw
        squashed = self.proj(raw)
        return squashed * 2.0 + 3

    def get_output_dim(self):
        """Return the output width (always 1)."""
        return self.output_dim
| |
|