# Chatterbox-Finnish / inference_example.py
# Uploaded by RASMUS - commit 67ea4ca (verified): "Upload Finnish Chatterbox model"
import os
import torch
import soundfile as sf
from src.chatterbox_.tts import ChatterboxTTS
from safetensors.torch import load_file
# --- CONFIGURABLE VARIABLES ---
# Sentence to synthesize.
TEXT = "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä."

# Reference recording handed to the engine as the audio prompt
# (zero-shot speaker identity).
REFERENCE_AUDIO = "./samples/reference_finnish.wav"

# File the synthesized waveform is written to.
OUTPUT_FILE = "output_finnish.wav"

# Directory holding the base model weights (ve.safetensors, etc.).
MODEL_DIR = "./pretrained_models"

# Best finetuned T3 checkpoint; in the upload package it usually lives
# under the 'models' directory.
FINETUNED_WEIGHTS = "./models/best_finnish_multilingual_cp986.safetensors"
# ------------------------------
def main():
    """Synthesize TEXT with the (optionally finetuned) Chatterbox engine.

    Steps:
      1. Load the base engine from MODEL_DIR on CUDA when available, else CPU.
      2. If FINETUNED_WEIGHTS exists, load it into ``engine.t3`` (non-strict)
         and report any key mismatches; otherwise print a warning and keep
         the base weights.
      3. Generate audio for TEXT using REFERENCE_AUDIO as the audio prompt.
      4. Write the waveform to OUTPUT_FILE at the engine's sample rate.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 1. Load the base Chatterbox engine
    print(f"Loading base model from {MODEL_DIR}...")
    engine = ChatterboxTTS.from_local(MODEL_DIR, device=device)

    # 2. Inject the finetuned weights
    if os.path.exists(FINETUNED_WEIGHTS):
        print(f"Loading finetuned weights from {FINETUNED_WEIGHTS}...")
        checkpoint_state = load_file(FINETUNED_WEIGHTS)

        # Strip "t3." prefix if present
        t3_state_dict = {
            (k[3:] if k.startswith("t3.") else k): v
            for k, v in checkpoint_state.items()
        }

        # strict=False tolerates key mismatches, so a checkpoint with the
        # wrong layout would otherwise load nothing without any indication.
        # Inspect the returned report instead of discarding it.
        # NOTE(review): assumes engine.t3 is a torch.nn.Module whose
        # load_state_dict returns missing_keys / unexpected_keys - confirm.
        load_result = engine.t3.load_state_dict(t3_state_dict, strict=False)
        missing = list(getattr(load_result, "missing_keys", []))
        unexpected = list(getattr(load_result, "unexpected_keys", []))

        if unexpected:
            print(
                f"Warning: {len(unexpected)} checkpoint key(s) were not found "
                f"in the T3 module and were ignored (e.g. {unexpected[:5]})."
            )
        if missing:
            print(
                f"Warning: {len(missing)} T3 parameter(s) are absent from the "
                f"checkpoint and keep their base values (e.g. {missing[:5]})."
            )
        if t3_state_dict and len(unexpected) == len(t3_state_dict):
            print(
                "Warning: no checkpoint keys matched the T3 module; "
                "the finetuned weights were NOT applied."
            )
    else:
        print(f"Warning: Finetuned weights not found at {FINETUNED_WEIGHTS}. Using base weights.")

    # 3. Generate Audio
    print(f"Generating audio for: '{TEXT}'")
    # Using optimized parameters for Finnish
    wav_tensor = engine.generate(
        text=TEXT,
        audio_prompt_path=REFERENCE_AUDIO,
        repetition_penalty=1.2,
        temperature=0.8,
        exaggeration=0.6,
    )

    # 4. Save the result
    # detach() is a no-op for tensors without autograd history and prevents
    # numpy() from raising if the tensor happens to require grad.
    wav_np = wav_tensor.squeeze().detach().cpu().numpy()
    sf.write(OUTPUT_FILE, wav_np, engine.sr)
    print(f"Successfully saved audio to {OUTPUT_FILE}")
# Run the synthesis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()