| """Conv-based prosody predictor (~681K params).""" |
|
|
import math
import torch
import torch.nn as nn
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        # x: [B, L, d_model] -> add the encoding for the first L positions.
        return x + self.pe[:, :x.size(1)]
|
|
class ConvBlock(nn.Module):
    """Conv1d -> ReLU -> LayerNorm (over channels) -> Dropout, length-preserving."""

    def __init__(self, channels, kernel_size=5, dropout=0.2):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
        self.norm = nn.LayerNorm(channels)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, C, T]
        out = self.conv(x)
        out = torch.relu(out)
        out = out.transpose(1, 2)   # [B, T, C] so LayerNorm acts on the channel dim
        out = self.norm(out)
        out = out.transpose(1, 2)   # back to [B, C, T]
        out = self.drop(out)
        return out
|
|
class CharEncoder(nn.Module):
    """Character embedding + positional encoding followed by a stack of ConvBlocks."""

    def __init__(self, vocab_size=50, d_model=128, n_layers=4, kernel_size=5, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_enc = SinusoidalPositionalEncoding(d_model)
        self.blocks = nn.ModuleList([
            ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers)
        ])

    def forward(self, char_ids):
        # char_ids: [B, L] -> encoder output [B, d_model, L]
        x = self.embed(char_ids)   # [B, L, d_model]
        x = self.pos_enc(x)
        x = x.transpose(1, 2)      # [B, d_model, L] for Conv1d layers
        for block in self.blocks:
            x = block(x)
        return x
|
|
class DurationPredictor(nn.Module):
    """Predicts log-duration from detached encoder output (tiny-tts pattern)."""

    def __init__(self, d_model=128, kernel_size=3, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2)
        self.norm1 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)
        self.proj = nn.Conv1d(d_model, 1, 1)

    def forward(self, x, mask=None):
        # x: [B, d_model, L]; mask: [B, L] with 1 for valid positions.
        # Detach so duration gradients do not flow back into the encoder.
        x = x.detach()
        h = self.conv1(x)
        h = torch.relu(h)
        h = h.transpose(1, 2)
        h = self.norm1(h)
        h = h.transpose(1, 2)
        h = self.drop1(h)
        h = self.conv2(h)
        h = torch.relu(h)
        h = h.transpose(1, 2)
        h = self.norm2(h)
        h = h.transpose(1, 2)
        h = self.drop2(h)
        h = self.proj(h)                  # [B, 1, L]
        if mask is not None:
            h = h * mask.unsqueeze(1)     # zero out padded positions
        return h.squeeze(1)               # [B, L] log-durations
|
|
def length_regulate(encoder_out, durations):
    """Repeat encoder frames according to durations.

    Args:
        encoder_out: [B, D, L]
        durations: [B, L] (integer)

    Returns:
        regulated: [B, D, T_max] zero-padded
        frame_lengths: [B]
    """
    B, D, L = encoder_out.shape
    durations = durations.long()
    frame_lengths = durations.sum(dim=1)
    T_max = int(frame_lengths.max().item())

    regulated = torch.zeros(B, D, T_max, device=encoder_out.device, dtype=encoder_out.dtype)
    for b in range(B):
        idx = 0
        for l in range(L):
            dur = durations[b, l].item()
            if dur > 0 and idx < T_max:
                # Copy the l-th encoder frame into the next `dur` output frames.
                end = min(idx + dur, T_max)
                regulated[b, :, idx:end] = encoder_out[b, :, l:l + 1].expand(-1, end - idx)
                idx = end
    return regulated, frame_lengths
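

# A minimal vectorized sketch of the same expansion using torch.repeat_interleave.
# Illustrative only: the helper name `length_regulate_vectorized` is not part of the
# original module, and it assumes the same [B, D, L] / [B, L] conventions as above.
def length_regulate_vectorized(encoder_out, durations):
    B, D, L = encoder_out.shape
    durations = durations.long()
    frame_lengths = durations.sum(dim=1)
    T_max = int(frame_lengths.max().item())
    regulated = torch.zeros(B, D, T_max, device=encoder_out.device, dtype=encoder_out.dtype)
    for b in range(B):
        # Repeat each encoder column durations[b, l] times along the time axis.
        expanded = torch.repeat_interleave(encoder_out[b], durations[b], dim=1)
        regulated[b, :, :expanded.size(1)] = expanded
    return regulated, frame_lengths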
|
|
|
|
class FrameDecoder(nn.Module):
    """Decodes length-regulated features into per-frame F0 and RMS."""

    def __init__(self, d_model=128, n_layers=3, kernel_size=5, dropout=0.2):
        super().__init__()
        self.blocks = nn.ModuleList([
            ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers)
        ])
        self.proj = nn.Conv1d(d_model, 2, 1)

    def forward(self, x):
        # x: [B, d_model, T] -> [B, 2, T] (channel 0: F0, channel 1: RMS)
        for block in self.blocks:
            x = block(x)
        return self.proj(x)
|
|
class ProsodyPredictor(nn.Module):
    """Full model: char encoder -> duration predictor -> length regulation -> frame decoder."""

    def __init__(self, vocab_size=50, d_model=128, dropout=0.2):
        super().__init__()
        self.encoder = CharEncoder(vocab_size, d_model, n_layers=4, kernel_size=5, dropout=dropout)
        self.dur_predictor = DurationPredictor(d_model, kernel_size=3, dropout=dropout)
        self.decoder = FrameDecoder(d_model, n_layers=3, kernel_size=5, dropout=dropout)

    def forward(self, char_ids, durations=None, char_lengths=None):
        """
        Args:
            char_ids: [B, L]
            durations: [B, L] ground-truth (training) or None (inference)
            char_lengths: [B] for masking

        Returns:
            pred_f0: [B, T]
            pred_rms: [B, T]
            pred_log_dur: [B, L]
            frame_lengths: [B]
        """
        enc_out = self.encoder(char_ids)  # [B, d_model, L]

        # Mask padded characters before predicting log-durations.
        char_mask = None
        if char_lengths is not None:
            char_mask = torch.arange(char_ids.size(1), device=char_ids.device).unsqueeze(0) < char_lengths.unsqueeze(1)
            char_mask = char_mask.float()
        pred_log_dur = self.dur_predictor(enc_out, char_mask)

        if durations is not None:
            # Training: expand with ground-truth durations (teacher forcing).
            regulated, frame_lengths = length_regulate(enc_out, durations)
        else:
            # Inference: expand with predicted durations.
            pred_dur = torch.round(torch.exp(pred_log_dur)).long().clamp(min=1)
            if char_mask is not None:
                pred_dur = pred_dur * char_mask.long()
            regulated, frame_lengths = length_regulate(enc_out, pred_dur)

        # Decode per-frame prosody targets.
        out = self.decoder(regulated)
        pred_f0 = out[:, 0]
        pred_rms = out[:, 1]

        return pred_f0, pred_rms, pred_log_dur, frame_lengths
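

# Minimal smoke test (illustrative sketch only; the batch size, lengths, and random
# durations below are assumptions, not values from the original training setup).
if __name__ == "__main__":
    model = ProsodyPredictor(vocab_size=50, d_model=128)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params / 1e3:.0f}K")

    char_ids = torch.randint(1, 50, (2, 20))      # [B=2, L=20]; id 0 is reserved for padding
    char_lengths = torch.tensor([20, 15])
    gt_durations = torch.randint(1, 6, (2, 20))   # stand-in ground-truth durations

    # Training path: ground-truth durations drive length regulation.
    f0, rms, log_dur, frame_lengths = model(char_ids, durations=gt_durations, char_lengths=char_lengths)
    print(f0.shape, rms.shape, log_dur.shape, frame_lengths)

    # Inference path: durations come from the duration predictor.
    model.eval()
    with torch.no_grad():
        f0, rms, log_dur, frame_lengths = model(char_ids, char_lengths=char_lengths)
    print(f0.shape, frame_lengths)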
|
|