import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig

from .config import HexaConfig
|
class HexaHFConfig(PretrainedConfig):
    """Hugging Face-compatible wrapper around the internal :class:`HexaConfig`.

    Starts from ``HexaConfig`` defaults, overlays any keyword arguments whose
    names match existing ``HexaConfig`` attributes, and forwards *all* kwargs
    (matched or not) to :class:`PretrainedConfig` so standard HF config
    machinery (``name_or_path``, ``torch_dtype``, …) keeps working.
    """

    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # Library defaults first, then overlay recognised overrides.
        self.hexa_config = HexaConfig()
        recognised = [key for key in kwargs if hasattr(self.hexa_config, key)]
        for key in recognised:
            setattr(self.hexa_config, key, kwargs[key])
        # Every kwarg still flows to PretrainedConfig, unchanged.
        super().__init__(**kwargs)
|
# Imported here, after HexaHFConfig is defined, rather than at the top of the
# file — presumably to avoid a circular import with .model (TODO: confirm).
from .model import HexaTransformer as CoreTransformer
|
class HexaModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper exposing the core Hexa TTS transformer.

    Delegates the actual text-to-mel computation to ``CoreTransformer`` and
    adds the pieces the HF ecosystem expects: a ``config_class``, a
    ``forward`` that returns a ``{"loss", "logits"}`` dict when ``labels``
    are supplied, and the gradient-checkpointing toggle hook.
    """

    config_class = HexaHFConfig

    def __init__(self, config):
        """Build the core transformer from the nested ``HexaConfig``.

        Args:
            config: A :class:`HexaHFConfig` instance; its ``hexa_config``
                attribute carries the actual model hyperparameters.
        """
        # PreTrainedModel.__init__ already assigns self.config, so the
        # original's extra `self.config = config` was redundant and dropped.
        super().__init__(config)
        self.core = CoreTransformer(config.hexa_config)
        self.gradient_checkpointing = False

    def forward(self, text_ids, speaker_ids=None, language_ids=None, emotion_ids=None, labels=None):
        """Run the core model; optionally compute an MSE mel-reconstruction loss.

        Args:
            text_ids: Token-ID tensor; conditioning defaults are shaped from it.
            speaker_ids / language_ids / emotion_ids: Optional conditioning ID
                tensors; each defaults to zeros shaped like ``text_ids``.
            labels: Optional target mels. NOTE(review): assumed to be
                ``(batch, time, n_mels)`` like the model output — confirm.

        Returns:
            ``{"loss": ..., "logits": mels}`` when ``labels`` is given,
            otherwise ``{"logits": mels}``.
        """
        # zeros_like already matches text_ids' device and dtype; the
        # original's trailing `.to(device)` was a no-op and was removed.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # Predicted and target lengths may disagree; compare the overlap
            # only rather than erroring on the length mismatch.
            min_len = min(mels.shape[1], labels.shape[1])
            loss = torch.nn.functional.mse_loss(
                mels[:, :min_len, :], labels[:, :min_len, :]
            )

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}

    def _set_gradient_checkpointing(self, module, value=False):
        """Hook called by transformers' gradient_checkpointing_enable/disable.

        Forwards the flag to the core transformer, which is expected to honor
        its own ``gradient_checkpointing`` attribute.
        """
        if isinstance(module, CoreTransformer):
            module.gradient_checkpointing = value