Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass, field | |
| from typing import List | |
@dataclass
class HexaConfig:
    """
    Configuration for Hexa TTS 5B Model.

    Designed to scale to ~5 Billion parameters. Every field has a default,
    so ``HexaConfig()`` yields the stock configuration and any field can be
    overridden by keyword, e.g. ``HexaConfig(depth=2, dim=512)``.

    Instantiating the config prints an approximate parameter count
    (see ``__post_init__``).
    """

    # Model Architecture
    dim: int = 3200  # Tuned for ~5B params (4.92B)
    depth: int = 40  # Number of layers
    heads: int = 32  # Number of attention heads
    dim_head: int = 100  # Dimension of each head (heads * dim_head == dim by default)
    mlp_ratio: float = 4.0  # Feedforward expansion factor
    dropout: float = 0.1

    # Input / Output
    num_languages: int = 15
    vocab_size: int = 256  # Size of phoneme/text vocabulary
    num_speakers: int = 10000  # Embedding slot for speakers
    num_emotions: int = 32  # Distinct emotion categories

    # Audio Settings
    sample_rate: int = 24000
    n_mel_channels: int = 100
    n_fft: int = 1024
    hop_length: int = 256
    win_length: int = 1024

    # Context
    max_text_len: int = 1024
    max_audio_len: int = 4096  # In mel frames

    # Checkpoints
    checkpoint_path: str = "checkpoints/hexa_5b_latest.pt"

    def __post_init__(self):
        """Print a rough size estimate for the configured model.

        Called automatically by the dataclass-generated ``__init__`` after
        all fields are assigned, so the estimate reflects any overrides.
        """
        # Rough parameter count estimation:
        # 12 * layers * dim^2 (approximate for standard transformer)
        total_params = 12 * self.depth * (self.dim ** 2)
        print("Hexa Config initialized.")
        print(f"Approximate Model Size: {total_params / 1e9:.2f} Billion parameters")