Hexa09 committed on
Commit
e729286
·
verified ·
1 Parent(s): b38ebdf

Upload folder using huggingface_hub

Browse files
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (123 Bytes). View file
 
src/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.85 kB). View file
 
src/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (4.95 kB). View file
 
src/__pycache__/hf_model.cpython-312.pyc ADDED
Binary file (3.05 kB). View file
 
src/__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.84 kB). View file
 
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (9.34 kB). View file
 
src/__pycache__/test_tiny.cpython-312.pyc ADDED
Binary file (2.74 kB). View file
 
src/__pycache__/text_encoder.cpython-312.pyc ADDED
Binary file (2.35 kB). View file
 
src/__pycache__/train.cpython-312.pyc ADDED
Binary file (3.61 kB). View file
 
src/__pycache__/train_hf.cpython-312.pyc ADDED
Binary file (2.82 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from dataclasses import dataclass, field  # NOTE(review): `field` appears unused in this module
from typing import List  # NOTE(review): `List` appears unused in this module


@dataclass
class HexaConfig:
    """
    Configuration for Hexa TTS 5B Model.

    Defaults are tuned so the transformer core lands at roughly 5 billion
    parameters; pass smaller overrides (see src/test_tiny.py) for debugging.
    Instantiation prints a rough size estimate via __post_init__.
    """
    # Model Architecture
    dim: int = 3200         # Hidden size, tuned for ~5B params (4.92B)
    depth: int = 40         # Number of transformer layers
    heads: int = 32         # Number of attention heads
    dim_head: int = 100     # Dimension of each head
    mlp_ratio: float = 4.0  # Feedforward expansion factor
    dropout: float = 0.1

    # Input / Output
    num_languages: int = 15
    vocab_size: int = 256       # Size of phoneme/text vocabulary
    num_speakers: int = 10000   # Embedding slots for speakers
    num_emotions: int = 32      # Distinct emotion categories

    # Audio Settings
    sample_rate: int = 24000
    n_mel_channels: int = 100
    n_fft: int = 1024
    hop_length: int = 256
    win_length: int = 1024

    # Context
    max_text_len: int = 1024
    max_audio_len: int = 4096  # In mel frames

    # Checkpoints
    checkpoint_path: str = "checkpoints/hexa_5b_latest.pt"

    def __post_init__(self):
        # Rough estimate: a standard transformer has ~12 * depth * dim^2
        # weights (attention + MLP), ignoring embeddings and the output head.
        total_params = 12 * self.depth * (self.dim ** 2)
        # Fixed: was an f-string with no placeholders.
        print("Hexa Config initialized.")
        print(f"Approximate Model Size: {total_params / 1e9:.2f} Billion parameters")
src/dataset.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchaudio
3
+ import os
4
+ from torch.utils.data import Dataset
5
+ from .text_encoder import TextEncoder
6
+ from .config import HexaConfig
7
+
8
class HexaDataset(Dataset):
    """
    Real Dataset Loader for Hexa TTS.

    Expects a directory layout of:
        <root_dir>/wavs/         -- audio clips named "<stem>.wav"
        <root_dir>/metadata.csv  -- lines of "filename|text"
    """
    def __init__(self, root_dir, config: HexaConfig, train=True):
        self.root_dir = root_dir
        self.config = config
        self.encoder = TextEncoder()

        self.wav_dir = os.path.join(root_dir, "wavs")
        self.metadata_path = os.path.join(root_dir, "metadata.csv")

        # Parse "filename|text" rows; extra '|'-separated columns beyond the
        # first two are ignored, matching LJSpeech-style metadata.
        self.files = []
        if os.path.exists(self.metadata_path):
            with open(self.metadata_path, 'r', encoding='utf-8') as meta_file:
                for row in meta_file:
                    columns = row.strip().split('|')
                    if len(columns) >= 2:
                        self.files.append((columns[0], columns[1]))
        else:
            print(f"Warning: Metadata not found at {self.metadata_path}")

        # Waveform -> mel-spectrogram transform, configured from HexaConfig.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sample_rate,
            n_fft=config.n_fft,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_mels=config.n_mel_channels,
        )

    def __len__(self):
        """Number of (audio, text) pairs listed in the metadata."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return (text_ids, speaker, lang, emotion, mel) for one sample."""
        stem, transcript = self.files[idx]

        # 1. Load audio; resample to the configured rate when it differs.
        waveform, source_rate = torchaudio.load(os.path.join(self.wav_dir, stem + ".wav"))
        if source_rate != self.config.sample_rate:
            waveform = torchaudio.transforms.Resample(source_rate, self.config.sample_rate)(waveform)

        # 2. Mel spectrogram: [channels, frames] -> [frames, channels]
        mel = self.mel_transform(waveform).squeeze(0).transpose(0, 1)

        # 3. Tokenize text (English assumed for the starter LJSpeech-style set).
        text_ids = self.encoder.preprocess(transcript, lang_code='en').squeeze(0)

        # 4. Single-speaker dataset: fixed speaker / language / emotion ids.
        speaker = torch.tensor(0)
        lang = torch.tensor(0)
        emotion = torch.tensor(0)

        return text_ids, speaker, lang, emotion, mel
