Siddarth commited on
Commit
709ed73
·
1 Parent(s): 872069c

Upload Transformer.py

Browse files
Files changed (1) hide show
  1. Transformer.py +97 -0
Transformer.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import Transformer
4
+ from UpdatedTransformer import neko_TransformerEncoderLayer
5
+
6
+ # To speed-up training process
7
+ torch.autograd.set_detect_anomaly(False)
8
+ torch.autograd.profiler.profile(False)
9
+ torch.autograd.profiler.emit_nvtx(False)
10
+
11
+ import math
12
+ import json
13
+
14
+ with open("config.json") as json_data_file:
15
+ data = json.load(json_data_file)
16
+
17
+ N_enc = data['N_encoder']
18
+ N_heads_enc = data['N_heads_enc']
19
+
20
+ N_dec = data['N_decoder']
21
+ N_heads_dec = data['N_heads_dec']
22
+
23
+ gen_emb = data['Gen_Embed']
24
+ fwd_exp = data['Forward_Expansion']
25
+
26
+ dropout = data['Dropout']
27
+ device = 'cpu'
28
+
29
+
30
+ class BiLSTM(nn.Module):
31
+ def __init__(self, phonemed, speaker, gender, seq_len):
32
+
33
+ super(BiLSTM, self).__init__()
34
+
35
+ self.phonemed = phonemed
36
+ self.speakered = speaker
37
+ self.gendered = gender
38
+ self.device = device
39
+
40
+ self.seq_len = seq_len
41
+
42
+ factor = (self.phonemed * 1) + (self.speakered * 1) + (self.gendered * 1) + 1
43
+
44
+ self.embed_mfcc = nn.Sequential(
45
+ nn.Linear(13, gen_emb),
46
+ nn.ReLU(),
47
+ nn.Linear(gen_emb, gen_emb))
48
+
49
+ self.emb_phoneme = nn.Sequential(
50
+ nn.Embedding(40, gen_emb),
51
+ nn.ReLU(),
52
+ nn.Linear(gen_emb, gen_emb))
53
+
54
+ self.emb_speaker = nn.Sequential(
55
+ nn.Embedding(38, gen_emb),
56
+ nn.ReLU(),
57
+ nn.Linear(gen_emb, gen_emb))
58
+
59
+ self.emb_gender = nn.Sequential(
60
+ nn.Embedding(2, gen_emb),
61
+ nn.ReLU(),
62
+ nn.Linear(gen_emb, gen_emb))
63
+
64
+ self.pre_position = nn.Sequential(
65
+ nn.ReLU(),
66
+ nn.Linear(gen_emb * factor, gen_emb))
67
+
68
+ self.aai_rnn = nn.LSTM(gen_emb, fwd_exp, 3, bidirectional = True, batch_first = True)
69
+
70
+ self.artic_linear = nn.Sequential(
71
+ nn.ReLU(),
72
+ nn.Linear(fwd_exp * 2, 32),
73
+ nn.ReLU(),
74
+ nn.Linear(32, 12))
75
+
76
+
77
+ def forward(self, mfcc, pho, spk, gnd):
78
+
79
+ mfcc_embedded = self.embed_mfcc(mfcc)
80
+ cat_embedded = mfcc_embedded
81
+
82
+ if self.phonemed:
83
+ pho_embedded = self.emb_phoneme(pho)
84
+ cat_embedded = torch.concat((mfcc_embedded, pho_embedded), -1)
85
+
86
+
87
+ cat_embedded = self.pre_position(cat_embedded)
88
+
89
+ final, (hidden, cell) = self.aai_rnn(cat_embedded)
90
+
91
+ artic = self.artic_linear(final)
92
+
93
+ if torch.sum(artic[0]).isnan():
94
+ print('Encountered Nan. Exiting program ...')
95
+ exit(0)
96
+
97
+ return artic