# hexa-tts-trainer / src/config.py
# Uploaded by Hexa09 using huggingface_hub (commit 4650d1f, verified)
from dataclasses import dataclass, field
from typing import List
@dataclass
class HexaConfig:
    """
    Configuration for Hexa TTS 5B Model.

    Designed to scale to ~5 Billion parameters. With the defaults below,
    the rough transformer estimate ``12 * depth * dim**2`` gives ~4.92B
    parameters, and ``heads * dim_head == dim`` (32 * 100 == 3200).
    """

    # --- Model Architecture ---
    dim: int = 3200         # Model width; tuned for ~5B params (4.92B)
    depth: int = 40         # Number of transformer layers
    heads: int = 32         # Number of attention heads
    dim_head: int = 100     # Dimension of each head (heads * dim_head == dim)
    mlp_ratio: float = 4.0  # Feedforward expansion factor
    dropout: float = 0.1

    # --- Input / Output ---
    num_languages: int = 15
    vocab_size: int = 256       # Size of phoneme/text vocabulary
    num_speakers: int = 10000   # Embedding slots for speakers
    num_emotions: int = 32      # Distinct emotion categories

    # --- Audio Settings ---
    sample_rate: int = 24000
    n_mel_channels: int = 100
    n_fft: int = 1024
    hop_length: int = 256
    win_length: int = 1024

    # --- Context ---
    max_text_len: int = 1024   # In text/phoneme tokens
    max_audio_len: int = 4096  # In mel frames

    # --- Checkpoints ---
    checkpoint_path: str = "checkpoints/hexa_5b_latest.pt"

    def __post_init__(self) -> None:
        """Print an approximate parameter count when the config is built.

        Uses the standard rough transformer estimate of
        ``12 * layers * dim**2`` (attention + MLP weights per layer;
        embeddings and output head are ignored).
        """
        total_params = 12 * self.depth * (self.dim ** 2)
        # Plain string here: the original used an f-string with no
        # placeholders (ruff F541); output is byte-identical.
        print("Hexa Config initialized.")
        print(f"Approximate Model Size: {total_params / 1e9:.2f} Billion parameters")