# Source: RASMUS — "Upload Finnish Chatterbox model" (commit 67ea4ca, verified)
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class TrainConfig:
    """Configuration for finetuning the Chatterbox TTS model on Finnish data.

    All values are overridable via the generated ``__init__``. Fields that
    were previously plain (un-annotated) class attributes are now proper
    dataclass fields so they participate in ``__init__``/``repr``/``asdict``.
    """

    # --- Paths ---
    # Directory where setup.py downloaded the files.
    # Using the original pretrained_models directory which now contains the
    # English-only base weights.
    model_dir: str = "./pretrained_models"
    # Path to your metadata CSV (Format: ID|RawText|NormText)
    csv_path: str = "./chatterbox_midtune_cc_data_16k/metadata.csv"
    # Directory containing WAV files
    wav_dir: str = "./chatterbox_midtune_cc_data_16k"
    # Attribution file for speaker-aware splitting
    attribution_path: str = "./chatterbox_midtune_cc_data_16k/attribution.csv"
    # Cache directory for preprocessed features
    preprocessed_dir: str = "./chatterbox_midtune_cc_data_16k/preprocess"
    # Output directory for the finetuned model.
    # Changed to differentiate from the English-only run.
    output_dir: str = "./chatterbox_output_multilingual"
    # Set True if the dataset format is ljspeech, and False if it's file-based.
    ljspeech: bool = True
    # If you've already done preprocessing once, set it to False.
    preprocess: bool = True
    # Set True if you're training Turbo, False if you're training Standard
    # (multilingual, stronger).
    is_turbo: bool = False

    # --- OOD Evaluation ---
    # These speakers are strictly excluded from training and validation.
    # default_factory prevents one mutable list being shared across instances.
    ood_speakers: list = field(
        default_factory=lambda: ["cv-15_11", "cv-15_16", "cv-15_2"]
    )

    # --- Vocabulary ---
    # The size of the NEW vocabulary (from tokenizer.json).
    # Ensure this matches the JSON file generated by your tokenizer script.
    # For Turbo mode: Use the exact number provided by setup.py (e.g., 52260).
    # None -> derived from is_turbo in __post_init__. (The old class-level
    # `52260 if is_turbo else 2454` default was evaluated once at class
    # definition against the default is_turbo=False, so Turbo runs silently
    # received the Standard vocab size.)
    new_vocab_size: Optional[int] = None

    # --- Hyperparameters ---
    batch_size: int = 16  # Adjust based on VRAM
    grad_accum: int = 2  # Effective Batch Size = batch_size * grad_accum = 32
    learning_rate: float = 2e-5  # Research-optimized LR with warmup
    num_epochs: int = 5  # Run exactly 5 epochs
    weight_decay: float = 0.05  # Defensive weight decay
    # Training Strategy:
    # Stage 1 (Current): Multi-speaker Finnish -> 3-5 epochs, lower LR
    # Stage 2 (Later): Single speaker voice clone -> 50-150 epochs, higher LR

    # --- Validation ---
    validation_split: float = 0.05  # 5% of data for validation
    validation_seed: int = 42  # For reproducible train/val split

    # --- Constraints ---
    min_training_duration: float = 4.0  # Filter samples shorter than this
    min_training_snr: float = 35.0  # Filter samples with SNR lower than this
    # Filter samples with SNR higher than this (digital artifacts)
    max_training_snr: float = 100.0
    start_text_token: int = 255
    stop_text_token: int = 0
    max_text_len: int = 256
    max_speech_len: int = 1024  # Truncates very long audio
    prompt_duration: float = 3.0  # Duration for the reference prompt (seconds)

    def __post_init__(self) -> None:
        # Resolve the vocab size from the *instance's* is_turbo flag; an
        # explicitly passed new_vocab_size always wins.
        if self.new_vocab_size is None:
            self.new_vocab_size = 52260 if self.is_turbo else 2454