 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from .config import HexaConfig
# Re-importing the core layers from the existing definition or redefining for cleanliness.
# To integrate with HF Trainer, we wrap the existing module.

class HexaHFConfig(PretrainedConfig):
    """HF-compatible config that embeds the project's internal HexaConfig.

    Any keyword argument whose name matches an attribute of ``HexaConfig``
    is applied as an override on the embedded config; all kwargs are then
    forwarded to ``PretrainedConfig`` unchanged.
    """

    model_type = "hexa_tts"

    def __init__(self, **kwargs):
        # Start from the defaults, then overlay matching keyword overrides.
        # NOTE(review): the embedded HexaConfig object is presumably not
        # JSON-serializable, which would break save_pretrained round-trips —
        # verify against HexaConfig's definition.
        self.hexa_config = HexaConfig()
        overrides = {k: v for k, v in kwargs.items() if hasattr(self.hexa_config, k)}
        for key, value in overrides.items():
            setattr(self.hexa_config, key, value)
        super().__init__(**kwargs)

from .model import HexaTransformer as CoreTransformer

class HexaModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around the core Hexa TTS
    transformer.

    Adapts the core model to the standard HF forward contract (dict output
    with an optional ``loss`` key) so it can be trained with
    ``transformers.Trainer``.
    """

    config_class = HexaHFConfig

    def __init__(self, config):
        # super().__init__ already stores the config on self.config, so no
        # separate assignment is needed.
        super().__init__(config)
        # Initialize the core model using the internal HexaConfig.
        self.core = CoreTransformer(config.hexa_config)
        # Gradient checkpointing is off by default; toggled via the HF hook
        # _set_gradient_checkpointing below.
        self.gradient_checkpointing = False

    def forward(self, text_ids, speaker_ids=None, language_ids=None, emotion_ids=None, labels=None):
        """Run the core model and optionally compute an MSE training loss.

        Args:
            text_ids: integer token-id tensor (assumed (batch, seq) —
                not verifiable from this file).
            speaker_ids / language_ids / emotion_ids: optional id tensors;
                each defaults to zeros with the same shape, dtype, and
                device as ``text_ids``.
            labels: optional target mel spectrograms; when provided, an MSE
                loss over the overlapping time frames is computed.

        Returns:
            ``{"logits": mels}`` and additionally ``"loss"`` when labels
            are given.
        """
        # torch.zeros_like already matches text_ids' device and dtype,
        # so the former explicit .to(device) was redundant.
        if speaker_ids is None:
            speaker_ids = torch.zeros_like(text_ids)
        if language_ids is None:
            language_ids = torch.zeros_like(text_ids)
        if emotion_ids is None:
            emotion_ids = torch.zeros_like(text_ids)

        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # labels are target mels; predicted and target sequences may
            # differ in length, so truncate both to the shared prefix
            # before taking the MSE.
            min_len = min(mels.shape[1], labels.shape[1])
            loss = torch.nn.functional.mse_loss(
                mels[:, :min_len, :], labels[:, :min_len, :]
            )

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}

    def _set_gradient_checkpointing(self, module, value=False):
        # Legacy HF hook: called by the library to toggle checkpointing on
        # the wrapped core transformer.
        if isinstance(module, CoreTransformer):
            module.gradient_checkpointing = value