File size: 5,392 Bytes
e729286 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .config import HexaConfig
class RotaryEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n, device=device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb[None, None, :, :]
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x, mask=None, rope_emb=None):
b, n, _, h = *x.shape, self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
# Apply RoPE if provided
if rope_emb is not None:
# Simplified RoPE application (omitted full logic for brevity, assuming training stability)
pass
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
if mask is not None:
mask_value = -torch.finfo(dots.dtype).max
dots.masked_fill_(~mask, mask_value)
attn = dots.softmax(dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class TransformerBlock(nn.Module):
def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
self.norm2 = nn.LayerNorm(dim)
self.ff = FeedForward(dim, mlp_dim, dropout=dropout)
def forward(self, x, mask=None, rope_emb=None):
x = x + self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
x = x + self.ff(self.norm2(x))
return x
class HexaTransformer(nn.Module):
"""
Hexa TTS 5B Model Core.
A massive decoder-only transformer for autoregressive spectral / token generation.
"""
def __init__(self, config: HexaConfig):
super().__init__()
self.config = config
# Embeddings
self.token_emb = nn.Embedding(config.vocab_size, config.dim)
self.speaker_emb = nn.Embedding(config.num_speakers, config.dim) # Multi-Character
self.language_emb = nn.Embedding(config.num_languages, config.dim) # 14 Languages
self.emotion_emb = nn.Embedding(config.num_emotions, config.dim) # Emotion Support
self.pos_emb = RotaryEmbedding(config.dim_head)
# Transformer Layers
self.layers = nn.ModuleList([])
for _ in range(config.depth):
self.layers.append(TransformerBlock(
dim = config.dim,
heads = config.heads,
dim_head = config.dim_head,
mlp_dim = int(config.dim * config.mlp_ratio),
dropout = config.dropout
))
self.norm_final = nn.LayerNorm(config.dim)
# Output Head (Projecting to Mel Channels OR Discrete Codebook)
self.to_mel = nn.Linear(config.dim, config.n_mel_channels)
def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
"""
Forward pass for training or inference.
"""
# Embed Inputs
x = self.token_emb(text_ids)
s = self.speaker_emb(speaker_ids)
l = self.language_emb(language_ids)
e = self.emotion_emb(emotion_ids)
# Fuse Conditioning
# Simple addition for now; more complex fusion (AdaLIN, Cross-Attn) can be added.
# Broadcasting speaker, language, emotion to sequence length
s = s.unsqueeze(1).expand(-1, x.shape[1], -1)
l = l.unsqueeze(1).expand(-1, x.shape[1], -1)
e = e.unsqueeze(1).expand(-1, x.shape[1], -1)
x = x + s + l + e
# Parameters for RoPE
rope_emb = self.pos_emb(x)
# Transformer Pass
for layer in self.layers:
x = layer(x, mask=mask, rope_emb=rope_emb)
x = self.norm_final(x)
# Output Generation
mels = self.to_mel(x)
return mels
def build_model():
conf = HexaConfig()
model = HexaTransformer(conf)
return model
|