# WatchMeSpeak / Transformer.py
# Hosting-page header captured with the file: "Siddarth's picture",
# commit message "Upload Transformer.py", revision 709ed73.
import torch
import torch.nn as nn
from torch.nn import Transformer
from UpdatedTransformer import neko_TransformerEncoderLayer
# Turn off autograd debugging aids so they do not slow training down.
torch.autograd.set_detect_anomaly(mode=False)
# NOTE(review): the next two calls only construct profiler context managers
# in their disabled state and discard them; they look like no-ops -- confirm.
torch.autograd.profiler.profile(enabled=False)
torch.autograd.profiler.emit_nvtx(enabled=False)
import math
import json
# Load the hyper-parameter file; the relative path is resolved against the
# current working directory.
with open("config.json") as json_data_file:
    data = json.load(json_data_file)

# Module-level settings, each copied from one key of the JSON document.
N_enc, N_heads_enc = data['N_encoder'], data['N_heads_enc']
N_dec, N_heads_dec = data['N_decoder'], data['N_heads_dec']
# gen_emb is the shared embedding width and fwd_exp the LSTM hidden size
# used by BiLSTM in this file.
gen_emb, fwd_exp = data['Gen_Embed'], data['Forward_Expansion']
dropout = data['Dropout']

# Device string recorded on model instances (BiLSTM.device).
device = 'cpu'
class BiLSTM(nn.Module):
    """Bidirectional LSTM mapping 13-dim acoustic frames to 12 output values.

    Each frame is embedded to ``gen_emb`` features.  Optional phoneme,
    speaker and gender id embeddings (each ``gen_emb`` wide) are concatenated
    to it on the last axis, the result is projected back to ``gen_emb`` and
    fed to a 3-layer bidirectional LSTM (hidden size ``fwd_exp``,
    ``batch_first=True``) followed by a small MLP with 12 outputs.

    ``gen_emb``, ``fwd_exp`` and ``device`` are module-level settings.

    Args:
        phonemed: truthy to append the phoneme embedding (ids in ``[0, 40)``).
        speaker: truthy to append the speaker embedding (ids in ``[0, 38)``).
        gender: truthy to append the gender embedding (ids in ``[0, 2)``).
        seq_len: stored on the instance; not used by this class itself.
    """

    def __init__(self, phonemed, speaker, gender, seq_len):
        super().__init__()
        self.phonemed = phonemed
        self.speakered = speaker
        self.gendered = gender
        self.device = device
        self.seq_len = seq_len

        # Number of gen_emb-wide streams concatenated in forward(): the MFCC
        # stream plus one per enabled side input.  Counted by truthiness so it
        # always agrees with the `if self.<flag>` tests in forward().
        factor = 1 + sum(bool(flag) for flag in (phonemed, speaker, gender))

        # NOTE: sub-module names and creation order are kept as-is so that
        # state_dict keys and seeded initialisation do not change.
        self.embed_mfcc = nn.Sequential(
            nn.Linear(13, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))
        self.emb_phoneme = self._id_embedder(40)
        self.emb_speaker = self._id_embedder(38)
        self.emb_gender = self._id_embedder(2)
        self.pre_position = nn.Sequential(
            nn.ReLU(),
            nn.Linear(gen_emb * factor, gen_emb))
        self.aai_rnn = nn.LSTM(gen_emb, fwd_exp, 3,
                               bidirectional=True, batch_first=True)
        self.artic_linear = nn.Sequential(
            nn.ReLU(),
            # Forward and backward LSTM outputs are concatenated: 2 * fwd_exp.
            nn.Linear(fwd_exp * 2, 32),
            nn.ReLU(),
            nn.Linear(32, 12))

    @staticmethod
    def _id_embedder(num_ids):
        """Build the Embedding -> ReLU -> Linear stack used for integer ids."""
        return nn.Sequential(
            nn.Embedding(num_ids, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))

    @staticmethod
    def _spread_over_frames(side, frames):
        """Broadcast an id embedding so it has one row per frame of *frames*.

        Accepts an embedding that already has a frame axis (returned as an
        equivalent view) or one that lacks it (e.g. one id per utterance), in
        which case a singleton frame axis is inserted and expanded.
        """
        if side.dim() == frames.dim() - 1:
            side = side.unsqueeze(-2)
        return side.expand(*frames.shape[:-1], side.shape[-1])

    def forward(self, mfcc, pho, spk, gnd):
        """Run the network on one batch.

        Args:
            mfcc: float tensor whose last dimension is 13.
            pho: integer phoneme ids, one per frame; read only if ``phonemed``.
            spk: integer speaker ids, per frame or per utterance; read only
                if ``speakered``.
            gnd: integer gender ids, per frame or per utterance; read only
                if ``gendered``.

        Returns:
            Tensor with the leading dimensions of ``mfcc`` and a last
            dimension of 12.

        Raises:
            SystemExit: with status 1 if the first item of the output
                contains NaN.
        """
        frames = self.embed_mfcc(mfcc)
        streams = [frames]
        if self.phonemed:
            streams.append(self.emb_phoneme(pho))
        # pre_position was sized in __init__ to include these two streams, so
        # they must be concatenated here or its Linear layer rejects the input.
        if self.speakered:
            streams.append(
                self._spread_over_frames(self.emb_speaker(spk), frames))
        if self.gendered:
            streams.append(
                self._spread_over_frames(self.emb_gender(gnd), frames))

        if len(streams) == 1:
            cat_embedded = frames
        else:
            cat_embedded = torch.cat(streams, dim=-1)
        cat_embedded = self.pre_position(cat_embedded)

        final, _ = self.aai_rnn(cat_embedded)
        artic = self.artic_linear(final)

        # Only the first item is inspected, as before; a NaN anywhere in it
        # makes the sum NaN.
        if torch.sum(artic[0]).isnan():
            print('Encountered Nan. Exiting program ...')
            # Raise SystemExit directly rather than using the interactive
            # `exit` helper, and report failure with a non-zero status.
            raise SystemExit(1)
        return artic