# API_MC_AI/VietTTS/utils/class_utils.py
# duyv: "Upload 86 files" (a257816, verified)
import torch
from VietTTS.transformer.activation import Swish
from VietTTS.transformer.subsampling import (
LinearNoSubsampling,
EmbedinigNoSubsampling,
Conv1dSubsampling2,
Conv2dSubsampling4,
Conv2dSubsampling6,
Conv2dSubsampling8,
)
from VietTTS.transformer.embedding import (
PositionalEncoding,
RelPositionalEncoding,
WhisperPositionalEncoding,
LearnablePositionalEncoding,
NoPositionalEncoding
)
from VietTTS.transformer.attention import (
MultiHeadedAttention,
RelPositionMultiHeadedAttention
)
from VietTTS.transformer.embedding import EspnetRelPositionalEncoding
from VietTTS.transformer.subsampling import LegacyLinearNoSubsampling
# Registry of activation-function classes, keyed by string name.
# The values are the classes themselves; nothing is instantiated here.
ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    # Use torch.nn.SiLU when the installed torch exposes it; otherwise fall
    # back to the Swish class imported from VietTTS.transformer.activation.
    "swish": torch.nn.SiLU if hasattr(torch.nn, "SiLU") else Swish,
    "gelu": torch.nn.GELU,
}
# Registry of the classes imported from VietTTS.transformer.subsampling,
# keyed by string name. The values are classes, not instances.
SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
    "linear_legacy": LegacyLinearNoSubsampling,
    "embed": EmbedinigNoSubsampling,
    "conv1d2": Conv1dSubsampling2,
    # The un-suffixed "conv2d" key selects Conv2dSubsampling4.
    "conv2d": Conv2dSubsampling4,
    "conv2d6": Conv2dSubsampling6,
    "conv2d8": Conv2dSubsampling8,
    # The only entry that is not a project class: plain torch.nn.Identity.
    "paraformer_dummy": torch.nn.Identity,
}
# Registry of the positional-encoding classes imported from
# VietTTS.transformer.embedding, keyed by string name.
EMB_CLASSES = {
    # "embed" and "abs_pos" both resolve to the same PositionalEncoding class.
    "embed": PositionalEncoding,
    "abs_pos": PositionalEncoding,
    "rel_pos": RelPositionalEncoding,
    "rel_pos_espnet": EspnetRelPositionalEncoding,
    "no_pos": NoPositionalEncoding,
    "abs_pos_whisper": WhisperPositionalEncoding,
    "embed_learnable_pe": LearnablePositionalEncoding,
}
# Registry of the attention classes imported from
# VietTTS.transformer.attention, keyed by string name.
ATTENTION_CLASSES = {
    "selfattn": MultiHeadedAttention,
    "rel_selfattn": RelPositionMultiHeadedAttention,
}