File size: 5,392 Bytes
e729286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .config import HexaConfig

class RotaryEmbedding(nn.Module):
    """Produce rotary position-embedding angles for a sequence.

    Returns raw angles (not cos/sin) in the duplicated ``(freqs, freqs)``
    layout, shaped ``(1, 1, seq_len, dim)`` so they broadcast over batch
    and head dimensions.
    """

    def __init__(self, dim):
        super().__init__()
        # Classic RoPE inverse-frequency schedule over even channel indices.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product: angle[pos, channel] = pos * inv_freq[channel].
        angles = positions[:, None] * self.inv_freq[None, :]
        doubled = torch.cat((angles, angles), dim=-1)
        return doubled.unsqueeze(0).unsqueeze(0)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Expands ``dim`` to ``hidden_dim`` and projects back; applied
    independently at every sequence position.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Assembled as a list first; kept inside a single Sequential named
        # ``net`` so parameter names (net.0.weight, ...) stay stable.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    Fixes over the previous version:
      * ``rope_emb`` was accepted but never applied (stubbed with ``pass``),
        so positional information was silently discarded. It is now applied
        to Q and K via the standard rotate-half formulation, matching the
        duplicated ``(freqs, freqs)`` angle layout produced by
        ``RotaryEmbedding``.
      * ``masked_fill_`` no longer mutates the score tensor in place.

    Args:
        dim: Model (input/output) dimension.
        heads: Number of attention heads.
        dim_head: Per-head dimension; must be even for rotary rotation.
        dropout: Dropout probability on the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _rotate_half(x):
        """Swap and negate the two halves of the last dim: (x1, x2) -> (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, mask=None, rope_emb=None):
        """Attend over ``x`` of shape (batch, seq, dim).

        Args:
            x: Input tensor, shape (batch, seq, dim).
            mask: Optional boolean tensor broadcastable to the attention
                score shape (batch, heads, seq, seq); positions where the
                mask is False are excluded from attention.
            rope_emb: Optional rotary angles broadcastable to
                (batch, heads, seq, dim_head), as returned by
                ``RotaryEmbedding``.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        b, n, _ = x.shape
        h = self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h*d) -> (b, h, n, d) without third-party helpers.
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in qkv)

        if rope_emb is not None:
            # Standard RoPE: rotate each (even, odd-half) channel pair of
            # Q and K by a position-dependent angle. V is left unrotated.
            cos, sin = rope_emb.cos(), rope_emb.sin()
            q = q * cos + self._rotate_half(q) * sin
            k = k * cos + self._rotate_half(k) * sin

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if mask is not None:
            # Out-of-place fill: avoids mutating the score tensor.
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: residual attention then residual MLP.

    Computes ``x + Attn(LN(x))`` followed by ``x + FF(LN(x))``.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        attn_out = self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + attn_out
        mlp_out = self.ff(self.norm2(x))
        return x + mlp_out

class HexaTransformer(nn.Module):
    """Hexa TTS 5B model core.

    A decoder-only transformer for autoregressive spectral / token
    generation, conditioned on speaker, language, and emotion embeddings
    added to the token embeddings, with rotary position embeddings applied
    per attention head, and a final linear head projecting to mel channels.
    """

    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config

        # Token embedding plus three conditioning tables
        # (multi-character speakers, languages, emotions).
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)
        self.language_emb = nn.Embedding(config.num_languages, config.dim)
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)

        # Rotary embedding operates at per-head dimensionality.
        self.pos_emb = RotaryEmbedding(config.dim_head)

        # Stack of identical pre-norm transformer layers.
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=int(config.dim * config.mlp_ratio),
                dropout=config.dropout,
            )
            for _ in range(config.depth)
        ])

        self.norm_final = nn.LayerNorm(config.dim)

        # Output head projecting to mel channels (or a discrete codebook).
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """Forward pass for training or inference.

        Args:
            text_ids: Token id tensor, shape (batch, seq).
            speaker_ids: Per-sample speaker ids, shape (batch,).
            language_ids: Per-sample language ids, shape (batch,).
            emotion_ids: Per-sample emotion ids, shape (batch,).
            mask: Optional attention mask forwarded to every layer.

        Returns:
            Mel projection of shape (batch, seq, n_mel_channels).
        """
        x = self.token_emb(text_ids)

        # Fuse conditioning by simple addition, broadcasting each
        # per-sample vector across the sequence dimension. Richer fusion
        # (AdaLIN, cross-attention) could replace this later.
        for cond in (
            self.speaker_emb(speaker_ids),
            self.language_emb(language_ids),
            self.emotion_emb(emotion_ids),
        ):
            x = x + cond.unsqueeze(1)

        # Rotary angles computed once and shared by all layers.
        rope_emb = self.pos_emb(x)

        for block in self.layers:
            x = block(x, mask=mask, rope_emb=rope_emb)

        x = self.norm_final(x)
        return self.to_mel(x)

def build_model():
    """Construct a HexaTransformer using the default HexaConfig."""
    return HexaTransformer(HexaConfig())