import json
import math

import torch
import torch.nn as nn
from torch.nn import Transformer
from UpdatedTransformer import neko_TransformerEncoderLayer

# Disable autograd diagnostics/profiling hooks to speed up training.
torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

# Hyper-parameters are loaded once, at import time, from config.json.
with open("config.json") as json_data_file:
    data = json.load(json_data_file)

N_enc = data['N_encoder']
N_heads_enc = data['N_heads_enc']
N_dec = data['N_decoder']
N_heads_dec = data['N_heads_dec']
gen_emb = data['Gen_Embed']
fwd_exp = data['Forward_Expansion']
dropout = data['Dropout']

device = 'cpu'


class BiLSTM(nn.Module):
    """Acoustic-to-articulatory inversion model.

    13-dim MFCC frames are embedded, optionally concatenated with phoneme
    embeddings, projected back to ``gen_emb``, run through a 3-layer
    bidirectional LSTM, and mapped to 12 articulatory outputs per frame.

    Args:
        phonemed: truthy to concatenate phoneme embeddings to the MFCC
            embedding in ``forward``.
        speaker: truthy to widen ``pre_position`` for a speaker embedding.
        gender: truthy to widen ``pre_position`` for a gender embedding.
        seq_len: stored sequence length (not used in the visible code).
    """

    def __init__(self, phonemed, speaker, gender, seq_len):
        super(BiLSTM, self).__init__()
        self.phonemed = phonemed
        self.speakered = speaker
        self.gendered = gender
        self.device = device
        self.seq_len = seq_len
        # Number of gen_emb-sized chunks concatenated before pre_position:
        # one per enabled conditioning signal, plus 1 for the MFCC embedding.
        # NOTE(review): forward() only ever concatenates the phoneme
        # embedding, so enabling speaker/gender would make pre_position's
        # input width disagree with the tensor actually produced — confirm.
        factor = (self.phonemed * 1) + (self.speakered * 1) + (self.gendered * 1) + 1
        self.embed_mfcc = nn.Sequential(
            nn.Linear(13, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_phoneme = nn.Sequential(
            nn.Embedding(40, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_speaker = nn.Sequential(
            nn.Embedding(38, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_gender = nn.Sequential(
            nn.Embedding(2, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.pre_position = nn.Sequential(
            nn.ReLU(),
            nn.Linear(gen_emb * factor, gen_emb))
        self.aai_rnn = nn.LSTM(gen_emb, fwd_exp, 3,
                               bidirectional=True, batch_first=True)
        self.artic_linear = nn.Sequential(
            nn.ReLU(),
            nn.Linear(fwd_exp * 2, 32),
            nn.ReLU(),
            nn.Linear(32, 12))

    def forward(self, mfcc, pho, spk, gnd):
        """Map MFCC frames to 12 articulatory trajectories.

        Args:
            mfcc: float tensor whose last dimension is 13 (MFCC features);
                batch-first per the LSTM configuration.
            pho: integer phoneme indices in [0, 40); only consumed when
                ``self.phonemed`` is truthy.
            spk: speaker indices — accepted but unused in the visible code.
            gnd: gender indices — accepted but unused in the visible code.

        Returns:
            Tensor with last dimension 12 (articulatory values per frame).
        """
        mfcc_embedded = self.embed_mfcc(mfcc)
        cat_embedded = mfcc_embedded
        if self.phonemed:
            pho_embedded = self.emb_phoneme(pho)
            cat_embedded = torch.concat((mfcc_embedded, pho_embedded), -1)
        cat_embedded = self.pre_position(cat_embedded)
        final, (hidden, cell) = self.aai_rnn(cat_embedded)
        artic = self.artic_linear(final)
        # Abort on NaN in the output. NOTE(review): only the first batch
        # element is inspected, and exit(0) reports success to the shell on
        # what is a failure condition — both look unintentional; confirm.
        if torch.sum(artic[0]).isnan():
            # Bug fix: the original string literal was broken across a raw
            # newline (a syntax error); reconstituted as a valid literal.
            print('Encountered Nan.\nExiting program ...')
            exit(0)
        return artic