| import torch |
| import torch.nn as nn |
| from .positional_encoding import PositionalEncoding |
|
|
|
|
class RealtimeTTS(nn.Module):
    """Transformer-encoder model mapping token ids to mel-spectrogram frames.

    Pipeline: token embedding -> additive positional encoding ->
    TransformerEncoder stack -> linear projection to mel bins.

    Expected ``config`` attributes:
        vocab_size:      size of the token embedding table.
        d_model:         model/embedding width.
        max_seq_len:     maximum sequence length for positional encoding.
        nhead:           number of attention heads per encoder layer.
        dim_feedforward: hidden width of each layer's feed-forward block.
        num_layers:      number of stacked encoder layers.
        n_mels:          (optional) number of output mel bins; defaults to 80,
                         preserving the previously hard-coded value.
    """

    def __init__(self, config):
        super().__init__()

        # Token-id -> d_model embedding lookup.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Injects position information; self-attention alone is order-invariant.
        self.positional_encoding = PositionalEncoding(
            config.d_model,
            config.max_seq_len,
        )

        # batch_first=True so all tensors are (batch, seq, d_model).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )

        # Generalized: mel-bin count was hard-coded to 80; now configurable
        # via config.n_mels with 80 as the backward-compatible default.
        n_mels = getattr(config, "n_mels", 80)
        self.output_linear = nn.Linear(config.d_model, n_mels)

    def forward(self, tokens, mel_input):
        """Predict mel frames from token ids.

        Args:
            tokens: integer tensor of token ids; presumably (batch, seq) —
                TODO confirm against callers.
            mel_input: accepted for interface compatibility but NOT used.
                NOTE(review): likely intended for teacher forcing or a
                decoder path that is missing here — verify upstream intent.

        Returns:
            Tensor of shape (..., seq, n_mels): one mel frame per input token.
        """
        x = self.embedding(tokens)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        return self.output_linear(x)