Spaces:
Sleeping
Sleeping
import json
import math

import torch
import torch.nn as nn
from torch.nn import Transformer

from UpdatedTransformer import neko_TransformerEncoderLayer

# Speed up training: make sure autograd's debugging machinery is off.
torch.autograd.set_detect_anomaly(False)
# NOTE(review): profile() and emit_nvtx() are context managers; calling
# them without `with` constructs them and discards them, so these two
# lines are effectively no-ops. Harmless (profiling is off by default),
# kept for parity with the original.
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

# Hyper-parameters, read once at import time from config.json.
with open("config.json") as json_data_file:
    data = json.load(json_data_file)

N_enc = data['N_encoder']
N_heads_enc = data['N_heads_enc']
N_dec = data['N_decoder']
N_heads_dec = data['N_heads_dec']
gen_emb = data['Gen_Embed']          # shared embedding width
fwd_exp = data['Forward_Expansion']  # used as the LSTM hidden size below
dropout = data['Dropout']
device = 'cpu'
class BiLSTM(nn.Module):
    """BiLSTM acoustic-to-articulatory inversion (AAI) model.

    Embeds 13-dim MFCC frames, optionally concatenates phoneme, speaker
    and gender embeddings (one stream per enabled constructor flag),
    projects the concatenation back to `gen_emb`, runs a 3-layer
    bidirectional LSTM, and maps every frame to 12 articulatory values.
    """

    def __init__(self, phonemed, speaker, gender, seq_len):
        """Build the embedding, projection, LSTM and output layers.

        Args:
            phonemed: truthy to concatenate a phoneme embedding stream.
            speaker: truthy to concatenate a speaker embedding stream.
            gender: truthy to concatenate a gender embedding stream.
            seq_len: stored sequence length (not used in forward()).
        """
        super(BiLSTM, self).__init__()
        self.phonemed = phonemed
        self.speakered = speaker
        self.gendered = gender
        self.device = device
        self.seq_len = seq_len
        # One stream per enabled flag, plus 1 for the MFCC stream.
        # This MUST equal the number of tensors concatenated in
        # forward(), since it sizes pre_position's input layer.
        factor = (self.phonemed * 1) + (self.speakered * 1) + (self.gendered * 1) + 1
        self.embed_mfcc = nn.Sequential(
            nn.Linear(13, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_phoneme = nn.Sequential(
            nn.Embedding(40, gen_emb),  # 40 phoneme classes
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_speaker = nn.Sequential(
            nn.Embedding(38, gen_emb),  # 38 speaker ids
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_gender = nn.Sequential(
            nn.Embedding(2, gen_emb),   # 2 gender classes
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        # Projects the concatenated streams back down to gen_emb.
        self.pre_position = nn.Sequential(
            nn.ReLU(),
            nn.Linear(gen_emb * factor, gen_emb))
        self.aai_rnn = nn.LSTM(gen_emb, fwd_exp, 3, bidirectional=True, batch_first=True)
        # Bidirectional LSTM outputs 2 * fwd_exp features per frame.
        self.artic_linear = nn.Sequential(
            nn.ReLU(),
            nn.Linear(fwd_exp * 2, 32),
            nn.ReLU(),
            nn.Linear(32, 12))

    def forward(self, mfcc, pho, spk, gnd):
        """Map MFCC frames (plus optional side inputs) to 12 articulators.

        Args:
            mfcc: float tensor whose last dim is 13 (batch_first).
            pho, spk, gnd: integer id tensors; each is only read when
                the corresponding constructor flag was truthy.

        Returns:
            Tensor with 12 articulatory values per frame.

        BUG FIX: the original concatenated only the phoneme stream, so
        enabling `speaker`/`gender` sized pre_position for
        gen_emb * factor inputs that forward() never produced — a
        guaranteed shape mismatch at runtime. Every enabled stream is
        now concatenated, matching `factor`.
        """
        cat_embedded = self.embed_mfcc(mfcc)
        if self.phonemed:
            cat_embedded = torch.concat((cat_embedded, self.emb_phoneme(pho)), -1)
        if self.speakered:
            cat_embedded = torch.concat((cat_embedded, self.emb_speaker(spk)), -1)
        if self.gendered:
            cat_embedded = torch.concat((cat_embedded, self.emb_gender(gnd)), -1)
        cat_embedded = self.pre_position(cat_embedded)
        final, (hidden, cell) = self.aai_rnn(cat_embedded)
        artic = self.artic_linear(final)
        # A NaN output means training diverged; bail out early.
        # NOTE(review): exit(0) signals success to the shell — raising
        # would be cleaner, but kept to preserve existing behavior.
        if torch.sum(artic[0]).isnan():
            print('Encountered Nan. Exiting program ...')
            exit(0)
        return artic