"""Synthesize Finnish speech with Chatterbox using finetuned T3 weights.

Loads the base Chatterbox TTS engine from ``MODEL_DIR``, optionally overlays
finetuned T3 weights from ``FINETUNED_WEIGHTS``, conditions the speaker
identity on ``REFERENCE_AUDIO`` (zero-shot) and writes the synthesized
waveform to ``OUTPUT_FILE``.
"""

import os

import soundfile as sf
import torch
from safetensors.torch import load_file

from src.chatterbox_.tts import ChatterboxTTS

# --- CONFIGURABLE VARIABLES ---
# Path to the directory containing base weights (ve.safetensors, etc.)
MODEL_DIR = "./pretrained_models"

# Path to our best finetuned T3 weights
# In the upload package, this is usually in the 'models' directory
FINETUNED_WEIGHTS = "./models/best_finnish_multilingual_cp986.safetensors"

# Text to synthesize
TEXT = "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä."

# Reference audio for the speaker identity (Zero-shot)
REFERENCE_AUDIO = "./samples/reference_finnish.wav"

# Output filename
OUTPUT_FILE = "output_finnish.wav"
# ------------------------------

# Prefix that may precede T3 parameter names in a saved checkpoint.
T3_PREFIX = "t3."


def _strip_t3_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``"t3."`` removed from each key.

    Keys without the prefix are kept unchanged, so checkpoints saved either
    with or without the prefix are accepted.
    """
    prefix_len = len(T3_PREFIX)
    return {
        (key[prefix_len:] if key.startswith(T3_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _load_finetuned_t3(engine, weights_path):
    """Overlay finetuned weights from *weights_path* onto ``engine.t3``.

    Args:
        engine: A loaded ``ChatterboxTTS`` instance.
        weights_path: Path to a ``.safetensors`` checkpoint.

    Returns:
        True if a checkpoint was found and at least one tensor was loaded,
        False otherwise (the engine then keeps its base T3 weights).
    """
    if not os.path.exists(weights_path):
        print(f"Warning: Finetuned weights not found at {weights_path}. Using base weights.")
        return False

    print(f"Loading finetuned weights from {weights_path}...")
    t3_state_dict = _strip_t3_prefix(load_file(weights_path))

    # strict=False tolerates partial checkpoints, but it also silently ignores
    # every key that does not match. Inspect the result so a wrong file or
    # naming scheme cannot quietly fall back to the base weights.
    result = engine.t3.load_state_dict(t3_state_dict, strict=False)
    unexpected = list(result.unexpected_keys)
    missing = list(result.missing_keys)

    if unexpected:
        print(
            f"Warning: {len(unexpected)} checkpoint keys did not match any T3 "
            f"parameter and were ignored (e.g. {unexpected[:5]})."
        )
    if missing:
        print(
            f"Warning: {len(missing)} T3 parameters were absent from the "
            f"checkpoint and keep their base values (e.g. {missing[:5]})."
        )

    matched = len(t3_state_dict) - len(unexpected)
    if matched <= 0:
        print(
            "Warning: no tensors from the finetuned checkpoint matched the T3 "
            "model. Output will use base weights."
        )
        return False
    return True


def main():
    """Run the synthesis pipeline configured by the module-level constants."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Fail fast, before the slow model load, if the voice prompt is missing.
    if not os.path.exists(REFERENCE_AUDIO):
        raise FileNotFoundError(f"Reference audio not found: {REFERENCE_AUDIO}")

    # 1. Load the base Chatterbox engine
    print(f"Loading base model from {MODEL_DIR}...")
    engine = ChatterboxTTS.from_local(MODEL_DIR, device=device)

    # 2. Inject the finetuned weights (best-effort; falls back to base weights)
    _load_finetuned_t3(engine, FINETUNED_WEIGHTS)

    # 3. Generate audio
    print(f"Generating audio for: '{TEXT}'")
    # Sampling parameters chosen for Finnish by the original author.
    wav_tensor = engine.generate(
        text=TEXT,
        audio_prompt_path=REFERENCE_AUDIO,
        repetition_penalty=1.2,
        temperature=0.8,
        exaggeration=0.6,
    )

    # 4. Save the result
    # detach() so .numpy() also works if the output still tracks gradients.
    wav_np = wav_tensor.detach().squeeze().cpu().numpy()
    sf.write(OUTPUT_FILE, wav_np, engine.sr)
    print(f"Successfully saved audio to {OUTPUT_FILE}")


if __name__ == "__main__":
    main()