Update modeling_bert_vits2.py
Browse files- modeling_bert_vits2.py +16 -22
modeling_bert_vits2.py
CHANGED
|
@@ -33,16 +33,10 @@ from transformers.modeling_outputs import (
|
|
| 33 |
from transformers.models.bert.modeling_bert import BertModel
|
| 34 |
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
| 36 |
-
from configuration_bert_vits2 import BertVits2Config
|
| 37 |
-
|
| 38 |
|
| 39 |
logger = logging.get_logger(__name__)
|
| 40 |
|
| 41 |
|
| 42 |
-
# General docstring
|
| 43 |
-
_CONFIG_FOR_DOC = "BertVits2Config"
|
| 44 |
-
|
| 45 |
-
|
| 46 |
@dataclass
|
| 47 |
class BertVits2ModelOutput(ModelOutput):
|
| 48 |
"""
|
|
@@ -328,7 +322,7 @@ def _rational_quadratic_spline(
|
|
| 328 |
|
| 329 |
|
| 330 |
class BertVits2WaveNet(torch.nn.Module):
|
| 331 |
-
def __init__(self, config
|
| 332 |
super().__init__()
|
| 333 |
self.hidden_size = config.hidden_size
|
| 334 |
self.num_layers = num_layers
|
|
@@ -408,7 +402,7 @@ class BertVits2WaveNet(torch.nn.Module):
|
|
| 408 |
|
| 409 |
|
| 410 |
class BertVits2PosteriorEncoder(nn.Module):
|
| 411 |
-
def __init__(self, config
|
| 412 |
super().__init__()
|
| 413 |
self.out_channels = config.flow_size
|
| 414 |
|
|
@@ -485,7 +479,7 @@ class HifiGanResidualBlock(nn.Module):
|
|
| 485 |
|
| 486 |
|
| 487 |
class BertVits2HifiGan(nn.Module):
|
| 488 |
-
def __init__(self, config
|
| 489 |
super().__init__()
|
| 490 |
self.config = config
|
| 491 |
self.num_kernels = len(config.resblock_kernel_sizes)
|
|
@@ -571,7 +565,7 @@ class BertVits2HifiGan(nn.Module):
|
|
| 571 |
|
| 572 |
|
| 573 |
class BertVits2ResidualCouplingLayer(nn.Module):
|
| 574 |
-
def __init__(self, config
|
| 575 |
super().__init__()
|
| 576 |
self.half_channels = config.flow_size // 2
|
| 577 |
|
|
@@ -593,7 +587,7 @@ class BertVits2ResidualCouplingLayer(nn.Module):
|
|
| 593 |
|
| 594 |
|
| 595 |
class BertVits2ResidualCouplingBlock(nn.Module):
|
| 596 |
-
def __init__(self, config
|
| 597 |
super().__init__()
|
| 598 |
self.flows = nn.ModuleList()
|
| 599 |
for _ in range(config.prior_encoder_num_flows):
|
|
@@ -608,7 +602,7 @@ class BertVits2ResidualCouplingBlock(nn.Module):
|
|
| 608 |
|
| 609 |
|
| 610 |
class BertVits2TransformerCouplingLayer(nn.Module):
|
| 611 |
-
def __init__(self, config
|
| 612 |
super().__init__()
|
| 613 |
self.half_channels = config.flow_size // 2
|
| 614 |
|
|
@@ -653,7 +647,7 @@ class BertVits2TransformerCouplingLayer(nn.Module):
|
|
| 653 |
|
| 654 |
|
| 655 |
class BertVits2TransformerCouplingBlock(nn.Module):
|
| 656 |
-
def __init__(self, config
|
| 657 |
super().__init__()
|
| 658 |
self.flows = nn.ModuleList([
|
| 659 |
BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
|
|
@@ -672,7 +666,7 @@ class BertVits2TransformerCouplingBlock(nn.Module):
|
|
| 672 |
|
| 673 |
|
| 674 |
class BertVits2DilatedDepthSeparableConv(nn.Module):
|
| 675 |
-
def __init__(self, config
|
| 676 |
super().__init__()
|
| 677 |
kernel_size = config.duration_predictor_kernel_size
|
| 678 |
channels = config.hidden_size
|
|
@@ -718,7 +712,7 @@ class BertVits2DilatedDepthSeparableConv(nn.Module):
|
|
| 718 |
|
| 719 |
|
| 720 |
class BertVits2ConvFlow(nn.Module):
|
| 721 |
-
def __init__(self, config
|
| 722 |
super().__init__()
|
| 723 |
self.filter_channels = config.hidden_size
|
| 724 |
self.half_channels = config.depth_separable_channels // 2
|
|
@@ -761,7 +755,7 @@ class BertVits2ConvFlow(nn.Module):
|
|
| 761 |
|
| 762 |
|
| 763 |
class BertVits2ElementwiseAffine(nn.Module):
|
| 764 |
-
def __init__(self, config
|
| 765 |
super().__init__()
|
| 766 |
self.channels = config.depth_separable_channels
|
| 767 |
self.translate = nn.Parameter(torch.zeros(self.channels, 1))
|
|
@@ -918,7 +912,7 @@ class BertVits2DurationPredictor(nn.Module):
|
|
| 918 |
class BertVits2Attention(nn.Module):
|
| 919 |
"""Multi-headed attention with relative positional representation."""
|
| 920 |
|
| 921 |
-
def __init__(self, config
|
| 922 |
super().__init__()
|
| 923 |
self.embed_dim = config.hidden_size
|
| 924 |
self.num_heads = config.num_attention_heads
|
|
@@ -1130,7 +1124,7 @@ class BertVits2FeedForward(nn.Module):
|
|
| 1130 |
|
| 1131 |
|
| 1132 |
class BertVits2EncoderLayer(nn.Module):
|
| 1133 |
-
def __init__(self, config
|
| 1134 |
super().__init__()
|
| 1135 |
self.attention = BertVits2Attention(config)
|
| 1136 |
self.dropout = nn.Dropout(config.hidden_dropout)
|
|
@@ -1169,7 +1163,7 @@ class BertVits2EncoderLayer(nn.Module):
|
|
| 1169 |
|
| 1170 |
|
| 1171 |
class BertVits2Encoder(nn.Module):
|
| 1172 |
-
def __init__(self, config
|
| 1173 |
super().__init__()
|
| 1174 |
self.config = config
|
| 1175 |
if n_layers is None:
|
|
@@ -1260,7 +1254,7 @@ class BertVits2TextEncoder(nn.Module):
|
|
| 1260 |
Transformer encoder that uses relative positional representation instead of absolute positional encoding.
|
| 1261 |
"""
|
| 1262 |
|
| 1263 |
-
def __init__(self, config
|
| 1264 |
super().__init__()
|
| 1265 |
self.config = config
|
| 1266 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
|
@@ -1330,7 +1324,7 @@ class BertVits2TextEncoder(nn.Module):
|
|
| 1330 |
|
| 1331 |
|
| 1332 |
class BertVits2ReferenceEncoder(nn.Module):
|
| 1333 |
-
def __init__(self, config
|
| 1334 |
super().__init__()
|
| 1335 |
self.config = config
|
| 1336 |
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
|
@@ -1464,7 +1458,7 @@ BERT_VITS2_INPUTS_DOCSTRING = r"""
|
|
| 1464 |
BERT_VITS2_START_DOCSTRING,
|
| 1465 |
)
|
| 1466 |
class BertVits2Model(BertVits2PreTrainedModel):
|
| 1467 |
-
def __init__(self, config
|
| 1468 |
super().__init__(config)
|
| 1469 |
self.config = config
|
| 1470 |
self.text_encoder = BertVits2TextEncoder(config)
|
|
|
|
| 33 |
from transformers.models.bert.modeling_bert import BertModel
|
| 34 |
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
|
|
|
|
|
|
|
| 36 |
|
| 37 |
logger = logging.get_logger(__name__)
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
@dataclass
|
| 41 |
class BertVits2ModelOutput(ModelOutput):
|
| 42 |
"""
|
|
|
|
| 322 |
|
| 323 |
|
| 324 |
class BertVits2WaveNet(torch.nn.Module):
|
| 325 |
+
def __init__(self, config, num_layers: int):
|
| 326 |
super().__init__()
|
| 327 |
self.hidden_size = config.hidden_size
|
| 328 |
self.num_layers = num_layers
|
|
|
|
| 402 |
|
| 403 |
|
| 404 |
class BertVits2PosteriorEncoder(nn.Module):
|
| 405 |
+
def __init__(self, config):
|
| 406 |
super().__init__()
|
| 407 |
self.out_channels = config.flow_size
|
| 408 |
|
|
|
|
| 479 |
|
| 480 |
|
| 481 |
class BertVits2HifiGan(nn.Module):
|
| 482 |
+
def __init__(self, config):
|
| 483 |
super().__init__()
|
| 484 |
self.config = config
|
| 485 |
self.num_kernels = len(config.resblock_kernel_sizes)
|
|
|
|
| 565 |
|
| 566 |
|
| 567 |
class BertVits2ResidualCouplingLayer(nn.Module):
|
| 568 |
+
def __init__(self, config):
|
| 569 |
super().__init__()
|
| 570 |
self.half_channels = config.flow_size // 2
|
| 571 |
|
|
|
|
| 587 |
|
| 588 |
|
| 589 |
class BertVits2ResidualCouplingBlock(nn.Module):
|
| 590 |
+
def __init__(self, config):
|
| 591 |
super().__init__()
|
| 592 |
self.flows = nn.ModuleList()
|
| 593 |
for _ in range(config.prior_encoder_num_flows):
|
|
|
|
| 602 |
|
| 603 |
|
| 604 |
class BertVits2TransformerCouplingLayer(nn.Module):
|
| 605 |
+
def __init__(self, config):
|
| 606 |
super().__init__()
|
| 607 |
self.half_channels = config.flow_size // 2
|
| 608 |
|
|
|
|
| 647 |
|
| 648 |
|
| 649 |
class BertVits2TransformerCouplingBlock(nn.Module):
|
| 650 |
+
def __init__(self, config):
|
| 651 |
super().__init__()
|
| 652 |
self.flows = nn.ModuleList([
|
| 653 |
BertVits2TransformerCouplingLayer(config) for _ in range(config.prior_encoder_num_flows)
|
|
|
|
| 666 |
|
| 667 |
|
| 668 |
class BertVits2DilatedDepthSeparableConv(nn.Module):
|
| 669 |
+
def __init__(self, config, dropout_rate=0.0):
|
| 670 |
super().__init__()
|
| 671 |
kernel_size = config.duration_predictor_kernel_size
|
| 672 |
channels = config.hidden_size
|
|
|
|
| 712 |
|
| 713 |
|
| 714 |
class BertVits2ConvFlow(nn.Module):
|
| 715 |
+
def __init__(self, config):
|
| 716 |
super().__init__()
|
| 717 |
self.filter_channels = config.hidden_size
|
| 718 |
self.half_channels = config.depth_separable_channels // 2
|
|
|
|
| 755 |
|
| 756 |
|
| 757 |
class BertVits2ElementwiseAffine(nn.Module):
|
| 758 |
+
def __init__(self, config):
|
| 759 |
super().__init__()
|
| 760 |
self.channels = config.depth_separable_channels
|
| 761 |
self.translate = nn.Parameter(torch.zeros(self.channels, 1))
|
|
|
|
| 912 |
class BertVits2Attention(nn.Module):
|
| 913 |
"""Multi-headed attention with relative positional representation."""
|
| 914 |
|
| 915 |
+
def __init__(self, config):
|
| 916 |
super().__init__()
|
| 917 |
self.embed_dim = config.hidden_size
|
| 918 |
self.num_heads = config.num_attention_heads
|
|
|
|
| 1124 |
|
| 1125 |
|
| 1126 |
class BertVits2EncoderLayer(nn.Module):
|
| 1127 |
+
def __init__(self, config, kernel_size=None):
|
| 1128 |
super().__init__()
|
| 1129 |
self.attention = BertVits2Attention(config)
|
| 1130 |
self.dropout = nn.Dropout(config.hidden_dropout)
|
|
|
|
| 1163 |
|
| 1164 |
|
| 1165 |
class BertVits2Encoder(nn.Module):
|
| 1166 |
+
def __init__(self, config, kernel_size=None, n_layers=None):
|
| 1167 |
super().__init__()
|
| 1168 |
self.config = config
|
| 1169 |
if n_layers is None:
|
|
|
|
| 1254 |
Transformer encoder that uses relative positional representation instead of absolute positional encoding.
|
| 1255 |
"""
|
| 1256 |
|
| 1257 |
+
def __init__(self, config):
|
| 1258 |
super().__init__()
|
| 1259 |
self.config = config
|
| 1260 |
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
|
|
|
|
| 1324 |
|
| 1325 |
|
| 1326 |
class BertVits2ReferenceEncoder(nn.Module):
|
| 1327 |
+
def __init__(self, config):
|
| 1328 |
super().__init__()
|
| 1329 |
self.config = config
|
| 1330 |
ref_enc_filters = [32, 32, 64, 64, 128, 128]
|
|
|
|
| 1458 |
BERT_VITS2_START_DOCSTRING,
|
| 1459 |
)
|
| 1460 |
class BertVits2Model(BertVits2PreTrainedModel):
|
| 1461 |
+
def __init__(self, config):
|
| 1462 |
super().__init__(config)
|
| 1463 |
self.config = config
|
| 1464 |
self.text_encoder = BertVits2TextEncoder(config)
|