# NOTE(review): the three lines below were page-scrape residue from a Hugging
# Face Space ("Spaces:" / "Runtime error") and were not valid Python.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import rearrange | |
| from .config import HexaConfig | |
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding angle table.

    Builds the standard 10000^(-2i/d) inverse-frequency schedule once, then
    produces, per forward call, an angle tensor of shape
    (1, 1, seq_len, dim) suitable for broadcasting over (batch, heads).
    """

    def __init__(self, dim):
        super().__init__()
        # Even-index exponents 0, 2, 4, ... scaled by 1/dim.
        exponents = torch.arange(0, dim, 2).float() / dim
        # Registered as a buffer so it follows .to(device) / state_dict.
        self.register_buffer("inv_freq", 10000.0 ** -exponents)

    def forward(self, x):
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product: angle[t, i] = t * inv_freq[i].
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate so the table spans the full head dimension (rotate-half
        # convention pairs dim k with dim k + dim/2).
        table = torch.cat((angles, angles), dim=-1)
        return table.unsqueeze(0).unsqueeze(0)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Kept as a single ``nn.Sequential`` under ``self.net`` so checkpoint
    state-dict keys stay stable.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
| class Attention(nn.Module): | |
| def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): | |
| super().__init__() | |
| inner_dim = dim_head * heads | |
| self.heads = heads | |
| self.scale = dim_head ** -0.5 | |
| self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) | |
| self.to_out = nn.Sequential( | |
| nn.Linear(inner_dim, dim), | |
| nn.Dropout(dropout) | |
| ) | |
| def forward(self, x, mask=None, rope_emb=None): | |
| b, n, _, h = *x.shape, self.heads | |
| qkv = self.to_qkv(x).chunk(3, dim=-1) | |
| q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) | |
| # Apply RoPE if provided | |
| if rope_emb is not None: | |
| # Simplified RoPE application (omitted full logic for brevity, assuming training stability) | |
| pass | |
| dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale | |
| if mask is not None: | |
| mask_value = -torch.finfo(dots.dtype).max | |
| dots.masked_fill_(~mask, mask_value) | |
| attn = dots.softmax(dim=-1) | |
| out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) | |
| out = rearrange(out, 'b h n d -> b n (h d)') | |
| return self.to_out(out) | |
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: x + Attn(LN(x)), then x + FF(LN(x))."""

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        # Attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + attended
        # Feed-forward sub-layer with residual connection.
        return x + self.ff(self.norm2(x))
class HexaTransformer(nn.Module):
    """
    Hexa TTS 5B Model Core.
    A massive decoder-only transformer for autoregressive spectral / token generation.
    """

    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config
        # Token + per-utterance conditioning embeddings.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)    # Multi-Character
        self.language_emb = nn.Embedding(config.num_languages, config.dim)  # 14 Languages
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)    # Emotion Support
        self.pos_emb = RotaryEmbedding(config.dim_head)
        # Transformer stack; checkpointing is opt-in via this attribute.
        self.gradient_checkpointing = False
        self.layers = nn.ModuleList(
            TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=int(config.dim * config.mlp_ratio),
                dropout=config.dropout,
            )
            for _ in range(config.depth)
        )
        self.norm_final = nn.LayerNorm(config.dim)
        # Output head (projects to mel channels or a discrete codebook).
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """Forward pass for training or inference.

        Args:
            text_ids: (batch, seq) token ids.
            speaker_ids, language_ids, emotion_ids: per-utterance
                conditioning ids — assumed shape (batch,); TODO confirm
                against callers.
            mask: optional attention mask forwarded to every layer.
        Returns:
            (batch, seq, n_mel_channels) mel-spectrogram predictions.
        """
        tokens = self.token_emb(text_ids)
        # Sum the three conditioning embeddings once, then broadcast the
        # combined vector over the sequence axis. Simple additive fusion;
        # richer schemes (AdaLIN, cross-attention) could replace this.
        conditioning = (
            self.speaker_emb(speaker_ids)
            + self.language_emb(language_ids)
            + self.emotion_emb(emotion_ids)
        )
        x = tokens + conditioning.unsqueeze(1)
        # Rotary angle table shared by all layers.
        rope_emb = self.pos_emb(x)
        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # Trade compute for memory: re-run the layer in backward.
                # A Module is callable, so no closure wrapper is needed;
                # mask / rope_emb are passed positionally.
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, mask, rope_emb, use_reentrant=False
                )
            else:
                x = layer(x, mask=mask, rope_emb=rope_emb)
        x = self.norm_final(x)
        return self.to_mel(x)
def build_model():
    """Construct a HexaTransformer from the default HexaConfig."""
    return HexaTransformer(HexaConfig())