72
+
73
def collate_fn(batch):
    """
    Collate (text_ids, speaker, lang, emotion, mel) samples into a padded batch.

    Sorts the batch in place, longest text first (useful for packing), then
    zero-pads text and mel sequences to the longest item.
    """
    # In-place sort, longest text first.
    batch.sort(key=lambda sample: sample[0].shape[0], reverse=True)

    text_ids, speakers, langs, emotions, mels = zip(*batch)

    pad = torch.nn.utils.rnn.pad_sequence
    text_padded = pad(list(text_ids), batch_first=True, padding_value=0)
    mel_padded = pad(list(mels), batch_first=True, padding_value=0.0)

    return (
        text_padded,
        torch.stack(speakers),
        torch.stack(langs),
        torch.stack(emotions),
        mel_padded,
    )
src/hf_model.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import PreTrainedModel, PretrainedConfig
4
+ from .config import HexaConfig
5
+ # Re-importing the core layers from the existing definition or redefining for cleanliness.
6
+ # To integrate with HF Trainer, we wrap the existing module.
7
+
8
class HexaHFConfig(PretrainedConfig):
    """
    Thin HuggingFace wrapper around the project's HexaConfig.

    Keeps a full HexaConfig instance on the config object (so the core
    transformer can be built from it) while forwarding all kwargs to
    PretrainedConfig for Trainer compatibility.
    """
    # Registry key for the transformers config/model machinery.
    model_type = "hexa_tts"
    def __init__(self, **kwargs):
        # Flatten HexaConfig into kwargs for HF compatibility
        self.hexa_config = HexaConfig()
        # Update with manual kwargs if provided; only keys that already exist
        # on HexaConfig are applied, everything else is left for
        # PretrainedConfig to consume below.
        for k, v in kwargs.items():
            if hasattr(self.hexa_config, k):
                setattr(self.hexa_config, k, v)
        # NOTE(review): hexa_config is a dataclass instance stored as an
        # attribute — confirm it round-trips through
        # PretrainedConfig.to_dict()/save_pretrained as intended.
        super().__init__(**kwargs)
