# prosody-predictor / model_prosody.py
# Uploaded by hidude562 via huggingface_hub (commit 33ab2a7, verified)
"""Conv-based prosody predictor (~681K params)."""
import math
import torch
import torch.nn as nn
class SinusoidalPositionalEncoding(nn.Module):
    """Adds fixed sine/cosine positional encodings (Transformer-style) to inputs."""
    def __init__(self, d_model, max_len=4096):
        super().__init__()
        positions = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # Geometric frequency ladder: 10000^(-2i/d_model) for even indices i.
        freqs = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)
        table[:, 1::2] = torch.cos(positions * freqs)
        # Kept as [1, max_len, d_model] so it broadcasts over the batch dim.
        self.register_buffer('pe', table.unsqueeze(0))
    def forward(self, x):
        """x: [B, L, D] -> [B, L, D] with positional encoding added."""
        return x + self.pe[:, :x.size(1)]
class ConvBlock(nn.Module):
def __init__(self, channels, kernel_size=5, dropout=0.2):
super().__init__()
self.conv = nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2)
self.norm = nn.LayerNorm(channels)
self.drop = nn.Dropout(dropout)
def forward(self, x):
# x: [B, C, T]
out = self.conv(x)
out = torch.relu(out)
out = out.transpose(1, 2) # [B, T, C]
out = self.norm(out)
out = out.transpose(1, 2) # [B, C, T]
out = self.drop(out)
return out
class CharEncoder(nn.Module):
    """Embeds character ids and refines them with a stack of conv blocks.

    Id 0 is the padding index (its embedding is frozen at zero).
    """
    def __init__(self, vocab_size=50, d_model=128, n_layers=4, kernel_size=5, dropout=0.2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_enc = SinusoidalPositionalEncoding(d_model)
        self.blocks = nn.ModuleList(
            ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers)
        )
    def forward(self, char_ids):
        """char_ids: [B, L] int ids -> features [B, D, L]."""
        h = self.pos_enc(self.embed(char_ids))  # [B, L, D]
        h = h.transpose(1, 2)  # [B, D, L], channel-first for Conv1d
        for blk in self.blocks:
            h = blk(h)
        return h
class DurationPredictor(nn.Module):
    """Predicts log-duration from detached encoder output (tiny-tts pattern)."""
    def __init__(self, d_model=128, kernel_size=3, dropout=0.2):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size, padding=pad)
        self.norm1 = nn.LayerNorm(d_model)
        self.drop1 = nn.Dropout(dropout)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size, padding=pad)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop2 = nn.Dropout(dropout)
        self.proj = nn.Conv1d(d_model, 1, 1)
    def forward(self, x, mask=None):
        """x: [B, D, L]; mask: optional [B, L] (1=valid) -> [B, L] log-durations."""
        # Detach so the duration loss never backprops into the encoder.
        h = x.detach()
        stages = (
            (self.conv1, self.norm1, self.drop1),
            (self.conv2, self.norm2, self.drop2),
        )
        for conv, norm, drop in stages:
            h = torch.relu(conv(h))
            # LayerNorm acts on the last dim: swap to [B, L, D] and back.
            h = norm(h.transpose(1, 2)).transpose(1, 2)
            h = drop(h)
        h = self.proj(h)  # [B, 1, L]
        if mask is not None:
            h = h * mask.unsqueeze(1)  # zero out padded positions
        return h.squeeze(1)  # [B, L]
def length_regulate(encoder_out, durations):
    """Expand encoder frames in time by repeating each position per its duration.

    Replaces the original per-position Python loop with one vectorized
    `repeat_interleave` per batch item — same output, far fewer tensor ops.

    Args:
        encoder_out: [B, D, L] encoder features.
        durations: [B, L] frame counts per position (cast to long; entries <= 0
            contribute no frames).
    Returns:
        regulated: [B, D, T_max] expanded features, zero-padded to the longest
            sequence in the batch (T_max = max total duration).
        frame_lengths: [B] total frames per batch item (sum of durations).
    """
    B, D, L = encoder_out.shape
    durations = durations.long()
    frame_lengths = durations.sum(dim=1)  # [B]
    # Guard the empty batch; clamp so a pathological all-negative sum
    # still yields a valid (empty) time axis instead of a crash.
    T_max = max(int(frame_lengths.max().item()), 0) if B > 0 else 0
    regulated = torch.zeros(B, D, T_max, device=encoder_out.device, dtype=encoder_out.dtype)
    for b in range(B):
        # repeat_interleave repeats column l exactly durations[b, l] times
        # (0 drops it); slice defensively so we never write past T_max.
        reps = durations[b].clamp(min=0)
        expanded = encoder_out[b].repeat_interleave(reps, dim=1)[:, :T_max]
        regulated[b, :, :expanded.size(1)] = expanded
    return regulated, frame_lengths
class FrameDecoder(nn.Module):
    """Maps length-regulated features to per-frame [f0, rms] predictions."""
    def __init__(self, d_model=128, n_layers=3, kernel_size=5, dropout=0.2):
        super().__init__()
        self.blocks = nn.ModuleList(
            ConvBlock(d_model, kernel_size, dropout) for _ in range(n_layers)
        )
        # 1x1 conv projects to 2 channels: index 0 = f0, index 1 = rms.
        self.proj = nn.Conv1d(d_model, 2, 1)
    def forward(self, x):
        """x: [B, D, T] -> [B, 2, T]."""
        for blk in self.blocks:
            x = blk(x)
        return self.proj(x)
class ProsodyPredictor(nn.Module):
    """End-to-end prosody model: char encoder -> duration predictor -> frame decoder."""
    def __init__(self, vocab_size=50, d_model=128, dropout=0.2):
        super().__init__()
        self.encoder = CharEncoder(vocab_size, d_model, n_layers=4, kernel_size=5, dropout=dropout)
        self.dur_predictor = DurationPredictor(d_model, kernel_size=3, dropout=dropout)
        self.decoder = FrameDecoder(d_model, n_layers=3, kernel_size=5, dropout=dropout)
    def forward(self, char_ids, durations=None, char_lengths=None):
        """
        Args:
            char_ids: [B, L]
            durations: [B, L] ground-truth (training) or None (inference)
            char_lengths: [B] for masking
        Returns:
            pred_f0: [B, T]
            pred_rms: [B, T]
            pred_log_dur: [B, L]
            frame_lengths: [B]
        """
        encoded = self.encoder(char_ids)  # [B, D, L]
        # Validity mask [B, L] (1 = real char, 0 = padding), if lengths given.
        mask = None
        if char_lengths is not None:
            positions = torch.arange(char_ids.size(1), device=char_ids.device)
            mask = (positions.unsqueeze(0) < char_lengths.unsqueeze(1)).float()
        # Durations are predicted from detached features (see DurationPredictor).
        pred_log_dur = self.dur_predictor(encoded, mask)  # [B, L]
        if durations is None:
            # Inference: decode predicted log-durations; every valid position
            # gets at least one frame, padded positions get zero.
            use_dur = torch.exp(pred_log_dur).round().long().clamp(min=1)
            if mask is not None:
                use_dur = use_dur * mask.long()
        else:
            # Training: teacher-force with ground-truth durations.
            use_dur = durations
        regulated, frame_lengths = length_regulate(encoded, use_dur)
        frames = self.decoder(regulated)  # [B, 2, T]
        pred_f0 = frames[:, 0]   # [B, T]
        pred_rms = frames[:, 1]  # [B, T]
        return pred_f0, pred_rms, pred_log_dur, frame_lengths