# Chatterbox-Finnish / sweep_params.py
# Uploaded by RASMUS (commit 0a78f68, verified)
import os
import torch
import json
import pandas as pd
from pathlib import Path
from safetensors.torch import load_file
from faster_whisper import WhisperModel
from google import genai
from google.genai import types
import soundfile as sf
import re
from tqdm import tqdm
import itertools
# Internal Modules
from src.config import TrainConfig
from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
from src.chatterbox_.models.t3.t3 import T3
# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Prefer the environment variable so the key never has to be committed;
# the placeholder default preserves the original drop-in behaviour.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "INSERT_API_KEY_HERE")
# Fine-tuned checkpoint to sweep and the speaker-reference prompt audio.
CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"
REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"
# Align with evaluate_checkpoints.py
LEAN_HOLDOUT_IDS = [
    "growthmindset_00547",  # Short
    "growthmindset_00548",  # Medium/Long
    "growthmindset_00564"   # Very expressive
]
# Out-of-dataset Finnish sentences to probe generalisation.
EVERYDAY_PHRASES = [
    "Voisitko ystävällisesti auttaa minua tämän asian kanssa?",  # Short
    "Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.",  # Long 1
    "Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia."  # Long 2
]
# Parameter Grid: the sweep runs the full Cartesian product (2^4 = 16 trials).
PARAM_GRID = {
    "repetition_penalty": [1.2, 1.5],
    "temperature": [0.7, 0.8],
    "exaggeration": [0.5, 0.6],
    "cfg_weight": [0.3, 0.5]
}
OUTPUT_BASE_DIR = "./param_sweep_results"
# ==============================================================================
def setup_gemini():
    """Construct and return a Gemini API client for the configured key."""
    client = genai.Client(api_key=GEMINI_API_KEY)
    return client
def get_mos_score(client, audio_path, target_text):
    """Ask Gemini to rate a synthesized utterance; return a MOS dict.

    Args:
        client: a google-genai ``Client`` (as returned by ``setup_gemini``).
        audio_path: path to the WAV file to upload and rate.
        target_text: the sentence the TTS model was asked to say.

    Returns:
        A dict with at least ``{"mos": <int 1-5>}`` (plus ``"reason"``) on
        success, or ``{"mos": 0, "reason": ...}`` on any failure. Best-effort
        by design: one failed rating must not abort the whole sweep, and the
        caller treats ``mos <= 0`` as "no valid score".
    """
    import time  # hoisted to the top of the function (was buried mid-body)
    try:
        audio_file = client.files.upload(file=audio_path)
        # Poll until the upload finishes processing; give up after ~10 s and
        # attempt the rating anyway (the API call below will fail gracefully).
        for _ in range(10):
            file_info = client.files.get(name=audio_file.name)
            # NOTE(review): assumes the state compares equal to the plain
            # string "ACTIVE" — confirm against the SDK's FileState enum.
            if file_info.state == "ACTIVE":
                break
            time.sleep(1)
        prompt = f"""
Olet asiantunteva puheenlaadun arvioija.
Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"
Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
1. Luonnollisuus: Kuulostaako se ihmiseltä?
2. Selkeys: Ovatko sanat helposti erotettavissa?
3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?
Vastaa TARKALLEEN tässä JSON-muodossa: {{"mos": <numero>, "reason": "<lyhyt_perustelu>"}}
"""
        response = client.models.generate_content(
            model='gemini-3-flash-preview',
            contents=[prompt, audio_file],
            config=types.GenerateContentConfig(response_mime_type="application/json")
        )
        result = json.loads(response.text)
        # The model occasionally wraps the JSON object in a list; unwrap it.
        if isinstance(result, list):
            result = result[0]
        return result
    except Exception as e:
        # Deliberate best-effort swallow, but now with a diagnostic so the
        # failure cause is visible to anyone inspecting the result dict.
        return {"mos": 0, "reason": f"scoring failed: {e}"}
def calculate_wer(reference, hypothesis):
    """Return the word error rate of *hypothesis* against *reference*.

    Uses ``jiwer.wer`` when the package is installed; otherwise falls back
    to a difflib word-sequence similarity (``1 - ratio``), which is only an
    approximation of true WER.

    Fix: an empty reference previously returned 0.0 only on the fallback
    path — ``jiwer.wer`` raises ValueError on an empty reference — so the
    guard is now applied before either path runs.

    Returns:
        A float in [0, 1]; 0.0 means a perfect match (or nothing to score).
    """
    def _clean(t):
        # Strip punctuation and lowercase so only word identity is compared.
        return re.sub(r'[^\w\s]', '', t.lower()).strip()

    ref_words = _clean(reference).split()
    if not ref_words:
        # Nothing to score against; treat as perfect rather than raising.
        return 0.0
    try:
        import jiwer
        return jiwer.wer(reference, hypothesis)
    except ImportError:
        import difflib
        hyp_words = _clean(hypothesis).split()
        return 1.0 - difflib.SequenceMatcher(None, ref_words, hyp_words).ratio()
