File size: 5,871 Bytes
be29b5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A table of shape ``[1, max_len, d_model]`` is computed once at
    construction: even feature columns hold ``sin(pos * w_k)`` and odd
    columns hold ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.
    The model needs this because it has no recurrence.

    Args:
        d_model: Size of the feature dimension (may be even or odd).
        dropout: Dropout probability applied after the encoding is added.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim the cosine term to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register buffer allows us to save this with state_dict but not train it
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the positional table, passed through dropout.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, d_model]``.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built with.
        """
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        # Fail early with a clear message: slicing past the table would
        # otherwise surface as a broadcast error (or, for a one-row table,
        # silently reuse position 0 for every step).
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {table_len} "
                "of the positional encoding table"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

class MiniTTS(nn.Module):
    """Transformer encoder-decoder that maps text tokens to mel-spectrogram frames.

    The text side embeds integer tokens, adds positional encodings and runs a
    Transformer encoder. The mel side projects spectrogram frames to the model
    width, adds positional encodings and runs a Transformer decoder that
    attends to the encoder output under a causal mask. A linear layer maps the
    decoder output back to ``num_mels`` channels, and a small convolutional
    post-net produces a residual that is added to that projection.

    Args:
        num_chars: Size of the token vocabulary.
        num_mels: Number of mel channels per spectrogram frame.
        d_model: Width of the Transformer layers.
        nhead: Number of attention heads per layer.
        num_layers: Number of layers in both the encoder and the decoder.
    """

    def __init__(self, num_chars, num_mels, d_model=256, nhead=4, num_layers=4):
        super().__init__()

        # 1. Text Encoder Layers
        self.embedding = nn.Embedding(num_chars, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # 2. Spectrogram Decoder Layers
        # Project mel dimension to model dimension before decoding
        self.mel_embedding = nn.Linear(num_mels, d_model)
        self.pos_decoder = PositionalEncoding(d_model)

        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # 3. Final Projection
        # Project back from model dimension to Mel Spectrogram dimension
        self.output_layer = nn.Linear(d_model, num_mels)

        # 4. Post-Net
        # Simple convolutional network that refines the projected output
        self.post_net = nn.Sequential(
            nn.Conv1d(num_mels, 512, kernel_size=5, padding=2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Conv1d(512, num_mels, kernel_size=5, padding=2)
        )

    def forward(self, text_tokens, mel_target=None):
        """Encode the text and, when a target is given, decode mel frames.

        Args:
            text_tokens: ``[batch, text_len]`` integer token ids.
            mel_target: ``[batch, mel_len, num_mels]`` spectrogram fed to the
                decoder (teacher forcing), or ``None``.

        Returns:
            With ``mel_target``: ``[batch, mel_len, num_mels]``, the linear
            projection plus the post-net residual. Without it: the encoder
            memory ``[batch, text_len, d_model]`` (the autoregressive
            inference loop is not implemented in this class).
        """
        # --- ENCODING ---
        # [batch, text_len] -> [batch, text_len, d_model]
        src = self.pos_encoder(self.embedding(text_tokens))

        # Memory is the output of the encoder that the decoder attends to
        memory = self.transformer_encoder(src)

        if mel_target is None:
            # No decoding loop here; expose the encoder memory for debugging.
            return memory

        # --- DECODING (teacher forcing) ---
        # NOTE(review): mel_target is used exactly as passed; no shift is
        # applied in this method, so the caller presumably supplies an
        # already right-shifted spectrogram - confirm against the training loop.
        tgt = self.pos_decoder(self.mel_embedding(mel_target))

        # Causal mask: prevents the decoder from peeking at future frames.
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)

        output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        output_mel = self.output_layer(output)

        # Conv1d expects [batch, channels, time], so transpose in and out.
        residual = self.post_net(output_mel.transpose(1, 2)).transpose(1, 2)

        # Combine raw output + residual
        return output_mel + residual

    def generate_square_subsequent_mask(self, sz, device=None):
        """Return a ``[sz, sz]`` float mask for causal self-attention.

        Entries strictly above the main diagonal are ``-inf``; the diagonal
        and everything below it are ``0.0``.

        Args:
            sz: Sequence length.
            device: Device to allocate the mask on; ``None`` uses the default
                device.
        """
        # Built directly on the requested device to avoid a host-side
        # allocation followed by a transfer on every forward pass.
        return torch.triu(
            torch.full((sz, sz), float('-inf'), device=device),
            diagonal=1,
        )

# --- SANITY CHECK ---
# Executing this module as a script pushes random tensors through the model
# to confirm that every layer agrees on tensor shapes.
if __name__ == "__main__":
    print("Testing Model Dimensions...")

    # Toy configuration
    num_chars, num_mels = 50, 80  # vocabulary size (phonemes), mel channels
    batch_size, text_len, mel_len = 2, 10, 100

    # Build the model with its default hyper-parameters
    model = MiniTTS(num_chars, num_mels)

    # Random stand-ins for a tokenised batch and its target spectrogram
    dummy_text = torch.randint(0, num_chars, (batch_size, text_len))
    dummy_mel = torch.randn(batch_size, mel_len, num_mels)

    # One teacher-forced forward pass; any failure is reported, not raised
    try:
        output = model(dummy_text, dummy_mel)
        shape_report = [
            f"Input Text Shape: {dummy_text.shape}",
            f"Input Mel Shape:  {dummy_mel.shape}",
            f"Output Shape:     {output.shape}",
        ]
        for line in shape_report:
            print(line)
        print("\nSUCCESS: The architecture is valid!")
    except Exception as e:
        print(f"\nERROR: {e}")