# hexa-tts-trainer / src/hf_model.py
# Uploaded by Hexa09 via huggingface_hub (commit c1bc764, verified)
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .config import HexaConfig
# Re-importing the core layers from the existing definition or redefining for cleanliness.
# To integrate with HF Trainer, we wrap the existing module.
class HexaHFConfig(PretrainedConfig):
    """HF-compatible configuration exposing the fields of HexaConfig as flat attributes."""

    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # Seed every attribute from the dataclass defaults, flattened onto self.
        for name, value in HexaConfig().__dict__.items():
            if not name.startswith("__"):
                setattr(self, name, value)
        # PretrainedConfig.__init__ then applies any user-supplied overrides.
        super().__init__(**kwargs)

    @property
    def hexa_config(self):
        """Rebuild a HexaConfig dataclass from the flattened attributes on demand.

        Storing only flat attributes (and reconstructing here) avoids a
        'HexaConfig not JSON serializable' error when checkpoints are saved.
        """
        rebuilt = HexaConfig()
        for field_name in rebuilt.__dict__:
            if hasattr(self, field_name):
                setattr(rebuilt, field_name, getattr(self, field_name))
        return rebuilt
from .model import HexaTransformer as CoreTransformer
class HexaModel(PreTrainedModel):
    """HF-Trainer-compatible wrapper around the core Hexa TTS transformer.

    Delegates the text -> mel computation to CoreTransformer and adds the
    loss computation plus the dict-shaped output that transformers.Trainer
    expects ("loss" / "logits" keys).
    """

    config_class = HexaHFConfig
    _supports_gradient_checkpointing = True

    @property
    def supports_gradient_checkpointing(self):
        # Kept alongside the class-level flag in case callers query the
        # property form instead of _supports_gradient_checkpointing.
        return True

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Build the core model from the reconstructed internal HexaConfig.
        self.core = CoreTransformer(config.hexa_config)
        # Flag toggled by transformers' gradient-checkpointing machinery.
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.core.token_emb

    def set_input_embeddings(self, value):
        self.core.token_emb = value

    def forward(self, text_ids, speaker_ids=None, language_ids=None, emotion_ids=None, labels=None):
        """Run the core model; compute an MSE loss against target mels when given.

        Args:
            text_ids: tensor of token ids, shape (batch, seq).
            speaker_ids / language_ids / emotion_ids: optional id tensors;
                default to all-zeros matching text_ids. (Presumably id 0 is
                the neutral/default entry — confirm against the embedding
                tables in the core model.)
            labels: optional target mel spectrograms; compared frame-wise
                against the predicted mels.

        Returns:
            dict with "logits" (predicted mels) and, when labels are given,
            "loss" (MSE over the overlapping frame prefix).
        """
        # FIX: zeros_like already allocates on text_ids' device (and dtype),
        # so the original trailing `.to(device)` was a redundant transfer.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # Predicted and target sequences may differ in frame count;
            # compare only the overlapping prefix along dim 1.
            min_len = min(mels.shape[1], labels.shape[1])
            mels_sliced = mels[:, :min_len, :]
            labels_sliced = labels[:, :min_len, :]
            # Match dtypes (e.g. fp16 model output vs fp32 labels) before
            # computing the loss, so mixed-precision training doesn't fail.
            if labels_sliced.dtype != mels_sliced.dtype:
                labels_sliced = labels_sliced.to(mels_sliced.dtype)
            loss = torch.nn.functional.mse_loss(mels_sliced, labels_sliced)

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}