18
+
19
+ from .model import HexaTransformer as CoreTransformer
20
+
21
class HexaModel(PreTrainedModel):
    """
    HuggingFace-compatible wrapper around the core HexaTransformer.

    Takes padded token ids plus per-sample conditioning ids, returns predicted
    mel frames as "logits" and, when `labels` (target mels) are given, an MSE
    "loss" truncated to the shared sequence length.
    """
    config_class = HexaHFConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Initialize the core model using the internal HexaConfig
        self.core = CoreTransformer(config.hexa_config)

        # Enable Gradient Checkpointing for memory savings
        self.gradient_checkpointing = False

    def forward(self, text_ids, speaker_ids=None, language_ids=None, emotion_ids=None, labels=None):
        """
        Args:
            text_ids: [batch, seq] token ids.
            speaker_ids / language_ids / emotion_ids: [batch] conditioning ids;
                default to zeros when omitted.
            labels: optional [batch, frames, n_mels] target mel spectrogram.

        Returns:
            dict with "logits" (predicted mels) and "loss" when labels given.
        """
        device = text_ids.device
        batch_size = text_ids.shape[0]
        # BUGFIX: the defaults used torch.zeros_like(text_ids), which has shape
        # [batch, seq] — but the core model expects ONE id per sample ([batch])
        # and broadcasts it over the sequence itself (unsqueeze(1).expand).
        if speaker_ids is None:
            speaker_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        if language_ids is None:
            language_ids = torch.zeros(batch_size, dtype=torch.long, device=device)
        if emotion_ids is None:
            emotion_ids = torch.zeros(batch_size, dtype=torch.long, device=device)

        # Forward pass
        mels = self.core(text_ids, speaker_ids, language_ids, emotion_ids)

        loss = None
        if labels is not None:
            # labels = target mels; truncate both to the shared length for MSE.
            min_len = min(mels.shape[1], labels.shape[1])
            mels_sliced = mels[:, :min_len, :]
            labels_sliced = labels[:, :min_len, :]
            loss = torch.nn.functional.mse_loss(mels_sliced, labels_sliced)

        return {"loss": loss, "logits": mels} if loss is not None else {"logits": mels}

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by transformers to toggle checkpointing on the core module.
        if isinstance(module, CoreTransformer):
            module.gradient_checkpointing = value
57
+
src/inference.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import soundfile as sf
3
+ import os
4
+ from .model import build_model
5
+ from .text_encoder import TextEncoder
6
+ from .config import HexaConfig
7
+
8
def generate_audio(text, output_path, lang='en', speaker_id=0, emotion_id=0):
    """
    Generates audio from text using the Hexa 5B model.

    Runs a single forward pass through an untrained model (architecture demo)
    and writes placeholder audio to `output_path`.
    """
    print(f"Initializing Hexa 5B TTS System...")

    # 1. Configuration (5B defaults).
    config = HexaConfig()

    # 2. Model with random weights — demo only.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = build_model().to(device)
    model.eval()

    # 3. Text -> token ids.
    encoder = TextEncoder()
    print(f"Processing text: '{text}' ({lang})")
    text_ids = encoder.preprocess(text, lang_code=lang).to(device)

    # 4. Clamp conditioning ids into their valid embedding ranges.
    speaker_tensor = torch.tensor([speaker_id]).to(device).clamp(0, config.num_speakers - 1)
    language_tensor = torch.tensor([0]).to(device)  # Placeholder mapping
    emotion_tensor = torch.tensor([emotion_id]).to(device).clamp(0, config.num_emotions - 1)

    # 5. One forward pass verifies the architecture; a real autoregressive
    # decoder would loop here.
    with torch.no_grad():
        mel_output = model(text_ids, speaker_tensor, language_tensor, emotion_tensor)

    print(f"Model forward pass successful. Output shape: {mel_output.shape}")
    print("Note: Since this is an untrained model, the output is random noise.")

    # 6. Simulated vocoder — production would run HifiGAN on the mels.
    sr = config.sample_rate
    dummy_audio = torch.randn(mel_output.shape[1] * 256)  # Approx length

    sf.write(output_path, dummy_audio.cpu().numpy(), sr)
    print(f"Saved generated (random) audio to: {output_path}")

