"""Sweep inference parameters for a fine-tuned Chatterbox TTS checkpoint.

For every combination in ``PARAM_GRID`` the script synthesizes a fixed set of
sentences, transcribes each result with Faster Whisper to compute a word error
rate (WER), asks Gemini for a 1-5 MOS-style rating, and writes the per-trial
averages to JSON files under ``OUTPUT_BASE_DIR``.
"""
import os
import torch
import json
import pandas as pd
from pathlib import Path
from safetensors.torch import load_file
from faster_whisper import WhisperModel
from google import genai
from google.genai import types
import soundfile as sf
import re
from tqdm import tqdm
import itertools
import difflib
import time

# Internal Modules
from src.config import TrainConfig
from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS
from src.chatterbox_.models.t3.t3 import T3

# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Prefer the environment variable so the key does not have to be committed to
# source control; the placeholder is kept as the fallback value.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "INSERT_API_KEY_HERE")
CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"
REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"

# Align with evaluate_checkpoints.py
LEAN_HOLDOUT_IDS = [
    "growthmindset_00547",  # Short
    "growthmindset_00548",  # Medium/Long
    "growthmindset_00564",  # Very expressive
]

EVERYDAY_PHRASES = [
    "Voisitko ystävällisesti auttaa minua tämän asian kanssa?",  # Short
    "Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.",  # Long 1
    "Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia.",  # Long 2
]

# Parameter Grid
PARAM_GRID = {
    "repetition_penalty": [1.2, 1.5],
    "temperature": [0.7, 0.8],
    "exaggeration": [0.5, 0.6],
    "cfg_weight": [0.3, 0.5],
}

OUTPUT_BASE_DIR = "./param_sweep_results"
# ==============================================================================


def setup_gemini():
    """Return a Gemini client authenticated with ``GEMINI_API_KEY``."""
    return genai.Client(api_key=GEMINI_API_KEY)


def get_mos_score(client, audio_path, target_text):
    """Ask Gemini to rate a synthesized utterance on a 1-5 scale.

    Args:
        client: Gemini client as returned by ``setup_gemini``.
        audio_path: Path of the audio file to upload and rate.
        target_text: Text the audio is supposed to contain; it is embedded in
            the prompt.

    Returns:
        The parsed JSON object from the model with ``"mos"`` coerced to
        ``float``. On any failure (upload, generation, parsing, unexpected
        payload) a warning is printed and ``{"mos": 0}`` is returned, which the
        caller treats as "no valid rating".
    """
    try:
        audio_file = client.files.upload(file=audio_path)

        # Poll for up to ~10 seconds until the uploaded file is usable.
        for _ in range(10):
            file_info = client.files.get(name=audio_file.name)
            if file_info.state == "ACTIVE":
                break
            time.sleep(1)

        prompt = f"""
        Olet asiantunteva puheenlaadun arvioija.
        Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"

        Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
        1. Luonnollisuus: Kuulostaako se ihmiseltä?
        2. Selkeys: Ovatko sanat helposti erotettavissa?
        3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?

        Vastaa TARKALLEEN tässä JSON-muodossa:
        {{"mos": , "reason": ""}}
        """

        response = client.models.generate_content(
            model='gemini-3-flash-preview',
            contents=[prompt, audio_file],
            config=types.GenerateContentConfig(response_mime_type="application/json"),
        )

        result = json.loads(response.text)
        if isinstance(result, list):
            result = result[0]
        if not isinstance(result, dict):
            raise ValueError(f"unexpected MOS payload: {result!r}")

        # The caller compares and sums "mos" numerically, so normalize values
        # such as "4" here; anything non-numeric falls through to the handler.
        result["mos"] = float(result.get("mos", 0))
        return result
    except Exception as exc:
        # Best-effort metric: keep the sweep running, but surface the reason.
        print(f"Warning: MOS scoring failed for {audio_path}: {exc}")
        return {"mos": 0}


def calculate_wer(reference, hypothesis):
    """Return the word error rate of ``hypothesis`` against ``reference``.

    Uses ``jiwer`` when it is installed. Otherwise falls back to an
    approximation: one minus the ``difflib`` similarity ratio of the
    lower-cased, punctuation-stripped word sequences.

    Args:
        reference: Ground-truth text.
        hypothesis: Transcribed text.

    Returns:
        The error rate as a float; ``0.0`` when the reference has no words.
    """
    # jiwer raises ValueError on an empty reference, while the fallback
    # returns 0.0; guard up front so both paths agree.
    if not reference.split():
        return 0.0

    try:
        import jiwer
    except ImportError:
        jiwer = None

    if jiwer is not None:
        return jiwer.wer(reference, hypothesis)

    def clean(t):
        return re.sub(r'[^\w\s]', '', t.lower()).strip()

    ref_words = clean(reference).split()
    hyp_words = clean(hypothesis).split()
    if not ref_words:
        return 0.0
    return 1.0 - difflib.SequenceMatcher(None, ref_words, hyp_words).ratio()


