Update modeling_bert_vits2.py
Browse files- modeling_bert_vits2.py +1 -55
modeling_bert_vits2.py
CHANGED
|
@@ -32,7 +32,7 @@ from transformers.modeling_outputs import (
|
|
| 32 |
)
|
| 33 |
from transformers.models.bert.modeling_bert import BertModel
|
| 34 |
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
-
from transformers.utils import
|
| 36 |
|
| 37 |
logger = logging.get_logger(__name__)
|
| 38 |
|
|
@@ -1404,58 +1404,6 @@ class BertVits2PreTrainedModel(PreTrainedModel):
|
|
| 1404 |
module.weight.data[module.padding_idx].zero_()
|
| 1405 |
|
| 1406 |
|
| 1407 |
-
BERT_VITS2_START_DOCSTRING = r"""
|
| 1408 |
-
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 1409 |
-
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 1410 |
-
etc.)
|
| 1411 |
-
|
| 1412 |
-
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 1413 |
-
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 1414 |
-
and behavior.
|
| 1415 |
-
|
| 1416 |
-
Parameters:
|
| 1417 |
-
config ([`BertVits2Config`]):
|
| 1418 |
-
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 1419 |
-
load the weights associated with the model, only the configuration. Check out the
|
| 1420 |
-
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 1421 |
-
"""
|
| 1422 |
-
|
| 1423 |
-
|
| 1424 |
-
BERT_VITS2_INPUTS_DOCSTRING = r"""
|
| 1425 |
-
Args:
|
| 1426 |
-
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1427 |
-
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1428 |
-
it.
|
| 1429 |
-
|
| 1430 |
-
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 1431 |
-
[`PreTrainedTokenizer.__call__`] for details.
|
| 1432 |
-
|
| 1433 |
-
[What are input IDs?](../glossary#input-ids)
|
| 1434 |
-
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1435 |
-
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
|
| 1436 |
-
1]`:
|
| 1437 |
-
|
| 1438 |
-
- 1 for tokens that are **not masked**,
|
| 1439 |
-
- 0 for tokens that are **masked**.
|
| 1440 |
-
|
| 1441 |
-
[What are attention masks?](../glossary#attention-mask)
|
| 1442 |
-
speaker_id (`int`, *optional*):
|
| 1443 |
-
Which speaker embedding to use. Only used for multispeaker models.
|
| 1444 |
-
output_attentions (`bool`, *optional*):
|
| 1445 |
-
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 1446 |
-
tensors for more detail.
|
| 1447 |
-
output_hidden_states (`bool`, *optional*):
|
| 1448 |
-
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 1449 |
-
more detail.
|
| 1450 |
-
return_dict (`bool`, *optional*):
|
| 1451 |
-
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 1452 |
-
"""
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
@add_start_docstrings(
|
| 1456 |
-
"The complete VITS model, for text-to-speech synthesis.",
|
| 1457 |
-
BERT_VITS2_START_DOCSTRING,
|
| 1458 |
-
)
|
| 1459 |
class BertVits2Model(BertVits2PreTrainedModel):
|
| 1460 |
def __init__(self, config):
|
| 1461 |
super().__init__(config)
|
|
@@ -1492,8 +1440,6 @@ class BertVits2Model(BertVits2PreTrainedModel):
|
|
| 1492 |
def get_encoder(self):
|
| 1493 |
return self.text_encoder
|
| 1494 |
|
| 1495 |
-
@add_start_docstrings_to_model_forward(BERT_VITS2_INPUTS_DOCSTRING)
|
| 1496 |
-
@replace_return_docstrings(output_type=BertVits2ModelOutput)
|
| 1497 |
def forward(
|
| 1498 |
self,
|
| 1499 |
input_ids: Optional[torch.Tensor] = None,
|
|
|
|
| 32 |
)
|
| 33 |
from transformers.models.bert.modeling_bert import BertModel
|
| 34 |
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
+
from transformers.utils import logging
|
| 36 |
|
| 37 |
logger = logging.get_logger(__name__)
|
| 38 |
|
|
|
|
| 1404 |
module.weight.data[module.padding_idx].zero_()
|
| 1405 |
|
| 1406 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1407 |
class BertVits2Model(BertVits2PreTrainedModel):
|
| 1408 |
def __init__(self, config):
|
| 1409 |
super().__init__(config)
|
|
|
|
| 1440 |
def get_encoder(self):
|
| 1441 |
return self.text_encoder
|
| 1442 |
|
|
|
|
|
|
|
| 1443 |
def forward(
|
| 1444 |
self,
|
| 1445 |
input_ids: Optional[torch.Tensor] = None,
|