drixo committed on
Commit
430a758
·
verified ·
1 Parent(s): 42e537a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +29 -36
model.py CHANGED
@@ -1,49 +1,42 @@
1
  import torch
2
  import torch.nn as nn
3
- from .positional_encoding import FramePositionalEncoding
4
- from .config import TTSConfig
5
 
6
class RealtimeTTS(nn.Module):
    """Encoder-decoder transformer mapping text tokens to mel-spectrogram frames.

    Text tokens are embedded and encoded by a transformer encoder; mel frames
    receive a frame-level positional encoding and are decoded against the text
    memory, then projected to ``config.mel_bins`` output channels.
    """

    def __init__(self, config=None):
        """Build the model.

        Args:
            config: Hyperparameter object exposing ``vocab_size``, ``d_model``,
                ``n_heads``, ``num_encoder_layers``, ``num_decoder_layers`` and
                ``mel_bins``. Defaults to a fresh ``TTSConfig()``.

        NOTE: the previous signature ``def __init__(self, config=TTSConfig())``
        evaluated the default ONCE at import time, so every default-constructed
        model shared a single config instance (mutable-default-argument bug).
        The ``None`` sentinel below creates a fresh config per call instead.
        """
        super().__init__()
        if config is None:
            config = TTSConfig()

        # Token embedding: vocab ids -> d_model vectors.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Text encoder (batch_first=True: inputs are (batch, seq, d_model)).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.d_model,
                nhead=config.n_heads,
                batch_first=True,
            ),
            num_layers=config.num_encoder_layers,
        )

        # Positional encoding applied to the mel-frame decoder inputs.
        self.frame_pe = FramePositionalEncoding(config.d_model)

        # Mel-frame decoder attending over the encoded text memory.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=config.d_model,
                nhead=config.n_heads,
                batch_first=True,
            ),
            num_layers=config.num_decoder_layers,
        )

        # Final projection from model dimension to mel bins.
        self.mel_projection = nn.Linear(config.d_model, config.mel_bins)

    def forward(self, text_tokens, mel_inputs):
        """Run one teacher-forced decoding pass.

        Args:
            text_tokens: integer token ids; presumably (batch, text_len) —
                TODO(review): confirm against caller.
            mel_inputs: decoder-side mel frames, projected/embedded to
                d_model upstream — assumed (batch, frames, d_model); verify.

        Returns:
            Tensor of predicted mel frames with ``mel_bins`` channels.

        NOTE(review): no causal mask is passed to the decoder, so decoder
        self-attention can see future frames — confirm this is intended.
        """
        # Text embedding.
        x = self.embedding(text_tokens)

        # Text encoding.
        memory = self.encoder(x)

        # Frame positional encoding.
        mel_inputs = self.frame_pe(mel_inputs)

        # Decode mel frames against the text memory.
        out = self.decoder(mel_inputs, memory)

        mel_output = self.mel_projection(out)

        return mel_output
1
  import torch
2
  import torch.nn as nn
3
+ from .positional_encoding import PositionalEncoding
4
+
5
 
6
class RealtimeTTS(nn.Module):
    """Encoder-only transformer mapping text tokens directly to mel frames.

    Tokens are embedded, positionally encoded, passed through a transformer
    encoder, and linearly projected to mel bins.
    """

    #: Fallback output size when ``config`` does not define ``mel_bins``.
    DEFAULT_MEL_BINS = 80

    def __init__(self, config):
        """Build the model.

        Args:
            config: Hyperparameter object exposing ``vocab_size``, ``d_model``,
                ``max_seq_len``, ``nhead``, ``dim_feedforward`` and
                ``num_layers``. May optionally define ``mel_bins``; the
                previous code hard-coded 80, which remains the default, so
                existing configs behave identically.
        """
        super().__init__()

        # Token embedding: vocab ids -> d_model vectors.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.d_model,
        )

        # Sinusoidal/learned positions up to max_seq_len (project-local module).
        self.positional_encoding = PositionalEncoding(
            config.d_model,
            config.max_seq_len,
        )

        # batch_first=True: all tensors are (batch, seq, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            batch_first=True,
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Project to mel bins; configurable, defaulting to the historical 80.
        self.output_linear = nn.Linear(
            config.d_model,
            getattr(config, "mel_bins", self.DEFAULT_MEL_BINS),
        )

    def forward(self, tokens, mel_input):
        """Predict mel frames from text tokens.

        Args:
            tokens: integer token ids; presumably (batch, seq) — TODO(review):
                confirm against caller.
            mel_input: UNUSED. Kept only for caller compatibility (the prior
                encoder-decoder version consumed teacher-forced mel frames).
                NOTE(review): consider deprecating or wiring it back in.

        Returns:
            Tensor of shape (batch, seq, mel_bins).
        """
        x = self.embedding(tokens)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        mel = self.output_linear(x)
        return mel