def _build_param_combinations(grid):
    """Expand a ``{name: [values]}`` grid into a list of ``{name: value}`` dicts."""
    keys, values = zip(*grid.items())
    return [dict(zip(keys, combo)) for combo in itertools.product(*values)]


def _trial_score(result):
    """Combined quality score of one trial: ``MOS * (1 - WER)`` (higher is better)."""
    return result['avg_mos'] * (1 - result['avg_wer'])


def _select_best_result(sweep_results):
    """Return the highest-scoring trial, or ``None`` when there are no trials.

    WER can exceed 1.0, which makes scores arbitrarily negative, so the best
    trial is chosen with ``max`` rather than against a fixed starting score.
    Ties resolve to the earliest trial.
    """
    if not sweep_results:
        return None
    return max(sweep_results, key=_trial_score)


def main():
    """Run the full parameter sweep and write the summaries to disk."""
    cfg = TrainConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)

    # Load metadata for holdouts
    meta = pd.read_csv(cfg.csv_path, sep="|", header=None, quoting=3)
    lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
    missing_ids = sorted(set(LEAN_HOLDOUT_IDS) - set(lean_meta[0]))
    if missing_ids:
        print(f"Warning: holdout IDs not found in metadata: {missing_ids}")
    sweep_sentences = list(lean_meta[1]) + EVERYDAY_PHRASES

    print("Loading Faster Whisper...")
    whisper_model = WhisperModel(
        "large-v3",
        device=device,
        compute_type="float16" if device == "cuda" else "int8",
    )
    gemini_client = setup_gemini()

    # Load engine and checkpoint weights once
    engine = ChatterboxMultilingualTTS.from_local(cfg.model_dir, device=device)
    weights_path = Path(CHECKPOINT_PATH) / "model.safetensors"
    checkpoint_state = load_file(str(weights_path))
    # Drop the leading "t3." from checkpoint keys before loading them into T3.
    t3_state_dict = {
        k[3:] if k.startswith("t3.") else k: v
        for k, v in checkpoint_state.items()
    }
    if "text_emb.weight" in t3_state_dict:
        # Size the text vocabulary from the checkpoint's embedding table
        # before T3 is rebuilt from these hyper-parameters.
        engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
    engine.t3 = T3(hp=engine.t3.hp).to(device)
    engine.t3.load_state_dict(t3_state_dict, strict=False)
    engine.t3.eval()

    # Generate parameter combinations
    combinations = _build_param_combinations(PARAM_GRID)

    print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")

    sweep_results = []
    partial_path = os.path.join(OUTPUT_BASE_DIR, "sweep_summary_partial.json")

    for i, params in enumerate(combinations):
        print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")

        total_wer = 0
        total_mos = 0
        valid_mos_count = 0

        for j, text in enumerate(sweep_sentences):
            wav_tensor = engine.generate(
                text=text,
                language_id="fi",
                audio_prompt_path=REFERENCE_WAV,
                **params
            )

            # Format filename with key params for easy manual review
            param_str = f"rp{params['repetition_penalty']}_temp{params['temperature']}_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
            audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{i}_sent_{j}_{param_str}.wav")
            sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)

            # WER
            segments, _ = whisper_model.transcribe(audio_path, language="fi")
            hyp = " ".join(s.text for s in segments)
            total_wer += calculate_wer(text, hyp)

            # MOS (a score of 0 means the rating failed and is excluded)
            mos_data = get_mos_score(gemini_client, audio_path, text)
            if mos_data.get('mos', 0) > 0:
                total_mos += mos_data['mos']
                valid_mos_count += 1

        avg_wer = total_wer / len(sweep_sentences)
        avg_mos = total_mos / valid_mos_count if valid_mos_count > 0 else 0

        result_entry = {
            "trial_id": i,
            "params": params,
            "avg_wer": avg_wer,
            "avg_mos": avg_mos
        }
        sweep_results.append(result_entry)
        print(f"Result: WER={avg_wer:.4f}, MOS={avg_mos:.2f}")

        # Save intermediate results
        with open(partial_path, "w", encoding="utf-8") as f:
            json.dump(sweep_results, f, indent=4)

    # Find the best combination
    # We want low WER and high MOS. A simple score: MOS * (1 - WER)
    best_params = _select_best_result(sweep_results)

    print("\n" + "="*60)
    print("SWEEP COMPLETE")
    if best_params is None:
        print("No results were produced.")
    else:
        print(f"Best Params: {best_params['params']}")
        print(f"Best Metrics: WER={best_params['avg_wer']:.4f}, MOS={best_params['avg_mos']:.2f}")
    print("="*60)

    with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary.json"), "w", encoding="utf-8") as f:
        json.dump(sweep_results, f, indent=4)


if __name__ == "__main__":
    main()