import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from .config import HexaConfig


class RotaryEmbedding(nn.Module):
    """Precomputes rotary position embedding (RoPE) angle tables.

    Uses the standard inverse-frequency schedule theta_i = 10000^(-2i/dim)
    and returns a per-position angle table that attention layers apply to
    queries and keys.
    """

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Buffer (not a Parameter): moves with .to(device) but is not trained.
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """Return angles of shape (1, 1, seq_len, dim) for a (b, n, d) input.

        Only x's sequence length and device are used; its values are ignored.
        """
        n, device = x.shape[1], x.device
        t = torch.arange(n, device=device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate the half-size frequency table so it spans the full head dim,
        # matching the rotate-half application below.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb[None, None, :, :]


def _rotate_half(t):
    """Map (x1, x2) halves of the last dim to (-x2, x1) — the RoPE rotation."""
    x1, x2 = t.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def _apply_rope(t, rope_emb):
    """Apply rotary position angles to a (b, h, n, d) query/key tensor.

    rope_emb is (1, 1, n, d) and broadcasts over batch and heads.
    """
    return (t * rope_emb.cos()) + (_rotate_half(t) * rope_emb.sin())


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product attention with optional RoPE and masking."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # 1/sqrt(d_head) scaling keeps logits well-conditioned at init.
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None, rope_emb=None):
        """Attend over x (b, n, dim).

        mask: optional bool tensor broadcastable to (b, h, n, n); False
              positions are masked out (assumed bool — TODO confirm callers).
        rope_emb: optional (1, 1, n, d_head) angle table from RotaryEmbedding.
        """
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        # FIX: previously rope_emb was accepted but never applied, leaving the
        # model with no positional information. Apply rotate-half RoPE to q/k.
        if rope_emb is not None:
            q = _apply_rope(q, rope_emb)
            k = _apply_rope(k, rope_emb)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if mask is not None:
            # Most-negative finite value so masked logits softmax to ~0.
            mask_value = -torch.finfo(dots.dtype).max
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: x + Attn(LN(x)); x + FF(LN(x))."""

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        x = x + self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + self.ff(self.norm2(x))
        return x


class HexaTransformer(nn.Module):
    """Hexa TTS 5B Model Core.

    A massive decoder-only transformer for autoregressive spectral / token
    generation, conditioned on speaker, language and emotion embeddings.
    """

    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)   # Multi-Character
        self.language_emb = nn.Embedding(config.num_languages, config.dim)  # 14 Languages
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)    # Emotion Support
        # RoPE angles are sized per head, not per model dim.
        self.pos_emb = RotaryEmbedding(config.dim_head)

        # Transformer Layers
        self.layers = nn.ModuleList([])
        for _ in range(config.depth):
            self.layers.append(TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=int(config.dim * config.mlp_ratio),
                dropout=config.dropout,
            ))

        self.norm_final = nn.LayerNorm(config.dim)

        # Output Head (Projecting to Mel Channels OR Discrete Codebook)
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """Forward pass for training or inference.

        text_ids: (b, n) token ids; speaker/language/emotion_ids: (b,) ids.
        Returns predicted mel features of shape (b, n, n_mel_channels).
        """
        # Embed Inputs
        x = self.token_emb(text_ids)
        spk = self.speaker_emb(speaker_ids)
        lang = self.language_emb(language_ids)
        emo = self.emotion_emb(emotion_ids)

        # Fuse Conditioning
        # Simple addition for now; more complex fusion (AdaLIN, Cross-Attn)
        # can be added. (b, dim) -> (b, 1, dim) broadcasts over the sequence,
        # equivalent to the explicit expand-to-seq-length it replaces.
        x = x + (spk + lang + emo).unsqueeze(1)

        # Parameters for RoPE
        rope_emb = self.pos_emb(x)

        # Transformer Pass
        for layer in self.layers:
            x = layer(x, mask=mask, rope_emb=rope_emb)

        x = self.norm_final(x)

        # Output Generation
        mels = self.to_mel(x)
        return mels


def build_model():
    """Construct a HexaTransformer from the default HexaConfig."""
    conf = HexaConfig()
    model = HexaTransformer(conf)
    return model