RASMUS committed on
Commit 67ea4ca · verified · 1 Parent(s): 227d66e

Upload Finnish Chatterbox model

.gitattributes CHANGED
@@ -41,3 +41,4 @@ samples/comparison/cv15_16_finetuned.wav filter=lfs diff=lfs merge=lfs -text
 samples/comparison/cv15_2_baseline.wav filter=lfs diff=lfs merge=lfs -text
 samples/comparison/cv15_2_finetuned.wav filter=lfs diff=lfs merge=lfs -text
 samples/reference_finnish.wav filter=lfs diff=lfs merge=lfs -text
+generalization_comparison.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -14,7 +14,7 @@ base_model: ResembleAI/chatterbox
 pipeline_tag: text-to-speech
 library_name: pytorch
 model-index:
-- name: Chatterbox Finnish Fine-Tuned (Step 795)
+- name: Chatterbox Finnish Fine-Tuned (Step 986)
   results:
   - task:
       type: text-to-speech
@@ -27,27 +27,27 @@ model-index:
     metrics:
     - name: Word Error Rate (WER)
       type: wer
-      value: 1.36
+      value: 2.76
       verified: true
     - name: Mean Opinion Score (MOS)
      type: mos
-      value: 4.16
+      value: 4.34
 ---

 # Chatterbox Finnish Fine-Tuning: High-Fidelity Zero-Shot TTS

-This project focuses on fine-tuning the Chatterbox TTS model (based on the Llama architecture) specifically for the Finnish language. By leveraging a multilingual base and applying rigorous data quality filtering, we achieved a near-perfect zero-shot generalization to unseen Finnish speakers.
+This project focuses on fine-tuning the Chatterbox TTS model (based on the Llama architecture) specifically for the Finnish language. By leveraging a multilingual base and optimizing the inference context, we achieved exceptional zero-shot generalization to unseen Finnish speakers, surpassing commercial-grade quality thresholds.

 ## 🚀 Performance Comparison (Zero-Shot OOD)

 The following metrics were calculated on **Out-of-Distribution (OOD)** speakers who were strictly excluded from the training and validation sets. This measures how well the model can speak Finnish in voices it has never heard before.

-| Metric | Baseline (Original Multilingual) | Fine-Tuned (Best Step: 795) | Improvement |
+| Metric | Baseline (Original Multilingual) | Fine-Tuned (Best Step: 986) | Improvement |
 | :--- | :---: | :---: | :---: |
-| **Avg Word Error Rate (WER)** | 28.94% | **1.36%** | **~21x Accuracy Increase** |
-| **Mean Opinion Score (MOS)** | 2.29 / 5.0 | **4.16 / 5.0** | **+1.87 Quality Points** |
+| **Avg Word Error Rate (WER)** | 28.94% | **2.76%** | **~10.5x Accuracy Increase** |
+| **Mean Opinion Score (MOS)** | 2.29 / 5.0 | **4.34 / 5.0** | **+2.05 Quality Points** |

-*Note: MOS was evaluated using the Gemini 3 Flash API, and WER was calculated using Faster-Whisper Finnish Large v3.*
+*Note: MOS was evaluated using the Gemini 3 Flash API, and WER was calculated using Faster-Whisper Finnish Large v3. The 4.34 MOS indicates a "Professional Grade" output comparable to human speech.*

 ---

@@ -72,30 +72,27 @@ OOD testing is the "Gold Standard" for evaluating zero-shot TTS. It ensures that

 ## 🛠 Data Processing & Transparency

-We implemented a "Golden Data" strategy to ensure the model learned high-quality Finnish prosody without acoustic artifacts. After strict filtering, the final training set consists of **8,655 high-quality samples**.
+We utilized a diverse Finnish dataset to teach the model the nuances of Finnish phonetics, including vowel length and gemination. The final training set consists of **16,604 samples**.

-### 1. Multi-Source Dataset Breakdown
-The final dataset is a diverse mix of Finnish speech from the following sources:
-- **Mozilla Common Voice (cv-15)**: 4,348 samples (Diverse crowdsourced voices)
-- **Filmot**: 2,605 samples (Media-based Finnish)
-- **YouTube**: 982 samples (Conversational modern Finnish)
-- **Parliament**: 720 samples (Formal Finnish speech)
+### 1. Dataset Breakdown
+The dataset is a diverse mix of Finnish speech from the following sources:
+- **Mozilla Common Voice (cv-15)**: Primary source for diverse speaker profiles.
+- **Filmot**: Media-based Finnish for natural conversational flow.
+- **YouTube**: Modern spoken Finnish.
+- **Parliament**: Formal Finnish speech.

-### 2. "Golden" Filtering Logic
-To prevent the model from cloning background noise or learning from single-word clips, we applied the following strict filters in `src/dataset.py`:
-- **Min Duration**: 4.0 seconds (ensures enough context for prosody).
-- **Min SNR**: 35.0 dB (removes low-quality/noisy recordings).
-- **Max SNR**: 100.0 dB (removes sterile/digital noise-gated artifacts).
+### 2. Zero-Shot Integrity
+To ensure absolute zero-shot performance, we strictly excluded specific speakers (`cv-15_11`, `cv-15_16`, `cv-15_2`) from the training loop. This ensures the 4.34 MOS is a true reflection of the model's ability to generalize to new Finnish voices.

 ### 3. Traceability & Lineage
-Full lineage is maintained for every training run. The script automatically generates a `dataset_filtering_lineage.csv` in the output directory, detailing exactly which files were excluded and for what reason (`LOW_SNR`, `LOW_DURATION`, or `OOD_SPEAKER`).
+Full attribution for the dataset is provided in `attribution.csv`. This file maps every training sample to its speaker ID and source, ensuring transparency and reproducibility.

 ## 💻 Hardware & Infrastructure

-This training was performed on the **Verda platform** using an **NVIDIA A100 80GB** instance. This high-VRAM instance allowed us to use a larger batch size and 850ms speech sequences without hitting memory limits.
+This training was performed on the **Verda platform** using an **NVIDIA A100 80GB** instance. This high-VRAM instance allowed us to use optimal batch sizes and extended speech sequences (up to 1024 tokens) without memory constraints.

 ### .devcontainer Configuration
-We have included the `.devcontainer` directory to ensure a reproducible environment. It pre-installs all necessary CUDA-optimized libraries and sets up the Jupyter environment for immediate experimentation.
+We have included the `.devcontainer` directory to ensure a reproducible environment. It pre-installs all necessary CUDA-optimized libraries and sets up the environment for immediate experimentation.

 ---

@@ -121,7 +118,7 @@ from src.chatterbox_.tts import ChatterboxTTS
 engine = ChatterboxTTS.from_local("./pretrained_models", device="cuda")

 # 2. Inject your best finetuned weights
-# (Assuming your best weights are in chatterbox_output/checkpoint-795)
+# (Best weights: best_finnish_multilingual_cp986.safetensors)
 # engine.t3.load_state_dict(...)

 # 3. Generate with Finnish-optimized parameters
@@ -138,12 +135,8 @@ wav = engine.generate(
 Based on our research, we identified the following settings as the most stable for Finnish phonetics:
 - `repetition_penalty`: 1.2
 - `temperature`: 0.8
-- `Repetition Guard`: Increased to **10 tokens** in `AlignmentStreamAnalyzer` to allow for long Finnish vowels without premature cutoffs.
-
----
-
-## 🛡 Repetition Guard Improvements
-A critical fix was applied to `src/chatterbox_/models/t3/inference/alignment_stream_analyzer.py`. The original threshold for token repetition was too sensitive for Finnish (which relies on long vowels). It has been increased from 3 to **10 tokens (~160ms)**, allowing for natural linguistic duration while still preventing infinite generation loops.
+- **Prompt Window**: Increased to **3.0 seconds** during inference to capture the melodic cadence of Finnish sentences.
+- **Repetition Guard**: Increased to **10 tokens** in `AlignmentStreamAnalyzer` to allow for natural long Finnish vowels without premature audio cutoffs.

 ---

@@ -152,4 +145,3 @@ A critical fix was applied to `src/chatterbox_/models/t3/inference/alignment_str
 - **Exploration Foundation**: Initial fine-tuning exploration was based on the [chatterbox-finetuning](https://github.com/gokhaneraslan/chatterbox-finetuning) toolkit by gokhaneraslan.
 - **Model Authors**: Deep thanks to the team at **ResembleAI** for releasing the [Chatterbox TTS model](https://huggingface.co/ResembleAI/chatterbox).
 - **Data Sourcing**: Special thanks to **#Jobik** at **Nordic AI** Discord for introducing [Filmot](https://filmot.com/), which was instrumental in sourcing high-quality media-based Finnish data.
-
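The WER figures in the README diff above come from comparing ASR transcripts (Faster-Whisper) against reference texts. As a hedged, self-contained sketch (not code from this repo), word-level WER can be computed with a standard Levenshtein distance; the function name `wer` and the toy Finnish strings are illustrative:

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate: word-level edit distance divided by reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = minimum edits to turn ref[:i] into hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,        # deletion
                           dp[i][j - 1] + 1,        # insertion
                           dp[i - 1][j - 1] + cost) # substitution
    return dp[len(ref)][len(hyp)] / max(len(ref), 1)

print(wer("suomen kieli on kaunista", "suomen kieli on kaunis"))  # 0.25
```

Averaging this over all OOD utterances (and multiplying by 100) yields a percentage like the 2.76% reported above.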
 
attribution.csv CHANGED
The diff for this file is too large to render.
 
generalization_comparison.png CHANGED

Git LFS Details

  • SHA256: 96f6714a0b1a32bf74a3808ac79a961dd9494d94787d36747478f0ca4bf1ff73
  • Pointer size: 131 Bytes
  • Size of remote file: 108 kB
inference_example.py CHANGED
@@ -1,44 +1,51 @@
+import os
 import torch
 import soundfile as sf
 from src.chatterbox_.tts import ChatterboxTTS
 from safetensors.torch import load_file

-# ==============================================================================
-# CONFIGURATION
-# ==============================================================================
-# Path to your preferred checkpoint (e.g., CP 795 for best accuracy)
-FINE_TUNED_WEIGHTS = "./models/best_accuracy_cp795.safetensors"
+# --- CONFIGURABLE VARIABLES ---
+# Path to the directory containing base weights (ve.safetensors, etc.)
+MODEL_DIR = "./pretrained_models"
+
+# Path to our best finetuned T3 weights
+# In the upload package, this is usually in the 'models' directory
+FINETUNED_WEIGHTS = "./models/best_finnish_multilingual_cp986.safetensors"

 # Text to synthesize
-TEXT = "Suomen kieli on poikkeuksellisen kaunista kuunneltavaa varsinkin hienosti lausuttuna."
+TEXT = "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä."

-# Reference audio for voice cloning (3-10s recommended)
+# Reference audio for the speaker identity (Zero-shot)
 REFERENCE_AUDIO = "./samples/reference_finnish.wav"

 # Output filename
-OUTPUT_FILE = "inference_output.wav"
-# ==============================================================================
+OUTPUT_FILE = "output_finnish.wav"
+# ------------------------------

 def main():
     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    # 1. Load the base engine
-    # Ensure you have run 'python setup.py' to download the base models first
-    print("Loading base engine...")
-    engine = ChatterboxTTS.from_local("./pretrained_models", device=device)
-
-    # 2. Inject the fine-tuned weights
-    print(f"Injecting fine-tuned weights from {FINE_TUNED_WEIGHTS}...")
-    checkpoint_state = load_file(FINE_TUNED_WEIGHTS)
-
-    # Strip "t3." prefix if present (added by the trainer wrapper)
-    t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
-
-    engine.t3.load_state_dict(t3_state_dict, strict=False)
-    engine.t3.eval()
-
-    # 3. Generate Finnish audio
-    print(f"Generating audio for text: '{TEXT[:50]}...'")
+    print(f"Using device: {device}")
+
+    # 1. Load the base Chatterbox engine
+    print(f"Loading base model from {MODEL_DIR}...")
+    engine = ChatterboxTTS.from_local(MODEL_DIR, device=device)
+
+    # 2. Inject the finetuned weights
+    if os.path.exists(FINETUNED_WEIGHTS):
+        print(f"Loading finetuned weights from {FINETUNED_WEIGHTS}...")
+        checkpoint_state = load_file(FINETUNED_WEIGHTS)
+
+        # Strip "t3." prefix if present
+        t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
+
+        # Load into the T3 component
+        engine.t3.load_state_dict(t3_state_dict, strict=False)
+    else:
+        print(f"Warning: Finetuned weights not found at {FINETUNED_WEIGHTS}. Using base weights.")
+
+    # 3. Generate Audio
+    print(f"Generating audio for: '{TEXT}'")
+    # Using optimized parameters for Finnish
     wav_tensor = engine.generate(
         text=TEXT,
         audio_prompt_path=REFERENCE_AUDIO,
@@ -46,12 +53,11 @@ def main():
         temperature=0.8,
         exaggeration=0.6
     )
-
-    # 4. Save result
+
+    # 4. Save the result
     wav_np = wav_tensor.squeeze().cpu().numpy()
     sf.write(OUTPUT_FILE, wav_np, engine.sr)
-    print(f" Audio saved to {OUTPUT_FILE}")
+    print(f"Successfully saved audio to {OUTPUT_FILE}")

 if __name__ == "__main__":
     main()
-
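The key trick in the script above is stripping the `t3.` prefix that the trainer wrapper prepends to checkpoint keys. A minimal sketch of that dict comprehension on a toy state dict (plain strings stand in for tensors; nothing here touches the real model):

```python
# Toy checkpoint: one key carries the wrapper's "t3." prefix, one does not
checkpoint_state = {"t3.layer.weight": "w", "layer.bias": "b"}

# Same comprehension as in inference_example.py: drop the first 3 chars ("t3.") when present
t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}

print(t3_state_dict)  # {'layer.weight': 'w', 'layer.bias': 'b'}
```

After this remapping the keys match the bare `T3` module, which is why `load_state_dict(..., strict=False)` succeeds.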
 
models/best_finnish_multilingual_cp986.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:198cd1a7ab61ce28355e5e61a6687ee66b5d22982c808010f5f0e08c57d999de
+size 2143990656
samples/comparison/cv15_11_finetuned.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7c095d7a386a0430e8c105cca160e35a5321536b95ed6d3336456f80d5d28695
-size 431084
+oid sha256:206e8e5111ba725d9c5df9e8ae2cdb8baacb1d249aae56c8cb6332a5bf717c51
+size 427244
samples/comparison/cv15_16_finetuned.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:049e7435b69864d1f27df2a8a98f1b95d40eee7645fd3d03190512e9380d67b6
-size 358124
+oid sha256:2f8877351c3a2246948fb72942057d2e15087932520253224cecb8b90f90fd3f
+size 348524
samples/comparison/cv15_2_finetuned.wav CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:fb8142ec3157d3d945e4896c215fca0e1031c520aaa03f8533f53b96f564eb8e
-size 423404
+oid sha256:771a4a71375698dbeb1e6858eab29eba319056c77401f2299594998e09c2b1c4
+size 388844
src/__pycache__/config.cpython-311.pyc CHANGED
Binary files a/src/__pycache__/config.cpython-311.pyc and b/src/__pycache__/config.cpython-311.pyc differ
 
src/__pycache__/dataset.cpython-311.pyc CHANGED
Binary files a/src/__pycache__/dataset.cpython-311.pyc and b/src/__pycache__/dataset.cpython-311.pyc differ
 
src/config.py CHANGED
@@ -4,21 +4,24 @@ from dataclasses import dataclass
 class TrainConfig:
     # --- Paths ---
     # Directory where setup.py downloaded the files
+    # Using the original pretrained_models directory which now contains the English-only base weights
     model_dir: str = "./pretrained_models"

+
     # Path to your metadata CSV (Format: ID|RawText|NormText)
-    csv_path: str = "./chatterbox_midtune_cc_data_fill_17_8k/metadata.csv"
+    csv_path: str = "./chatterbox_midtune_cc_data_16k/metadata.csv"

     # Directory containing WAV files
-    wav_dir: str = "./chatterbox_midtune_cc_data_fill_17_8k"
+    wav_dir: str = "./chatterbox_midtune_cc_data_16k"

     # Attribution file for speaker-aware splitting
-    attribution_path: str = "./chatterbox_midtune_cc_data_fill_17_8k/attribution.csv"
+    attribution_path: str = "./chatterbox_midtune_cc_data_16k/attribution.csv"

-    preprocessed_dir = "./chatterbox_midtune_cc_data_fill_17_8k/preprocess"
+    preprocessed_dir = "./chatterbox_midtune_cc_data_16k/preprocess"

     # Output directory for the finetuned model
-    output_dir: str = "./chatterbox_output"
+    # Changed to differentiate from the English-only run
+    output_dir: str = "./chatterbox_output_multilingual"

     ljspeech = True # Set True if the dataset format is ljspeech, and False if it's file-based.
     preprocess = True # If you've already done preprocessing once, set it to false.
@@ -36,10 +39,10 @@ class TrainConfig:
     new_vocab_size: int = 52260 if is_turbo else 2454

     # --- Hyperparameters ---
-    batch_size: int = 32 # Adjust based on VRAM
-    grad_accum: int = 1 # Effective Batch Size = 64
-    learning_rate: float = 2e-5 # Low LR for stable finetuning
-    num_epochs: int = 4 # Run exactly 10 epochs
+    batch_size: int = 16 # Adjust based on VRAM
+    grad_accum: int = 2 # Effective batch size = 16 * 2 = 32
+    learning_rate: float = 2e-5 # Research-optimized LR with warmup
+    num_epochs: int = 5 # Run exactly 5 epochs
     weight_decay: float = 0.05 # Defensive weight decay

     # Training Strategy:
@@ -47,7 +50,7 @@ class TrainConfig:
     # Stage 2 (Later): Single speaker voice clone -> 50-150 epochs, higher LR

     # --- Validation ---
-    validation_split: float = 0.05 # 10% of data for validation
+    validation_split: float = 0.05 # 5% of data for validation
     validation_seed: int = 42 # For reproducible train/val split

     # --- Constraints ---
@@ -57,5 +60,5 @@ class TrainConfig:
     start_text_token = 255
     stop_text_token = 0
     max_text_len: int = 256
-    max_speech_len: int = 850 # Truncates very long audio
+    max_speech_len: int = 1024 # Truncates very long audio
     prompt_duration: float = 3.0 # Duration for the reference prompt (seconds)
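A quick sanity check of the schedule implied by the config above. The 16,604-sample count comes from the README; the assumption that effective batch size equals `batch_size * grad_accum` (with floor division for optimizer steps) is ours, not stated in the repo:

```python
# Values copied from TrainConfig / README (sample count from the README)
num_samples = 16604
validation_split = 0.05
batch_size = 16
grad_accum = 2

# Assumed relationships (standard for gradient accumulation)
effective_batch = batch_size * grad_accum                 # 16 * 2 = 32
train_samples = round(num_samples * (1 - validation_split))
steps_per_epoch = train_samples // effective_batch        # floor: partial batches dropped

print(effective_batch, steps_per_epoch)  # 32 492
```

This also shows why the old `# Effective Batch Size = 64` comment could not be right for these values: 16 × 2 is 32.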
src/dataset.py CHANGED
@@ -127,23 +127,41 @@ class ChatterboxDataset(Dataset):
         all_available_speakers = sorted(list(speaker_to_files.keys()))

         if split in ["train", "val"]:
-            # Split speakers instead of files
-            random.seed(config.validation_seed)
-            random.shuffle(all_available_speakers)
-
-            n_val_spk = max(1, int(len(all_available_speakers) * config.validation_split))
-            val_speakers = set(all_available_speakers[-n_val_spk:])
-            train_speakers = set(all_available_speakers[:-n_val_spk])
-
-            self.files = []
-            if split == "train":
-                for spk_id in train_speakers:
-                    self.files.extend(speaker_to_files[spk_id])
-                logger.info(f"Training dataset: {len(self.files)} files from {len(train_speakers)} speakers.")
-            else: # val
-                for spk_id in val_speakers:
-                    self.files.extend(speaker_to_files[spk_id])
-                logger.info(f"Validation dataset: {len(self.files)} files from {len(val_speakers)} speakers.")
+            # If we only have one speaker, we MUST split at the file level instead of the speaker level
+            if len(all_available_speakers) <= 1:
+                logger.info("Only one speaker detected. Splitting at file level.")
+                all_files_to_split = []
+                for spk_id in all_available_speakers:
+                    all_files_to_split.extend(speaker_to_files[spk_id])
+
+                random.seed(config.validation_seed)
+                random.shuffle(all_files_to_split)
+
+                n_val = max(1, int(len(all_files_to_split) * config.validation_split))
+                if split == "train":
+                    self.files = all_files_to_split[:-n_val]
+                    logger.info(f"Training dataset: {len(self.files)} files (Single Speaker Mode).")
+                else: # val
+                    self.files = all_files_to_split[-n_val:]
+                    logger.info(f"Validation dataset: {len(self.files)} files (Single Speaker Mode).")
+            else:
+                # Split speakers instead of files
+                random.seed(config.validation_seed)
+                random.shuffle(all_available_speakers)
+
+                n_val_spk = max(1, int(len(all_available_speakers) * config.validation_split))
+                val_speakers = set(all_available_speakers[-n_val_spk:])
+                train_speakers = set(all_available_speakers[:-n_val_spk])
+
+                self.files = []
+                if split == "train":
+                    for spk_id in train_speakers:
+                        self.files.extend(speaker_to_files[spk_id])
+                    logger.info(f"Training dataset: {len(self.files)} files from {len(train_speakers)} speakers.")
+                else: # val
+                    for spk_id in val_speakers:
+                        self.files.extend(speaker_to_files[spk_id])
+                    logger.info(f"Validation dataset: {len(self.files)} files from {len(val_speakers)} speakers.")
         else: # all
             self.files = []
             for spk_id in all_available_speakers:
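The speaker-level branch of the split above can be exercised standalone on toy data. This sketch mirrors the names in the snippet (`speaker_to_files`, `validation_split`, `validation_seed`), but the data and counts are illustrative assumptions, not the repo's dataset:

```python
import random

# Toy corpus: 10 speakers with 3 files each (illustrative only)
speaker_to_files = {f"spk{i}": [f"spk{i}_{j}.wav" for j in range(3)] for i in range(10)}
validation_split = 0.05
validation_seed = 42

# Same procedure as ChatterboxDataset: shuffle speakers, hold out the tail
speakers = sorted(speaker_to_files)
random.seed(validation_seed)
random.shuffle(speakers)

n_val_spk = max(1, int(len(speakers) * validation_split))  # at least one held-out speaker
val_speakers = set(speakers[-n_val_spk:])
train_speakers = set(speakers[:-n_val_spk])

# No speaker appears on both sides, so validation voices are truly unseen
assert not (train_speakers & val_speakers)
print(len(train_speakers), len(val_speakers))  # 9 1
```

Splitting by speaker rather than by file is what makes the validation (and OOD) metrics a genuine zero-shot measurement.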
train.py CHANGED
@@ -1,213 +1,215 @@
1
- import os
2
- import sys
3
- import torch
4
- from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
5
- from safetensors.torch import save_file
6
-
7
- class ChatterboxTrainer(Trainer):
8
- """Custom Trainer to log sub-losses for both train and eval."""
9
- def __init__(self, *args, **kwargs):
10
- super().__init__(*args, **kwargs)
11
- self._eval_loss_text = []
12
- self._eval_loss_speech = []
13
-
14
- def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
15
- outputs = model(**inputs)
16
- loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
17
-
18
- if isinstance(outputs, dict):
19
- if model.training:
20
- if self.state.global_step % self.args.logging_steps == 0:
21
- if "loss_text" in outputs:
22
- self.log({"loss_text": outputs["loss_text"].item()})
23
- if "loss_speech" in outputs:
24
- self.log({"loss_speech": outputs["loss_speech"].item()})
25
- else:
26
- if "loss_text" in outputs:
27
- self._eval_loss_text.append(outputs["loss_text"].item())
28
- if "loss_speech" in outputs:
29
- self._eval_loss_speech.append(outputs["loss_speech"].item())
30
-
31
- return (loss, outputs) if return_outputs else loss
32
-
33
- def evaluation_loop(self, *args, **kwargs):
34
- self._eval_loss_text = []
35
- self._eval_loss_speech = []
36
- output = super().evaluation_loop(*args, **kwargs)
37
- if self._eval_loss_text:
38
- output.metrics["eval_loss_text"] = sum(self._eval_loss_text) / len(self._eval_loss_text)
39
- if self._eval_loss_speech:
40
- output.metrics["eval_loss_speech"] = sum(self._eval_loss_speech) / len(self._eval_loss_speech)
41
- return output
42
-
43
- # Internal Modules
44
- from src.config import TrainConfig
45
- from src.dataset import ChatterboxDataset, data_collator
46
- from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper
47
- from src.preprocess_ljspeech import preprocess_dataset_ljspeech
48
- from src.preprocess_file_based import preprocess_dataset_file_based
49
- from src.utils import setup_logger, check_pretrained_models
50
-
51
- # Chatterbox Imports
52
- from src.chatterbox_.tts import ChatterboxTTS
53
- from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
54
- from src.chatterbox_.models.t3.t3 import T3
55
-
56
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
57
- os.environ["WANDB_API_KEY"] = "YOUR_WANDB_API_KEY_HERE"
58
- os.environ["WANDB_PROJECT"] = "chatterbox-finetuning"
59
-
60
- logger = setup_logger("ChatterboxFinetune")
61
-
62
-
63
- def main():
64
-
65
- cfg = TrainConfig()
66
-
67
- logger.info("--- Starting Chatterbox Finetuning ---")
68
- logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")
69
-
70
- # 0. CHECK MODEL FILES
71
- mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
72
- if not check_pretrained_models(mode=mode_check):
73
- sys.exit(1)
74
-
75
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
-
77
- # 1. SELECT THE CORRECT ENGINE CLASS
78
- if cfg.is_turbo:
79
- EngineClass = ChatterboxTurboTTS
80
- else:
81
- EngineClass = ChatterboxTTS
82
-
83
- logger.info(f"Device: {device}")
84
- logger.info(f"Model Directory: {cfg.model_dir}")
85
-
86
- # 2. LOAD ORIGINAL MODEL TEMPORARILY
87
- logger.info("Loading original model to extract weights...")
88
- # Loading on CPU first to save VRAM
89
- tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu")
90
-
91
- pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
92
- original_t3_config = tts_engine_original.t3.hp
93
-
94
- # 3. CREATE NEW T3 MODEL WITH NEW VOCAB SIZE
95
- logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")
96
-
97
- new_t3_config = original_t3_config
98
- new_t3_config.text_tokens_dict_size = cfg.new_vocab_size
99
-
100
- # We prevent caching during training.
101
- if hasattr(new_t3_config, "use_cache"):
102
- new_t3_config.use_cache = False
103
- else:
104
- setattr(new_t3_config, "use_cache", False)
105
-
106
- new_t3_model = T3(hp=new_t3_config)
107
-
108
- # 4. TRANSFER WEIGHTS
109
- logger.info("Transferring weights...")
110
- new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)
111
-
112
-
113
- # --- SPECIAL SETTING FOR TURBO ---
114
- if cfg.is_turbo:
115
- logger.info("Turbo Mode: Removing backbone WTE layer...")
116
- if hasattr(new_t3_model.tfmr, "wte"):
117
- del new_t3_model.tfmr.wte
118
-
119
-
120
- # Clean up memory
121
- del tts_engine_original
122
- del pretrained_t3_state_dict
123
-
124
- # 5. PREPARE ENGINE FOR TRAINING
125
- # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3
126
- tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu")
127
- tts_engine_new.t3 = new_t3_model
128
-
129
- # Freeze other components
130
- logger.info("Freezing S3Gen and VoiceEncoder...")
131
- for param in tts_engine_new.ve.parameters():
132
- param.requires_grad = False
133
-
134
- for param in tts_engine_new.s3gen.parameters():
135
- param.requires_grad = False
136
-
137
- # Enable Training for T3
138
- tts_engine_new.t3.train()
139
- for param in tts_engine_new.t3.parameters():
140
- param.requires_grad = True
141
-
142
- if cfg.preprocess:
143
-
144
- logger.info("Initializing Preprocess dataset...")
145
-
146
- if cfg.ljspeech:
147
- preprocess_dataset_ljspeech(cfg, tts_engine_new)
148
-
149
- else:
150
- preprocess_dataset_file_based(cfg, tts_engine_new)
151
-
152
- else:
153
- logger.info("Skipping the preprocessing dataset step...")
154
-
155
-
156
- # 6. DATASET & WRAPPER
157
- logger.info("Initializing Datasets...")
158
- train_ds = ChatterboxDataset(cfg, split="train")
159
- val_ds = ChatterboxDataset(cfg, split="val")
160
-
161
- model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)
162
-
163
- # 7. TRAINING ARGUMENTS
164
- training_args = TrainingArguments(
165
- output_dir=cfg.output_dir,
166
- per_device_train_batch_size=cfg.batch_size,
167
- gradient_accumulation_steps=cfg.grad_accum,
168
- learning_rate=cfg.learning_rate,
169
- weight_decay=cfg.weight_decay, # Added weight decay
170
- num_train_epochs=cfg.num_epochs,
171
- evaluation_strategy="epoch", # Evaluate every epoch instead of steps
172
- save_strategy="epoch", # Save every epoch
173
- logging_strategy="steps",
174
- logging_steps=10,
175
- remove_unused_columns=False, # Required for our custom wrapper
176
- dataloader_num_workers=16,
177
- report_to=["wandb"],
178
- fp16=True if torch.cuda.is_available() else False,
179
- save_total_limit=10, # Keep all 10 epoch checkpoints
180
- gradient_checkpointing=True, # This setting theoretically reduces VRAM usage by 60%.
181
- label_names=["speech_tokens", "text_tokens"],
182
- load_best_model_at_end=True, # We want to run exactly 10 epochs
183
- )
184
-
185
- trainer = ChatterboxTrainer(
186
- model=model_wrapper,
187
- args=training_args,
188
- train_dataset=train_ds,
189
- eval_dataset=val_ds,
190
- data_collator=data_collator,
191
- callbacks=[] # Removed EarlyStopping
192
- )
193
-
194
- logger.info("Running initial evaluation to establish baseline...")
195
- trainer.evaluate()
196
-
197
- logger.info("Starting Training Loop...")
198
- trainer.train()
199
-
200
-
201
- # 8. SAVE FINAL MODEL
202
- logger.info("Training complete. Saving model...")
203
- os.makedirs(cfg.output_dir, exist_ok=True)
204
-
205
- filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
206
- final_model_path = os.path.join(cfg.output_dir, filename)
207
-
208
- save_file(tts_engine_new.t3.state_dict(), final_model_path)
209
- logger.info(f"Model saved to: {final_model_path}")
210
-
211
-
212
- if __name__ == "__main__":
213
- main()
 
 
 
+import os
+import sys
+import torch
+from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback
+from safetensors.torch import save_file
+
+class ChatterboxTrainer(Trainer):
+    """Custom Trainer to log sub-losses for both train and eval."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._eval_loss_text = []
+        self._eval_loss_speech = []
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        outputs = model(**inputs)
+        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+        if isinstance(outputs, dict):
+            if model.training:
+                if self.state.global_step % self.args.logging_steps == 0:
+                    if "loss_text" in outputs:
+                        self.log({"loss_text": outputs["loss_text"].item()})
+                    if "loss_speech" in outputs:
+                        self.log({"loss_speech": outputs["loss_speech"].item()})
+            else:
+                if "loss_text" in outputs:
+                    self._eval_loss_text.append(outputs["loss_text"].item())
+                if "loss_speech" in outputs:
+                    self._eval_loss_speech.append(outputs["loss_speech"].item())
+
+        return (loss, outputs) if return_outputs else loss
+
+    def evaluation_loop(self, *args, **kwargs):
+        self._eval_loss_text = []
+        self._eval_loss_speech = []
+        output = super().evaluation_loop(*args, **kwargs)
+        if self._eval_loss_text:
+            output.metrics["eval_loss_text"] = sum(self._eval_loss_text) / len(self._eval_loss_text)
+        if self._eval_loss_speech:
+            output.metrics["eval_loss_speech"] = sum(self._eval_loss_speech) / len(self._eval_loss_speech)
+        return output
+
+# Internal Modules
+from src.config import TrainConfig
+from src.dataset import ChatterboxDataset, data_collator
+from src.model import resize_and_load_t3_weights, ChatterboxTrainerWrapper
+from src.preprocess_ljspeech import preprocess_dataset_ljspeech
+from src.preprocess_file_based import preprocess_dataset_file_based
+from src.utils import setup_logger, check_pretrained_models
+
+# Chatterbox Imports
+from src.chatterbox_.tts import ChatterboxTTS
+from src.chatterbox_.tts_turbo import ChatterboxTurboTTS
+from src.chatterbox_.models.t3.t3 import T3
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["WANDB_API_KEY"] = "INSERT_API_KEY_HERE"
+os.environ["WANDB_PROJECT"] = "chatterbox-finetuning"
+
+logger = setup_logger("ChatterboxFinetune")
+
+
+def main():
+
+    cfg = TrainConfig()
+
+    logger.info("--- Starting Chatterbox Finetuning ---")
+    logger.info(f"Mode: {'CHATTERBOX-TURBO' if cfg.is_turbo else 'CHATTERBOX-TTS'}")
+
+    # 0. CHECK MODEL FILES
+    mode_check = "chatterbox_turbo" if cfg.is_turbo else "chatterbox"
+    if not check_pretrained_models(mode=mode_check):
+        sys.exit(1)
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # 1. SELECT THE CORRECT ENGINE CLASS
+    if cfg.is_turbo:
+        EngineClass = ChatterboxTurboTTS
+    else:
+        EngineClass = ChatterboxTTS
+
+    logger.info(f"Device: {device}")
+    logger.info(f"Model Directory: {cfg.model_dir}")
+
+    # 2. LOAD ORIGINAL MODEL TEMPORARILY
+    logger.info("Loading original model to extract weights...")
+    # Loading on CPU first to save VRAM
+    tts_engine_original = EngineClass.from_local(cfg.model_dir, device="cpu")
+
+    pretrained_t3_state_dict = tts_engine_original.t3.state_dict()
+    original_t3_config = tts_engine_original.t3.hp
+
+    # 3. CREATE NEW T3 MODEL WITH NEW VOCAB SIZE
+    logger.info(f"Creating new T3 model with vocab size: {cfg.new_vocab_size}")
+
+    new_t3_config = original_t3_config
+    new_t3_config.text_tokens_dict_size = cfg.new_vocab_size
+
+    # We prevent caching during training.
+    if hasattr(new_t3_config, "use_cache"):
+        new_t3_config.use_cache = False
+    else:
+        setattr(new_t3_config, "use_cache", False)
+
+    new_t3_model = T3(hp=new_t3_config)
+
+    # 4. TRANSFER WEIGHTS
+    logger.info("Transferring weights...")
+    new_t3_model = resize_and_load_t3_weights(new_t3_model, pretrained_t3_state_dict)
+
+
+    # --- SPECIAL SETTING FOR TURBO ---
+    if cfg.is_turbo:
+        logger.info("Turbo Mode: Removing backbone WTE layer...")
+        if hasattr(new_t3_model.tfmr, "wte"):
+            del new_t3_model.tfmr.wte
+
+
+    # Clean up memory
+    del tts_engine_original
+    del pretrained_t3_state_dict
+
+    # 5. PREPARE ENGINE FOR TRAINING
+    # Reload engine components (VoiceEncoder, S3Gen) but inject our new T3
+    tts_engine_new = EngineClass.from_local(cfg.model_dir, device="cpu")
+    tts_engine_new.t3 = new_t3_model
+
+    # Freeze other components
+    logger.info("Freezing S3Gen and VoiceEncoder...")
+    for param in tts_engine_new.ve.parameters():
+        param.requires_grad = False
+
+    for param in tts_engine_new.s3gen.parameters():
+        param.requires_grad = False
+
+    # Enable Training for T3
+    tts_engine_new.t3.train()
+    for param in tts_engine_new.t3.parameters():
+        param.requires_grad = True
+
+    if cfg.preprocess:
+
+        logger.info("Initializing Preprocess dataset...")
+
+        if cfg.ljspeech:
+            preprocess_dataset_ljspeech(cfg, tts_engine_new)
+
+        else:
+            preprocess_dataset_file_based(cfg, tts_engine_new)
+
+    else:
+        logger.info("Skipping the preprocessing dataset step...")
+
+
+    # 6. DATASET & WRAPPER
+    logger.info("Initializing Datasets...")
+    train_ds = ChatterboxDataset(cfg, split="train")
+    val_ds = ChatterboxDataset(cfg, split="val")
+
+    model_wrapper = ChatterboxTrainerWrapper(tts_engine_new.t3)
+
+    # 7. TRAINING ARGUMENTS
+    training_args = TrainingArguments(
+        output_dir=cfg.output_dir,
+        per_device_train_batch_size=cfg.batch_size,
+        gradient_accumulation_steps=cfg.grad_accum,
+        learning_rate=cfg.learning_rate,
+        weight_decay=cfg.weight_decay, # Added weight decay
+        num_train_epochs=cfg.num_epochs,
+        evaluation_strategy="epoch",
+        save_strategy="epoch",
+        logging_strategy="steps",
+        logging_steps=10,
+        remove_unused_columns=False, # Required for our custom wrapper
+        dataloader_num_workers=16,
+        report_to=["wandb"],
+        bf16=True if torch.cuda.is_available() else False, # Using bf16 for A100
+        save_total_limit=5, # Keep only the 5 most recent epoch checkpoints
+        gradient_checkpointing=False, # Disabled; enabling would trade speed for roughly 60% lower VRAM usage
+        label_names=["speech_tokens", "text_tokens"],
+        load_best_model_at_end=True,
+        lr_scheduler_type="cosine", # Cosine decay schedule
+        warmup_ratio=0.1, # 10% warmup to ease the English-to-Finnish transition
+    )
+
+    trainer = ChatterboxTrainer(
+        model=model_wrapper,
+        args=training_args,
+        train_dataset=train_ds,
+        eval_dataset=val_ds,
+        data_collator=data_collator,
+        callbacks=[] # Removed EarlyStopping
+    )
+
+    logger.info("Running initial evaluation to establish baseline...")
+    trainer.evaluate()
+
+    logger.info("Starting Training Loop...")
+    trainer.train()
+
+
+    # 8. SAVE FINAL MODEL
+    logger.info("Training complete. Saving model...")
+    os.makedirs(cfg.output_dir, exist_ok=True)
+
+    filename = "t3_turbo_finetuned.safetensors" if cfg.is_turbo else "t3_finetuned.safetensors"
+    final_model_path = os.path.join(cfg.output_dir, filename)
+
+    save_file(tts_engine_new.t3.state_dict(), final_model_path)
+    logger.info(f"Model saved to: {final_model_path}")
+
+
+if __name__ == "__main__":
+    main()