if __name__ == "__main__":
    # Smoke test: short English phrase, emotion id 5 (e.g. 'Happy').
    generate_audio(
        "Hello, this is Hexa TTS.",
        "test_output.wav",
        lang='en',
        emotion_id=5,
    )
src/model.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+ from .config import HexaConfig
6
+
7
class RotaryEmbedding(nn.Module):
    """Precomputes rotary position angles for head dimension `dim`."""

    def __init__(self, dim):
        super().__init__()
        # Classic RoPE inverse-frequency schedule: 1 / 10000^(2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x):
        """Return angles shaped [1, 1, seq_len, dim] for sequence input x."""
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product position x frequency, then duplicate along the last
        # axis so the table covers the full head dimension.
        angles = positions[:, None] * self.inv_freq[None, :]
        doubled = torch.cat((angles, angles), dim=-1)
        return doubled[None, None, :, :]
19
+
20
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Keep the Sequential layout (state_dict keys net.0 ... net.4).
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; preserves input shape [..., dim]."""
        return self.net(x)
32
+
33
class Attention(nn.Module):
    """
    Multi-head self-attention with optional boolean mask.

    `rope_emb` is accepted for interface compatibility but is currently NOT
    applied, matching the original simplified implementation.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None, rope_emb=None):
        batch, seq, _ = x.shape
        heads = self.heads

        def split_heads(t):
            # [b, n, h*d] -> [b, h, n, d]
            return t.reshape(batch, seq, heads, -1).permute(0, 2, 1, 3)

        q, k, v = (split_heads(part) for part in self.to_qkv(x).chunk(3, dim=-1))

        # RoPE is intentionally a no-op here, as in the original code path.

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if mask is not None:
            # Positions where mask is False are pushed to -inf before softmax.
            scores.masked_fill_(~mask, -torch.finfo(scores.dtype).max)

        weights = scores.softmax(dim=-1)

        merged = torch.matmul(weights, v).permute(0, 2, 1, 3).reshape(batch, seq, -1)
        return self.to_out(merged)
