# hexa-tts-5b / src/model.py
# Uploaded by Hexa09 via huggingface_hub (commit e729286)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .config import HexaConfig
class RotaryEmbedding(nn.Module):
    """Precompute rotary position-embedding angles (RoFormer-style).

    Holds the standard inverse-frequency buffer 1 / 10000^(2i/d) and, on
    forward, builds the per-position angle table for the input's sequence
    length.
    """

    def __init__(self, dim):
        super().__init__()
        # One frequency per pair of channels; buffer so it follows .to(device).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """Return angles of shape (1, 1, seq_len, dim) for broadcasting.

        Only x's sequence length (dim 1) and device are used; the angles are
        duplicated along the last axis (cat of two half-tables) to match the
        rotate-half RoPE formulation.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat([angles, angles], dim=-1)
        return table.unsqueeze(0).unsqueeze(0)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Expands to `hidden_dim`, applies GELU, and projects back to `dim`.
    Module structure (a single `net` Sequential) is kept so checkpoint
    state_dict keys are stable.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of x."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    A fused linear layer produces Q/K/V, scaled dot-product attention is
    computed per head, and the concatenated head outputs are projected back
    to the model dimension.

    Fix: the original implementation accepted `rope_emb` but its application
    site was `pass`, so rotary position information was silently discarded
    and attention was position-blind. RoPE is now actually applied to Q and K
    using the standard rotate-half formulation.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # 1/sqrt(d) keeps attention logits in a numerically stable range.
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _rotate_half(x):
        """Map (x1, x2) halves of the last dim to (-x2, x1) for RoPE."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x, mask=None, rope_emb=None):
        """Attend over the sequence.

        Args:
            x: input of shape (batch, seq, dim).
            mask: optional boolean mask broadcastable to the attention matrix
                (batch, heads, seq, seq); positions where mask is False are
                excluded from attention.
            rope_emb: optional rotary angle table of shape (1, 1, seq,
                dim_head), as produced by RotaryEmbedding.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        b, n, _ = x.shape
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h*d) -> (b, h, n, d); native ops replace the einops rearrange.
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in (q, k, v))
        if rope_emb is not None:
            # RoPE: rotate each (even, odd) channel pair by its position angle.
            cos, sin = rope_emb.cos(), rope_emb.sin()
            q = q * cos + self._rotate_half(q) * sin
            k = k * cos + self._rotate_half(k) * sin
        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        if mask is not None:
            # Most-negative finite value so softmax sends masked slots to ~0.
            mask_value = -torch.finfo(dots.dtype).max
            dots.masked_fill_(~mask, mask_value)
        attn = dots.softmax(dim=-1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        # (b, h, n, d) -> (b, n, h*d)
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then MLP, each with a residual.

    Submodule names (norm1/attn/norm2/ff) are kept for checkpoint
    compatibility.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        """Apply the two sub-layers; mask/rope_emb are forwarded to attention."""
        attn_out = self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + attn_out
        ff_out = self.ff(self.norm2(x))
        return x + ff_out
class HexaTransformer(nn.Module):
    """
    Hexa TTS 5B Model Core.
    A massive decoder-only transformer for autoregressive spectral / token generation.
    """

    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config
        # Input embedding tables.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)    # Multi-Character
        self.language_emb = nn.Embedding(config.num_languages, config.dim)  # 14 Languages
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)    # Emotion Support
        # Rotary angles are sized per head, not per model dim.
        self.pos_emb = RotaryEmbedding(config.dim_head)
        # Stack of pre-norm transformer layers.
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=int(config.dim * config.mlp_ratio),
                dropout=config.dropout,
            )
            for _ in range(config.depth)
        ])
        self.norm_final = nn.LayerNorm(config.dim)
        # Output head (projecting to mel channels OR a discrete codebook).
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """
        Forward pass for training or inference.

        Args:
            text_ids: token id tensor of shape (batch, seq).
            speaker_ids / language_ids / emotion_ids: per-sample conditioning
                ids — assumed shape (batch,); TODO confirm against callers.
            mask: optional attention mask forwarded to every layer.

        Returns:
            Mel-channel projections of shape (batch, seq, n_mel_channels).
        """
        hidden = self.token_emb(text_ids)
        # Fuse conditioning by simple addition; the three (batch, dim)
        # embeddings are summed once and broadcast across the sequence
        # (equivalent to expanding each to seq length and adding).
        cond = (
            self.speaker_emb(speaker_ids)
            + self.language_emb(language_ids)
            + self.emotion_emb(emotion_ids)
        )
        hidden = hidden + cond.unsqueeze(1)
        # Rotary angle table shared by every layer.
        rope_emb = self.pos_emb(hidden)
        for block in self.layers:
            hidden = block(hidden, mask=mask, rope_emb=rope_emb)
        hidden = self.norm_final(hidden)
        return self.to_mel(hidden)
def build_model():
    """Build a HexaTransformer from the default HexaConfig."""
    return HexaTransformer(HexaConfig())