import torch
import torch.nn as nn
import fairseq
import os
import hydra

# NOTE: the WavLM branch below references WavLM and WavLMConfig, which are not
# imported in this file; they are assumed to come from the official WavLM
# implementation (e.g. `from WavLM import WavLM, WavLMConfig`).

def load_ssl_model(cp_path):
    # Infer the SSL model type from the checkpoint filename.
    ssl_model_type = cp_path.split("/")[-1]
    wavlm = "WavLM" in ssl_model_type
    if wavlm:
        checkpoint = torch.load(cp_path)
        cfg = WavLMConfig(checkpoint['cfg'])
        ssl_model = WavLM(cfg)
        ssl_model.load_state_dict(checkpoint['model'])
        if 'Large' in ssl_model_type:
            SSL_OUT_DIM = 1024
        else:
            SSL_OUT_DIM = 768
    else:
        if ssl_model_type == "wav2vec_small.pt":
            SSL_OUT_DIM = 768
        elif ssl_model_type in ["w2v_large_lv_fsh_swbd_cv.pt", "xlsr_53_56k.pt"]:
            SSL_OUT_DIM = 1024
        else:
            raise ValueError(f"SSL model type {ssl_model_type} is not supported.")
        model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [cp_path]
        )
        ssl_model = model[0]
        ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM, wavlm)
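
# Usage sketch (illustrative, not part of the original file): the checkpoint
# directory is an assumption; the filename must match one of the supported
# types above so the output dimension can be inferred.
#
#   ssl = load_ssl_model("checkpoints/wav2vec_small.pt")
#   ssl.get_output_dim()  # 768 for wav2vec_small.pt
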
class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim, wavlm) -> None:
        super().__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim
        self.WavLM = wavlm

    def forward(self, batch):
        wav = batch['wav']
        wav = wav.squeeze(1)  # [batches, audio_len]
        if self.WavLM:
            x = self.ssl_model.extract_features(wav)[0]
        else:
            res = self.ssl_model(wav, mask=False, features_only=True)
            x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim
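
# Usage sketch (illustrative): SSL_model expects waveforms shaped
# [batch, 1, audio_len]; the batch size and one-second length at 16 kHz are
# assumptions, and `ssl` is taken from the sketch above.
#
#   wav = torch.randn(4, 1, 16000)
#   feats = ssl({"wav": wav})["ssl-feature"]  # [4, frames, ssl_out_dim]
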
class PhonemeEncoder(nn.Module):
    '''
    PhonemeEncoder consists of an embedding layer, an LSTM layer, and a linear layer.
    Args:
        vocab_size: the size of the vocabulary
        hidden_dim: the size of the hidden state of the LSTM
        emb_dim: the size of the embedding layer
        out_dim: the size of the output of the linear layer
        n_lstm_layers: the number of LSTM layers
        with_reference: whether a reference phoneme sequence is also encoded
    '''
    def __init__(self, vocab_size, hidden_dim, emb_dim, out_dim, n_lstm_layers, with_reference=True) -> None:
        super().__init__()
        self.with_reference = with_reference
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.LSTM(emb_dim, hidden_dim,
                               num_layers=n_lstm_layers, dropout=0.1, bidirectional=True)
        self.linear = nn.Sequential(
            nn.Linear(hidden_dim + hidden_dim * self.with_reference, out_dim),
            nn.ReLU()
        )
        self.out_dim = out_dim

    def forward(self, batch):
        seq = batch['phonemes']
        lens = batch['phoneme_lens']
        reference_seq = batch['reference']
        reference_lens = batch['reference_lens']
        emb = self.embedding(seq)
        emb = torch.nn.utils.rnn.pack_padded_sequence(
            emb, lens, batch_first=True, enforce_sorted=False)
        _, (ht, _) = self.encoder(emb)
        # Sum the layer-0 forward and top-layer backward final hidden states.
        feature = ht[-1] + ht[0]
        if self.with_reference:
            if reference_seq is None or reference_lens is None:
                raise ValueError("reference_seq and reference_lens should not be None when with_reference is True")
            reference_emb = self.embedding(reference_seq)
            reference_emb = torch.nn.utils.rnn.pack_padded_sequence(
                reference_emb, reference_lens, batch_first=True, enforce_sorted=False)
            # Bug fix: encode the packed reference sequence, not `emb` again.
            _, (ht_ref, _) = self.encoder(reference_emb)
            reference_feature = ht_ref[-1] + ht_ref[0]
            feature = self.linear(torch.cat([feature, reference_feature], 1))
        else:
            feature = self.linear(feature)
        return {"phoneme-feature": feature}

    def get_output_dim(self):
        return self.out_dim
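
# Usage sketch (illustrative, not part of the original file): vocabulary size,
# dimensions, and batch contents are toy assumptions.
#
#   encoder = PhonemeEncoder(vocab_size=40, hidden_dim=256, emb_dim=128,
#                            out_dim=256, n_lstm_layers=1, with_reference=True)
#   batch = {
#       "phonemes": torch.randint(0, 40, (4, 12)),
#       "phoneme_lens": torch.tensor([12, 10, 8, 12]),
#       "reference": torch.randint(0, 40, (4, 12)),
#       "reference_lens": torch.tensor([12, 10, 8, 12]),
#   }
#   encoder(batch)["phoneme-feature"]  # [4, 256]
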
class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch['domains'])}

    def get_output_dim(self):
        return self.output_dim
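
# Usage sketch (illustrative counts): one embedding vector per dataset domain.
#
#   dom = DomainEmbedding(n_domains=3, domain_dim=128)
#   dom({"domains": torch.tensor([0, 2, 1, 0])})["domain-feature"]  # [4, 128]
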
class LDConditioner(nn.Module):
    '''
    Conditions the SSL output on a listener (judge) embedding.
    '''
    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # concat [self.output_layer, phoneme features]
        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )  # linear?
        self.out_dim = self.decoder_rnn.hidden_size * 2

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch['judge_id']
        if 'phoneme-feature' in x.keys():
            # Broadcast the utterance-level phoneme feature across all frames.
            concatenated_feature = torch.cat(
                (
                    x['ssl-feature'],
                    x['phoneme-feature']
                    .unsqueeze(1)
                    .expand(-1, x['ssl-feature'].size(1), -1),
                ),
                dim=2,
            )
        else:
            concatenated_feature = x['ssl-feature']
        if 'domain-feature' in x.keys():
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    x['domain-feature']
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids)
                    .unsqueeze(1)
                    .expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output
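
# Usage sketch (illustrative dimensions): input_dim must equal the summed
# width of the concatenated per-frame features, here ssl (768) + phoneme (256)
# + domain (128) = 1152; the judge count is an arbitrary assumption.
#
#   cond = LDConditioner(input_dim=1152, judge_dim=128, num_judges=3000)
#   x = {"ssl-feature": torch.randn(4, 100, 768),
#        "phoneme-feature": torch.randn(4, 256),
#        "domain-feature": torch.randn(4, 128)}
#   cond(x, {"judge_id": torch.tensor([5, 1, 7, 2])})  # [4, 100, 1024]
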
class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super().__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        # `batch` is unused here but kept for interface consistency.
        output = self.net(x)
        # Range clipping: affinely map tanh output from [-1, 1] to [1, 5],
        # consistent with a 1-to-5 MOS scale.
        if self.range_clipping:
            return self.proj(output) * 2.0 + 3
        else:
            return output

    def get_output_dim(self):
        return self.output_dim
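
# Usage sketch (illustrative dimensions): with range_clipping=True every
# predicted score falls in [1, 5].
#
#   proj = Projection(input_dim=1024, hidden_dim=2048,
#                     activation=nn.ReLU(), range_clipping=True)
#   proj(torch.randn(4, 100, 1024), batch=None)  # [4, 100, 1], values in [1, 5]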