|
|
import os |
|
|
import torch |
|
|
import json |
|
|
import pandas as pd |
|
|
from pathlib import Path |
|
|
from safetensors.torch import load_file |
|
|
from faster_whisper import WhisperModel |
|
|
from google import genai |
|
|
from google.genai import types |
|
|
import soundfile as sf |
|
|
import re |
|
|
from tqdm import tqdm |
|
|
import itertools |
|
|
|
|
|
|
|
|
from src.config import TrainConfig |
|
|
from src.chatterbox_.mtl_tts import ChatterboxMultilingualTTS |
|
|
from src.chatterbox_.models.t3.t3 import T3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Configuration ----------------------------------------------------------

# Gemini API key used for LLM-based MOS (mean opinion score) evaluation.
# NOTE(review): hardcoded placeholder committed to source — supply a real key
# (ideally via an environment variable) before running; never commit real keys.
GEMINI_API_KEY = "INSERT_API_KEY_HERE"

# Fine-tuned Chatterbox stage-2 checkpoint directory (contains model.safetensors).
CHECKPOINT_PATH = "./chatterbox_stage2_output/checkpoint-16"

# Reference speaker wav used as the voice prompt for every generation.
REFERENCE_WAV = "/workspaces/work/Chatterbox-Finnish/GrowthMindset_Chatterbox_Dataset/wavs/growthmindset_00000.wav"

# Dataset utterance IDs held out from training; their transcripts (looked up
# from the metadata CSV in main()) become in-domain sweep sentences.
LEAN_HOLDOUT_IDS = [
    "growthmindset_00547",
    "growthmindset_00548",
    "growthmindset_00564"
]

# Out-of-domain everyday Finnish phrases, mixed into the sweep to test
# generalization beyond the training corpus.
EVERYDAY_PHRASES = [
    "Voisitko ystävällisesti auttaa minua tämän asian kanssa?",
    "Tänään on todella kaunis päivä, joten ajattelin lähteä ulos kävelemään ja nauttimaan auringosta ennen kuin ilta viilenee.",
    "Huomenta kaikille, toivottavasti teillä on ollut mukava aamu ja olette valmiita aloittamaan uuden päivän täynnä mielenkiintoisia haasteita ja onnistumisia."
]

# Generation hyperparameter grid; the full cartesian product is swept
# (2 * 2 * 2 * 2 = 16 combinations). Keys must match engine.generate kwargs.
PARAM_GRID = {
    "repetition_penalty": [1.2, 1.5],
    "temperature": [0.7, 0.8],
    "exaggeration": [0.5, 0.6],
    "cfg_weight": [0.3, 0.5]
}

# Directory where generated wavs and the JSON summaries are written.
OUTPUT_BASE_DIR = "./param_sweep_results"
|
|
|
|
|
|
|
|
def setup_gemini():
    """Create and return a Gemini API client.

    Prefers the ``GEMINI_API_KEY`` environment variable so the key does not
    have to live in source; falls back to the module-level constant, which
    preserves the original behavior when the variable is unset.
    """
    api_key = os.environ.get("GEMINI_API_KEY", GEMINI_API_KEY)
    return genai.Client(api_key=api_key)
