import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Inject absolute-position information into token embeddings.

    The Transformer has no recurrence, so without this the model cannot
    distinguish token order.  Uses the fixed sinusoidal scheme from
    "Attention Is All You Need": even channels get sin, odd channels cos,
    with geometrically increasing wavelengths.

    Args:
        d_model: embedding dimension (odd values are supported).
        dropout: dropout probability applied after adding positions.
        max_len: longest sequence length the table is precomputed for.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # exp(-log(10000) * 2i / d_model): wavelengths form a geometric
        # progression from 2*pi up to 10000 * 2*pi.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice to d_model // 2 columns so an odd d_model works too
        # (the cos half then has one channel fewer than the sin half).
        pe[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]
        # Buffer: saved in state_dict and moved by .to(device), not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add positions to x of shape [batch, seq_len, d_model]."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class MiniTTS(nn.Module):
    """Minimal Transformer text-to-spectrogram model.

    Encodes a phoneme/character id sequence with a Transformer encoder,
    then decodes a mel spectrogram with a Transformer decoder.  Training
    uses teacher forcing behind a causal mask; autoregressive inference
    is not implemented here (forward returns the encoder memory instead,
    so shapes can be debugged externally).

    Args:
        num_chars: vocabulary size (phoneme/character ids).
        num_mels: mel-spectrogram channel count (typically 80).
        d_model: Transformer model width.
        nhead: attention heads per layer.
        num_layers: layers in both encoder and decoder.
    """

    def __init__(self, num_chars, num_mels, d_model=256, nhead=4, num_layers=4):
        super().__init__()

        # 1. Text encoder: embed ids, add positions, run self-attention.
        self.embedding = nn.Embedding(num_chars, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers)

        # 2. Spectrogram decoder: project mel frames into model space and
        #    cross-attend over the encoder memory.
        self.mel_embedding = nn.Linear(num_mels, d_model)
        self.pos_decoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer, num_layers=num_layers)

        # 3. Project back from model space to mel channels.
        self.output_layer = nn.Linear(d_model, num_mels)

        # 4. Post-net: conv stack predicting a residual that refines the
        #    raw decoder output (standard Tacotron-style refinement).
        self.post_net = nn.Sequential(
            nn.Conv1d(num_mels, 512, kernel_size=5, padding=2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Conv1d(512, num_mels, kernel_size=5, padding=2),
        )

    def forward(self, text_tokens, mel_target=None):
        """
        text_tokens: [batch, text_len] integer ids (phonemes)
        mel_target:  [batch, mel_len, num_mels] target spectrogram, or None

        Returns:
            [batch, mel_len, num_mels] refined spectrogram when mel_target
            is given (training / teacher forcing); otherwise the encoder
            memory [batch, text_len, d_model] for shape debugging.
        """
        # --- ENCODING ---
        # [batch, text_len] -> [batch, text_len, d_model]
        src = self.pos_encoder(self.embedding(text_tokens))
        # Memory is what the decoder cross-attends to.
        memory = self.transformer_encoder(src)

        if mel_target is None:
            # INFERENCE MODE: greedy decoding is handled elsewhere
            # (inference.py); return the memory so shapes can be checked.
            return memory

        # --- DECODING (teacher forcing) ---
        # NOTE(review): the original comment claims the target is "shifted",
        # but no shift happens here — confirm the dataloader provides the
        # shifted sequence.
        tgt = self.pos_decoder(self.mel_embedding(mel_target))

        # Causal mask: each frame attends only to itself and earlier frames.
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len).to(tgt.device)
        decoded = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        output_mel = self.output_layer(decoded)

        # Post-net residual; Conv1d expects [batch, channels, time].
        residual = self.post_net(output_mel.transpose(1, 2)).transpose(1, 2)
        # Raw output + residual refinement.
        return output_mel + residual

    def generate_square_subsequent_mask(self, sz):
        """Additive causal mask: 0 on/below the diagonal, -inf above."""
        mask = torch.full((sz, sz), float('-inf'))
        return torch.triu(mask, diagonal=1)
# --- SANITY CHECK ---
# Run this file directly to confirm the tensor dimensions line up.
if __name__ == "__main__":
    print("Testing Model Dimensions...")

    # Dummy config
    num_chars = 50    # vocabulary size (phonemes)
    num_mels = 80     # standard mel-spectrogram channel count
    batch_size = 2
    text_len = 10
    mel_len = 100

    model = MiniTTS(num_chars, num_mels)

    # Dummy data matching the shapes documented on MiniTTS.forward
    dummy_text = torch.randint(0, num_chars, (batch_size, text_len))
    dummy_mel = torch.randn(batch_size, mel_len, num_mels)

    try:
        output = model(dummy_text, dummy_mel)
        print(f"Input Text Shape: {dummy_text.shape}")
        print(f"Input Mel Shape: {dummy_mel.shape}")
        print(f"Output Shape: {output.shape}")
        print("\nSUCCESS: The architecture is valid!")
    except Exception as e:
        # Print the full traceback — a bare message hides where the
        # shape mismatch actually occurred, defeating the sanity check.
        import traceback
        traceback.print_exc()
        print(f"\nERROR: {e}")