# hexa-tts-5b/src/hf_model.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .config import HexaConfig
# The core layers are reused from the existing definition in .model; to
# integrate with the HF Trainer, we wrap that module in a PreTrainedModel subclass.
class HexaHFConfig(PretrainedConfig):
    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # Build the internal HexaConfig, then override any of its fields that
        # were passed as kwargs so the two configs stay in sync.
        self.hexa_config = HexaConfig()
        for k, v in kwargs.items():
            if hasattr(self.hexa_config, k):
                setattr(self.hexa_config, k, v)
        super().__init__(**kwargs)
from .model import HexaTransformer as CoreTransformer
class HexaModel(PreTrainedModel):
    config_class = HexaHFConfig

    def __init__(self, config):
        super().__init__(config)
        # Initialize the core model using the internal HexaConfig
        self.core = CoreTransformer(config.hexa_config)
        # Gradient checkpointing is off by default; the HF hook below toggles
        # it on the core transformer when requested, saving memory at the cost
        # of recomputation.
        self.gradient_checkpointing = False
    def forward(self, text_ids, speaker_ids=None, language_ids=None, emotion_ids=None, labels=None):
        # Default optional conditioning IDs to zeros on the same device/shape as the text.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        # Forward pass through the core transformer: predicted mel spectrograms.
        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # labels are the target mel spectrograms; trim both tensors to the
            # shorter time axis before computing the reconstruction loss.
            min_len = min(mels.shape[1], labels.shape[1])
            mels_sliced = mels[:, :min_len, :]
            labels_sliced = labels[:, :min_len, :]
            loss = torch.nn.functional.mse_loss(mels_sliced, labels_sliced)

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, CoreTransformer):
            module.gradient_checkpointing = value
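
# A minimal smoke-test sketch, not part of the library: it assumes HexaConfig
# exposes fields such as vocab_size and n_mels (the real attribute names may
# differ) and that CoreTransformer returns mels shaped (batch, frames, n_mels).
if __name__ == "__main__":
    cfg = HexaHFConfig()
    model = HexaModel(cfg)

    # Hypothetical sizes; replace with the actual values from your HexaConfig.
    vocab_size = getattr(cfg.hexa_config, "vocab_size", 256)
    n_mels = getattr(cfg.hexa_config, "n_mels", 80)

    text_ids = torch.randint(0, vocab_size, (2, 16))   # dummy token IDs
    target_mels = torch.randn(2, 64, n_mels)            # dummy mel targets

    out = model(text_ids, labels=target_mels)
    print(out["logits"].shape, out["loss"].item())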