API_MC_AI / VietTTS /hifigan /f0_predictor.py
duyv's picture
Upload 86 files
a257816 verified
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
class ConvRNNF0Predictor(nn.Module):
    """Frame-wise predictor made of a 1-D convolution stack and a linear head.

    The network is five weight-normalised ``Conv1d`` layers (kernel size 3,
    padding 1, so the time length is preserved), each followed by ``ELU``,
    and then a ``Linear`` layer applied independently to every time step.
    The absolute value of the linear output is returned, so predictions are
    always non-negative.

    Args:
        num_class: Number of values predicted per time step.
        in_channels: Channel count of the input features (the default of 80
            presumably corresponds to mel bins -- confirm against the caller).
        cond_channels: Channel width of every convolution layer's output.
    """

    def __init__(self,
                 num_class: int = 1,
                 in_channels: int = 80,
                 cond_channels: int = 512
                 ):
        super().__init__()
        self.num_class = num_class

        # Five conv + ELU stages; only the first one changes the channel
        # count. Modules are created in the same order and land on the same
        # nn.Sequential indices (0..9) as a hand-written list would, so
        # state_dict keys and seeded initialisation are unaffected.
        stages = []
        stage_in = in_channels
        for _ in range(5):
            stages.append(
                weight_norm(
                    nn.Conv1d(stage_in, cond_channels, kernel_size=3, padding=1)
                )
            )
            stages.append(nn.ELU())
            stage_in = cond_channels
        self.condnet = nn.Sequential(*stages)

        self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the predictor on a feature sequence.

        Args:
            x: Tensor of shape ``(batch, in_channels, time)``; an unbatched
                ``(in_channels, time)`` tensor is accepted as well.

        Returns:
            Non-negative tensor of shape ``(batch, time)`` when
            ``num_class == 1`` (the trailing singleton axis is squeezed),
            otherwise ``(batch, time, num_class)``. The batch axis is absent
            for unbatched input.
        """
        x = self.condnet(x)
        # Move channels last for the per-frame linear layer. Negative axes
        # keep this identical to transpose(1, 2) for batched input while also
        # working when there is no batch axis.
        x = x.transpose(-1, -2)
        return torch.abs(self.classifier(x).squeeze(-1))