import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

from .config import HexaConfig

# Re-importing the core layers from the existing definition or redefining for
# cleanliness. To integrate with the HF Trainer, we wrap the existing module.


class HexaHFConfig(PretrainedConfig):
    """Hugging Face-compatible config that wraps the internal :class:`HexaConfig`.

    Any keyword argument whose name matches an attribute of ``HexaConfig`` is
    mirrored onto the internal config; all kwargs are also forwarded to
    ``PretrainedConfig`` so HF serialization keeps the overrides.
    """

    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # Flatten HexaConfig into kwargs for HF compatibility.
        self.hexa_config = HexaConfig()
        # Apply any manual overrides that name a HexaConfig attribute.
        for key, value in kwargs.items():
            if hasattr(self.hexa_config, key):
                setattr(self.hexa_config, key, value)
        super().__init__(**kwargs)


# NOTE(review): imported mid-file, after HexaHFConfig is defined — presumably
# to avoid a circular import with .model; confirm before hoisting to the top.
from .model import HexaTransformer as CoreTransformer


class HexaModel(PreTrainedModel):
    """HF ``Trainer``-compatible wrapper around the core Hexa TTS transformer.

    Delegates the actual computation to :class:`HexaTransformer` built from
    ``config.hexa_config``, and adds optional MSE loss computation against
    target mel-spectrograms so the model plugs into the HF training loop.
    """

    config_class = HexaHFConfig

    def __init__(self, config):
        super().__init__(config)
        # PreTrainedModel.__init__ already stores ``config`` on ``self``;
        # kept explicitly for clarity / callers that read it directly.
        self.config = config
        # Initialize the core model using the internal HexaConfig.
        self.core = CoreTransformer(config.hexa_config)
        # Gradient checkpointing is off by default; toggled by the HF
        # machinery via ``_set_gradient_checkpointing``.
        self.gradient_checkpointing = False

    def forward(self, text_ids, speaker_ids=None, language_ids=None,
                emotion_ids=None, labels=None):
        """Run the core transformer and optionally compute an MSE loss.

        Args:
            text_ids: Integer tensor of input token ids.
            speaker_ids / language_ids / emotion_ids: Optional id tensors;
                default to all-zeros with the shape/device of ``text_ids``.
            labels: Optional target mel-spectrograms; when given, an MSE loss
                is computed over the overlapping time dimension.

        Returns:
            Dict with ``"logits"`` (predicted mels) and, when ``labels`` is
            provided, ``"loss"``.
        """
        # zeros_like already matches the device (and dtype) of text_ids,
        # so no extra .to(device) transfer is needed.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        # Forward pass through the core transformer.
        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:  # labels = target_mels
            # Predicted and target sequences may differ in length; align by
            # truncating both to the shorter time dimension before the loss.
            min_len = min(mels.shape[1], labels.shape[1])
            mels_sliced = mels[:, :min_len, :]
            labels_sliced = labels[:, :min_len, :]
            loss = torch.nn.functional.mse_loss(mels_sliced, labels_sliced)

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}

    def _set_gradient_checkpointing(self, module, value=False):
        # Legacy HF hook: propagate the flag onto the wrapped core model.
        if isinstance(module, CoreTransformer):
            module.gradient_checkpointing = value