|
|
|
|
|
def get_mos_score(client, audio_path, target_text):
    """Ask Gemini to rate a generated wav against its target text.

    Uploads the file, waits for it to become ACTIVE, then requests a JSON
    rating. Returns a dict with at least ``{"mos": <int>}``; on any failure
    returns ``{"mos": 0, "reason": ...}`` so a scoring error never aborts the
    sweep (callers filter on ``mos > 0``).
    """
    import time  # hoisted out of the try block; import errors are not "scoring" errors

    try:
        audio_file = client.files.upload(file=audio_path)
        # Poll up to ~10s for the uploaded file to finish server-side processing.
        for _ in range(10):
            file_info = client.files.get(name=audio_file.name)
            if file_info.state == "ACTIVE":
                break
            time.sleep(1)
        else:
            # Never became ACTIVE — report a scored failure instead of calling
            # generate_content with an unusable file.
            return {"mos": 0, "reason": "uploaded file never became ACTIVE"}

        # Finnish evaluation prompt (runtime string — kept verbatim).
        prompt = f"""
Olet asiantunteva puheenlaadun arvioija.
Arvioi oheinen äänitiedosto, jossa hienoviritetty TTS-malli sanoo: "{target_text}"

Arvioi asteikolla 1-5 (1=huono, 5=erinomainen):
1. Luonnollisuus: Kuulostaako se ihmiseltä?
2. Selkeys: Ovatko sanat helposti erotettavissa?
3. Prosodia: Kuulostaako rytmi luonnolliselta suomen kielelle?

Vastaa TARKALLEEN tässä JSON-muodossa: {{"mos": <numero>, "reason": "<lyhyt_perustelu>"}}
"""
        response = client.models.generate_content(
            model='gemini-3-flash-preview',
            contents=[prompt, audio_file],
            # Force a JSON body so json.loads below is reliable.
            config=types.GenerateContentConfig(response_mime_type="application/json")
        )
        result = json.loads(response.text)
        # The model occasionally wraps the object in a one-element list.
        if isinstance(result, list):
            result = result[0]
        return result
    except Exception as exc:
        # Best-effort by design: surface the cause instead of swallowing it silently.
        return {"mos": 0, "reason": f"scoring failed: {exc}"}
|
|
|
|
|
def _word_edit_distance(ref_words, hyp_words):
    """Levenshtein distance between two word sequences (two-row DP, O(m*n))."""
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(prev[j] + 1,        # deletion
                            curr[j - 1] + 1,    # insertion
                            prev[j - 1] + cost  # substitution / match
                            ))
        prev = curr
    return prev[-1]


def calculate_wer(reference, hypothesis):
    """Return the word error rate of *hypothesis* against *reference*.

    Both texts are normalized (lowercased, punctuation stripped) before
    scoring so the jiwer path and the stdlib fallback measure comparably
    (the original normalized only the fallback). An empty reference returns
    0.0 up front — jiwer would otherwise raise on it. The stdlib fallback
    computes true WER (word-level Levenshtein / reference length) rather
    than a similarity-ratio approximation.
    """
    def _normalize(text):
        # Strip punctuation, lowercase, and trim surrounding whitespace.
        return re.sub(r'[^\w\s]', '', text.lower()).strip()

    ref_words = _normalize(reference).split()
    hyp_words = _normalize(hypothesis).split()
    if not ref_words:
        return 0.0

    try:
        import jiwer
        return jiwer.wer(" ".join(ref_words), " ".join(hyp_words))
    except ImportError:
        # Stdlib fallback: exact WER via word-level edit distance.
        return _word_edit_distance(ref_words, hyp_words) / len(ref_words)
