# realtime-tts / model.py
# Committed by drixo — last commit: "Update model.py" (430a758, verified)
import torch
import torch.nn as nn
from .positional_encoding import PositionalEncoding
class RealtimeTTS(nn.Module):
    """Transformer-encoder model that maps token ids to mel-spectrogram frames.

    Pipeline: token embedding -> positional encoding -> ``nn.TransformerEncoder``
    (built with ``batch_first=True``) -> linear projection to ``n_mels`` channels.

    The model emits exactly one output frame per input token. There is no
    duration predictor or length regulator, so the output sequence length
    always equals the input sequence length.

    Args:
        config: Object exposing the attributes ``vocab_size``, ``d_model``,
            ``max_seq_len``, ``nhead``, ``dim_feedforward`` and ``num_layers``.
            Optional attributes:

            * ``n_mels``: number of output channels (default 80).
            * ``dropout``: encoder-layer dropout (default 0.1, which is the
              ``nn.TransformerEncoderLayer`` default).
    """

    def __init__(self, config):
        super().__init__()
        # Formerly hard-coded to 80; read it from config when provided so a
        # different mel configuration does not require editing the model.
        self.n_mels = getattr(config, "n_mels", 80)

        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        # NOTE(review): PositionalEncoding comes from .positional_encoding and
        # its internals (sinusoidal vs learned, dropout) are not visible here.
        # It is assumed to preserve the (batch, seq, d_model) shape.
        self.positional_encoding = PositionalEncoding(
            config.d_model,
            config.max_seq_len,
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.dim_feedforward,
            dropout=getattr(config, "dropout", 0.1),
            # Inputs and outputs are (batch, seq, feature), not (seq, batch, feature).
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=config.num_layers,
        )
        self.output_linear = nn.Linear(config.d_model, self.n_mels)

    def forward(self, tokens, mel_input=None, *, src_key_padding_mask=None):
        """Predict mel frames for a batch of token sequences.

        Args:
            tokens: Integer token ids, presumably shaped (batch, seq). This is
                assumed from ``batch_first=True``; confirm against the callers.
            mel_input: Unused. It is kept (now optional) so that existing
                callers passing a target/teacher-forcing mel keep working.
            src_key_padding_mask: Optional bool tensor of shape (batch, seq).
                ``True`` marks padding positions that attention should ignore.
                It is passed straight through to ``nn.TransformerEncoder``.

        Returns:
            Tensor of shape (batch, seq, n_mels): one mel frame per token.
        """
        x = self.embedding(tokens)
        x = self.positional_encoding(x)
        x = self.transformer(x, src_key_padding_mask=src_key_padding_mask)
        return self.output_linear(x)