Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel, PretrainedConfig | |
| from .config import HexaConfig | |
| # Re-importing the core layers from the existing definition or redefining for cleanliness. | |
| # To integrate with HF Trainer, we wrap the existing module. | |
class HexaHFConfig(PretrainedConfig):
    """HF-compatible config that flattens the HexaConfig dataclass fields.

    Flattening keeps every field individually JSON-serializable, so
    Trainer checkpointing (``config.to_json_string()``) does not fail on
    a nested dataclass attribute.
    """

    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # 1. Seed attributes with the dataclass defaults so a config
        #    loaded from a partial JSON file is always complete.
        for key, value in HexaConfig().__dict__.items():
            if not key.startswith("__"):
                setattr(self, key, value)
        # 2. Apply user overrides: PretrainedConfig.__init__ setattr's
        #    any kwargs it does not consume itself, replacing defaults.
        super().__init__(**kwargs)

    @property
    def hexa_config(self):
        """Rebuild a HexaConfig dataclass from the flattened attributes.

        BUG FIX: this was a plain method, but the consumer reads it as an
        attribute (``CoreTransformer(config.hexa_config)``), which handed
        the core model a *bound method* instead of a HexaConfig. Exposing
        it as a property matches the call site.

        Reconstructing on demand (rather than storing the dataclass on
        self) avoids 'HexaConfig not JSON serializable' errors during
        checkpointing.
        """
        conf = HexaConfig()
        for key in conf.__dict__:
            if hasattr(self, key):
                setattr(conf, key, getattr(self, key))
        return conf
| from .model import HexaTransformer as CoreTransformer | |
class HexaModel(PreTrainedModel):
    """HF ``PreTrainedModel`` wrapper around the core Hexa TTS transformer.

    Provides a Trainer-friendly ``forward`` returning ``{"loss", "logits"}``
    and computes an MSE loss between predicted and target mel frames.
    """

    config_class = HexaHFConfig
    _supports_gradient_checkpointing = True

    def supports_gradient_checkpointing(self):
        # NOTE(review): transformers keys off the class attribute
        # `_supports_gradient_checkpointing`; this method is kept only
        # for any caller that invokes it directly.
        return True

    def __init__(self, config):
        # PreTrainedModel.__init__ already stores `config` on self, so
        # the former explicit `self.config = config` was redundant.
        super().__init__(config)
        # BUG FIX: on the original config class, `hexa_config` is a plain
        # method, so plain attribute access yields a *bound method*, not
        # a HexaConfig. Guard with callable() so the core model always
        # receives the dataclass, whichever config definition is in use.
        hexa_conf = config.hexa_config
        if callable(hexa_conf):
            hexa_conf = hexa_conf()
        self.core = CoreTransformer(hexa_conf)
        # Off by default; HF Trainer flips this via
        # gradient_checkpointing_enable() when requested.
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        # Token embedding table of the core transformer.
        return self.core.token_emb

    def set_input_embeddings(self, value):
        self.core.token_emb = value

    def forward(self, text_ids, speaker_ids=None, language_ids=None,
                emotion_ids=None, labels=None):
        """Run the core model and optionally compute an MSE loss.

        Args:
            text_ids: LongTensor of input token ids.
            speaker_ids / language_ids / emotion_ids: optional id tensors;
                default to zeros shaped like ``text_ids``.
                NOTE(review): the defaults mirror text_ids' full shape —
                confirm the core model expects per-token (not
                per-utterance) conditioning ids.
            labels: optional target mel spectrogram ``(batch, frames, bins)``
                — TODO confirm layout; only dims 0/1 are sliced here.

        Returns:
            Dict with ``"logits"`` (predicted mels) and, when ``labels``
            is given, ``"loss"`` (MSE over the overlapping frames).
        """
        # zeros_like already allocates on text_ids' device/dtype, so the
        # former extra `.to(device)` calls were redundant.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # Predicted and target frame counts may differ; compare only
            # the overlapping prefix.
            min_len = min(mels.shape[1], labels.shape[1])
            mels_sliced = mels[:, :min_len, :]
            labels_sliced = labels[:, :min_len, :]
            # Ensure the targets match the model's dtype (e.g. fp16 vs
            # fp32 under mixed-precision training).
            if labels_sliced.dtype != mels_sliced.dtype:
                labels_sliced = labels_sliced.to(mels_sliced.dtype)
            loss = torch.nn.functional.mse_loss(mels_sliced, labels_sliced)

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}