File size: 2,545 Bytes
709ed73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
from torch.nn import Transformer
from UpdatedTransformer import neko_TransformerEncoderLayer

# To speed-up training process
# Disable anomaly detection (it adds per-op bookkeeping overhead).
torch.autograd.set_detect_anomaly(False)
# NOTE(review): profile(False) and emit_nvtx(False) construct context
# managers that are never entered, so these two lines are effectively
# no-ops — confirm intent; they do not globally disable profiling.
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)

import math
import json

# Load hyper-parameters from config.json at import time; the module
# fails to import if the file is missing or any key below is absent.
with open("config.json") as json_data_file:
    data = json.load(json_data_file)

# Encoder depth and attention-head count (used by transformer variants).
N_enc = data['N_encoder']
N_heads_enc = data['N_heads_enc']

# Decoder depth and attention-head count.
N_dec = data['N_decoder']
N_heads_dec = data['N_heads_dec']

# gen_emb: shared embedding width for all input streams;
# fwd_exp: LSTM hidden size (per direction) in BiLSTM below.
gen_emb = data['Gen_Embed']
fwd_exp = data['Forward_Expansion']

dropout = data['Dropout']
device = 'cpu'

       
class BiLSTM(nn.Module):
    """Acoustic-to-articulatory inversion model.

    Embeds 13-dim MFCC frames (optionally fused with phoneme, speaker and
    gender embeddings), projects the concatenation back to ``gen_emb``,
    runs a 3-layer bidirectional LSTM, and regresses 12 articulatory
    values per frame.

    Args:
        phonemed: truthy (0/1) flag — fuse phoneme embeddings.
        speaker:  truthy (0/1) flag — fuse speaker embeddings.
        gender:   truthy (0/1) flag — fuse gender embeddings.
        seq_len:  sequence length (stored; not used in forward).
    """

    def __init__(self, phonemed, speaker, gender, seq_len):

        super(BiLSTM, self).__init__()

        self.phonemed = phonemed
        self.speakered = speaker
        self.gendered = gender
        self.device = device

        self.seq_len = seq_len

        # Number of gen_emb-wide streams concatenated in forward():
        # MFCC (always) plus one per enabled conditioning flag.
        factor = (self.phonemed * 1) + (self.speakered * 1) + (self.gendered * 1) + 1

        # 13 MFCC coefficients -> gen_emb.
        self.embed_mfcc = nn.Sequential(
            nn.Linear(13, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))

        # 40 phoneme classes -> gen_emb.
        self.emb_phoneme = nn.Sequential(
            nn.Embedding(40, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))

        # 38 speakers -> gen_emb.
        self.emb_speaker = nn.Sequential(
            nn.Embedding(38, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))

        # 2 genders -> gen_emb.
        self.emb_gender = nn.Sequential(
            nn.Embedding(2, gen_emb),
            nn.ReLU(),
            nn.Linear(gen_emb, gen_emb))

        # Fuse the concatenated streams back down to gen_emb.
        self.pre_position = nn.Sequential(
            nn.ReLU(),
            nn.Linear(gen_emb * factor, gen_emb))

        # 3-layer BiLSTM; output width is fwd_exp * 2 (both directions).
        self.aai_rnn = nn.LSTM(gen_emb, fwd_exp, 3, bidirectional = True, batch_first = True)

        # Per-frame regression head: fwd_exp*2 -> 32 -> 12 articulators.
        self.artic_linear = nn.Sequential(
            nn.ReLU(),
            nn.Linear(fwd_exp * 2, 32),
            nn.ReLU(),
            nn.Linear(32, 12))


    def forward(self, mfcc, pho, spk, gnd):
        """Return (batch, seq, 12) articulatory predictions.

        Args:
            mfcc: (batch, seq, 13) float MFCC frames.
            pho:  phoneme ids for nn.Embedding; used iff phonemed.
            spk:  speaker ids for nn.Embedding; used iff speakered.
            gnd:  gender ids for nn.Embedding; used iff gendered.
        """
        # BUGFIX: the original only concatenated the phoneme stream, so
        # enabling speaker/gender made the concatenated width disagree
        # with pre_position's expected gen_emb * factor and crashed
        # (and spk/gnd were silently ignored). Apply every enabled
        # embedding, mirroring the factor computation in __init__.
        streams = [self.embed_mfcc(mfcc)]

        if self.phonemed:
            streams.append(self.emb_phoneme(pho))

        if self.speakered:
            streams.append(self.emb_speaker(spk))

        if self.gendered:
            streams.append(self.emb_gender(gnd))

        cat_embedded = streams[0] if len(streams) == 1 else torch.cat(streams, -1)

        cat_embedded = self.pre_position(cat_embedded)

        final, (hidden, cell) = self.aai_rnn(cat_embedded)

        artic = self.artic_linear(final)

        # Hard-stop on NaN so training doesn't continue with poisoned
        # weights. NOTE(review): exit(0) reports success to the shell;
        # a nonzero code or an exception may be preferable — confirm.
        if torch.sum(artic[0]).isnan():
            print('Encountered Nan. Exiting program ...')
            exit(0)

        return artic