|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from einops import rearrange
|
|
|
from .config import HexaConfig
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Precompute rotary position angles (RoFormer-style) for attention.

    Stores inverse frequencies 1 / 10000^(2i/dim) in a buffer; the forward
    pass turns sequence positions into per-position angles, duplicated across
    both halves of the head dimension so the consumer can rotate pairs.
    """

    def __init__(self, dim):
        """`dim` is the per-head feature dimension the angles will cover."""
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        # Buffer (not a parameter): moves with .to(device) but is not trained.
        self.register_buffer("inv_freq", 10000.0 ** -exponents)

    def forward(self, x):
        """Return angles of shape (1, 1, seq_len, dim) for input (b, seq_len, ...)."""
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product position x frequency -> (seq_len, dim/2).
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate so the angle table spans the full head dimension.
        full = torch.cat((angles, angles), dim=-1)
        # Leading singleton batch and head axes for broadcasting.
        return full.unsqueeze(0).unsqueeze(0)
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: dim -> hidden_dim -> dim with GELU.

    Dropout is applied after the activation and after the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Keep the attribute name `net` so checkpoint keys stay stable.
        stack = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the MLP token-wise; shape (..., dim) -> (..., dim)."""
        return self.net(x)
|
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    Args:
        dim: input/output feature dimension.
        heads: number of attention heads.
        dim_head: per-head feature dimension.
        dropout: dropout probability on the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # 1/sqrt(d) scaling keeps pre-softmax logits well-conditioned.
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _rotate_half(x):
        """Rotate pairs across the last dim: (x1, x2) -> (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def _apply_rope(self, t, rope_emb):
        """Apply rotary position embeddings to a (b, h, n, d) tensor.

        `rope_emb` holds angles broadcastable to (1, 1, n, d), as produced
        by RotaryEmbedding.
        """
        cos, sin = rope_emb.cos(), rope_emb.sin()
        return t * cos + self._rotate_half(t) * sin

    def forward(self, x, mask=None, rope_emb=None):
        """Attend over the sequence.

        Args:
            x: (b, n, dim) input.
            mask: optional boolean tensor broadcastable to (b, h, n, n);
                positions where mask is False are excluded from attention.
            rope_emb: optional rotary angle table of shape (1, 1, n, dim_head);
                when given, q and k are rotated before the dot product.

        Returns:
            (b, n, dim) output after the projection.
        """
        b, n, _ = x.shape
        h = self.heads

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, h*d) -> (b, h, n, d); pure-torch equivalent of the einops rearrange.
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in qkv)

        # BUG FIX: rope_emb was previously accepted but never applied (the
        # branch was `pass`), so positional information had no effect on
        # attention. Rotate q and k so relative positions enter the scores.
        if rope_emb is not None:
            q = self._apply_rope(q, rope_emb)
            k = self._apply_rope(k, rope_emb)

        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if mask is not None:
            # Use the most negative finite value so masked logits vanish
            # after softmax without producing NaNs.
            mask_value = -torch.finfo(dots.dtype).max
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim=-1)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        # (b, h, n, d) -> (b, n, h*d) back to the flat head layout.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each with a
    residual connection around it.
    """

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        # Attribute names preserved so checkpoint keys stay stable.
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        """Run one layer; `mask` and `rope_emb` are forwarded to attention."""
        attn_out = self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + attn_out
        ff_out = self.ff(self.norm2(x))
        return x + ff_out
|
|
|
|
|
|
class HexaTransformer(nn.Module):
    """Hexa TTS 5B Model Core.

    A massive decoder-only transformer for autoregressive spectral / token
    generation. Token embeddings are summed with per-utterance speaker,
    language, and emotion conditioning, passed through a stack of pre-norm
    transformer blocks, and projected to mel-channel outputs.
    """

    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config

        # Content embedding plus three per-utterance conditioning tables,
        # all in model dimension so they can be summed directly.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)
        self.language_emb = nn.Embedding(config.num_languages, config.dim)
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)

        # Rotary embeddings operate per attention head, hence dim_head.
        self.pos_emb = RotaryEmbedding(config.dim_head)

        mlp_dim = int(config.dim * config.mlp_ratio)
        self.layers = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=mlp_dim,
                dropout=config.dropout,
            )
            for _ in range(config.depth)
        ])

        self.norm_final = nn.LayerNorm(config.dim)

        # Project hidden states to mel-spectrogram channels.
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """Forward pass for training or inference.

        Args (assumed from usage — confirm against callers):
            text_ids: (batch, seq) long tensor of token ids.
            speaker_ids / language_ids / emotion_ids: (batch,) long tensors,
                one conditioning id per utterance.
            mask: optional attention mask forwarded to every layer.

        Returns:
            (batch, seq, n_mel_channels) mel predictions.
        """
        x = self.token_emb(text_ids)
        seq_len = x.shape[1]

        # Broadcast each per-utterance conditioning vector across the
        # sequence axis; addition order matches the original implementation.
        spk = self.speaker_emb(speaker_ids).unsqueeze(1).expand(-1, seq_len, -1)
        lang = self.language_emb(language_ids).unsqueeze(1).expand(-1, seq_len, -1)
        emo = self.emotion_emb(emotion_ids).unsqueeze(1).expand(-1, seq_len, -1)
        x = x + spk + lang + emo

        # One shared rotary angle table for the whole stack.
        rope_emb = self.pos_emb(x)

        for block in self.layers:
            x = block(x, mask=mask, rope_emb=rope_emb)

        x = self.norm_final(x)
        return self.to_mel(x)
|
|
|
|
|
|
def build_model():
    """Construct a HexaTransformer using the default HexaConfig values."""
    return HexaTransformer(HexaConfig())
|
|
|
|