import torch
import torch.nn as nn

from .positional_encoding import PositionalEncoding


class RealtimeTTS(nn.Module):
    """Transformer-encoder-based text-to-mel-spectrogram model.

    Embeds input token ids, adds positional encodings, runs a stack of
    Transformer encoder layers, and projects each position to a mel frame.

    Expected ``config`` attributes:
        vocab_size (int): size of the token embedding table.
        d_model (int): model/embedding dimension.
        max_seq_len (int): maximum sequence length for positional encoding.
        nhead (int): number of attention heads per encoder layer.
        dim_feedforward (int): hidden size of the encoder feed-forward block.
        num_layers (int): number of stacked encoder layers.
        n_mels (int, optional): number of mel bins in the output; defaults
            to 80 to preserve the previous hard-coded behavior.
    """

    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.positional_encoding = PositionalEncoding(
            config.d_model, config.max_seq_len
        )
        # batch_first=True: all tensors are (batch, seq, feature).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_layers
        )
        # Generalized from a hard-coded 80: configs without `n_mels`
        # keep the original 80-bin output.
        self.output_linear = nn.Linear(
            config.d_model, getattr(config, "n_mels", 80)
        )

    def forward(self, tokens, mel_input):
        """Map token ids to a mel-spectrogram prediction.

        Args:
            tokens: LongTensor of token ids, shape (batch, seq).
            mel_input: currently UNUSED. NOTE(review): presumably intended
                as a teacher-forcing input for a decoder that was never
                wired in — confirm with callers before removing; kept to
                preserve the call signature.

        Returns:
            Tensor of shape (batch, seq, n_mels) with predicted mel frames.
        """
        x = self.embedding(tokens)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        mel = self.output_linear(x)
        return mel