"""Conv-based prosody predictor (~681K params).""" import math import torch import torch.nn as nn class SinusoidalPositionalEncoding(nn.Module): def __init__(self, d_model, max_len=4096): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) # [1, max_len, d_model] def forward(self, x): # x: [B, L, D] return x + self.pe[:, :x.size(1)] class ConvBlock(nn.Module): def __init__(self, channels, kernel_size=5, dropout=0.2): super().__init__() self.conv = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2) self.norm = nn.LayerNorm(channels) self.drop = nn.Dropout(dropout) def forward(self, x): # x: [B, C, T] out = self.conv(x) out = torch.relu(out) out = out.transpose(1, 2) # [B, T, C] out = self.norm(out) out = out.transpose(1, 2) # [B, C, T] out = self.drop(out) return out class CharEncoder(nn.Module): def __init__(self, vocab_size=50, d_model=128, n_layers=4, kernel_size=5, dropout=0.2): super().__init__() self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0) self.pos_enc = SinusoidalPositionalEncoding(d_model) self.blocks = nn.ModuleList([ ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers) ]) def forward(self, char_ids): # char_ids: [B, L] x = self.embed(char_ids) # [B, L, D] x = self.pos_enc(x) # [B, L, D] x = x.transpose(1, 2) # [B, D, L] for block in self.blocks: x = block(x) return x # [B, D, L] class DurationPredictor(nn.Module): """Predicts log-duration from detached encoder output (tiny-tts pattern).""" def __init__(self, d_model=128, kernel_size=3, dropout=0.2): super().__init__() self.conv1 = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2) self.norm1 = nn.LayerNorm(d_model) self.drop1 = nn.Dropout(dropout) self.conv2 = nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2) self.norm2 = nn.LayerNorm(d_model) self.drop2 = nn.Dropout(dropout) self.proj = nn.Conv1d(d_model, 1, 1) def forward(self, x, mask=None): # x: [B, D, L] — detached x = x.detach() h = self.conv1(x) h = torch.relu(h) h = h.transpose(1, 2) h = self.norm1(h) h = h.transpose(1, 2) h = self.drop1(h) h = self.conv2(h) h = torch.relu(h) h = h.transpose(1, 2) h = self.norm2(h) h = h.transpose(1, 2) h = self.drop2(h) h = self.proj(h) # [B, 1, L] if mask is not None: h = h * mask.unsqueeze(1) return h.squeeze(1) # [B, L] def length_regulate(encoder_out, durations): """Repeat encoder frames according to durations. 
Args: encoder_out: [B, D, L] durations: [B, L] (integer) Returns: regulated: [B, D, T_max] zero-padded frame_lengths: [B] """ B, D, L = encoder_out.shape durations = durations.long() frame_lengths = durations.sum(dim=1) # [B] T_max = frame_lengths.max().item() regulated = torch.zeros(B, D, T_max, device=encoder_out.device, dtype=encoder_out.dtype) for b in range(B): idx = 0 for l in range(L): dur = durations[b, l].item() if dur > 0 and idx < T_max: end = min(idx + dur, T_max) regulated[b, :, idx:end] = encoder_out[b, :, l:l+1].expand(-1, end - idx) idx = end return regulated, frame_lengths class FrameDecoder(nn.Module): def __init__(self, d_model=128, n_layers=3, kernel_size=5, dropout=0.2): super().__init__() self.blocks = nn.ModuleList([ ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers) ]) self.proj = nn.Conv1d(d_model, 2, 1) # [f0, rms] def forward(self, x): # x: [B, D, T] for block in self.blocks: x = block(x) return self.proj(x) # [B, 2, T] class ProsodyPredictor(nn.Module): def __init__(self, vocab_size=50, d_model=128, dropout=0.2): super().__init__() self.encoder = CharEncoder(vocab_size, d_model, n_layers=4, kernel_size=5, dropout=dropout) self.dur_predictor = DurationPredictor(d_model, kernel_size=3, dropout=dropout) self.decoder = FrameDecoder(d_model, n_layers=3, kernel_size=5, dropout=dropout) def forward(self, char_ids, durations=None, char_lengths=None): """ Args: char_ids: [B, L] durations: [B, L] ground-truth (training) or None (inference) char_lengths: [B] for masking Returns: pred_f0: [B, T] pred_rms: [B, T] pred_log_dur: [B, L] frame_lengths: [B] """ # Encoder enc_out = self.encoder(char_ids) # [B, D, L] # Duration prediction (detached input) char_mask = None if char_lengths is not None: char_mask = torch.arange(char_ids.size(1), device=char_ids.device).unsqueeze(0) < char_lengths.unsqueeze(1) char_mask = char_mask.float() pred_log_dur = self.dur_predictor(enc_out, char_mask) # [B, L] # Length regulate if durations is not None: # Training: use GT durations regulated, frame_lengths = length_regulate(enc_out, durations) else: # Inference: use predicted durations pred_dur = torch.round(torch.exp(pred_log_dur)).long().clamp(min=1) if char_mask is not None: pred_dur = pred_dur * char_mask.long() regulated, frame_lengths = length_regulate(enc_out, pred_dur) # Decode out = self.decoder(regulated) # [B, 2, T] pred_f0 = out[:, 0] # [B, T] pred_rms = out[:, 1] # [B, T] return pred_f0, pred_rms, pred_log_dur, frame_lengths
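

# Minimal usage sketch: runs the forward pass once with ground-truth durations
# (training path) and once with predicted durations (inference path), checking
# output shapes. The batch size, sequence length, and duration values below are
# illustrative assumptions, not taken from any particular training setup.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = ProsodyPredictor(vocab_size=50, d_model=128)
    print(f"parameters: {sum(p.numel() for p in model.parameters()):,}")

    B, L = 2, 12
    char_ids = torch.randint(1, 50, (B, L))      # 0 is reserved for padding
    char_lengths = torch.tensor([12, 9])         # second sample is shorter
    gt_durations = torch.randint(1, 6, (B, L))   # frames per character

    # Training path: length regulation uses ground-truth durations.
    f0, rms, log_dur, frame_lens = model(char_ids, durations=gt_durations, char_lengths=char_lengths)
    print("train:", f0.shape, rms.shape, log_dur.shape, frame_lens.tolist())

    # Inference path: durations come from the (detached) duration predictor.
    model.eval()
    with torch.no_grad():
        f0, rms, log_dur, frame_lens = model(char_ids, durations=None, char_lengths=char_lengths)
    print("infer:", f0.shape, rms.shape, log_dur.shape, frame_lens.tolist())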