68
+
69
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then x + FF(LN(x))."""

    def __init__(self, dim, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForward(dim, mlp_dim, dropout=dropout)

    def forward(self, x, mask=None, rope_emb=None):
        # Residual attention sub-layer.
        attended = self.attn(self.norm1(x), mask=mask, rope_emb=rope_emb)
        x = x + attended
        # Residual feed-forward sub-layer.
        return x + self.ff(self.norm2(x))
81
+
82
class HexaTransformer(nn.Module):
    """
    Hexa TTS 5B Model Core.

    A large decoder-only transformer mapping phoneme/text tokens plus
    per-utterance speaker / language / emotion conditioning to mel frames.
    """
    def __init__(self, config: HexaConfig):
        super().__init__()
        self.config = config

        # Input embeddings: token table plus three conditioning tables.
        self.token_emb = nn.Embedding(config.vocab_size, config.dim)
        self.speaker_emb = nn.Embedding(config.num_speakers, config.dim)    # Multi-Character
        self.language_emb = nn.Embedding(config.num_languages, config.dim)  # Language Support
        self.emotion_emb = nn.Embedding(config.num_emotions, config.dim)    # Emotion Support

        # Rotary angles are sized per head dimension.
        self.pos_emb = RotaryEmbedding(config.dim_head)

        # Stack of identical pre-norm transformer blocks.
        self.layers = nn.ModuleList(
            TransformerBlock(
                dim=config.dim,
                heads=config.heads,
                dim_head=config.dim_head,
                mlp_dim=int(config.dim * config.mlp_ratio),
                dropout=config.dropout,
            )
            for _ in range(config.depth)
        )

        self.norm_final = nn.LayerNorm(config.dim)

        # Project hidden states to mel channels (or a discrete codebook later).
        self.to_mel = nn.Linear(config.dim, config.n_mel_channels)

    def forward(self, text_ids, speaker_ids, language_ids, emotion_ids, mask=None):
        """
        Args:
            text_ids: [batch, seq] token ids.
            speaker_ids / language_ids / emotion_ids: [batch] conditioning ids.
            mask: optional attention mask forwarded to every block.

        Returns:
            [batch, seq, n_mel_channels] predicted mel frames.
        """
        hidden = self.token_emb(text_ids)
        seq_len = hidden.shape[1]

        # Fuse conditioning by simple addition, broadcasting each
        # per-utterance vector over the sequence. (Richer fusion such as
        # AdaLIN or cross-attention could replace this.)
        # Addition order matches the original: speaker, language, emotion.
        hidden = hidden + self.speaker_emb(speaker_ids).unsqueeze(1).expand(-1, seq_len, -1)
        hidden = hidden + self.language_emb(language_ids).unsqueeze(1).expand(-1, seq_len, -1)
        hidden = hidden + self.emotion_emb(emotion_ids).unsqueeze(1).expand(-1, seq_len, -1)

        # Rotary angles for this sequence length.
        rope_emb = self.pos_emb(hidden)

        # Transformer stack.
        for block in self.layers:
            hidden = block(hidden, mask=mask, rope_emb=rope_emb)

        # Final norm + mel projection.
        return self.to_mel(self.norm_final(hidden))
146
+
147
def build_model():
    """Construct a HexaTransformer from the default (5B-scale) HexaConfig."""
    return HexaTransformer(HexaConfig())
src/test_tiny.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import soundfile as sf
3
+ import os
4
+ from .model import HexaTransformer
5
+ from .text_encoder import TextEncoder
6
+ from .config import HexaConfig
7
+
8
def run_tiny_test():
    """
    Sanity-check the architecture end to end with a small config.

    Builds a shrunken model on CPU, runs one forward pass on an encoded
    sentence, and writes placeholder audio to tiny_output.wav.
    """
    print("Initializing Tiny Hexa Model for Code Verification...")

    # Shrunken architecture so the test fits comfortably in memory.
    config = HexaConfig(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        num_languages=15,
    )

    device = "cpu"
    model = HexaTransformer(config).to(device)
    model.eval()

    params = sum(p.numel() for p in model.parameters())
    print(f"Tiny Model Size: {params / 1e6:.2f} Million parameters")

    # Encode a short test sentence.
    text = "Hello world, testing tiny hexa."
    encoder = TextEncoder()
    text_ids = encoder.preprocess(text, lang_code='en').to(device)
    print(f"Encoded text shape: {text_ids.shape}")

    # Zero-valued conditioning ids (single speaker / language / emotion).
    speaker = torch.tensor([0]).to(device)
    language = torch.tensor([0]).to(device)
    emotion = torch.tensor([0]).to(device)

    with torch.no_grad():
        output = model(text_ids, speaker, language, emotion)

    print(f"Forward pass successful. Output shape: {output.shape}")

    # Output is (B, Frames, Mel_Channels); fake a waveform of matching length.
    dummy_wav = torch.randn(output.shape[1] * 256).numpy()
    sf.write("tiny_output.wav", dummy_wav, config.sample_rate)
    print("Saved tiny_output.wav")

if __name__ == "__main__":
    run_tiny_test()
src/text_encoder.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from phonemizer import phonemize
3
+ from phonemizer.separator import Separator
4
+
5
class TextEncoder:
    """
    Handles text-to-phoneme conversion for 14 languages.

    Wraps the phonemizer espeak backend; falls back to raw character
    tokenization when espeak is unavailable at runtime.
    """
    def __init__(self, vocab_map=None):
        # phone=' ' separates phones with spaces; word='|' marks word
        # boundaries so they survive into the token stream.
        self.separator = Separator(phone=' ', word='|', syllable='')
        # Maps 14 languages to phonemizer language codes
        # NOTE(review): the table actually holds 15 entries ('bn' included),
        # which matches HexaConfig.num_languages = 15 — update the docstring
        # or the table so the two agree.
        self.lang_map = {
            'en': 'en-us', 'zh': 'cmn', 'es': 'es', 'fr': 'fr-fr',
            'de': 'de', 'ja': 'ja', 'ko': 'ko', 'ru': 'ru',
            'pt': 'pt', 'it': 'it', 'hi': 'hi', 'ar': 'ar',
            'tr': 'tr', 'nl': 'nl', 'bn': 'bn'
        }
        # Simple character-to-id mapping (placeholder). Only lowercase
        # letters, space and '|' are present — any other character (uppercase,
        # digits, punctuation) maps to id 0 in preprocess().
        self.vocab = vocab_map if vocab_map else {c: i for i, c in enumerate(" abcdefghijklmnopqrstuvwxyz|")}

    def preprocess(self, text, lang_code='en'):
        """
        Converts text to phoneme IDs.

        Returns a LongTensor of shape [1, num_tokens] (batch dim prepended).
        Unknown symbols map to id 0.
        """
        # Unknown languages fall back to the English espeak backend.
        if lang_code not in self.lang_map:
            print(f"Warning: Language {lang_code} not fully supported, defaulting to English backend.")
            backend_lang = 'en-us'
        else:
            backend_lang = self.lang_map[lang_code]

        try:
            # Phonemize
            phonemes = phonemize(
                text,
                language=backend_lang,
                backend='espeak',
                separator=self.separator,
                strip=True,
                preserve_punctuation=True,
                njobs=1
            )
        except RuntimeError:
            # phonemizer raises RuntimeError when the espeak backend is
            # missing; degrade gracefully to per-character tokens.
            print("Warning: eSpeak not found. Falling back to character-level tokenization.")
            phonemes = list(text)  # Simple list of characters as fallback

        # Tokenize (Simple lookup for now).
        # NOTE(review): for string input phonemize() returns a single string,
        # so iterating `phonemes` yields CHARACTERS (including the ' ' phone
        # separators), not whole phone symbols — confirm this is intended.
        token_ids = [self.vocab.get(p, 0) for p in phonemes]
        return torch.tensor(token_ids).unsqueeze(0)  # Batch dim
src/train.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from torch.optim import AdamW
4
+ from accelerate import Accelerator
5
+ from tqdm import tqdm
6
+ import os
7
+ from .model import build_model
8
+ from .config import HexaConfig
9
+ from .dataset import HexaDataset, collate_fn
10
+
11
def train():
    """
    Massive Scale Training Loop.

    Orchestrates accelerate-based training of the full 5B model: builds the
    model and dataset, wraps them with Accelerator (16-step gradient
    accumulation), then runs a fixed 5-epoch MSE regression on mel targets,
    saving accelerator state after every epoch.
    """
    # 1. Setup
    config = HexaConfig()

    # Gradient Accumulation is CRITICAL for large models on small GPUs
    accelerator = Accelerator(gradient_accumulation_steps=16)

    print(f"Initializing 5B Parameter Model... (This takes memory!)")
    try:
        model = build_model()
    except RuntimeError as e:
        # NOTE(review): build_model() constructs on CPU, so a GPU OOM would
        # normally surface later (in prepare()/forward), not here — confirm
        # this catch fires in practice.
        print(f"Error initializing full model: {e}")
        print("Fallback: Your GPU memory is too small for 5B. Please try reducing config.dim in config.py")
        return

    # 2. Data — hard-coded Windows path; expects wavs/ + metadata.csv below it.
    data_root = "d:\\hexatts\\data"
    if not os.path.exists(data_root):
        print("Data not found. Run 'python get_data.py' first.")
        return

    dataset = HexaDataset(data_root, config)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

    # 3. Optimize
    optimizer = AdamW(model.parameters(), lr=1e-4) # Standard LR

    # prepare() must wrap model/optimizer/dataloader together so accelerate
    # can handle device placement and mixed precision consistently.
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    print("Starting Training...")
    model.train()

    # 4. Loop
    global_step = 0
    epochs = 5 # arbitrary for demo

    for epoch in range(epochs):
        # NOTE(review): progress_bar is never closed; consider
        # progress_bar.close() at end of epoch.
        progress_bar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}")

        for batch in dataloader:
            # accumulate() defers the actual optimizer step until 16
            # micro-batches have contributed gradients.
            with accelerator.accumulate(model):
                text, speakers, langs, emotions, target_mels = batch

                # Check shapes
                # Output: [Batch, Time, Channels]
                # Target: [Batch, Time, Channels]

                output_mels = model(text, speakers, langs, emotions)

                # Align lengths (Simple truncation to min length for loss)
                min_len = min(output_mels.shape[1], target_mels.shape[1])
                output_sliced = output_mels[:, :min_len, :]
                target_sliced = target_mels[:, :min_len, :]

                loss = torch.nn.functional.mse_loss(output_sliced, target_sliced)

                # backward through accelerator so accumulation/scaling applies.
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)
            global_step += 1

        # Save Checkpoint — full accelerator state (model, optimizer, RNG).
        save_path = os.path.join("checkpoints", f"checkpoint_epoch_{epoch}")
        os.makedirs(save_path, exist_ok=True)
        accelerator.save_state(save_path)
        print(f"Saved checkpoint to {save_path}")

if __name__ == "__main__":
    train()
src/train_hf.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import Trainer, TrainingArguments
3
+ from .hf_model import HexaModel, HexaHFConfig
4
+ from .dataset import HexaDataset
5
+ from .config import HexaConfig
6
+
7
+ # Data Collator for HF Trainer
8
def data_collator(features):
    """
    Collate HexaDataset samples into the keyword batch expected by HexaModel.

    Each feature is a (text_ids, speaker, lang, emotion, mel) tuple; text and
    mel sequences are zero-padded to the batch maximum, scalar ids stacked.
    """
    texts, speakers, langs, emotions, mels = zip(*features)

    pad = torch.nn.utils.rnn.pad_sequence
    return {
        "text_ids": pad(list(texts), batch_first=True, padding_value=0),
        "speaker_ids": torch.stack(list(speakers)),
        "language_ids": torch.stack(list(langs)),
        "emotion_ids": torch.stack(list(emotions)),
        "labels": pad(list(mels), batch_first=True, padding_value=0.0),
    }
33
+
34
def train():
    """
    Train the 5B Hexa configuration via the HuggingFace Trainer.

    Uses batch size 1 with 16-step gradient accumulation and gradient
    checkpointing so the 5B model can fit on a single GPU.
    """
    print("Initializing Hexa TTS (5B Config) with HuggingFace Trainer...")

    # 1. Config & model: dim=3200 / depth=40 yields roughly 5B parameters.
    hexa_conf = HexaConfig(dim=3200, depth=40, heads=32, dim_head=100)
    hf_config = HexaHFConfig()
    hf_config.hexa_config = hexa_conf
    model = HexaModel(hf_config)

    # 2. Dataset rooted at the local data directory.
    dataset = HexaDataset("d:\\hexatts\\data", hexa_conf)

    # 3. Memory-oriented training arguments.
    args = TrainingArguments(
        output_dir="./hexa_checkpoints",
        per_device_train_batch_size=1,   # 5B model: one sample per GPU
        gradient_accumulation_steps=16,  # effective batch size of 16
        learning_rate=1e-4,
        num_train_epochs=3,
        logging_steps=1,
        save_steps=100,
        fp16=False,                      # flip on for Tensor Core GPUs
        gradient_checkpointing=True,     # essential for 5B memory footprint
        dataloader_num_workers=0,
        report_to="tensorboard",
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        data_collator=data_collator,
    )

    print("Starting Training...")
    trainer.train()

if __name__ == "__main__":
    train()
+ train()