def main():
    """Sweep generation hyper-parameters for the fine-tuned Chatterbox TTS.

    For every combination in PARAM_GRID: synthesize each sweep sentence,
    transcribe it with Faster-Whisper to compute WER, ask Gemini for a MOS
    rating, and record per-trial averages. Prints the best combination by
    the combined score MOS * (1 - WER) and writes JSON summaries to
    OUTPUT_BASE_DIR.
    """
    cfg = TrainConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)
    # Load metadata for holdouts (pipe-separated, headerless CSV;
    # quoting=3 is csv.QUOTE_NONE so quote characters pass through verbatim)
    meta = pd.read_csv(cfg.csv_path, sep="|", header=None, quoting=3)
    lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
    # Column 1 presumably holds the transcript text — TODO confirm dataset layout
    sweep_sentences = list(lean_meta[1]) + EVERYDAY_PHRASES
    print("Loading Faster Whisper...")
    whisper_model = WhisperModel("large-v3", device=device, compute_type="float16" if device == "cuda" else "int8")
    gemini_client = setup_gemini()
    # Load engine and checkpoint weights once; only generation params vary per trial
    engine = ChatterboxMultilingualTTS.from_local(cfg.model_dir, device=device)
    weights_path = Path(CHECKPOINT_PATH) / "model.safetensors"
    checkpoint_state = load_file(str(weights_path))
    # Strip the "t3." prefix that the training wrapper prepended to T3 keys
    t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
    if "text_emb.weight" in t3_state_dict:
        # Fine-tuning may have resized the text vocabulary; rebuild T3 to match
        # the checkpoint's embedding shape before loading weights into it.
        engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
        engine.t3 = T3(hp=engine.t3.hp).to(device)
    # strict=False: tolerate keys present in one side only (e.g. buffers)
    engine.t3.load_state_dict(t3_state_dict, strict=False)
    engine.t3.eval()
    # Generate parameter combinations (full Cartesian product of PARAM_GRID)
    keys, values = zip(*PARAM_GRID.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]
    print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")
    sweep_results = []
    for i, params in enumerate(combinations):
        print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")
        total_wer = 0
        total_mos = 0
        valid_mos_count = 0
        for j, text in enumerate(sweep_sentences):
            wav_tensor = engine.generate(
                text=text,
                language_id="fi",
                audio_prompt_path=REFERENCE_WAV,
                **params
            )
            # Format filename with key params for easy manual review
            param_str = f"rp{params['repetition_penalty']}_temp{params['temperature']}_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
            audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{i}_sent_{j}_{param_str}.wav")
            sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)
            # WER: transcribe the synthesized audio and compare to the target text
            segments, _ = whisper_model.transcribe(audio_path, language="fi")
            hyp = " ".join([s.text for s in segments])
            wer = calculate_wer(text, hyp)
            total_wer += wer
            # MOS: LLM-as-judge rating; mos == 0 means the scoring call failed
            mos_data = get_mos_score(gemini_client, audio_path, text)
            if mos_data.get('mos', 0) > 0:
                total_mos += mos_data['mos']
                valid_mos_count += 1
        avg_wer = total_wer / len(sweep_sentences)
        # Average only over successful MOS calls; 0 when every call failed
        avg_mos = total_mos / valid_mos_count if valid_mos_count > 0 else 0
        result_entry = {
            "trial_id": i,
            "params": params,
            "avg_wer": avg_wer,
            "avg_mos": avg_mos
        }
        sweep_results.append(result_entry)
        print(f"Result: WER={avg_wer:.4f}, MOS={avg_mos:.2f}")
        # Save intermediate results after every trial so a crash loses nothing
        with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary_partial.json"), "w") as f:
            json.dump(sweep_results, f, indent=4)
    # Find the best combination
    # We want low WER and high MOS. A simple score: MOS * (1 - WER)
    best_score = -1
    best_params = None
    for r in sweep_results:
        score = r['avg_mos'] * (1 - r['avg_wer'])
        if score > best_score:
            best_score = score
            best_params = r
    print("\n" + "="*60)
    print("SWEEP COMPLETE")
    print(f"Best Params: {best_params['params']}")
    print(f"Best Metrics: WER={best_params['avg_wer']:.4f}, MOS={best_params['avg_mos']:.2f}")
    print("="*60)
    with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary.json"), "w") as f:
        json.dump(sweep_results, f, indent=4)
# Script entry point: run the sweep only when executed directly, not on import.
if __name__ == "__main__":
    main()