|
|
|
|
|
def main():
    """Run the full parameter sweep: generate audio for every hyperparameter
    combination, score each clip with Whisper WER and Gemini MOS, and write
    JSON summaries (partial after each trial, final at the end)."""
    cfg = TrainConfig()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(OUTPUT_BASE_DIR, exist_ok=True)

    # Load the pipe-separated metadata CSV (col 0 = utterance id, col 1 = transcript;
    # quoting=3 is csv.QUOTE_NONE) and pull the held-out transcripts.
    meta = pd.read_csv(cfg.csv_path, sep="|", header=None, quoting=3)
    lean_meta = meta[meta[0].isin(LEAN_HOLDOUT_IDS)]
    sweep_sentences = list(lean_meta[1]) + EVERYDAY_PHRASES

    # ASR model used only to compute WER on the generated audio.
    print("Loading Faster Whisper...")
    whisper_model = WhisperModel("large-v3", device=device, compute_type="float16" if device == "cuda" else "int8")

    gemini_client = setup_gemini()

    # Load the base TTS engine, then overwrite its T3 module with the
    # fine-tuned stage-2 checkpoint weights (stripping the "t3." prefix).
    engine = ChatterboxMultilingualTTS.from_local(cfg.model_dir, device=device)
    weights_path = Path(CHECKPOINT_PATH) / "model.safetensors"
    checkpoint_state = load_file(str(weights_path))
    t3_state_dict = {k[3:] if k.startswith("t3.") else k: v for k, v in checkpoint_state.items()}
    if "text_emb.weight" in t3_state_dict:
        # Checkpoint may have a resized text vocabulary — rebuild T3 to match
        # before loading. NOTE(review): indentation reconstructed; confirm the
        # load_state_dict/eval calls below are intended to run unconditionally.
        engine.t3.hp.text_tokens_dict_size = t3_state_dict["text_emb.weight"].shape[0]
        engine.t3 = T3(hp=engine.t3.hp).to(device)
    engine.t3.load_state_dict(t3_state_dict, strict=False)
    engine.t3.eval()

    # Cartesian product of the grid -> list of kwargs dicts for engine.generate.
    keys, values = zip(*PARAM_GRID.items())
    combinations = [dict(zip(keys, v)) for v in itertools.product(*values)]

    print(f"Starting sweep of {len(combinations)} combinations using {len(sweep_sentences)} sentences...")

    sweep_results = []

    for i, params in enumerate(combinations):
        print(f"\n[{i+1}/{len(combinations)}] Testing: {params}")

        total_wer = 0
        total_mos = 0
        valid_mos_count = 0  # MOS failures return 0 and are excluded from the average

        for j, text in enumerate(sweep_sentences):
            # Synthesize this sentence with the trial's hyperparameters.
            wav_tensor = engine.generate(
                text=text,
                language_id="fi",
                audio_prompt_path=REFERENCE_WAV,
                **params
            )

            # Persist the clip with the params encoded in the filename.
            param_str = f"rp{params['repetition_penalty']}_temp{params['temperature']}_ex{params['exaggeration']}_cfg{params['cfg_weight']}"
            audio_path = os.path.join(OUTPUT_BASE_DIR, f"trial_{i}_sent_{j}_{param_str}.wav")
            sf.write(audio_path, wav_tensor.squeeze().cpu().numpy(), engine.sr)

            # Intelligibility: transcribe with Whisper and score WER vs. the target text.
            segments, _ = whisper_model.transcribe(audio_path, language="fi")
            hyp = " ".join([s.text for s in segments])
            wer = calculate_wer(text, hyp)
            total_wer += wer

            # Quality: LLM-judged MOS; mos == 0 marks a failed scoring attempt.
            mos_data = get_mos_score(gemini_client, audio_path, text)
            if mos_data.get('mos', 0) > 0:
                total_mos += mos_data['mos']
                valid_mos_count += 1

        avg_wer = total_wer / len(sweep_sentences)
        avg_mos = total_mos / valid_mos_count if valid_mos_count > 0 else 0

        result_entry = {
            "trial_id": i,
            "params": params,
            "avg_wer": avg_wer,
            "avg_mos": avg_mos
        }
        sweep_results.append(result_entry)
        print(f"Result: WER={avg_wer:.4f}, MOS={avg_mos:.2f}")

        # Checkpoint results after every trial so a crash loses at most one trial.
        with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary_partial.json"), "w") as f:
            json.dump(sweep_results, f, indent=4)

    # Pick the winner by a combined score: high MOS weighted by (1 - WER).
    best_score = -1
    best_params = None

    for r in sweep_results:
        score = r['avg_mos'] * (1 - r['avg_wer'])
        if score > best_score:
            best_score = score
            best_params = r

    print("\n" + "="*60)
    print("SWEEP COMPLETE")
    print(f"Best Params: {best_params['params']}")
    print(f"Best Metrics: WER={best_params['avg_wer']:.4f}, MOS={best_params['avg_mos']:.2f}")
    print("="*60)

    # Final summary (same content as the last partial dump).
    with open(os.path.join(OUTPUT_BASE_DIR, "sweep_summary.json"), "w") as f:
        json.dump(sweep_results, f, indent=4)
|
|
|
|
|
# Script entry point: run the sweep only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|